{ "cells": [ { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The autoreload extension is already loaded. To reload it, use:\n", " %reload_ext autoreload\n" ] } ], "source": [ "%load_ext autoreload\n", "%autoreload 2" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "import tensorflow as tf\n", "from tensorflow.keras.layers import Dense, Conv2D, Reshape\n", "from polymap import Polymap\n", "from aesthetic_loss import aesthetic_loss" ] }, { "cell_type": "code", "execution_count": 46, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "['🟥', '🟦', '🟧', '🟨', '🟩', '🟪', '🟫']\n" ] } ], "source": [ "from math import pi\n", "squares = [chr(i) for i in range(0x1F7E5, 0x1F7EC)]\n", "tfsquares = tf.constant(squares)\n", "print(squares)\n", "colors = tf.constant([10, 219, 38, 48, 80, 282, 20] * tf.constant(pi/180.0), dtype=tf.float32)\n", "color_vectors = tf.transpose(\n", " tf.stack([\n", " tf.math.cos(colors),\n", " tf.math.sin(colors)]\n", " )\n", ")\n", "\n", "def vector_to_index(tensor):\n", " return tf.argmax((tf.einsum(\"ijk,lk->ijl\", tensor, color_vectors)), axis=-1)\n", "\n", "def index_to_string(tensor):\n", " # square_select = tf.math.argmax(tensor, 2)\n", "\n", " tfstring = tf.strings.join(\n", " tf.map_fn(\n", " lambda v:\n", " tf.strings.join(\n", " tf.gather(tfsquares, tf.cast(v, tf.int64))), tensor, fn_output_signature=tf.string), \"\\n\")\n", "\n", " return tfstring.numpy().decode()" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "batch_size = 16" ] }, { "cell_type": "code", "execution_count": 51, "metadata": {}, "outputs": [], "source": [ "# class PatchGen(tf.keras.Model):\n", "\n", "# def __init__(self, batch_size):\n", "# super(PatchGen, self).__init__()\n", "# self.pm = Polymap(8, 3)(input)\n", "# self.d1 = Dense(64)\n", "# self.rs = Reshape([8,8,1])\n", "# self.dc1 = Conv2D(3,(1,1))\n", "\n", "# def call(self, inputs):\n", "# x = self.block_1(inputs)\n", "# x = self.block_2(x)\n", "# x = self.global_pool(x)\n", "# return self.classifier(x)\n", "\n", "input = tf.random.normal(shape=[batch_size,2])\n", "pm = Polymap(8, 3)\n", "d1 = Dense(64, activation=\"relu\")\n", "d2 = Dense(64, activation=\"relu\")\n", "rs = Reshape([8,8,1])\n", "dc1 = Conv2D(2,(1,1))\n", "\n", "def gen():\n", " input = tf.random.normal(shape=[batch_size,2])\n", " x = pm(input)\n", " x = d1(x)\n", " x = d2(x)\n", " x = rs(x)\n", " return dc1(x)\n", "# gen()" ] }, { "cell_type": "code", "execution_count": 52, "metadata": {}, "outputs": [], "source": [ "optimizer = tf.keras.optimizers.Adam()\n", "train_loss = tf.keras.metrics.Mean(name='train_loss')" ] }, { "cell_type": "code", "execution_count": 53, "metadata": {}, "outputs": [], "source": [ "@tf.function\n", "def train_step():\n", " with tf.GradientTape() as tape:\n", " pics = gen()\n", " loss = aesthetic_loss(pics)\n", " tv = sum([pm.trainable_weights, d1.trainable_weights,\n", " dc1.trainable_weights], [])\n", " print(tv)\n", " gradients = tape.gradient(loss, tv)\n", " optimizer.apply_gradients(zip(gradients, tv))\n", "\n", " train_loss(loss)\n", " # train_accuracy(labels, predictions)" ] }, { "cell_type": "code", "execution_count": 57, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1, Loss: 0.0004605448921211064, \n", "🟩🟩🟩🟩🟩🟩🟩🟩\n", "🟩🟩🟩🟩🟩🟩🟩🟩\n", "🟩🟩🟩🟩🟩🟩🟩🟩\n", "🟩🟩🟩🟩🟩🟩🟩🟩\n", "🟩🟩🟩🟩🟩🟩🟩🟩\n", "🟩🟩🟩🟩🟩🟩🟩🟩\n", "🟩🟩🟩🟩🟩🟩🟩🟩\n", "🟩🟩🟩🟩🟩🟩🟩🟩\n", "Epoch 2, Loss: 9.45829197007697e-06, \n", "🟥🟥🟥🟥🟥🟥🟥🟥\n", "🟥🟥🟥🟥🟥🟥🟥🟥\n", "🟥🟥🟥🟥🟥🟥🟥🟥\n", "🟥🟥🟥🟥🟥🟥🟥🟥\n", "🟥🟥🟥🟥🟥🟥🟥🟥\n", "🟥🟥🟥🟥🟥🟥🟥🟥\n", "🟥🟥🟥🟥🟥🟥🟥🟥\n", "🟥🟥🟥🟥🟥🟥🟥🟥\n", "Epoch 3, Loss: 8.17671389086172e-05, \n", "🟦🟦🟦🟦🟦🟦🟦🟦\n", "🟦🟦🟦🟦🟦🟦🟦🟦\n", "🟦🟦🟦🟦🟦🟦🟦🟦\n", "🟦🟦🟦🟦🟦🟦🟦🟦\n", "🟦🟦🟦🟦🟦🟦🟦🟦\n", "🟦🟦🟦🟦🟦🟦🟦🟦\n", "🟦🟦🟦🟦🟦🟦🟦🟦\n", "🟦🟦🟦🟦🟦🟦🟦🟦\n", "Epoch 4, Loss: 0.0001505890249973163, \n", "🟦🟦🟦🟦🟦🟦🟦🟦\n", "🟦🟦🟦🟦🟦🟦🟦🟦\n", "🟦🟦🟦🟦🟦🟦🟦🟦\n", "🟦🟦🟦🟦🟦🟦🟦🟦\n", "🟦🟦🟦🟦🟦🟦🟦🟦\n", "🟦🟦🟦🟦🟦🟦🟦🟦\n", "🟦🟦🟦🟦🟦🟦🟦🟦\n", "🟦🟦🟦🟦🟦🟦🟦🟦\n", "Epoch 5, Loss: 0.00018983485642820597, \n", "🟩🟩🟩🟩🟩🟩🟩🟩\n", "🟩🟩🟩🟩🟩🟩🟩🟩\n", "🟩🟩🟩🟩🟩🟩🟩🟩\n", "🟩🟩🟩🟩🟩🟩🟩🟩\n", "🟩🟩🟩🟩🟩🟩🟩🟩\n", "🟩🟩🟩🟩🟩🟩🟩🟩\n", "🟩🟩🟩🟩🟩🟩🟩🟩\n", "🟩🟩🟩🟩🟩🟩🟩🟩\n", "Epoch 6, Loss: 0.00012902110756840557, \n", "🟦🟦🟦🟦🟦🟦🟦🟦\n", "🟦🟦🟦🟦🟦🟦🟦🟦\n", "🟦🟦🟦🟦🟦🟦🟦🟦\n", "🟦🟦🟦🟦🟦🟦🟦🟦\n", "🟦🟦🟦🟦🟦🟦🟦🟦\n", "🟦🟦🟦🟦🟦🟦🟦🟦\n", "🟦🟦🟦🟦🟦🟦🟦🟦\n", "🟦🟦🟦🟦🟦🟦🟦🟦\n", "Epoch 7, Loss: 1.342521863989532e-05, \n", "🟨🟨🟨🟨🟨🟨🟨🟨\n", "🟨🟨🟨🟨🟨🟨🟨🟨\n", "🟨🟨🟨🟨🟨🟨🟨🟨\n", "🟨🟨🟨🟨🟨🟨🟨🟨\n", "🟨🟨🟨🟨🟨🟨🟨🟨\n", "🟨🟨🟨🟨🟨🟨🟨🟨\n", "🟨🟨🟨🟨🟨🟨🟨🟨\n", "🟨🟨🟨🟨🟨🟨🟨🟨\n", "Epoch 8, Loss: 0.00034237385261803865, \n", "🟩🟩🟩🟩🟩🟩🟩🟩\n", "🟩🟩🟩🟩🟩🟩🟩🟩\n", "🟩🟩🟩🟩🟩🟩🟩🟩\n", "🟩🟩🟩🟩🟩🟩🟩🟩\n", "🟩🟩🟩🟩🟩🟩🟩🟩\n", "🟩🟩🟩🟩🟩🟩🟩🟩\n", "🟩🟩🟩🟩🟩🟩🟩🟩\n", "🟩🟩🟩🟩🟩🟩🟩🟩\n", "Epoch 9, Loss: 8.251221879618242e-05, \n", "🟦🟦🟦🟦🟦🟦🟦🟦\n", "🟦🟦🟦🟦🟦🟦🟦🟦\n", "🟦🟦🟦🟦🟦🟦🟦🟦\n", "🟦🟦🟦🟦🟦🟦🟦🟦\n", "🟦🟦🟦🟦🟦🟦🟦🟦\n", "🟦🟦🟦🟦🟦🟦🟦🟦\n", "🟦🟦🟦🟦🟦🟦🟦🟦\n", "🟦🟦🟦🟦🟦🟦🟦🟦\n", "Epoch 10, Loss: 1.0330303013006414e-07, \n", "🟧🟧🟧🟧🟧🟧🟧🟧\n", "🟧🟧🟧🟧🟧🟧🟧🟧\n", "🟧🟧🟧🟧🟧🟧🟧🟧\n", "🟧🟧🟧🟧🟧🟧🟧🟧\n", "🟧🟧🟧🟧🟧🟧🟧🟧\n", "🟧🟧🟧🟧🟧🟧🟧🟧\n", "🟧🟧🟧🟧🟧🟧🟧🟧\n", "🟧🟧🟧🟧🟧🟧🟧🟧\n", "Epoch 11, Loss: 3.0386236176127568e-05, \n", "🟥🟥🟥🟥🟥🟥🟥🟥\n", "🟥🟥🟥🟥🟥🟥🟥🟥\n", "🟥🟥🟥🟥🟥🟥🟥🟥\n", "🟥🟥🟥🟥🟥🟥🟥🟥\n", "🟥🟥🟥🟥🟥🟥🟥🟥\n", "🟥🟥🟥🟥🟥🟥🟥🟥\n", "🟥🟥🟥🟥🟥🟥🟥🟥\n", "🟥🟥🟥🟥🟥🟥🟥🟥\n", "Epoch 12, Loss: 4.6750748879276216e-05, \n", "🟦🟦🟦🟦🟦🟦🟦🟦\n", "🟦🟦🟦🟦🟦🟦🟦🟦\n", "🟦🟦🟦🟦🟦🟦🟦🟦\n", "🟦🟦🟦🟦🟦🟦🟦🟦\n", "🟦🟦🟦🟦🟦🟦🟦🟦\n", "🟦🟦🟦🟦🟦🟦🟦🟦\n", "🟦🟦🟦🟦🟦🟦🟦🟦\n", "🟦🟦🟦🟦🟦🟦🟦🟦\n", "Epoch 13, Loss: 7.047886185773677e-08, \n", "🟦🟦🟦🟦🟦🟦🟦🟦\n", "🟦🟦🟦🟦🟦🟦🟦🟦\n", "🟦🟦🟦🟦🟦🟦🟦🟦\n", "🟦🟦🟦🟦🟦🟦🟦🟦\n", "🟦🟦🟦🟦🟦🟦🟦🟦\n", "🟦🟦🟦🟦🟦🟦🟦🟦\n", "🟦🟦🟦🟦🟦🟦🟦🟦\n", "🟦🟦🟦🟦🟦🟦🟦🟦\n", "Epoch 14, Loss: 9.711433317672463e-14, \n", "🟥🟥🟥🟥🟥🟥🟥🟥\n", "🟥🟥🟥🟥🟥🟥🟥🟥\n", "🟥🟥🟥🟥🟥🟥🟥🟥\n", "🟥🟥🟥🟥🟥🟥🟥🟥\n", "🟥🟥🟥🟥🟥🟥🟥🟥\n", "🟥🟥🟥🟥🟥🟥🟥🟥\n", "🟥🟥🟥🟥🟥🟥🟥🟥\n", "🟥🟥🟥🟥🟥🟥🟥🟥\n", "Epoch 15, Loss: 7.075177338602006e-23, \n", "🟪🟪🟪🟪🟪🟪🟪🟪\n", "🟪🟪🟪🟪🟪🟪🟪🟪\n", "🟪🟪🟪🟪🟪🟪🟪🟪\n", "🟪🟪🟪🟪🟪🟪🟪🟪\n", "🟪🟪🟪🟪🟪🟪🟪🟪\n", "🟪🟪🟪🟪🟪🟪🟪🟪\n", "🟪🟪🟪🟪🟪🟪🟪🟪\n", "🟪🟪🟪🟪🟪🟪🟪🟪\n", "Epoch 16, Loss: 0.11943158507347107, \n", "🟨🟨🟨🟨🟨🟨🟨🟨\n", "🟨🟨🟨🟨🟨🟨🟨🟨\n", "🟨🟨🟨🟨🟨🟨🟨🟨\n", "🟨🟨🟨🟨🟨🟨🟨🟨\n", "🟨🟨🟨🟨🟨🟨🟨🟨\n", "🟨🟨🟨🟨🟨🟨🟨🟨\n", "🟨🟨🟨🟨🟨🟨🟨🟨\n", "🟨🟨🟨🟨🟨🟨🟨🟨\n", "Epoch 17, Loss: 0.002167836995795369, \n", "🟨🟨🟨🟨🟨🟨🟨🟨\n", "🟨🟨🟨🟨🟨🟨🟨🟨\n", "🟨🟨🟨🟨🟨🟨🟨🟨\n", "🟨🟨🟨🟨🟨🟨🟨🟨\n", "🟨🟨🟨🟨🟨🟨🟨🟨\n", "🟨🟨🟨🟨🟨🟨🟨🟨\n", "🟨🟨🟨🟨🟨🟨🟨🟨\n", "🟨🟨🟨🟨🟨🟨🟨🟨\n", "Epoch 18, Loss: 6.590542034246027e-05, \n", "🟧🟧🟧🟧🟧🟧🟧🟧\n", "🟧🟧🟧🟧🟧🟧🟧🟧\n", "🟧🟧🟧🟧🟧🟧🟧🟧\n", "🟧🟧🟧🟧🟧🟧🟧🟧\n", "🟧🟧🟧🟧🟧🟧🟧🟧\n", "🟧🟧🟧🟧🟧🟧🟧🟧\n", "🟧🟧🟧🟧🟧🟧🟧🟧\n", "🟧🟧🟧🟧🟧🟧🟧🟧\n", "Epoch 19, Loss: 0.0017910723108798265, \n", "🟩🟩🟩🟩🟩🟩🟩🟩\n", "🟩🟩🟩🟩🟩🟩🟩🟩\n", "🟩🟩🟩🟩🟩🟩🟩🟩\n", "🟩🟩🟩🟩🟩🟩🟩🟩\n", "🟩🟩🟩🟩🟩🟩🟩🟩\n", "🟩🟩🟩🟩🟩🟩🟩🟩\n", "🟩🟩🟩🟩🟩🟩🟩🟩\n", "🟩🟩🟩🟩🟩🟩🟩🟩\n", "Epoch 20, Loss: 0.004703103564679623, \n", "🟦🟦🟦🟦🟦🟦🟦🟦\n", "🟦🟦🟦🟦🟦🟦🟦🟦\n", "🟦🟦🟦🟦🟦🟦🟦🟦\n", "🟦🟦🟦🟦🟦🟦🟦🟦\n", "🟦🟦🟦🟦🟦🟦🟦🟦\n", "🟦🟦🟦🟦🟦🟦🟦🟦\n", "🟦🟦🟦🟦🟦🟦🟦🟦\n", "🟦🟦🟦🟦🟦🟦🟦🟦\n" ] } ], "source": [ "EPOCHS = 50\n", "\n", "for epoch in range(EPOCHS):\n", " # Reset the metrics at the start of the next epoch\n", " train_loss.reset_states()\n", "# train_accuracy.reset_states()\n", "# test_loss.reset_states()\n", "# test_accuracy.reset_states()\n", "\n", "# for images, labels in train_ds:\n", " for i in range(50):\n", " train_step()\n", "\n", "# for test_images, test_labels in test_ds:\n", "# test_step(test_images, test_labels)\n", "\n", " print(\n", " f'Epoch {epoch + 1}, '\n", " f'Loss: {train_loss.result()}, '\n", " # f'Accuracy: {train_accuracy.result() * 100}, '\n", " # f'Test Loss: {test_loss.result()}, '\n", " # f'Test Accuracy: {test_accuracy.result() * 100}'\n", " )\n", "\n", " ps = gen()\n", " # print(ps[0])\n", " print(index_to_string(vector_to_index(ps[0])))" ] } ], "metadata": { "kernelspec": { "display_name": ".venv_patch_gen", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.9" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 }