import tensorflow as tf

# NOTE(review): vector_to_index is not used by the code below; the import is
# kept in case other modules or experiments rely on it.
from colors import vector_to_index

# Normalisation constant applied to every loss term. Presumably the pixel
# count of an 8x8 picture -- TODO confirm against the caller.
_PIXELS = 64
# Number of positions in a 3x3 neighbourhood (centre pixel included).
_NEIGHBOURS = 9


def _symmetry_term(y_pics):
    """Return the batch sum of each image's best symmetry score.

    Five mirrored/rotated copies of the batch are compared with the original
    via ``compute_score``. For every image the highest of the five scores is
    kept, and those maxima are summed over the batch.
    """
    flip_x = tf.reverse(y_pics, axis=(-2,))          # reverse axis -2
    flip_y = tf.reverse(y_pics, axis=(-3,))          # reverse axis -3
    rot_180 = tf.reverse(y_pics, axis=(-2, -3))      # reverse both spatial axes
    trans = tf.transpose(y_pics, perm=(0, 2, 1, 3))  # swap axes 1 and 2
    # Anti-diagonal mirror: out[i, j] = in[N-1-j, N-1-i]; like the transpose
    # above, this only matches the input shape when axes 1 and 2 are equal.
    anti_trans = tf.reverse(
        tf.transpose(tf.reverse(y_pics, axis=(-2,)), perm=(0, 2, 1, 3)),
        axis=(-2,))
    scores = tf.stack([
        compute_score(y_pics, flip_x),
        compute_score(y_pics, flip_y),
        compute_score(y_pics, trans),
        compute_score(y_pics, anti_trans),
        compute_score(y_pics, rot_180),
    ])
    return tf.reduce_sum(tf.reduce_max(scores, axis=0))


def _neighbourhood_term(y_pics):
    """Return sum over pixels of |sum of dots between a pixel and its 3x3 patch|."""
    # Depth axis of the result is laid out as (row, col, channel), channel
    # fastest; "SAME" padding zero-fills the border neighbourhoods.
    patches = tf.image.extract_patches(
        y_pics,
        sizes=(1, 3, 3, 1),
        strides=(1, 1, 1, 1),
        rates=(1, 1, 1, 1),
        padding="SAME")
    # Tile (not repeat) the centre pixel so its channel c lines up with
    # channel c of each of the 9 patch positions. tf.repeat would give
    # [c0]*9 + [c1]*9 + ... and mix channels whenever there is more than one.
    centre = tf.tile(y_pics, (1, 1, 1, _NEIGHBOURS))
    return tf.reduce_sum(
        tf.abs(tf.einsum("...k,...k->...", centre, patches)))


def _diversity_term(y_pics):
    """Return the sum of squared pixel-wise dots between all pairs of images."""
    # Shape (batch, batch, H, W); the i == m self-pairs are included.
    pairwise = tf.einsum("ijkl,mjkl->imjk", y_pics, y_pics)
    return tf.reduce_sum(tf.square(pairwise))


@tf.function
def aesthetic_loss(y_pics):
    """Return a scalar loss rewarding symmetric, locally coherent pictures.

    Assumes ``y_pics`` is a 4-D float tensor laid out as
    (batch, height, width, channels) with height == width, as required by the
    transposes in the symmetry term -- TODO confirm against the caller.

    The loss is the sum of three weighted terms:

    * ``-1/64 * symmetry``: per-image best symmetry score, summed over the
      batch (negated, so more symmetric pictures lower the loss);
    * ``-0.1/(64*9) * neigh``: absolute pixel-vs-3x3-neighbourhood dot
      products (negated, rewarding locally coherent pixels);
    * ``10/64 * divers``: squared pixel-wise dots between every pair of
      images in the batch (penalising pictures that resemble each other).
    """
    symmetry = _symmetry_term(y_pics)
    neigh = _neighbourhood_term(y_pics)
    divers = _diversity_term(y_pics)
    return tf.reduce_sum([
        -1.0 / _PIXELS * symmetry,
        -1.0 / (_PIXELS * _NEIGHBOURS) * 0.1 * neigh,
        1.0 / _PIXELS * 10 * divers,
    ])


@tf.function
def compute_score(pic1, pic2):
    """Return how strongly two pictures agree, pixel by pixel.

    For each pixel the dot product over the last (channel) axis is taken,
    its absolute value is kept, and the result is summed over the two
    trailing remaining axes. For 4-D (batch, H, W, C) inputs of equal shape
    this yields one score per image, shape (batch,).
    """
    pixel_dots = tf.einsum("...k,...k->...", pic1, pic2)
    return tf.reduce_sum(tf.abs(pixel_dots), axis=(-2, -1))