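# Trains a small generative network that maps 2-D noise vectors to 8x8
# two-channel outputs, optimizing a custom aesthetic_loss and rendering the
# results as colored terminal text (polymap, aesthetic_loss, and colors are
# project-local modules not shown here, so their behavior is assumed).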
# autopep8: off
import tensorflow as tf
# Enable memory growth on the first GPU so TensorFlow does not reserve all of
# the card's memory up front (skipped when no GPU is visible).
physical_devices = tf.config.experimental.list_physical_devices('GPU')
print(physical_devices)
if physical_devices:
    tf.config.experimental.set_memory_growth(physical_devices[0], True)
import colors
from aesthetic_loss import aesthetic_loss
from polymap import Polymap
from tensorflow.keras.layers import Dense, Conv2D, Reshape, ReLU, Conv2DTranspose, LeakyReLU, BatchNormalization, GaussianNoise
# autopep8: on
batch_size = 256
input_small = tf.random.normal(shape=[4, 2])
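# Generator building blocks: a Polymap input layer (local module, presumably a
# learned polynomial feature map), three L1/L2-regularized Dense layers, a
# reshape to a 2x2 spatial grid, transposed convolutions with batch
# normalization, and a final 1x1 convolution with tanh producing a
# two-channel output.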
pm = Polymap(8, 3)
rl = ReLU()
d1 = Dense(256, activation="relu",
           kernel_regularizer=tf.keras.regularizers.l1_l2())
d2 = Dense(512, activation="relu",
           kernel_regularizer=tf.keras.regularizers.l1_l2())
d3 = Dense(1024, activation="relu",
           kernel_regularizer=tf.keras.regularizers.l1_l2())
rs = Reshape([2, 2, 1024//4])
dc1 = Conv2DTranspose(256, (5, 5),
                      padding="valid", use_bias=False)
dc2 = Conv2DTranspose(128, (3, 3),
                      padding="valid", use_bias=False)
dc3 = Conv2DTranspose(64, (3, 3),
                      padding="same", use_bias=False)
dc4 = Conv2D(2, (1, 1), activation="tanh", use_bias=False)
bn1 = BatchNormalization()
bn2 = BatchNormalization()
bn3 = BatchNormalization()
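
# Assemble the layers above into a single Sequential generator.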
def gen_model():
    model = tf.keras.Sequential()
    model.add(pm)
    # model.add(rl)
    model.add(d1)
    model.add(d2)
    model.add(d3)
    model.add(rs)
    model.add(dc1)
    model.add(bn1)
    model.add(LeakyReLU())
    model.add(dc2)
    model.add(bn2)
    model.add(LeakyReLU())
    model.add(dc3)
    model.add(bn3)
    model.add(LeakyReLU())
    model.add(dc4)
    return model
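
# Optimizer, a running-mean loss metric for logging, the generator instance,
# and a seeded random generator so the noise inputs are reproducible.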
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001, epsilon=0.001)
train_loss = tf.keras.metrics.Mean(name='train_loss')
generator = gen_model()
rng = tf.random.Generator.from_seed(1)
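
# One optimization step, compiled with tf.function: run the generator on a
# batch of noise vectors, compute the total loss under a GradientTape, and
# apply gradients to all of the generator's trainable weights.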
@tf.function
def train_step(input):
    # input = tf.random.normal(shape=[batch_size, 2])
    # input = rng.normal(shape=[batch_size, 2])
    # print(input)
    with tf.GradientTape() as tape:
        # pics = gen(input)
        pics = generator(input, training=True)
        # Aesthetic objective plus the dense layers' L1/L2 penalties and a
        # small L1 penalty on the Polymap kernels.
        loss = (aesthetic_loss(pics)
                + tf.reduce_sum([d1.losses, d2.losses, d3.losses])
                + 0.01 * tf.reduce_mean(tf.abs(pm.trainable_weights)))
        # tv = sum([pm.trainable_weights,
        #           d1.trainable_weights,
        #           d2.trainable_weights,
        #           d3.trainable_weights,
        #           dc1.trainable_weights,
        #           dc2.trainable_weights,
        #           dc3.trainable_weights,], [])
        # print(tv)
    gradients = tape.gradient(loss, generator.trainable_weights)
    optimizer.apply_gradients(zip(gradients, generator.trainable_weights))
    train_loss(loss + tf.reduce_sum([d1.losses, d2.losses, d3.losses]))
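
# Training loop: EPOCHS epochs of 100 steps each, drawing one batch of noise
# per epoch and reporting the mean loss after every epoch.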
EPOCHS = 200
for epoch in range(EPOCHS):
    train_loss.reset_states()
    input = rng.normal(shape=[batch_size, 2])
    # print(input)
    for i in range(100):
        train_step(input)
    print(
        f'Epoch {epoch + 1}, '
        f'Loss: {train_loss.result()}, '
    )
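
# After training, draw four fresh noise vectors, generate their outputs, and
# print the four renderings side by side; colors.vector_to_string presumably
# renders each two-channel map as a block of colored terminal text.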
input_small = rng.normal(shape=[4, 2])
ps = generator(input_small, training=False)
# print(ps[0])
# print(pm.kernels)
s0 = colors.vector_to_string(ps[0])
s1 = colors.vector_to_string(ps[1])
s2 = colors.vector_to_string(ps[2])
s3 = colors.vector_to_string(ps[3])
print("\n".join(" ".join(r)
                for r in zip(s0.split("\n"), s1.split("\n"), s2.split("\n"), s3.split("\n"))))