|
|
|
import tensorflow as tf
|
|
|
|
|
|
|
|
from colors import vector_to_index
|
|
|
|
|
|
|
|
|
|
|
|
@tf.function
def aesthetic_loss(y_pics):
    """Compute a scalar loss from symmetry, neighbourhood and diversity terms.

    The loss is the sum of three weighted terms:

    * symmetry (subtracted): for each picture, the best of five
      ``compute_score`` values against its mirror/rotation images;
    * neighbourhood (subtracted): magnitude of the summed dot products
      between every pixel and the pixels of its 3x3 neighbourhood;
    * diversity (added): squared pixel-wise dot products between every pair
      of pictures in the batch, each picture paired with itself included.

    Args:
        y_pics: rank-4 tensor laid out as ``(batch, height, width, channels)``
            (``perm=(0, 2, 1, 3)`` fixes the rank). The transposed images are
            compared against the original, which requires ``height == width``.
            NOTE(review): the ``1/64`` normalisers presumably assume 8x8
            pictures — TODO confirm.

    Returns:
        Scalar loss tensor.
    """
    # Mirror / rotation images of every picture.
    flipx = tf.reverse(y_pics, axis=(-2,))  # reverse the width axis
    flipy = tf.reverse(y_pics, axis=(-3,))  # reverse the height axis
    center = tf.reverse(y_pics, axis=(-2, -3))  # 180-degree rotation
    trans = tf.transpose(y_pics, perm=(0, 2, 1, 3))  # main-diagonal mirror
    # Anti-diagonal mirror: out[i, j] = in[n - 1 - j, n - 1 - i].
    anti_trans = tf.reverse(
        tf.transpose(tf.reverse(y_pics, axis=(-2,)), perm=(0, 2, 1, 3)),
        axis=(-2,))

    # One row per symmetry, one column per picture.
    symmetry_scores = tf.stack([
        compute_score(y_pics, flipx),
        compute_score(y_pics, flipy),
        compute_score(y_pics, trans),
        compute_score(y_pics, anti_trans),
        compute_score(y_pics, center),
    ])
    # Each picture only counts its single best-scoring symmetry.
    best_symmetry = tf.reduce_sum(tf.reduce_max(symmetry_scores, axis=(0,)))

    # 3x3 neighbourhood of every pixel. extract_patches flattens each patch
    # in (row, col, channel) order with channel innermost, giving shape
    # (batch, height, width, 9 * channels).
    patches = tf.image.extract_patches(
        y_pics,
        sizes=(1, 3, 3, 1),
        strides=(1, 1, 1, 1),
        rates=(1, 1, 1, 1),
        padding="SAME")
    # Tile (not repeat) the centre pixel so that channel c of the centre
    # lines up with channel c of each of the 9 patch positions. tf.repeat
    # would produce [c0]*9 + [c1]*9 + ..., which mis-pairs channels
    # whenever there is more than one channel.
    centre_tiled = tf.tile(y_pics, (1, 1, 1, 9))
    neigh = tf.reduce_sum(
        tf.abs(tf.einsum("...k,...k->...", centre_tiled, patches)))

    # Pixel-wise dot products between all pairs (i, m) of pictures in the
    # batch, shape (batch, batch, height, width), squared and summed.
    divers = tf.reduce_sum(
        tf.square(tf.einsum("ijkl,mjkl->imjk", y_pics, y_pics)))

    return tf.reduce_sum([
        -1.0 / 64 * best_symmetry,
        -1.0 / (64 * 9) * 0.1 * neigh,
        1.0 / 64 * 10 * divers,
    ])
|
|
|
|
|
|
|
|
|
|
|
|
@tf.function
def compute_score(pic1, pic2):
    """Sum the absolute per-pixel dot products of two pictures.

    The last axis of ``pic1``/``pic2`` holds each pixel's vector. The dot
    product is taken along that axis, its absolute value is taken, and the
    two axes immediately before it are summed out. For inputs shaped
    ``(batch, height, width, channels)`` the result has shape ``(batch,)``.
    """
    per_pixel_dot = tf.einsum("...k,...k->...", pic1, pic2)
    magnitude = tf.abs(per_pixel_dot)
    return tf.reduce_sum(magnitude, axis=(-2, -1))
|