You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

66 lines
2.4 KiB
Python

1 year ago
import tensorflow as tf
1 year ago
from colors import vector_to_index
@tf.function
def aesthetic_loss(y_pics):
    """Return a scalar loss rewarding symmetric, locally coherent, diverse pictures.

    The loss is the sum of three terms:

    * symmetry reward: for every picture, the largest of five scores from
      ``compute_score`` (the picture against its left/right mirror,
      top/bottom mirror, transpose, anti-transpose and 180-degree rotation),
      summed over the batch and negated so more symmetric pictures lower the
      loss;
    * neighbourhood reward: for every pixel, the absolute value of the dot
      product between the pixel (tiled 9 times) and its flattened 3x3
      neighbourhood (the pixel itself included, zero padded at the borders),
      summed and negated;
    * diversity penalty: squared per-pixel dot products between every pair of
      pictures in the batch (each picture also paired with itself), summed.

    The normalising constants 64 and 64 * 9 are kept from the original;
    presumably 64 is height * width for 8x8 pictures — TODO confirm.

    Args:
        y_pics: rank-4 tensor; axis -3 is treated as rows, axis -2 as columns
            and axis -1 as the per-pixel vector (assumed
            (batch, height, width, channels) — TODO confirm). The transposes
            swap axes 1 and 2, so pictures need to be square for the shapes
            fed to ``compute_score`` to match.

    Returns:
        Scalar float tensor.
    """
    # Symmetric counterparts of every picture.
    flipx = tf.reverse(y_pics, axis=(-2,))           # reverse columns
    flipy = tf.reverse(y_pics, axis=(-3,))           # reverse rows
    center = tf.reverse(y_pics, axis=(-2, -3))       # reverse both: 180-degree rotation
    trans = tf.transpose(y_pics, perm=(0, 2, 1, 3))  # swap rows and columns
    # Anti-diagonal mirror: out[i, j] = in[n-1-j, n-1-i].
    anti_trans = tf.reverse(
        tf.transpose(tf.reverse(y_pics, axis=(-2,)), perm=(0, 2, 1, 3)),
        axis=(-2,))

    # Best symmetry score per picture: stack the five scores and keep the
    # maximum along the stacking axis.
    best_symmetry = tf.reduce_max(tf.stack([
        compute_score(y_pics, flipx),
        compute_score(y_pics, flipy),
        compute_score(y_pics, trans),
        compute_score(y_pics, anti_trans),
        compute_score(y_pics, center),
    ]), axis=0)

    # Every output position holds its 3x3 neighbourhood (itself included,
    # zero padded at the borders) flattened row-major with the channel axis
    # innermost: index = (row * 3 + col) * channels + channel.
    patches = tf.image.extract_patches(
        y_pics, sizes=(1, 3, 3, 1), strides=(1, 1, 1, 1),
        rates=(1, 1, 1, 1), padding="SAME")
    # Tile (not repeat) the pixel so that channel c lines up with channel c
    # of each of the 9 neighbourhood entries. tf.repeat would produce
    # [c0 x9, c1 x9, ...] and pair mismatched channels whenever there is
    # more than one channel.
    tiled_pixels = tf.tile(y_pics, (1, 1, 1, 9))
    neigh = tf.reduce_sum(
        tf.abs(tf.einsum("...k,...k->...", tiled_pixels, patches)))

    # Per-pixel dot products between pictures i and m of the batch,
    # squared and summed (i == m pairs included).
    divers = tf.reduce_sum(
        tf.square(tf.einsum("ijkl,mjkl->imjk", y_pics, y_pics)))

    return tf.reduce_sum([
        -1.0 / 64 * tf.reduce_sum(best_symmetry),
        -1.0 / (64 * 9) * 0.1 * neigh,
        1.0 / 64 * 10 * divers,
    ])
@tf.function
def compute_score(pic1, pic2):
    """Measure how strongly two pictures agree, pixel by pixel.

    At every position the vectors along the last axis of ``pic1`` and
    ``pic2`` are dotted together. The absolute values of those dot products
    are then summed over the two axes that precede the vector axis.

    Args:
        pic1: tensor whose last axis holds the per-pixel vector (presumably
            (batch, height, width, channels) — TODO confirm).
        pic2: tensor with a shape compatible with ``pic1`` for an
            element-wise product along the last axis.

    Returns:
        Tensor with the last three input axes reduced away; for rank-4
        inputs this is one score per picture.
    """
    pixel_dots = tf.einsum("...k,...k->...", pic1, pic2)
    agreement = tf.abs(pixel_dots)
    return tf.reduce_sum(agreement, axis=(-2, -1))