diff --git a/aesthetic_loss.py b/aesthetic_loss.py index b504577..959fbb5 100644 --- a/aesthetic_loss.py +++ b/aesthetic_loss.py @@ -1,5 +1,65 @@ import tensorflow as tf -def my_loss_fn(y_true, y_pred): - - return tf.reduce_mean(squared_difference, axis=-1) # Note the `axis=-1` +from colors import vector_to_index + + +@tf.function +def aesthetic_loss(y_pics): + + flipx = tf.reverse(y_pics, axis=(-2,)) + flipy = tf.reverse(y_pics, axis=(-3,)) + center = tf.reverse(y_pics, axis=(-2, -3,)) + trans = tf.transpose(y_pics, perm=(0, 2, 1, 3)) + anti_trans = tf.reverse(tf.transpose(tf.reverse( + y_pics, axis=(-2,)), perm=(0, 2, 1, 3)), axis=(-2,)) + + flipx_score = compute_score(y_pics, flipx) + flipy_score = compute_score(y_pics, flipy) + trans_score = compute_score(y_pics, trans) + anti_trans_score = compute_score(y_pics, anti_trans) + center_score = compute_score(y_pics, center) + + # Extract patches for neighborhood check + patches = tf.image.extract_patches( + y_pics, (1, 3, 3, 1), (1, 1, 1, 1), (1, 1, 1, 1), padding="SAME") + # Dot product with central pixel + # neigh = tf.reduce_sum( + # tf.maximum(0.0, -tf.einsum("...k,...k->...", tf.repeat(y_pics, 9, -1), patches))) + neigh = tf.reduce_sum( + tf.abs(tf.einsum("...k,...k->...", tf.repeat(y_pics, 9, -1), patches))) + + divers = tf.reduce_sum( + tf.square(tf.einsum("ijkl,mjkl->imjk", y_pics, y_pics))) + + # print(flipx_score) + # print( tf.reduce_max(tf.stack([flipx_score, flipy_score, trans_score, anti_trans_score, center_score]), axis=(0,))) + + # idxs = vector_to_index(y_pics) + # ncol = tf.map_fn(lambda p: tf.size( + # tf.unique(tf.reshape(p, [-1]),)[0], out_type=tf.int64), idxs) + # print(ncol) + + return tf.reduce_sum([ + -1.0/(64) * tf.reduce_sum(tf.reduce_max(tf.stack([flipx_score, flipy_score, + trans_score, anti_trans_score, center_score]), axis=(0,))), + -1.0/(64*9) * 0.1 * neigh, + # -1.0/7 * 10 * tf.cast(ncol, tf.float32), + # 1.0/64 * 0.01 * tf.reduce_sum(tf.square(y_pics)), + 1.0/64 * 10 * divers + ], ) + + +@tf.function +def compute_score(pic1, pic2): + return tf.reduce_sum( + tf.abs( + tf.einsum( + "...k,...k->...", + pic1, + pic2)), axis=(-2, -1)) + # return tf.reduce_sum( + # -tf.math.maximum(0.0, + # -tf.einsum( + # "...k,...k->...", + # pic1, + # pic2)), axis=(-2, -1)) diff --git a/colors.py b/colors.py new file mode 100644 index 0000000..9fd2c5a --- /dev/null +++ b/colors.py @@ -0,0 +1,31 @@ +from math import pi +import tensorflow as tf + +squares = [chr(i) for i in range(0x1F7E5, 0x1F7EC)] +tfsquares = tf.constant(squares) +colors = tf.constant([10, 219, 38, 48, 80, 282, 20] * + tf.constant(pi/180.0), dtype=tf.float32) +color_vectors = tf.transpose( + tf.stack([ + tf.math.cos(colors), + tf.math.sin(colors)] + ) +) + + +def vector_to_index(tensor): + return tf.argmax((tf.einsum("...ijk,lk->...ijl", tensor, color_vectors)), axis=-1) + + +def index_to_string(tensor): + tfstring = tf.strings.join( + tf.map_fn( + lambda v: + tf.strings.join( + tf.gather(tfsquares, tf.cast(v, tf.int64))), tensor, fn_output_signature=tf.string), "\n") + + return tfstring.numpy().decode() + + +def vector_to_string(tensor): + return index_to_string(vector_to_index(tensor)) diff --git a/patch_gen.ipynb b/patch_gen.ipynb new file mode 100644 index 0000000..d6483a3 --- /dev/null +++ b/patch_gen.ipynb @@ -0,0 +1,398 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The autoreload extension is already loaded. To reload it, use:\n", + " %reload_ext autoreload\n" + ] + } + ], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "import tensorflow as tf\n", + "from tensorflow.keras.layers import Dense, Conv2D, Reshape\n", + "from polymap import Polymap\n", + "from aesthetic_loss import aesthetic_loss" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['🟥', '🟦', '🟧', '🟨', '🟩', '🟪', '🟫']\n" + ] + } + ], + "source": [ + "from math import pi\n", + "squares = [chr(i) for i in range(0x1F7E5, 0x1F7EC)]\n", + "tfsquares = tf.constant(squares)\n", + "print(squares)\n", + "colors = tf.constant([10, 219, 38, 48, 80, 282, 20] * tf.constant(pi/180.0), dtype=tf.float32)\n", + "color_vectors = tf.transpose(\n", + " tf.stack([\n", + " tf.math.cos(colors),\n", + " tf.math.sin(colors)]\n", + " )\n", + ")\n", + "\n", + "def vector_to_index(tensor):\n", + " return tf.argmax((tf.einsum(\"ijk,lk->ijl\", tensor, color_vectors)), axis=-1)\n", + "\n", + "def index_to_string(tensor):\n", + " # square_select = tf.math.argmax(tensor, 2)\n", + "\n", + " tfstring = tf.strings.join(\n", + " tf.map_fn(\n", + " lambda v:\n", + " tf.strings.join(\n", + " tf.gather(tfsquares, tf.cast(v, tf.int64))), tensor, fn_output_signature=tf.string), \"\\n\")\n", + "\n", + " return tfstring.numpy().decode()" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "batch_size = 16" + ] + }, + { + "cell_type": "code", + "execution_count": 51, + "metadata": {}, + "outputs": [], + "source": [ + "# class PatchGen(tf.keras.Model):\n", + "\n", + "# def __init__(self, batch_size):\n", + "# super(PatchGen, self).__init__()\n", + "# self.pm = Polymap(8, 3)(input)\n", + "# self.d1 = Dense(64)\n", + "# self.rs = Reshape([8,8,1])\n", + "# self.dc1 = Conv2D(3,(1,1))\n", + "\n", + "# def call(self, inputs):\n", + "# x = self.block_1(inputs)\n", + "# x = self.block_2(x)\n", + "# x = self.global_pool(x)\n", + "# return self.classifier(x)\n", + "\n", + "input = tf.random.normal(shape=[batch_size,2])\n", + "pm = Polymap(8, 3)\n", + "d1 = Dense(64, activation=\"relu\")\n", + "d2 = Dense(64, activation=\"relu\")\n", + "rs = Reshape([8,8,1])\n", + "dc1 = Conv2D(2,(1,1))\n", + "\n", + "def gen():\n", + " input = tf.random.normal(shape=[batch_size,2])\n", + " x = pm(input)\n", + " x = d1(x)\n", + " x = d2(x)\n", + " x = rs(x)\n", + " return dc1(x)\n", + "# gen()" + ] + }, + { + "cell_type": "code", + "execution_count": 52, + "metadata": {}, + "outputs": [], + "source": [ + "optimizer = tf.keras.optimizers.Adam()\n", + "train_loss = tf.keras.metrics.Mean(name='train_loss')" + ] + }, + { + "cell_type": "code", + "execution_count": 53, + "metadata": {}, + "outputs": [], + "source": [ + "@tf.function\n", + "def train_step():\n", + " with tf.GradientTape() as tape:\n", + " pics = gen()\n", + " loss = aesthetic_loss(pics)\n", + " tv = sum([pm.trainable_weights, d1.trainable_weights,\n", + " dc1.trainable_weights], [])\n", + " print(tv)\n", + " gradients = tape.gradient(loss, tv)\n", + " optimizer.apply_gradients(zip(gradients, tv))\n", + "\n", + " train_loss(loss)\n", + " # train_accuracy(labels, predictions)" + ] + }, + { + "cell_type": "code", + "execution_count": 57, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1, Loss: 0.0004605448921211064, \n", + "🟩🟩🟩🟩🟩🟩🟩🟩\n", + "🟩🟩🟩🟩🟩🟩🟩🟩\n", + "🟩🟩🟩🟩🟩🟩🟩🟩\n", + "🟩🟩🟩🟩🟩🟩🟩🟩\n", + "🟩🟩🟩🟩🟩🟩🟩🟩\n", + "🟩🟩🟩🟩🟩🟩🟩🟩\n", + "🟩🟩🟩🟩🟩🟩🟩🟩\n", + "🟩🟩🟩🟩🟩🟩🟩🟩\n", + "Epoch 2, Loss: 9.45829197007697e-06, \n", + "🟥🟥🟥🟥🟥🟥🟥🟥\n", + "🟥🟥🟥🟥🟥🟥🟥🟥\n", + "🟥🟥🟥🟥🟥🟥🟥🟥\n", + "🟥🟥🟥🟥🟥🟥🟥🟥\n", + "🟥🟥🟥🟥🟥🟥🟥🟥\n", + "🟥🟥🟥🟥🟥🟥🟥🟥\n", + "🟥🟥🟥🟥🟥🟥🟥🟥\n", + "🟥🟥🟥🟥🟥🟥🟥🟥\n", + "Epoch 3, Loss: 8.17671389086172e-05, \n", + "🟦🟦🟦🟦🟦🟦🟦🟦\n", + "🟦🟦🟦🟦🟦🟦🟦🟦\n", + "🟦🟦🟦🟦🟦🟦🟦🟦\n", + "🟦🟦🟦🟦🟦🟦🟦🟦\n", + "🟦🟦🟦🟦🟦🟦🟦🟦\n", + "🟦🟦🟦🟦🟦🟦🟦🟦\n", + "🟦🟦🟦🟦🟦🟦🟦🟦\n", + "🟦🟦🟦🟦🟦🟦🟦🟦\n", + "Epoch 4, Loss: 0.0001505890249973163, \n", + "🟦🟦🟦🟦🟦🟦🟦🟦\n", + "🟦🟦🟦🟦🟦🟦🟦🟦\n", + "🟦🟦🟦🟦🟦🟦🟦🟦\n", + "🟦🟦🟦🟦🟦🟦🟦🟦\n", + "🟦🟦🟦🟦🟦🟦🟦🟦\n", + "🟦🟦🟦🟦🟦🟦🟦🟦\n", + "🟦🟦🟦🟦🟦🟦🟦🟦\n", + "🟦🟦🟦🟦🟦🟦🟦🟦\n", + "Epoch 5, Loss: 0.00018983485642820597, \n", + "🟩🟩🟩🟩🟩🟩🟩🟩\n", + "🟩🟩🟩🟩🟩🟩🟩🟩\n", + "🟩🟩🟩🟩🟩🟩🟩🟩\n", + "🟩🟩🟩🟩🟩🟩🟩🟩\n", + "🟩🟩🟩🟩🟩🟩🟩🟩\n", + "🟩🟩🟩🟩🟩🟩🟩🟩\n", + "🟩🟩🟩🟩🟩🟩🟩🟩\n", + "🟩🟩🟩🟩🟩🟩🟩🟩\n", + "Epoch 6, Loss: 0.00012902110756840557, \n", + "🟦🟦🟦🟦🟦🟦🟦🟦\n", + "🟦🟦🟦🟦🟦🟦🟦🟦\n", + "🟦🟦🟦🟦🟦🟦🟦🟦\n", + "🟦🟦🟦🟦🟦🟦🟦🟦\n", + "🟦🟦🟦🟦🟦🟦🟦🟦\n", + "🟦🟦🟦🟦🟦🟦🟦🟦\n", + "🟦🟦🟦🟦🟦🟦🟦🟦\n", + "🟦🟦🟦🟦🟦🟦🟦🟦\n", + "Epoch 7, Loss: 1.342521863989532e-05, \n", + "🟨🟨🟨🟨🟨🟨🟨🟨\n", + "🟨🟨🟨🟨🟨🟨🟨🟨\n", + "🟨🟨🟨🟨🟨🟨🟨🟨\n", + "🟨🟨🟨🟨🟨🟨🟨🟨\n", + "🟨🟨🟨🟨🟨🟨🟨🟨\n", + "🟨🟨🟨🟨🟨🟨🟨🟨\n", + "🟨🟨🟨🟨🟨🟨🟨🟨\n", + "🟨🟨🟨🟨🟨🟨🟨🟨\n", + "Epoch 8, Loss: 0.00034237385261803865, \n", + "🟩🟩🟩🟩🟩🟩🟩🟩\n", + "🟩🟩🟩🟩🟩🟩🟩🟩\n", + "🟩🟩🟩🟩🟩🟩🟩🟩\n", + "🟩🟩🟩🟩🟩🟩🟩🟩\n", + "🟩🟩🟩🟩🟩🟩🟩🟩\n", + "🟩🟩🟩🟩🟩🟩🟩🟩\n", + "🟩🟩🟩🟩🟩🟩🟩🟩\n", + "🟩🟩🟩🟩🟩🟩🟩🟩\n", + "Epoch 9, Loss: 8.251221879618242e-05, \n", + "🟦🟦🟦🟦🟦🟦🟦🟦\n", + "🟦🟦🟦🟦🟦🟦🟦🟦\n", + "🟦🟦🟦🟦🟦🟦🟦🟦\n", + "🟦🟦🟦🟦🟦🟦🟦🟦\n", + "🟦🟦🟦🟦🟦🟦🟦🟦\n", + "🟦🟦🟦🟦🟦🟦🟦🟦\n", + "🟦🟦🟦🟦🟦🟦🟦🟦\n", + "🟦🟦🟦🟦🟦🟦🟦🟦\n", + "Epoch 10, Loss: 1.0330303013006414e-07, \n", + "🟧🟧🟧🟧🟧🟧🟧🟧\n", + "🟧🟧🟧🟧🟧🟧🟧🟧\n", + "🟧🟧🟧🟧🟧🟧🟧🟧\n", + "🟧🟧🟧🟧🟧🟧🟧🟧\n", + "🟧🟧🟧🟧🟧🟧🟧🟧\n", + "🟧🟧🟧🟧🟧🟧🟧🟧\n", + "🟧🟧🟧🟧🟧🟧🟧🟧\n", + "🟧🟧🟧🟧🟧🟧🟧🟧\n", + "Epoch 11, Loss: 3.0386236176127568e-05, \n", + "🟥🟥🟥🟥🟥🟥🟥🟥\n", + "🟥🟥🟥🟥🟥🟥🟥🟥\n", + "🟥🟥🟥🟥🟥🟥🟥🟥\n", + "🟥🟥🟥🟥🟥🟥🟥🟥\n", + "🟥🟥🟥🟥🟥🟥🟥🟥\n", + "🟥🟥🟥🟥🟥🟥🟥🟥\n", + "🟥🟥🟥🟥🟥🟥🟥🟥\n", + "🟥🟥🟥🟥🟥🟥🟥🟥\n", + "Epoch 12, Loss: 4.6750748879276216e-05, \n", + "🟦🟦🟦🟦🟦🟦🟦🟦\n", + "🟦🟦🟦🟦🟦🟦🟦🟦\n", + "🟦🟦🟦🟦🟦🟦🟦🟦\n", + "🟦🟦🟦🟦🟦🟦🟦🟦\n", + "🟦🟦🟦🟦🟦🟦🟦🟦\n", + "🟦🟦🟦🟦🟦🟦🟦🟦\n", + "🟦🟦🟦🟦🟦🟦🟦🟦\n", + "🟦🟦🟦🟦🟦🟦🟦🟦\n", + "Epoch 13, Loss: 7.047886185773677e-08, \n", + "🟦🟦🟦🟦🟦🟦🟦🟦\n", + "🟦🟦🟦🟦🟦🟦🟦🟦\n", + "🟦🟦🟦🟦🟦🟦🟦🟦\n", + "🟦🟦🟦🟦🟦🟦🟦🟦\n", + "🟦🟦🟦🟦🟦🟦🟦🟦\n", + "🟦🟦🟦🟦🟦🟦🟦🟦\n", + "🟦🟦🟦🟦🟦🟦🟦🟦\n", + "🟦🟦🟦🟦🟦🟦🟦🟦\n", + "Epoch 14, Loss: 9.711433317672463e-14, \n", + "🟥🟥🟥🟥🟥🟥🟥🟥\n", + "🟥🟥🟥🟥🟥🟥🟥🟥\n", + "🟥🟥🟥🟥🟥🟥🟥🟥\n", + "🟥🟥🟥🟥🟥🟥🟥🟥\n", + "🟥🟥🟥🟥🟥🟥🟥🟥\n", + "🟥🟥🟥🟥🟥🟥🟥🟥\n", + "🟥🟥🟥🟥🟥🟥🟥🟥\n", + "🟥🟥🟥🟥🟥🟥🟥🟥\n", + "Epoch 15, Loss: 7.075177338602006e-23, \n", + "🟪🟪🟪🟪🟪🟪🟪🟪\n", + "🟪🟪🟪🟪🟪🟪🟪🟪\n", + "🟪🟪🟪🟪🟪🟪🟪🟪\n", + "🟪🟪🟪🟪🟪🟪🟪🟪\n", + "🟪🟪🟪🟪🟪🟪🟪🟪\n", + "🟪🟪🟪🟪🟪🟪🟪🟪\n", + "🟪🟪🟪🟪🟪🟪🟪🟪\n", + "🟪🟪🟪🟪🟪🟪🟪🟪\n", + "Epoch 16, Loss: 0.11943158507347107, \n", + "🟨🟨🟨🟨🟨🟨🟨🟨\n", + "🟨🟨🟨🟨🟨🟨🟨🟨\n", + "🟨🟨🟨🟨🟨🟨🟨🟨\n", + "🟨🟨🟨🟨🟨🟨🟨🟨\n", + "🟨🟨🟨🟨🟨🟨🟨🟨\n", + "🟨🟨🟨🟨🟨🟨🟨🟨\n", + "🟨🟨🟨🟨🟨🟨🟨🟨\n", + "🟨🟨🟨🟨🟨🟨🟨🟨\n", + "Epoch 17, Loss: 0.002167836995795369, \n", + "🟨🟨🟨🟨🟨🟨🟨🟨\n", + "🟨🟨🟨🟨🟨🟨🟨🟨\n", + "🟨🟨🟨🟨🟨🟨🟨🟨\n", + "🟨🟨🟨🟨🟨🟨🟨🟨\n", + "🟨🟨🟨🟨🟨🟨🟨🟨\n", + "🟨🟨🟨🟨🟨🟨🟨🟨\n", + "🟨🟨🟨🟨🟨🟨🟨🟨\n", + "🟨🟨🟨🟨🟨🟨🟨🟨\n", + "Epoch 18, Loss: 6.590542034246027e-05, \n", + "🟧🟧🟧🟧🟧🟧🟧🟧\n", + "🟧🟧🟧🟧🟧🟧🟧🟧\n", + "🟧🟧🟧🟧🟧🟧🟧🟧\n", + "🟧🟧🟧🟧🟧🟧🟧🟧\n", + "🟧🟧🟧🟧🟧🟧🟧🟧\n", + "🟧🟧🟧🟧🟧🟧🟧🟧\n", + "🟧🟧🟧🟧🟧🟧🟧🟧\n", + "🟧🟧🟧🟧🟧🟧🟧🟧\n", + "Epoch 19, Loss: 0.0017910723108798265, \n", + "🟩🟩🟩🟩🟩🟩🟩🟩\n", + "🟩🟩🟩🟩🟩🟩🟩🟩\n", + "🟩🟩🟩🟩🟩🟩🟩🟩\n", + "🟩🟩🟩🟩🟩🟩🟩🟩\n", + "🟩🟩🟩🟩🟩🟩🟩🟩\n", + "🟩🟩🟩🟩🟩🟩🟩🟩\n", + "🟩🟩🟩🟩🟩🟩🟩🟩\n", + "🟩🟩🟩🟩🟩🟩🟩🟩\n", + "Epoch 20, Loss: 0.004703103564679623, \n", + "🟦🟦🟦🟦🟦🟦🟦🟦\n", + "🟦🟦🟦🟦🟦🟦🟦🟦\n", + "🟦🟦🟦🟦🟦🟦🟦🟦\n", + "🟦🟦🟦🟦🟦🟦🟦🟦\n", + "🟦🟦🟦🟦🟦🟦🟦🟦\n", + "🟦🟦🟦🟦🟦🟦🟦🟦\n", + "🟦🟦🟦🟦🟦🟦🟦🟦\n", + "🟦🟦🟦🟦🟦🟦🟦🟦\n" + ] + } + ], + "source": [ + "EPOCHS = 50\n", + "\n", + "for epoch in range(EPOCHS):\n", + " # Reset the metrics at the start of the next epoch\n", + " train_loss.reset_states()\n", + "# train_accuracy.reset_states()\n", + "# test_loss.reset_states()\n", + "# test_accuracy.reset_states()\n", + "\n", + "# for images, labels in train_ds:\n", + " for i in range(50):\n", + " train_step()\n", + "\n", + "# for test_images, test_labels in test_ds:\n", + "# test_step(test_images, test_labels)\n", + "\n", + " print(\n", + " f'Epoch {epoch + 1}, '\n", + " f'Loss: {train_loss.result()}, '\n", + " # f'Accuracy: {train_accuracy.result() * 100}, '\n", + " # f'Test Loss: {test_loss.result()}, '\n", + " # f'Test Accuracy: {test_accuracy.result() * 100}'\n", + " )\n", + "\n", + " ps = gen()\n", + " # print(ps[0])\n", + " print(index_to_string(vector_to_index(ps[0])))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv_patch_gen", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.9" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/patch_gen.py b/patch_gen.py index 04da175..2dbcff0 100644 --- a/patch_gen.py +++ b/patch_gen.py @@ -1,4 +1,127 @@ -from tensorflow import keras +# autopep8: off +import tensorflow as tf -input = keras.Input((2,)) +physical_devices = tf.config.experimental.list_physical_devices('GPU') +print(physical_devices) +tf.config.experimental.set_memory_growth(physical_devices[0], True) +import colors +from aesthetic_loss import aesthetic_loss +from polymap import Polymap +from tensorflow.keras.layers import Dense, Conv2D, Reshape, ReLU, Conv2DTranspose, LeakyReLU, BatchNormalization, GaussianNoise +# autopep8: on + +batch_size = 256 + + +input_small = tf.random.normal(shape=[4, 2]) +pm = Polymap(8, 3) +rl = ReLU() +d1 = Dense(256, activation="relu", + kernel_regularizer=tf.keras.regularizers.l1_l2()) +d2 = Dense(512, activation="relu", + kernel_regularizer=tf.keras.regularizers.l1_l2()) +d3 = Dense(1024, activation="relu", + kernel_regularizer=tf.keras.regularizers.l1_l2()) +rs = Reshape([2, 2, 1024//4]) +dc1 = Conv2DTranspose(256, (5, 5), + padding="valid", use_bias=False) +dc2 = Conv2DTranspose(128, (3, 3), + padding="valid", use_bias=False) +dc3 = Conv2DTranspose(64, (3, 3), + padding="same", use_bias=False) +dc4 = Conv2D(2, (1, 1), activation="tanh", use_bias=False) +bn1 = BatchNormalization() +bn2 = BatchNormalization() +bn3 = BatchNormalization() + +def gen_model(): + model = tf.keras.Sequential() + + model.add(pm) + # model.add(rl) + + model.add(d1) + model.add(d2) + model.add(d3) + + model.add(rs) + + model.add(dc1) + model.add(bn1) + model.add(LeakyReLU()) + + model.add(dc2) + model.add(bn2) + model.add(LeakyReLU()) + + model.add(dc3) + model.add(bn3) + model.add(LeakyReLU()) + + model.add(dc4) + return model + + +optimizer = tf.keras.optimizers.Adam(learning_rate=0.001, epsilon=0.001) +train_loss = tf.keras.metrics.Mean(name='train_loss') + +generator = gen_model() + +rng = tf.random.Generator.from_seed(1) + +@tf.function +def train_step(input): + # input = tf.random.normal(shape=[batch_size, 2]) + # input = rng.normal(shape=[batch_size, 2]) + + # print(input) + + with tf.GradientTape() as tape: + # pics = gen(input) + pics = generator(input, training=True) + loss = aesthetic_loss( + pics) + tf.reduce_sum([d1.losses, d2.losses, d3.losses]) + + 0.01 * tf.reduce_mean(tf.abs(pm.trainable_weights)) + # tv = sum([pm.trainable_weights, + # d1.trainable_weights, + # d2.trainable_weights, + # d3.trainable_weights, + # dc1.trainable_weights, + # dc2.trainable_weights, + # dc3.trainable_weights,], []) + + # print(tv) + gradients = tape.gradient(loss, generator.trainable_weights) + optimizer.apply_gradients(zip(gradients, generator.trainable_weights)) + + train_loss(loss + tf.reduce_sum([d1.losses, d2.losses, d3.losses])) + + +EPOCHS = 200 + +for epoch in range(EPOCHS): + train_loss.reset_states() + + input = rng.normal(shape=[batch_size, 2]) + # print(input) + + for i in range(100): + train_step(input) + + print( + f'Epoch {epoch + 1}, ' + f'Loss: {train_loss.result()}, ' + ) + + input_small = rng.normal(shape=[4, 2]) + ps = generator(input_small, training=False) + # print(ps[0]) + # print(pm.kernels) + s0 = colors.vector_to_string(ps[0]) + s1 = colors.vector_to_string(ps[1]) + s2 = colors.vector_to_string(ps[2]) + s3 = colors.vector_to_string(ps[3]) + + print("\n".join(" ".join(r) + for r in zip(s0.split("\n"), s1.split("\n"), s2.split("\n"), s3.split("\n")))) diff --git a/playground.ipynb b/playground.ipynb new file mode 100644 index 0000000..3ebb6b6 --- /dev/null +++ b/playground.ipynb @@ -0,0 +1,794 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "import tensorflow as tf\n", + "import random\n", + "from math import pi" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['🟥', '🟦', '🟧', '🟨', '🟩', '🟪', '🟫']\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "squares = [chr(i) for i in range(0x1F7E5, 0x1F7EC)]\n", + "tfsquares = tf.constant(squares)\n", + "print(squares)\n", + "colors = tf.constant([10, 219, 38, 48, 80, 282, 20] * tf.constant(pi/180.0), dtype=tf.float32)\n", + "color_vectors = tf.transpose(\n", + " tf.stack([\n", + " tf.math.cos(colors),\n", + " tf.math.sin(colors)]\n", + " )\n", + ")\n", + "color_vectors" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "def one_hot_to_string(tensor):\n", + " square_select = tf.math.argmax(tensor, 2)\n", + "\n", + " tfstring = tf.strings.join(\n", + " tf.map_fn(\n", + " lambda v:\n", + " tf.strings.join(\n", + " tf.gather(tfsquares, tf.cast(v, tf.int64))), square_select, fn_output_signature=tf.string), \"\\n\")\n", + "\n", + " return tfstring.numpy().decode()\n", + "\n", + "def index_to_string(tensor):\n", + " return tf.argmax((tf.einsum(\"ijk,lk->ijl\", tensor, color_vectors)), axis=2)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 57, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 57, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "vi = tf.map_fn(lambda v: tf.map_fn(lambda s: color_vectors[s],v, fn_output_signature=tf.float32),tf.constant(e, shape=(8, 8)),fn_output_signature=tf.float32)\n", + "\n", + "index_to_string(vi)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# def colors_to_one_hot(tensor):\n", + "vi = tf.map_fn(lambda v: tf.map_fn(lambda s: color_vectors[s],v, fn_output_signature=tf.float32),tf.constant(e, shape=(8, 8)),fn_output_signature=tf.float32)\n", + "ns = tf.image.extract_patches(tf.constant(vi, shape=(1,8,8,2)), (1,3,3,1), (1,1,1,1), (1,1,1,1), padding=\"SAME\")\n", + "virs = tf.reshape(vi, (8*8,2))\n", + "# tf.reduce_sum(tf.multiply(tf.repeat(vi, 9, -1), ns))\n", + "tf.abs(tf.einsum(\"...k,...k->...\", tf.repeat(vi, 9, -1), ns))" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tf.Tensor([0. 0. 0. 1. 0. 0. 0.], shape=(7,), dtype=float32)\n", + "tf.Tensor([[3]], shape=(1, 1), dtype=int64)\n", + "🟨🟥🟨🟨🟩🟫🟪🟫\n", + "🟩🟪🟧🟧🟦🟧🟦🟫\n", + "🟩🟦🟫🟨🟦🟨🟩🟫\n", + "🟪🟫🟩🟨🟧🟪🟥🟩\n", + "🟥🟫🟨🟫🟧🟦🟦🟧\n", + "🟩🟫🟫🟧🟨🟨🟪🟫\n", + "🟩🟪🟦🟥🟧🟫🟧🟨\n", + "🟨🟪🟨🟧🟪🟦🟨🟧\n", + "--------------\n", + "🟧🟨🟦🟪🟧🟨🟪🟨\n", + "🟨🟧🟫🟧🟥🟦🟪🟩\n", + "🟫🟪🟨🟨🟧🟫🟫🟩\n", + "🟧🟦🟦🟧🟫🟨🟫🟥\n", + "🟩🟥🟪🟧🟨🟩🟫🟪\n", + "🟫🟩🟨🟦🟨🟫🟦🟩\n", + "🟫🟦🟧🟦🟧🟧🟪🟩\n", + "🟫🟪🟫🟩🟨🟨🟥🟨\n", + "tf.Tensor(\n", + "[[0. 0. 0. 0. 0. 0. 0. 0.]\n", + " [0. 0. 1. 0. 0. 1. 0. 0.]\n", + " [0. 0. 0. 0. 0. 0. 0. 0.]\n", + " [0. 0. 0. 0. 0. 0. 0. 0.]\n", + " [0. 0. 0. 0. 0. 0. 0. 0.]\n", + " [0. 0. 0. 0. 0. 0. 0. 0.]\n", + " [0. 0. 0. 0. 0. 0. 0. 0.]\n", + " [0. 0. 0. 0. 0. 0. 0. 0.]], shape=(8, 8), dtype=float32)\n" + ] + } + ], + "source": [ + "oh = tf.one_hot(3, 7)\n", + "print(oh)\n", + "print(tf.where(oh))\n", + "\n", + "# a = tf.constant([[list(tf.one_hot(random.randrange(7), 7)) for i in range(8) ] for i in range(8)], shape=(8,8,len(squares)))\n", + "\n", + "# print([[list(tf.one_hot(random.randrange(7), 7)) for i in range(8) ] for i in range(8)])\n", + "e = [[random.randrange(7) for i in range(8)] for i in range(8)]\n", + "# print(e)\n", + "\n", + "a = tf.one_hot(tf.constant(e, shape=(8, 8)), len(squares))\n", + "\n", + "print(one_hot_to_string(a))\n", + "# print(\"--------------\")\n", + "# print(one_hot_to_string(tf.reverse(a, axis=(-2,))))\n", + "# print(\"--------------\")\n", + "# print(one_hot_to_string(tf.reverse(a, axis=(-3,))))\n", + "# print(\"--------------\")\n", + "# print(one_hot_to_string(tf.transpose(a, perm=(1,0,2))))\n", + "print(\"--------------\")\n", + "# print(one_hot_to_string(tf.reverse(tf.transpose(tf.reverse(a, axis=(-2,)), perm=(1,0,2)), axis=(-2,))))\n", + "print(one_hot_to_string(tf.reverse(a, axis=(-2,-3,))))\n", + "# print(tensor_to_string(tf.transpose(a, [1, 0, 2])))\n", + "# print(\"--------------\")\n", + "# print(tensor_to_string(tf.linalg.matmul( a , tf.constant(\n", + "# [\n", + "# [0, 0, 0, 0, 0, 0, 0, 1],\n", + "# [0, 0, 0, 0, 0, 0, 1, 0],\n", + "# [0, 0, 0, 0, 0, 1, 0, 0],\n", + "# [0, 0, 0, 0, 1, 0, 0, 0],\n", + "# [0, 0, 0, 1, 0, 0, 0, 0],\n", + "# [0, 0, 1, 0, 0, 0, 0, 0],\n", + "# [0, 1, 0, 0, 0, 0, 0, 0],\n", + "# [1, 0, 0, 0, 0, 0, 0, 0]\n", + "# ], dtype=tf.float32, shape=(8,8,1)\n", + "# ), transpose_a=True)))\n", + "# print(\"--------------\")\n", + "# print(tensor_to_string(tf.reverse(a, (0,))))\n", + "# print(\"--------------\")\n", + "# print(tensor_to_string(tf.reverse(a, (1,))))\n", + "# print(tensor_to_string(\n", + "# tf.linalg.matmul(\n", + "# tf.transpose(a, [0, 2, 1]),\n", + "# tf.reverse(a, (1,)))\n", + "# ))\n", + "print(\n", + " tf.einsum(\n", + " \"ijk,ijk->ij\",\n", + " a,\n", + " tf.reverse(a, (1,)))\n", + ")\n", + "\n", + "\n", + "# print(tf.constant(range(6), shape=(2,3)))\n", + "# tf.map_fn(tf.math.reduce_sum, tf.constant(range(6), shape=(2,3)))\n" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import aesthetic_loss\n", + "\n", + "es = [[[random.randrange(7) for i in range(8)]\n", + " for i in range(8)] for j in range(16)]\n", + "\n", + "vis = tf.stack(\n", + " [tf.map_fn(lambda v: tf.map_fn(lambda s: color_vectors[s], v, fn_output_signature=tf.float32), tf.constant(\n", + " e, shape=(8, 8)), fn_output_signature=tf.float32) for e in es]\n", + ")\n", + "\n", + "aesthetic_loss.aesthetic_loss(vis)\n", + "# aesthetic_loss.compute_score(vis, vis)\n", + "# tf.linalg.l2_normalize(vis, axis=-1)" + ] + }, + { + "cell_type": "code", + "execution_count": 91, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 91, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "x = tf.Variable(2)\n", + "y = tf.Variable(3)\n", + "\n", + "xp = tf.pow(x, [0,1,2,3])\n", + "yp = tf.pow(y, [0,1,2,3])\n", + "\n", + "A = tf.Variable([[i+j for i in range(4)] for j in range(4)], shape=(4,4))\n", + "tf.einsum(\"i,ij,j\", xp, A, yp)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 70, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 70, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "r = tf.random.normal((16,2))\n", + "rp = tf.map_fn(lambda v: tf.pow(tf.repeat(tf.reshape(v, shape=(2,1)), repeats=4, axis=1), tf.constant([[0,1,2,3], [0,1,2,3]], dtype=tf.float32)), r)\n", + "A = tf.Variable(tf.reshape([float(i) for i in range(8*4*4)], shape=(8,4,4)), shape=(8,4,4), dtype=tf.float32)\n", + "tf.einsum(\"ik,ij,lkj->il\", rp[:,0], rp[:,1], A)\n", + "# rp[:,0]\n" + ] + }, + { + "cell_type": "code", + "execution_count": 82, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tf.Tensor(\n", + "[[ 4.15302128e-01 -1.17927468e+00 2.49658436e-01 6.19551420e-01\n", + " -1.39033377e-01 2.79209971e-01 3.97601783e-01 -2.17277408e-01]\n", + " [ 1.74276757e+00 5.69359064e-02 5.38259745e-04 1.49045396e+00\n", + " -7.87080288e-01 -3.98119450e-01 1.18716180e-01 1.86502957e+00]\n", + " [ 1.78081155e+00 -1.29199183e+00 -5.72377980e-01 1.73489356e+00\n", + " -1.81523621e+00 -7.56731749e-01 5.50791979e-01 2.23066378e+00]\n", + " [-2.86107287e-02 2.38285348e-01 -2.90394664e-01 2.20198780e-02\n", + " 3.43291536e-02 2.77290910e-01 -3.25144343e-02 3.50418538e-01]\n", + " [-2.32135087e-01 1.95894271e-01 5.75335659e-02 -1.76543556e-03\n", + " -4.01449911e-02 2.00845733e-01 -9.64274257e-02 2.31855541e-01]\n", + " [ 4.60646629e-01 7.75275290e-01 -9.96064782e-01 -5.15016496e-01\n", + " 2.03099363e-02 7.87942708e-02 -3.07008505e-01 7.16598153e-01]\n", + " [ 3.02803397e-01 -8.83126408e-02 3.72119397e-01 1.14753373e-01\n", + " -8.30153376e-03 6.69163764e-02 -3.60930443e-01 4.32033747e-01]\n", + " [-1.55036688e-01 2.96032988e-02 5.69820963e-02 -1.04342289e-01\n", + " 2.18345195e-01 1.60513148e-01 -2.30654161e-02 2.92460948e-01]\n", + " [-9.37448144e-02 4.92064208e-02 1.25066668e-01 -1.08998209e-01\n", + " 1.47912696e-01 7.13704750e-02 -9.69988406e-02 3.16246897e-01]\n", + " [ 3.76237392e-01 -8.38992000e-02 3.31230164e-01 2.86825836e-01\n", + " 3.67110878e-01 -3.61552715e-01 6.06793404e-01 -1.22976625e+00]\n", + " [-1.00743435e-01 -1.46258920e-01 -1.82411104e-01 -3.14753801e-01\n", + " 7.92229354e-01 2.98170269e-01 -1.82591528e-01 2.78764069e-01]\n", + " [ 3.58289123e-01 1.97965109e+00 -1.18626821e+00 -1.47391105e+00\n", + " 6.11245155e-01 -4.61193591e-01 -7.64733970e-01 1.45492101e+00]\n", + " [-2.00078160e-01 4.32360053e-01 1.63023725e-01 -3.31694707e-02\n", + " -4.53773856e-01 4.14216101e-01 1.76714227e-01 2.12158859e-01]\n", + " [ 4.91659880e-01 1.56133473e+00 -8.19299459e-01 9.75016415e-01\n", + " -2.52336192e+00 1.69403172e+00 2.05944443e+00 1.03490460e+00]\n", + " [-2.33240795e+00 6.64168453e+00 -4.04779243e+00 -5.99489832e+00\n", + " 1.60179734e+00 2.19676828e+00 4.11457443e+00 2.81916380e-01]\n", + " [ 1.73429586e-02 4.88152355e-02 3.41429591e-01 -2.39478126e-02\n", + " -2.88104061e-02 1.31737068e-01 -2.55863100e-01 2.94781595e-01]], shape=(16, 8), dtype=float32)\n" + ] + } + ], + "source": [ + "import polymap\n", + "r = tf.random.normal((16,2))\n", + "\n", + "pm = polymap.Polymap(8,4)\n", + "# pm.build([16,2])\n", + "print(pm(r))\n" + ] + }, + { + "cell_type": "code", + "execution_count": 111, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 111, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "r = tf.random.normal(shape = (16, 8, 8, 2))\n", + "tf.linalg.l2_normalize(r,axis=-1)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".patch_gen_venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.9" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/playgrund.ipynb b/playgrund.ipynb deleted file mode 100644 index fd72057..0000000 --- a/playgrund.ipynb +++ /dev/null @@ -1,155 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [], - "source": [ - "import tensorflow as tf\n", - "import random" - ] - }, - { - "cell_type": "code", - "execution_count": 29, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "['🟥', '🟦', '🟧', '🟨', '🟩', '🟪', '🟫']\n" - ] - } - ], - "source": [ - "squares = [chr(i) for i in range(0x1F7E5, 0x1F7EC)]\n", - "tfsquares = tf.constant(squares)\n", - "print(squares)" - ] - }, - { - "cell_type": "code", - "execution_count": 51, - "metadata": {}, - "outputs": [], - "source": [ - "def tensor_to_string(tensor):\n", - " square_select = tf.math.argmax(tensor, 2)\n", - "\n", - " tfstring = tf.strings.join(\n", - " tf.map_fn(\n", - " lambda v:\n", - " tf.strings.join(\n", - " tf.gather(tfsquares, tf.cast(v, tf.int64))), square_select, fn_output_signature=tf.string), \"\\n\")\n", - "\n", - " return tfstring.numpy().decode()\n" - ] - }, - { - "cell_type": "code", - "execution_count": 95, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tf.Tensor([0. 0. 0. 1. 0. 0. 0.], shape=(7,), dtype=float32)\n", - "tf.Tensor([[3]], shape=(1, 1), dtype=int64)\n", - "🟦🟪🟦🟫🟦🟧🟨🟨\n", - "🟦🟫🟧🟫🟪🟥🟧🟪\n", - "🟫🟩🟨🟧🟨🟫🟪🟨\n", - "🟧🟦🟨🟪🟦🟨🟥🟨\n", - "🟫🟧🟧🟨🟧🟧🟩🟪\n", - "🟦🟦🟨🟦🟥🟫🟧🟩\n", - "🟫🟩🟧🟫🟨🟫🟥🟩\n", - "🟧🟫🟫🟩🟧🟧🟫🟩\n", - "--------------\n", - "tf.Tensor(\n", - "[[0. 0. 0. 0. 0. 0. 0. 0.]\n", - " [0. 0. 0. 0. 0. 0. 0. 0.]\n", - " [0. 0. 0. 0. 0. 0. 0. 0.]\n", - " [0. 0. 1. 0. 0. 1. 0. 0.]\n", - " [0. 0. 1. 0. 0. 1. 0. 0.]\n", - " [0. 0. 0. 0. 0. 0. 0. 0.]\n", - " [0. 0. 0. 0. 0. 0. 0. 0.]\n", - " [0. 1. 0. 0. 0. 0. 1. 0.]], shape=(8, 8), dtype=float32)\n" - ] - } - ], - "source": [ - "oh = tf.one_hot(3, 7)\n", - "print(oh)\n", - "print(tf.where(oh))\n", - "\n", - "# a = tf.constant([[list(tf.one_hot(random.randrange(7), 7)) for i in range(8) ] for i in range(8)], shape=(8,8,len(squares)))\n", - "\n", - "# print([[list(tf.one_hot(random.randrange(7), 7)) for i in range(8) ] for i in range(8)])\n", - "e = [[random.randrange(7) for i in range(8)] for i in range(8)]\n", - "# print(e)\n", - "\n", - "a = tf.one_hot(tf.constant(e, shape=(8, 8)), len(squares))\n", - "\n", - "print(tensor_to_string(a))\n", - "print(\"--------------\")\n", - "# print(tensor_to_string(tf.transpose(a, [1, 0, 2])))\n", - "# print(\"--------------\")\n", - "# print(tensor_to_string(tf.linalg.matmul( a , tf.constant(\n", - "# [\n", - "# [0, 0, 0, 0, 0, 0, 0, 1],\n", - "# [0, 0, 0, 0, 0, 0, 1, 0],\n", - "# [0, 0, 0, 0, 0, 1, 0, 0],\n", - "# [0, 0, 0, 0, 1, 0, 0, 0],\n", - "# [0, 0, 0, 1, 0, 0, 0, 0],\n", - "# [0, 0, 1, 0, 0, 0, 0, 0],\n", - "# [0, 1, 0, 0, 0, 0, 0, 0],\n", - "# [1, 0, 0, 0, 0, 0, 0, 0]\n", - "# ], dtype=tf.float32, shape=(8,8,1)\n", - "# ), transpose_a=True)))\n", - "# print(\"--------------\")\n", - "# print(tensor_to_string(tf.reverse(a, (0,))))\n", - "# print(\"--------------\")\n", - "# print(tensor_to_string(tf.reverse(a, (1,))))\n", - "# print(tensor_to_string(\n", - "# tf.linalg.matmul(\n", - "# tf.transpose(a, [0, 2, 1]),\n", - "# tf.reverse(a, (1,)))\n", - "# ))\n", - "print(\n", - " tf.einsum(\n", - " \"ijk,ijk->ij\",\n", - " a,\n", - " tf.reverse(a, (1,)))\n", - ")\n", - "\n", - "\n", - "# print(tf.constant(range(6), shape=(2,3)))\n", - "# tf.map_fn(tf.math.reduce_sum, tf.constant(range(6), shape=(2,3)))\n" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": ".patch_gen_venv", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.0" - }, - "orig_nbformat": 4 - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/polymap.py b/polymap.py new file mode 100644 index 0000000..aa7ecad --- /dev/null +++ b/polymap.py @@ -0,0 +1,28 @@ +import tensorflow as tf +# from tensorflow import Layer + + +class Polymap(tf.keras.layers.Layer): + def __init__(self, num_poly, poly_deg): + super(Polymap, self).__init__() + self.num_poly = num_poly + self.poly_deg = poly_deg + self.exponents = tf.constant( + [range(self.poly_deg + 1), range(self.poly_deg + 1)], dtype=tf.float32) + self.kernels = self.add_weight("kernels", + shape=[ + self.num_poly, self.poly_deg + 1, self.poly_deg + 1], + trainable=True) + + def build(self, input_shape): + if input_shape[-1] != 2: + raise "Input shape must be Ix2" + + + def call(self, input): + powers = tf.map_fn( + lambda v: tf.pow( + tf.repeat(tf.reshape(v, shape=(2, 1)), + repeats=self.poly_deg + 1, axis=1), + self.exponents), input) + return tf.einsum("ik,ij,lkj->il", powers[:, 0], powers[:, 1], self.kernels)