You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

399 lines
13 KiB
Plaintext

1 year ago
{
"cells": [
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The autoreload extension is already loaded. To reload it, use:\n",
" %reload_ext autoreload\n"
]
}
],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"from tensorflow.keras.layers import Dense, Conv2D, Reshape\n",
"from polymap import Polymap\n",
"from aesthetic_loss import aesthetic_loss"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['🟥', '🟦', '🟧', '🟨', '🟩', '🟪', '🟫']\n"
]
}
],
"source": [
"from math import pi\n",
"\n",
"# The seven coloured-square emoji U+1F7E5..U+1F7EB in code-point order:\n",
"# red, blue, orange, yellow, green, purple, brown.\n",
"squares = [chr(i) for i in range(0x1F7E5, 0x1F7EC)]\n",
"tfsquares = tf.constant(squares)\n",
"print(squares)\n",
"\n",
"# Hue angle in degrees for each square, index-aligned with `squares`.\n",
"SQUARE_HUES_DEG = [10, 219, 38, 48, 80, 282, 20]\n",
"colors = tf.constant(SQUARE_HUES_DEG, dtype=tf.float32) * (pi / 180.0)\n",
"# One unit vector (cos h, sin h) per square, shape (7, 2).\n",
"color_vectors = tf.stack([tf.math.cos(colors), tf.math.sin(colors)], axis=-1)\n",
"\n",
"\n",
"def vector_to_index(tensor):\n",
"    \"\"\"Map every 2-vector in a (rows, cols, 2) tensor to its nearest square index.\n",
"\n",
"    Each row of `color_vectors` is a unit vector, so for a non-zero input the\n",
"    largest dot product belongs to the square whose hue angle is closest to\n",
"    the vector's direction. Returns a (rows, cols) int64 tensor of indices\n",
"    into `squares`.\n",
"    \"\"\"\n",
"    similarity = tf.einsum(\"ijk,lk->ijl\", tensor, color_vectors)\n",
"    return tf.argmax(similarity, axis=-1)\n",
"\n",
"\n",
"def index_to_string(tensor):\n",
"    \"\"\"Render a (rows, cols) grid of square indices as multi-line emoji text.\n",
"\n",
"    Each row becomes one line of concatenated squares, and lines are separated\n",
"    by newlines (no trailing newline). Needs eager mode because it calls .numpy().\n",
"    \"\"\"\n",
"    rows = tf.map_fn(\n",
"        lambda row: tf.strings.join(tf.gather(tfsquares, tf.cast(row, tf.int64))),\n",
"        tensor,\n",
"        fn_output_signature=tf.string,\n",
"    )\n",
"    return tf.strings.join(rows, \"\\n\").numpy().decode()"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"# Number of random 2-D input vectors gen() draws per call.\n",
"batch_size = 16\n",
"\n",
"# Seed Python, NumPy and TensorFlow RNGs so Restart & Run All repeats the same\n",
"# random inputs and initial weights (GPU kernels may still be nondeterministic).\n",
"tf.keras.utils.set_random_seed(0)"
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {},
"outputs": [],
"source": [
"# Generator layers are kept as individual globals so train_step can gather\n",
"# their trainable weights explicitly.\n",
"pm = Polymap(8, 3)\n",
"d1 = Dense(64, activation=\"relu\")\n",
"d2 = Dense(64, activation=\"relu\")\n",
"rs = Reshape([8, 8, 1])  # 64 features per sample -> one 8x8 single-channel map\n",
"dc1 = Conv2D(2, (1, 1))  # 1x1 convolution: two output values per pixel\n",
"\n",
"\n",
"def gen():\n",
"    \"\"\"Generate one batch of 8x8 two-channel images from fresh Gaussian noise.\n",
"\n",
"    Draws `batch_size` standard-normal 2-D inputs and passes them through\n",
"    Polymap -> Dense(64) -> Dense(64) -> Reshape(8, 8, 1) -> Conv2D(2, 1x1).\n",
"    Returns a tensor of shape (batch_size, 8, 8, 2), assuming Polymap keeps\n",
"    the batch dimension.\n",
"    \"\"\"\n",
"    noise = tf.random.normal(shape=[batch_size, 2])\n",
"    x = pm(noise)\n",
"    x = d1(x)\n",
"    x = d2(x)\n",
"    x = rs(x)\n",
"    return dc1(x)"
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {},
"outputs": [],
"source": [
"# Running mean of the per-step training loss; the training loop resets it every epoch.\n",
"train_loss = tf.keras.metrics.Mean(name='train_loss')\n",
"\n",
"# Adam with its default hyper-parameters, used by train_step to apply gradients.\n",
"optimizer = tf.keras.optimizers.Adam()"
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {},
"outputs": [],
"source": [
"@tf.function\n",
"def train_step():\n",
"    \"\"\"Run one optimisation step on a freshly generated batch.\n",
"\n",
"    Generates a batch with gen(), scores it with aesthetic_loss, applies the\n",
"    loss gradients to the trainable weights of every generator layer through\n",
"    `optimizer`, and records the loss in `train_loss`.\n",
"    \"\"\"\n",
"    with tf.GradientTape() as tape:\n",
"        pics = gen()\n",
"        loss = aesthetic_loss(pics)\n",
"\n",
"    # Gather weights only after gen() has run: layers built without an input\n",
"    # shape create their weights on first call. Every layer must be listed --\n",
"    # any layer left out (e.g. d2) would silently never be trained.\n",
"    tv = [w for layer in (pm, d1, d2, rs, dc1) for w in layer.trainable_weights]\n",
"    gradients = tape.gradient(loss, tv)\n",
"    optimizer.apply_gradients(zip(gradients, tv))\n",
"\n",
"    train_loss(loss)"
]
},
{
"cell_type": "code",
"execution_count": 57,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1, Loss: 0.0004605448921211064, \n",
"🟩🟩🟩🟩🟩🟩🟩🟩\n",
"🟩🟩🟩🟩🟩🟩🟩🟩\n",
"🟩🟩🟩🟩🟩🟩🟩🟩\n",
"🟩🟩🟩🟩🟩🟩🟩🟩\n",
"🟩🟩🟩🟩🟩🟩🟩🟩\n",
"🟩🟩🟩🟩🟩🟩🟩🟩\n",
"🟩🟩🟩🟩🟩🟩🟩🟩\n",
"🟩🟩🟩🟩🟩🟩🟩🟩\n",
"Epoch 2, Loss: 9.45829197007697e-06, \n",
"🟥🟥🟥🟥🟥🟥🟥🟥\n",
"🟥🟥🟥🟥🟥🟥🟥🟥\n",
"🟥🟥🟥🟥🟥🟥🟥🟥\n",
"🟥🟥🟥🟥🟥🟥🟥🟥\n",
"🟥🟥🟥🟥🟥🟥🟥🟥\n",
"🟥🟥🟥🟥🟥🟥🟥🟥\n",
"🟥🟥🟥🟥🟥🟥🟥🟥\n",
"🟥🟥🟥🟥🟥🟥🟥🟥\n",
"Epoch 3, Loss: 8.17671389086172e-05, \n",
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
"Epoch 4, Loss: 0.0001505890249973163, \n",
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
"Epoch 5, Loss: 0.00018983485642820597, \n",
"🟩🟩🟩🟩🟩🟩🟩🟩\n",
"🟩🟩🟩🟩🟩🟩🟩🟩\n",
"🟩🟩🟩🟩🟩🟩🟩🟩\n",
"🟩🟩🟩🟩🟩🟩🟩🟩\n",
"🟩🟩🟩🟩🟩🟩🟩🟩\n",
"🟩🟩🟩🟩🟩🟩🟩🟩\n",
"🟩🟩🟩🟩🟩🟩🟩🟩\n",
"🟩🟩🟩🟩🟩🟩🟩🟩\n",
"Epoch 6, Loss: 0.00012902110756840557, \n",
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
"Epoch 7, Loss: 1.342521863989532e-05, \n",
"🟨🟨🟨🟨🟨🟨🟨🟨\n",
"🟨🟨🟨🟨🟨🟨🟨🟨\n",
"🟨🟨🟨🟨🟨🟨🟨🟨\n",
"🟨🟨🟨🟨🟨🟨🟨🟨\n",
"🟨🟨🟨🟨🟨🟨🟨🟨\n",
"🟨🟨🟨🟨🟨🟨🟨🟨\n",
"🟨🟨🟨🟨🟨🟨🟨🟨\n",
"🟨🟨🟨🟨🟨🟨🟨🟨\n",
"Epoch 8, Loss: 0.00034237385261803865, \n",
"🟩🟩🟩🟩🟩🟩🟩🟩\n",
"🟩🟩🟩🟩🟩🟩🟩🟩\n",
"🟩🟩🟩🟩🟩🟩🟩🟩\n",
"🟩🟩🟩🟩🟩🟩🟩🟩\n",
"🟩🟩🟩🟩🟩🟩🟩🟩\n",
"🟩🟩🟩🟩🟩🟩🟩🟩\n",
"🟩🟩🟩🟩🟩🟩🟩🟩\n",
"🟩🟩🟩🟩🟩🟩🟩🟩\n",
"Epoch 9, Loss: 8.251221879618242e-05, \n",
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
"Epoch 10, Loss: 1.0330303013006414e-07, \n",
"🟧🟧🟧🟧🟧🟧🟧🟧\n",
"🟧🟧🟧🟧🟧🟧🟧🟧\n",
"🟧🟧🟧🟧🟧🟧🟧🟧\n",
"🟧🟧🟧🟧🟧🟧🟧🟧\n",
"🟧🟧🟧🟧🟧🟧🟧🟧\n",
"🟧🟧🟧🟧🟧🟧🟧🟧\n",
"🟧🟧🟧🟧🟧🟧🟧🟧\n",
"🟧🟧🟧🟧🟧🟧🟧🟧\n",
"Epoch 11, Loss: 3.0386236176127568e-05, \n",
"🟥🟥🟥🟥🟥🟥🟥🟥\n",
"🟥🟥🟥🟥🟥🟥🟥🟥\n",
"🟥🟥🟥🟥🟥🟥🟥🟥\n",
"🟥🟥🟥🟥🟥🟥🟥🟥\n",
"🟥🟥🟥🟥🟥🟥🟥🟥\n",
"🟥🟥🟥🟥🟥🟥🟥🟥\n",
"🟥🟥🟥🟥🟥🟥🟥🟥\n",
"🟥🟥🟥🟥🟥🟥🟥🟥\n",
"Epoch 12, Loss: 4.6750748879276216e-05, \n",
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
"Epoch 13, Loss: 7.047886185773677e-08, \n",
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
"Epoch 14, Loss: 9.711433317672463e-14, \n",
"🟥🟥🟥🟥🟥🟥🟥🟥\n",
"🟥🟥🟥🟥🟥🟥🟥🟥\n",
"🟥🟥🟥🟥🟥🟥🟥🟥\n",
"🟥🟥🟥🟥🟥🟥🟥🟥\n",
"🟥🟥🟥🟥🟥🟥🟥🟥\n",
"🟥🟥🟥🟥🟥🟥🟥🟥\n",
"🟥🟥🟥🟥🟥🟥🟥🟥\n",
"🟥🟥🟥🟥🟥🟥🟥🟥\n",
"Epoch 15, Loss: 7.075177338602006e-23, \n",
"🟪🟪🟪🟪🟪🟪🟪🟪\n",
"🟪🟪🟪🟪🟪🟪🟪🟪\n",
"🟪🟪🟪🟪🟪🟪🟪🟪\n",
"🟪🟪🟪🟪🟪🟪🟪🟪\n",
"🟪🟪🟪🟪🟪🟪🟪🟪\n",
"🟪🟪🟪🟪🟪🟪🟪🟪\n",
"🟪🟪🟪🟪🟪🟪🟪🟪\n",
"🟪🟪🟪🟪🟪🟪🟪🟪\n",
"Epoch 16, Loss: 0.11943158507347107, \n",
"🟨🟨🟨🟨🟨🟨🟨🟨\n",
"🟨🟨🟨🟨🟨🟨🟨🟨\n",
"🟨🟨🟨🟨🟨🟨🟨🟨\n",
"🟨🟨🟨🟨🟨🟨🟨🟨\n",
"🟨🟨🟨🟨🟨🟨🟨🟨\n",
"🟨🟨🟨🟨🟨🟨🟨🟨\n",
"🟨🟨🟨🟨🟨🟨🟨🟨\n",
"🟨🟨🟨🟨🟨🟨🟨🟨\n",
"Epoch 17, Loss: 0.002167836995795369, \n",
"🟨🟨🟨🟨🟨🟨🟨🟨\n",
"🟨🟨🟨🟨🟨🟨🟨🟨\n",
"🟨🟨🟨🟨🟨🟨🟨🟨\n",
"🟨🟨🟨🟨🟨🟨🟨🟨\n",
"🟨🟨🟨🟨🟨🟨🟨🟨\n",
"🟨🟨🟨🟨🟨🟨🟨🟨\n",
"🟨🟨🟨🟨🟨🟨🟨🟨\n",
"🟨🟨🟨🟨🟨🟨🟨🟨\n",
"Epoch 18, Loss: 6.590542034246027e-05, \n",
"🟧🟧🟧🟧🟧🟧🟧🟧\n",
"🟧🟧🟧🟧🟧🟧🟧🟧\n",
"🟧🟧🟧🟧🟧🟧🟧🟧\n",
"🟧🟧🟧🟧🟧🟧🟧🟧\n",
"🟧🟧🟧🟧🟧🟧🟧🟧\n",
"🟧🟧🟧🟧🟧🟧🟧🟧\n",
"🟧🟧🟧🟧🟧🟧🟧🟧\n",
"🟧🟧🟧🟧🟧🟧🟧🟧\n",
"Epoch 19, Loss: 0.0017910723108798265, \n",
"🟩🟩🟩🟩🟩🟩🟩🟩\n",
"🟩🟩🟩🟩🟩🟩🟩🟩\n",
"🟩🟩🟩🟩🟩🟩🟩🟩\n",
"🟩🟩🟩🟩🟩🟩🟩🟩\n",
"🟩🟩🟩🟩🟩🟩🟩🟩\n",
"🟩🟩🟩🟩🟩🟩🟩🟩\n",
"🟩🟩🟩🟩🟩🟩🟩🟩\n",
"🟩🟩🟩🟩🟩🟩🟩🟩\n",
"Epoch 20, Loss: 0.004703103564679623, \n",
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
"🟦🟦🟦🟦🟦🟦🟦🟦\n"
]
}
],
"source": [
"EPOCHS = 50\n",
"STEPS_PER_EPOCH = 50\n",
"\n",
"for epoch in range(EPOCHS):\n",
"    # Reset the running loss so each report covers only this epoch's steps.\n",
"    train_loss.reset_states()\n",
"\n",
"    for _ in range(STEPS_PER_EPOCH):\n",
"        train_step()\n",
"\n",
"    print(f'Epoch {epoch + 1}, Loss: {train_loss.result()}, ')\n",
"\n",
"    # Preview the first image of a fresh batch as a grid of emoji squares.\n",
"    preview = gen()\n",
"    print(index_to_string(vector_to_index(preview[0])))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv_patch_gen",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}