something
parent
4091c0e79e
commit
1d52c5a78e
@ -1,5 +1,65 @@
|
|||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
def my_loss_fn(y_true, y_pred):
|
from colors import vector_to_index
|
||||||
|
|
||||||
return tf.reduce_mean(squared_difference, axis=-1) # Note the `axis=-1`
|
|
||||||
|
@tf.function
|
||||||
|
def aesthetic_loss(y_pics):
|
||||||
|
|
||||||
|
flipx = tf.reverse(y_pics, axis=(-2,))
|
||||||
|
flipy = tf.reverse(y_pics, axis=(-3,))
|
||||||
|
center = tf.reverse(y_pics, axis=(-2, -3,))
|
||||||
|
trans = tf.transpose(y_pics, perm=(0, 2, 1, 3))
|
||||||
|
anti_trans = tf.reverse(tf.transpose(tf.reverse(
|
||||||
|
y_pics, axis=(-2,)), perm=(0, 2, 1, 3)), axis=(-2,))
|
||||||
|
|
||||||
|
flipx_score = compute_score(y_pics, flipx)
|
||||||
|
flipy_score = compute_score(y_pics, flipy)
|
||||||
|
trans_score = compute_score(y_pics, trans)
|
||||||
|
anti_trans_score = compute_score(y_pics, anti_trans)
|
||||||
|
center_score = compute_score(y_pics, center)
|
||||||
|
|
||||||
|
# Extract patches for neighborhood check
|
||||||
|
patches = tf.image.extract_patches(
|
||||||
|
y_pics, (1, 3, 3, 1), (1, 1, 1, 1), (1, 1, 1, 1), padding="SAME")
|
||||||
|
# Dot product with central pixel
|
||||||
|
# neigh = tf.reduce_sum(
|
||||||
|
# tf.maximum(0.0, -tf.einsum("...k,...k->...", tf.repeat(y_pics, 9, -1), patches)))
|
||||||
|
neigh = tf.reduce_sum(
|
||||||
|
tf.abs(tf.einsum("...k,...k->...", tf.repeat(y_pics, 9, -1), patches)))
|
||||||
|
|
||||||
|
divers = tf.reduce_sum(
|
||||||
|
tf.square(tf.einsum("ijkl,mjkl->imjk", y_pics, y_pics)))
|
||||||
|
|
||||||
|
# print(flipx_score)
|
||||||
|
# print( tf.reduce_max(tf.stack([flipx_score, flipy_score, trans_score, anti_trans_score, center_score]), axis=(0,)))
|
||||||
|
|
||||||
|
# idxs = vector_to_index(y_pics)
|
||||||
|
# ncol = tf.map_fn(lambda p: tf.size(
|
||||||
|
# tf.unique(tf.reshape(p, [-1]),)[0], out_type=tf.int64), idxs)
|
||||||
|
# print(ncol)
|
||||||
|
|
||||||
|
return tf.reduce_sum([
|
||||||
|
-1.0/(64) * tf.reduce_sum(tf.reduce_max(tf.stack([flipx_score, flipy_score,
|
||||||
|
trans_score, anti_trans_score, center_score]), axis=(0,))),
|
||||||
|
-1.0/(64*9) * 0.1 * neigh,
|
||||||
|
# -1.0/7 * 10 * tf.cast(ncol, tf.float32),
|
||||||
|
# 1.0/64 * 0.01 * tf.reduce_sum(tf.square(y_pics)),
|
||||||
|
1.0/64 * 10 * divers
|
||||||
|
], )
|
||||||
|
|
||||||
|
|
||||||
|
@tf.function
|
||||||
|
def compute_score(pic1, pic2):
|
||||||
|
return tf.reduce_sum(
|
||||||
|
tf.abs(
|
||||||
|
tf.einsum(
|
||||||
|
"...k,...k->...",
|
||||||
|
pic1,
|
||||||
|
pic2)), axis=(-2, -1))
|
||||||
|
# return tf.reduce_sum(
|
||||||
|
# -tf.math.maximum(0.0,
|
||||||
|
# -tf.einsum(
|
||||||
|
# "...k,...k->...",
|
||||||
|
# pic1,
|
||||||
|
# pic2)), axis=(-2, -1))
|
||||||
|
@ -0,0 +1,31 @@
|
|||||||
|
from math import pi
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
squares = [chr(i) for i in range(0x1F7E5, 0x1F7EC)]
|
||||||
|
tfsquares = tf.constant(squares)
|
||||||
|
colors = tf.constant([10, 219, 38, 48, 80, 282, 20] *
|
||||||
|
tf.constant(pi/180.0), dtype=tf.float32)
|
||||||
|
color_vectors = tf.transpose(
|
||||||
|
tf.stack([
|
||||||
|
tf.math.cos(colors),
|
||||||
|
tf.math.sin(colors)]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def vector_to_index(tensor):
|
||||||
|
return tf.argmax((tf.einsum("...ijk,lk->...ijl", tensor, color_vectors)), axis=-1)
|
||||||
|
|
||||||
|
|
||||||
|
def index_to_string(tensor):
|
||||||
|
tfstring = tf.strings.join(
|
||||||
|
tf.map_fn(
|
||||||
|
lambda v:
|
||||||
|
tf.strings.join(
|
||||||
|
tf.gather(tfsquares, tf.cast(v, tf.int64))), tensor, fn_output_signature=tf.string), "\n")
|
||||||
|
|
||||||
|
return tfstring.numpy().decode()
|
||||||
|
|
||||||
|
|
||||||
|
def vector_to_string(tensor):
|
||||||
|
return index_to_string(vector_to_index(tensor))
|
@ -0,0 +1,398 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 8,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"The autoreload extension is already loaded. To reload it, use:\n",
|
||||||
|
" %reload_ext autoreload\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"%load_ext autoreload\n",
|
||||||
|
"%autoreload 2"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 9,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import tensorflow as tf\n",
|
||||||
|
"from tensorflow.keras.layers import Dense, Conv2D, Reshape\n",
|
||||||
|
"from polymap import Polymap\n",
|
||||||
|
"from aesthetic_loss import aesthetic_loss"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 46,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"['🟥', '🟦', '🟧', '🟨', '🟩', '🟪', '🟫']\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"from math import pi\n",
|
||||||
|
"squares = [chr(i) for i in range(0x1F7E5, 0x1F7EC)]\n",
|
||||||
|
"tfsquares = tf.constant(squares)\n",
|
||||||
|
"print(squares)\n",
|
||||||
|
"colors = tf.constant([10, 219, 38, 48, 80, 282, 20] * tf.constant(pi/180.0), dtype=tf.float32)\n",
|
||||||
|
"color_vectors = tf.transpose(\n",
|
||||||
|
" tf.stack([\n",
|
||||||
|
" tf.math.cos(colors),\n",
|
||||||
|
" tf.math.sin(colors)]\n",
|
||||||
|
" )\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"def vector_to_index(tensor):\n",
|
||||||
|
" return tf.argmax((tf.einsum(\"ijk,lk->ijl\", tensor, color_vectors)), axis=-1)\n",
|
||||||
|
"\n",
|
||||||
|
"def index_to_string(tensor):\n",
|
||||||
|
" # square_select = tf.math.argmax(tensor, 2)\n",
|
||||||
|
"\n",
|
||||||
|
" tfstring = tf.strings.join(\n",
|
||||||
|
" tf.map_fn(\n",
|
||||||
|
" lambda v:\n",
|
||||||
|
" tf.strings.join(\n",
|
||||||
|
" tf.gather(tfsquares, tf.cast(v, tf.int64))), tensor, fn_output_signature=tf.string), \"\\n\")\n",
|
||||||
|
"\n",
|
||||||
|
" return tfstring.numpy().decode()"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 10,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"batch_size = 16"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 51,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# class PatchGen(tf.keras.Model):\n",
|
||||||
|
"\n",
|
||||||
|
"# def __init__(self, batch_size):\n",
|
||||||
|
"# super(PatchGen, self).__init__()\n",
|
||||||
|
"# self.pm = Polymap(8, 3)(input)\n",
|
||||||
|
"# self.d1 = Dense(64)\n",
|
||||||
|
"# self.rs = Reshape([8,8,1])\n",
|
||||||
|
"# self.dc1 = Conv2D(3,(1,1))\n",
|
||||||
|
"\n",
|
||||||
|
"# def call(self, inputs):\n",
|
||||||
|
"# x = self.block_1(inputs)\n",
|
||||||
|
"# x = self.block_2(x)\n",
|
||||||
|
"# x = self.global_pool(x)\n",
|
||||||
|
"# return self.classifier(x)\n",
|
||||||
|
"\n",
|
||||||
|
"input = tf.random.normal(shape=[batch_size,2])\n",
|
||||||
|
"pm = Polymap(8, 3)\n",
|
||||||
|
"d1 = Dense(64, activation=\"relu\")\n",
|
||||||
|
"d2 = Dense(64, activation=\"relu\")\n",
|
||||||
|
"rs = Reshape([8,8,1])\n",
|
||||||
|
"dc1 = Conv2D(2,(1,1))\n",
|
||||||
|
"\n",
|
||||||
|
"def gen():\n",
|
||||||
|
" input = tf.random.normal(shape=[batch_size,2])\n",
|
||||||
|
" x = pm(input)\n",
|
||||||
|
" x = d1(x)\n",
|
||||||
|
" x = d2(x)\n",
|
||||||
|
" x = rs(x)\n",
|
||||||
|
" return dc1(x)\n",
|
||||||
|
"# gen()"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 52,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"optimizer = tf.keras.optimizers.Adam()\n",
|
||||||
|
"train_loss = tf.keras.metrics.Mean(name='train_loss')"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 53,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"@tf.function\n",
|
||||||
|
"def train_step():\n",
|
||||||
|
" with tf.GradientTape() as tape:\n",
|
||||||
|
" pics = gen()\n",
|
||||||
|
" loss = aesthetic_loss(pics)\n",
|
||||||
|
" tv = sum([pm.trainable_weights, d1.trainable_weights,\n",
|
||||||
|
" dc1.trainable_weights], [])\n",
|
||||||
|
" print(tv)\n",
|
||||||
|
" gradients = tape.gradient(loss, tv)\n",
|
||||||
|
" optimizer.apply_gradients(zip(gradients, tv))\n",
|
||||||
|
"\n",
|
||||||
|
" train_loss(loss)\n",
|
||||||
|
" # train_accuracy(labels, predictions)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 57,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Epoch 1, Loss: 0.0004605448921211064, \n",
|
||||||
|
"🟩🟩🟩🟩🟩🟩🟩🟩\n",
|
||||||
|
"🟩🟩🟩🟩🟩🟩🟩🟩\n",
|
||||||
|
"🟩🟩🟩🟩🟩🟩🟩🟩\n",
|
||||||
|
"🟩🟩🟩🟩🟩🟩🟩🟩\n",
|
||||||
|
"🟩🟩🟩🟩🟩🟩🟩🟩\n",
|
||||||
|
"🟩🟩🟩🟩🟩🟩🟩🟩\n",
|
||||||
|
"🟩🟩🟩🟩🟩🟩🟩🟩\n",
|
||||||
|
"🟩🟩🟩🟩🟩🟩🟩🟩\n",
|
||||||
|
"Epoch 2, Loss: 9.45829197007697e-06, \n",
|
||||||
|
"🟥🟥🟥🟥🟥🟥🟥🟥\n",
|
||||||
|
"🟥🟥🟥🟥🟥🟥🟥🟥\n",
|
||||||
|
"🟥🟥🟥🟥🟥🟥🟥🟥\n",
|
||||||
|
"🟥🟥🟥🟥🟥🟥🟥🟥\n",
|
||||||
|
"🟥🟥🟥🟥🟥🟥🟥🟥\n",
|
||||||
|
"🟥🟥🟥🟥🟥🟥🟥🟥\n",
|
||||||
|
"🟥🟥🟥🟥🟥🟥🟥🟥\n",
|
||||||
|
"🟥🟥🟥🟥🟥🟥🟥🟥\n",
|
||||||
|
"Epoch 3, Loss: 8.17671389086172e-05, \n",
|
||||||
|
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
|
||||||
|
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
|
||||||
|
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
|
||||||
|
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
|
||||||
|
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
|
||||||
|
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
|
||||||
|
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
|
||||||
|
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
|
||||||
|
"Epoch 4, Loss: 0.0001505890249973163, \n",
|
||||||
|
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
|
||||||
|
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
|
||||||
|
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
|
||||||
|
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
|
||||||
|
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
|
||||||
|
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
|
||||||
|
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
|
||||||
|
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
|
||||||
|
"Epoch 5, Loss: 0.00018983485642820597, \n",
|
||||||
|
"🟩🟩🟩🟩🟩🟩🟩🟩\n",
|
||||||
|
"🟩🟩🟩🟩🟩🟩🟩🟩\n",
|
||||||
|
"🟩🟩🟩🟩🟩🟩🟩🟩\n",
|
||||||
|
"🟩🟩🟩🟩🟩🟩🟩🟩\n",
|
||||||
|
"🟩🟩🟩🟩🟩🟩🟩🟩\n",
|
||||||
|
"🟩🟩🟩🟩🟩🟩🟩🟩\n",
|
||||||
|
"🟩🟩🟩🟩🟩🟩🟩🟩\n",
|
||||||
|
"🟩🟩🟩🟩🟩🟩🟩🟩\n",
|
||||||
|
"Epoch 6, Loss: 0.00012902110756840557, \n",
|
||||||
|
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
|
||||||
|
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
|
||||||
|
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
|
||||||
|
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
|
||||||
|
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
|
||||||
|
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
|
||||||
|
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
|
||||||
|
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
|
||||||
|
"Epoch 7, Loss: 1.342521863989532e-05, \n",
|
||||||
|
"🟨🟨🟨🟨🟨🟨🟨🟨\n",
|
||||||
|
"🟨🟨🟨🟨🟨🟨🟨🟨\n",
|
||||||
|
"🟨🟨🟨🟨🟨🟨🟨🟨\n",
|
||||||
|
"🟨🟨🟨🟨🟨🟨🟨🟨\n",
|
||||||
|
"🟨🟨🟨🟨🟨🟨🟨🟨\n",
|
||||||
|
"🟨🟨🟨🟨🟨🟨🟨🟨\n",
|
||||||
|
"🟨🟨🟨🟨🟨🟨🟨🟨\n",
|
||||||
|
"🟨🟨🟨🟨🟨🟨🟨🟨\n",
|
||||||
|
"Epoch 8, Loss: 0.00034237385261803865, \n",
|
||||||
|
"🟩🟩🟩🟩🟩🟩🟩🟩\n",
|
||||||
|
"🟩🟩🟩🟩🟩🟩🟩🟩\n",
|
||||||
|
"🟩🟩🟩🟩🟩🟩🟩🟩\n",
|
||||||
|
"🟩🟩🟩🟩🟩🟩🟩🟩\n",
|
||||||
|
"🟩🟩🟩🟩🟩🟩🟩🟩\n",
|
||||||
|
"🟩🟩🟩🟩🟩🟩🟩🟩\n",
|
||||||
|
"🟩🟩🟩🟩🟩🟩🟩🟩\n",
|
||||||
|
"🟩🟩🟩🟩🟩🟩🟩🟩\n",
|
||||||
|
"Epoch 9, Loss: 8.251221879618242e-05, \n",
|
||||||
|
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
|
||||||
|
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
|
||||||
|
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
|
||||||
|
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
|
||||||
|
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
|
||||||
|
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
|
||||||
|
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
|
||||||
|
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
|
||||||
|
"Epoch 10, Loss: 1.0330303013006414e-07, \n",
|
||||||
|
"🟧🟧🟧🟧🟧🟧🟧🟧\n",
|
||||||
|
"🟧🟧🟧🟧🟧🟧🟧🟧\n",
|
||||||
|
"🟧🟧🟧🟧🟧🟧🟧🟧\n",
|
||||||
|
"🟧🟧🟧🟧🟧🟧🟧🟧\n",
|
||||||
|
"🟧🟧🟧🟧🟧🟧🟧🟧\n",
|
||||||
|
"🟧🟧🟧🟧🟧🟧🟧🟧\n",
|
||||||
|
"🟧🟧🟧🟧🟧🟧🟧🟧\n",
|
||||||
|
"🟧🟧🟧🟧🟧🟧🟧🟧\n",
|
||||||
|
"Epoch 11, Loss: 3.0386236176127568e-05, \n",
|
||||||
|
"🟥🟥🟥🟥🟥🟥🟥🟥\n",
|
||||||
|
"🟥🟥🟥🟥🟥🟥🟥🟥\n",
|
||||||
|
"🟥🟥🟥🟥🟥🟥🟥🟥\n",
|
||||||
|
"🟥🟥🟥🟥🟥🟥🟥🟥\n",
|
||||||
|
"🟥🟥🟥🟥🟥🟥🟥🟥\n",
|
||||||
|
"🟥🟥🟥🟥🟥🟥🟥🟥\n",
|
||||||
|
"🟥🟥🟥🟥🟥🟥🟥🟥\n",
|
||||||
|
"🟥🟥🟥🟥🟥🟥🟥🟥\n",
|
||||||
|
"Epoch 12, Loss: 4.6750748879276216e-05, \n",
|
||||||
|
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
|
||||||
|
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
|
||||||
|
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
|
||||||
|
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
|
||||||
|
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
|
||||||
|
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
|
||||||
|
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
|
||||||
|
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
|
||||||
|
"Epoch 13, Loss: 7.047886185773677e-08, \n",
|
||||||
|
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
|
||||||
|
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
|
||||||
|
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
|
||||||
|
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
|
||||||
|
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
|
||||||
|
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
|
||||||
|
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
|
||||||
|
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
|
||||||
|
"Epoch 14, Loss: 9.711433317672463e-14, \n",
|
||||||
|
"🟥🟥🟥🟥🟥🟥🟥🟥\n",
|
||||||
|
"🟥🟥🟥🟥🟥🟥🟥🟥\n",
|
||||||
|
"🟥🟥🟥🟥🟥🟥🟥🟥\n",
|
||||||
|
"🟥🟥🟥🟥🟥🟥🟥🟥\n",
|
||||||
|
"🟥🟥🟥🟥🟥🟥🟥🟥\n",
|
||||||
|
"🟥🟥🟥🟥🟥🟥🟥🟥\n",
|
||||||
|
"🟥🟥🟥🟥🟥🟥🟥🟥\n",
|
||||||
|
"🟥🟥🟥🟥🟥🟥🟥🟥\n",
|
||||||
|
"Epoch 15, Loss: 7.075177338602006e-23, \n",
|
||||||
|
"🟪🟪🟪🟪🟪🟪🟪🟪\n",
|
||||||
|
"🟪🟪🟪🟪🟪🟪🟪🟪\n",
|
||||||
|
"🟪🟪🟪🟪🟪🟪🟪🟪\n",
|
||||||
|
"🟪🟪🟪🟪🟪🟪🟪🟪\n",
|
||||||
|
"🟪🟪🟪🟪🟪🟪🟪🟪\n",
|
||||||
|
"🟪🟪🟪🟪🟪🟪🟪🟪\n",
|
||||||
|
"🟪🟪🟪🟪🟪🟪🟪🟪\n",
|
||||||
|
"🟪🟪🟪🟪🟪🟪🟪🟪\n",
|
||||||
|
"Epoch 16, Loss: 0.11943158507347107, \n",
|
||||||
|
"🟨🟨🟨🟨🟨🟨🟨🟨\n",
|
||||||
|
"🟨🟨🟨🟨🟨🟨🟨🟨\n",
|
||||||
|
"🟨🟨🟨🟨🟨🟨🟨🟨\n",
|
||||||
|
"🟨🟨🟨🟨🟨🟨🟨🟨\n",
|
||||||
|
"🟨🟨🟨🟨🟨🟨🟨🟨\n",
|
||||||
|
"🟨🟨🟨🟨🟨🟨🟨🟨\n",
|
||||||
|
"🟨🟨🟨🟨🟨🟨🟨🟨\n",
|
||||||
|
"🟨🟨🟨🟨🟨🟨🟨🟨\n",
|
||||||
|
"Epoch 17, Loss: 0.002167836995795369, \n",
|
||||||
|
"🟨🟨🟨🟨🟨🟨🟨🟨\n",
|
||||||
|
"🟨🟨🟨🟨🟨🟨🟨🟨\n",
|
||||||
|
"🟨🟨🟨🟨🟨🟨🟨🟨\n",
|
||||||
|
"🟨🟨🟨🟨🟨🟨🟨🟨\n",
|
||||||
|
"🟨🟨🟨🟨🟨🟨🟨🟨\n",
|
||||||
|
"🟨🟨🟨🟨🟨🟨🟨🟨\n",
|
||||||
|
"🟨🟨🟨🟨🟨🟨🟨🟨\n",
|
||||||
|
"🟨🟨🟨🟨🟨🟨🟨🟨\n",
|
||||||
|
"Epoch 18, Loss: 6.590542034246027e-05, \n",
|
||||||
|
"🟧🟧🟧🟧🟧🟧🟧🟧\n",
|
||||||
|
"🟧🟧🟧🟧🟧🟧🟧🟧\n",
|
||||||
|
"🟧🟧🟧🟧🟧🟧🟧🟧\n",
|
||||||
|
"🟧🟧🟧🟧🟧🟧🟧🟧\n",
|
||||||
|
"🟧🟧🟧🟧🟧🟧🟧🟧\n",
|
||||||
|
"🟧🟧🟧🟧🟧🟧🟧🟧\n",
|
||||||
|
"🟧🟧🟧🟧🟧🟧🟧🟧\n",
|
||||||
|
"🟧🟧🟧🟧🟧🟧🟧🟧\n",
|
||||||
|
"Epoch 19, Loss: 0.0017910723108798265, \n",
|
||||||
|
"🟩🟩🟩🟩🟩🟩🟩🟩\n",
|
||||||
|
"🟩🟩🟩🟩🟩🟩🟩🟩\n",
|
||||||
|
"🟩🟩🟩🟩🟩🟩🟩🟩\n",
|
||||||
|
"🟩🟩🟩🟩🟩🟩🟩🟩\n",
|
||||||
|
"🟩🟩🟩🟩🟩🟩🟩🟩\n",
|
||||||
|
"🟩🟩🟩🟩🟩🟩🟩🟩\n",
|
||||||
|
"🟩🟩🟩🟩🟩🟩🟩🟩\n",
|
||||||
|
"🟩🟩🟩🟩🟩🟩🟩🟩\n",
|
||||||
|
"Epoch 20, Loss: 0.004703103564679623, \n",
|
||||||
|
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
|
||||||
|
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
|
||||||
|
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
|
||||||
|
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
|
||||||
|
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
|
||||||
|
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
|
||||||
|
"🟦🟦🟦🟦🟦🟦🟦🟦\n",
|
||||||
|
"🟦🟦🟦🟦🟦🟦🟦🟦\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"EPOCHS = 50\n",
|
||||||
|
"\n",
|
||||||
|
"for epoch in range(EPOCHS):\n",
|
||||||
|
" # Reset the metrics at the start of the next epoch\n",
|
||||||
|
" train_loss.reset_states()\n",
|
||||||
|
"# train_accuracy.reset_states()\n",
|
||||||
|
"# test_loss.reset_states()\n",
|
||||||
|
"# test_accuracy.reset_states()\n",
|
||||||
|
"\n",
|
||||||
|
"# for images, labels in train_ds:\n",
|
||||||
|
" for i in range(50):\n",
|
||||||
|
" train_step()\n",
|
||||||
|
"\n",
|
||||||
|
"# for test_images, test_labels in test_ds:\n",
|
||||||
|
"# test_step(test_images, test_labels)\n",
|
||||||
|
"\n",
|
||||||
|
" print(\n",
|
||||||
|
" f'Epoch {epoch + 1}, '\n",
|
||||||
|
" f'Loss: {train_loss.result()}, '\n",
|
||||||
|
" # f'Accuracy: {train_accuracy.result() * 100}, '\n",
|
||||||
|
" # f'Test Loss: {test_loss.result()}, '\n",
|
||||||
|
" # f'Test Accuracy: {test_accuracy.result() * 100}'\n",
|
||||||
|
" )\n",
|
||||||
|
"\n",
|
||||||
|
" ps = gen()\n",
|
||||||
|
" # print(ps[0])\n",
|
||||||
|
" print(index_to_string(vector_to_index(ps[0])))"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": ".venv_patch_gen",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.10.9"
|
||||||
|
},
|
||||||
|
"orig_nbformat": 4
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 2
|
||||||
|
}
|
@ -1,4 +1,127 @@
|
|||||||
from tensorflow import keras
|
# autopep8: off
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
input = keras.Input((2,))
|
physical_devices = tf.config.experimental.list_physical_devices('GPU')
|
||||||
|
print(physical_devices)
|
||||||
|
tf.config.experimental.set_memory_growth(physical_devices[0], True)
|
||||||
|
|
||||||
|
import colors
|
||||||
|
from aesthetic_loss import aesthetic_loss
|
||||||
|
from polymap import Polymap
|
||||||
|
from tensorflow.keras.layers import Dense, Conv2D, Reshape, ReLU, Conv2DTranspose, LeakyReLU, BatchNormalization, GaussianNoise
|
||||||
|
# autopep8: on
|
||||||
|
|
||||||
|
batch_size = 256
|
||||||
|
|
||||||
|
|
||||||
|
input_small = tf.random.normal(shape=[4, 2])
|
||||||
|
pm = Polymap(8, 3)
|
||||||
|
rl = ReLU()
|
||||||
|
d1 = Dense(256, activation="relu",
|
||||||
|
kernel_regularizer=tf.keras.regularizers.l1_l2())
|
||||||
|
d2 = Dense(512, activation="relu",
|
||||||
|
kernel_regularizer=tf.keras.regularizers.l1_l2())
|
||||||
|
d3 = Dense(1024, activation="relu",
|
||||||
|
kernel_regularizer=tf.keras.regularizers.l1_l2())
|
||||||
|
rs = Reshape([2, 2, 1024//4])
|
||||||
|
dc1 = Conv2DTranspose(256, (5, 5),
|
||||||
|
padding="valid", use_bias=False)
|
||||||
|
dc2 = Conv2DTranspose(128, (3, 3),
|
||||||
|
padding="valid", use_bias=False)
|
||||||
|
dc3 = Conv2DTranspose(64, (3, 3),
|
||||||
|
padding="same", use_bias=False)
|
||||||
|
dc4 = Conv2D(2, (1, 1), activation="tanh", use_bias=False)
|
||||||
|
bn1 = BatchNormalization()
|
||||||
|
bn2 = BatchNormalization()
|
||||||
|
bn3 = BatchNormalization()
|
||||||
|
|
||||||
|
def gen_model():
|
||||||
|
model = tf.keras.Sequential()
|
||||||
|
|
||||||
|
model.add(pm)
|
||||||
|
# model.add(rl)
|
||||||
|
|
||||||
|
model.add(d1)
|
||||||
|
model.add(d2)
|
||||||
|
model.add(d3)
|
||||||
|
|
||||||
|
model.add(rs)
|
||||||
|
|
||||||
|
model.add(dc1)
|
||||||
|
model.add(bn1)
|
||||||
|
model.add(LeakyReLU())
|
||||||
|
|
||||||
|
model.add(dc2)
|
||||||
|
model.add(bn2)
|
||||||
|
model.add(LeakyReLU())
|
||||||
|
|
||||||
|
model.add(dc3)
|
||||||
|
model.add(bn3)
|
||||||
|
model.add(LeakyReLU())
|
||||||
|
|
||||||
|
model.add(dc4)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001, epsilon=0.001)
|
||||||
|
train_loss = tf.keras.metrics.Mean(name='train_loss')
|
||||||
|
|
||||||
|
generator = gen_model()
|
||||||
|
|
||||||
|
rng = tf.random.Generator.from_seed(1)
|
||||||
|
|
||||||
|
@tf.function
|
||||||
|
def train_step(input):
|
||||||
|
# input = tf.random.normal(shape=[batch_size, 2])
|
||||||
|
# input = rng.normal(shape=[batch_size, 2])
|
||||||
|
|
||||||
|
# print(input)
|
||||||
|
|
||||||
|
with tf.GradientTape() as tape:
|
||||||
|
# pics = gen(input)
|
||||||
|
pics = generator(input, training=True)
|
||||||
|
loss = aesthetic_loss(
|
||||||
|
pics) + tf.reduce_sum([d1.losses, d2.losses, d3.losses])
|
||||||
|
+ 0.01 * tf.reduce_mean(tf.abs(pm.trainable_weights))
|
||||||
|
# tv = sum([pm.trainable_weights,
|
||||||
|
# d1.trainable_weights,
|
||||||
|
# d2.trainable_weights,
|
||||||
|
# d3.trainable_weights,
|
||||||
|
# dc1.trainable_weights,
|
||||||
|
# dc2.trainable_weights,
|
||||||
|
# dc3.trainable_weights,], [])
|
||||||
|
|
||||||
|
# print(tv)
|
||||||
|
gradients = tape.gradient(loss, generator.trainable_weights)
|
||||||
|
optimizer.apply_gradients(zip(gradients, generator.trainable_weights))
|
||||||
|
|
||||||
|
train_loss(loss + tf.reduce_sum([d1.losses, d2.losses, d3.losses]))
|
||||||
|
|
||||||
|
|
||||||
|
EPOCHS = 200
|
||||||
|
|
||||||
|
for epoch in range(EPOCHS):
|
||||||
|
train_loss.reset_states()
|
||||||
|
|
||||||
|
input = rng.normal(shape=[batch_size, 2])
|
||||||
|
# print(input)
|
||||||
|
|
||||||
|
for i in range(100):
|
||||||
|
train_step(input)
|
||||||
|
|
||||||
|
print(
|
||||||
|
f'Epoch {epoch + 1}, '
|
||||||
|
f'Loss: {train_loss.result()}, '
|
||||||
|
)
|
||||||
|
|
||||||
|
input_small = rng.normal(shape=[4, 2])
|
||||||
|
ps = generator(input_small, training=False)
|
||||||
|
# print(ps[0])
|
||||||
|
# print(pm.kernels)
|
||||||
|
s0 = colors.vector_to_string(ps[0])
|
||||||
|
s1 = colors.vector_to_string(ps[1])
|
||||||
|
s2 = colors.vector_to_string(ps[2])
|
||||||
|
s3 = colors.vector_to_string(ps[3])
|
||||||
|
|
||||||
|
print("\n".join(" ".join(r)
|
||||||
|
for r in zip(s0.split("\n"), s1.split("\n"), s2.split("\n"), s3.split("\n"))))
|
||||||
|
@ -0,0 +1,794 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 2,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"%load_ext autoreload\n",
|
||||||
|
"%autoreload 2"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import tensorflow as tf\n",
|
||||||
|
"import random\n",
|
||||||
|
"from math import pi"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 5,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"['🟥', '🟦', '🟧', '🟨', '🟩', '🟪', '🟫']\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"<tf.Tensor: shape=(7, 2), dtype=float32, numpy=\n",
|
||||||
|
"array([[ 0.9848077 , 0.17364818],\n",
|
||||||
|
" [-0.777146 , -0.62932044],\n",
|
||||||
|
" [ 0.7880108 , 0.6156615 ],\n",
|
||||||
|
" [ 0.6691306 , 0.74314487],\n",
|
||||||
|
" [ 0.17364822, 0.9848077 ],\n",
|
||||||
|
" [ 0.20791148, -0.9781476 ],\n",
|
||||||
|
" [ 0.9396926 , 0.34202012]], dtype=float32)>"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 5,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"squares = [chr(i) for i in range(0x1F7E5, 0x1F7EC)]\n",
|
||||||
|
"tfsquares = tf.constant(squares)\n",
|
||||||
|
"print(squares)\n",
|
||||||
|
"colors = tf.constant([10, 219, 38, 48, 80, 282, 20] * tf.constant(pi/180.0), dtype=tf.float32)\n",
|
||||||
|
"color_vectors = tf.transpose(\n",
|
||||||
|
" tf.stack([\n",
|
||||||
|
" tf.math.cos(colors),\n",
|
||||||
|
" tf.math.sin(colors)]\n",
|
||||||
|
" )\n",
|
||||||
|
")\n",
|
||||||
|
"color_vectors"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 11,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"def one_hot_to_string(tensor):\n",
|
||||||
|
" square_select = tf.math.argmax(tensor, 2)\n",
|
||||||
|
"\n",
|
||||||
|
" tfstring = tf.strings.join(\n",
|
||||||
|
" tf.map_fn(\n",
|
||||||
|
" lambda v:\n",
|
||||||
|
" tf.strings.join(\n",
|
||||||
|
" tf.gather(tfsquares, tf.cast(v, tf.int64))), square_select, fn_output_signature=tf.string), \"\\n\")\n",
|
||||||
|
"\n",
|
||||||
|
" return tfstring.numpy().decode()\n",
|
||||||
|
"\n",
|
||||||
|
"def index_to_string(tensor):\n",
|
||||||
|
" return tf.argmax((tf.einsum(\"ijk,lk->ijl\", tensor, color_vectors)), axis=2)\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 57,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"<tf.Tensor: shape=(8, 8), dtype=int64, numpy=\n",
|
||||||
|
"array([[5, 0, 1, 3, 3, 2, 4, 0],\n",
|
||||||
|
" [3, 2, 3, 6, 5, 3, 0, 5],\n",
|
||||||
|
" [1, 3, 0, 1, 6, 4, 5, 3],\n",
|
||||||
|
" [3, 5, 0, 6, 5, 5, 0, 3],\n",
|
||||||
|
" [0, 3, 0, 6, 6, 1, 5, 2],\n",
|
||||||
|
" [4, 1, 4, 6, 6, 1, 0, 5],\n",
|
||||||
|
" [6, 0, 2, 5, 0, 1, 5, 4],\n",
|
||||||
|
" [0, 2, 5, 4, 1, 5, 0, 4]])>"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 57,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"vi = tf.map_fn(lambda v: tf.map_fn(lambda s: color_vectors[s],v, fn_output_signature=tf.float32),tf.constant(e, shape=(8, 8)),fn_output_signature=tf.float32)\n",
|
||||||
|
"\n",
|
||||||
|
"index_to_string(vi)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 10,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"<tf.Tensor: shape=(1, 8, 8), dtype=float32, numpy=\n",
|
||||||
|
"array([[[0.71719116, 0.04602498, 1.0273366 , 2.9402392 , 2.4437628 ,\n",
|
||||||
|
" 1.0988994 , 1.7921778 , 0.6444071 ],\n",
|
||||||
|
" [1.2672982 , 1.8689629 , 3.8761568 , 2.4126112 , 2.840665 ,\n",
|
||||||
|
" 3.0419385 , 4.8909955 , 3.583321 ],\n",
|
||||||
|
" [0.4506898 , 0.01335579, 3.258954 , 3.80125 , 6.4259176 ,\n",
|
||||||
|
" 6.3595123 , 6.609476 , 1.6973627 ],\n",
|
||||||
|
" [1.8337119 , 1.4464988 , 4.6484404 , 4.2732587 , 7.6358337 ,\n",
|
||||||
|
" 5.4232316 , 4.313001 , 1.6071305 ],\n",
|
||||||
|
" [4.9838686 , 1.3140092 , 1.5253077 , 1.9003649 , 6.782009 ,\n",
|
||||||
|
" 3.5835938 , 4.203249 , 1.5495552 ],\n",
|
||||||
|
" [5.3104043 , 5.0123706 , 3.3745103 , 4.540594 , 6.6606255 ,\n",
|
||||||
|
" 3.9076035 , 4.9301 , 3.4342353 ],\n",
|
||||||
|
" [3.3523657 , 4.9394774 , 5.672796 , 4.8756404 , 7.1605525 ,\n",
|
||||||
|
" 8.192585 , 7.06907 , 4.199152 ],\n",
|
||||||
|
" [0.13434029, 4.039145 , 6.2488484 , 3.2168784 , 5.319997 ,\n",
|
||||||
|
" 5.4898353 , 5.301211 , 1.6475667 ]]], dtype=float32)>"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 10,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"# def colors_to_one_hot(tensor):\n",
|
||||||
|
"vi = tf.map_fn(lambda v: tf.map_fn(lambda s: color_vectors[s],v, fn_output_signature=tf.float32),tf.constant(e, shape=(8, 8)),fn_output_signature=tf.float32)\n",
|
||||||
|
"ns = tf.image.extract_patches(tf.constant(vi, shape=(1,8,8,2)), (1,3,3,1), (1,1,1,1), (1,1,1,1), padding=\"SAME\")\n",
|
||||||
|
"virs = tf.reshape(vi, (8*8,2))\n",
|
||||||
|
"# tf.reduce_sum(tf.multiply(tf.repeat(vi, 9, -1), ns))\n",
|
||||||
|
"tf.abs(tf.einsum(\"...k,...k->...\", tf.repeat(vi, 9, -1), ns))"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 27,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"tf.Tensor([0. 0. 0. 1. 0. 0. 0.], shape=(7,), dtype=float32)\n",
|
||||||
|
"tf.Tensor([[3]], shape=(1, 1), dtype=int64)\n",
|
||||||
|
"🟨🟥🟨🟨🟩🟫🟪🟫\n",
|
||||||
|
"🟩🟪🟧🟧🟦🟧🟦🟫\n",
|
||||||
|
"🟩🟦🟫🟨🟦🟨🟩🟫\n",
|
||||||
|
"🟪🟫🟩🟨🟧🟪🟥🟩\n",
|
||||||
|
"🟥🟫🟨🟫🟧🟦🟦🟧\n",
|
||||||
|
"🟩🟫🟫🟧🟨🟨🟪🟫\n",
|
||||||
|
"🟩🟪🟦🟥🟧🟫🟧🟨\n",
|
||||||
|
"🟨🟪🟨🟧🟪🟦🟨🟧\n",
|
||||||
|
"--------------\n",
|
||||||
|
"🟧🟨🟦🟪🟧🟨🟪🟨\n",
|
||||||
|
"🟨🟧🟫🟧🟥🟦🟪🟩\n",
|
||||||
|
"🟫🟪🟨🟨🟧🟫🟫🟩\n",
|
||||||
|
"🟧🟦🟦🟧🟫🟨🟫🟥\n",
|
||||||
|
"🟩🟥🟪🟧🟨🟩🟫🟪\n",
|
||||||
|
"🟫🟩🟨🟦🟨🟫🟦🟩\n",
|
||||||
|
"🟫🟦🟧🟦🟧🟧🟪🟩\n",
|
||||||
|
"🟫🟪🟫🟩🟨🟨🟥🟨\n",
|
||||||
|
"tf.Tensor(\n",
|
||||||
|
"[[0. 0. 0. 0. 0. 0. 0. 0.]\n",
|
||||||
|
" [0. 0. 1. 0. 0. 1. 0. 0.]\n",
|
||||||
|
" [0. 0. 0. 0. 0. 0. 0. 0.]\n",
|
||||||
|
" [0. 0. 0. 0. 0. 0. 0. 0.]\n",
|
||||||
|
" [0. 0. 0. 0. 0. 0. 0. 0.]\n",
|
||||||
|
" [0. 0. 0. 0. 0. 0. 0. 0.]\n",
|
||||||
|
" [0. 0. 0. 0. 0. 0. 0. 0.]\n",
|
||||||
|
" [0. 0. 0. 0. 0. 0. 0. 0.]], shape=(8, 8), dtype=float32)\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"oh = tf.one_hot(3, 7)\n",
|
||||||
|
"print(oh)\n",
|
||||||
|
"print(tf.where(oh))\n",
|
||||||
|
"\n",
|
||||||
|
"# a = tf.constant([[list(tf.one_hot(random.randrange(7), 7)) for i in range(8) ] for i in range(8)], shape=(8,8,len(squares)))\n",
|
||||||
|
"\n",
|
||||||
|
"# print([[list(tf.one_hot(random.randrange(7), 7)) for i in range(8) ] for i in range(8)])\n",
|
||||||
|
"e = [[random.randrange(7) for i in range(8)] for i in range(8)]\n",
|
||||||
|
"# print(e)\n",
|
||||||
|
"\n",
|
||||||
|
"a = tf.one_hot(tf.constant(e, shape=(8, 8)), len(squares))\n",
|
||||||
|
"\n",
|
||||||
|
"print(one_hot_to_string(a))\n",
|
||||||
|
"# print(\"--------------\")\n",
|
||||||
|
"# print(one_hot_to_string(tf.reverse(a, axis=(-2,))))\n",
|
||||||
|
"# print(\"--------------\")\n",
|
||||||
|
"# print(one_hot_to_string(tf.reverse(a, axis=(-3,))))\n",
|
||||||
|
"# print(\"--------------\")\n",
|
||||||
|
"# print(one_hot_to_string(tf.transpose(a, perm=(1,0,2))))\n",
|
||||||
|
"print(\"--------------\")\n",
|
||||||
|
"# print(one_hot_to_string(tf.reverse(tf.transpose(tf.reverse(a, axis=(-2,)), perm=(1,0,2)), axis=(-2,))))\n",
|
||||||
|
"print(one_hot_to_string(tf.reverse(a, axis=(-2,-3,))))\n",
|
||||||
|
"# print(tensor_to_string(tf.transpose(a, [1, 0, 2])))\n",
|
||||||
|
"# print(\"--------------\")\n",
|
||||||
|
"# print(tensor_to_string(tf.linalg.matmul( a , tf.constant(\n",
|
||||||
|
"# [\n",
|
||||||
|
"# [0, 0, 0, 0, 0, 0, 0, 1],\n",
|
||||||
|
"# [0, 0, 0, 0, 0, 0, 1, 0],\n",
|
||||||
|
"# [0, 0, 0, 0, 0, 1, 0, 0],\n",
|
||||||
|
"# [0, 0, 0, 0, 1, 0, 0, 0],\n",
|
||||||
|
"# [0, 0, 0, 1, 0, 0, 0, 0],\n",
|
||||||
|
"# [0, 0, 1, 0, 0, 0, 0, 0],\n",
|
||||||
|
"# [0, 1, 0, 0, 0, 0, 0, 0],\n",
|
||||||
|
"# [1, 0, 0, 0, 0, 0, 0, 0]\n",
|
||||||
|
"# ], dtype=tf.float32, shape=(8,8,1)\n",
|
||||||
|
"# ), transpose_a=True)))\n",
|
||||||
|
"# print(\"--------------\")\n",
|
||||||
|
"# print(tensor_to_string(tf.reverse(a, (0,))))\n",
|
||||||
|
"# print(\"--------------\")\n",
|
||||||
|
"# print(tensor_to_string(tf.reverse(a, (1,))))\n",
|
||||||
|
"# print(tensor_to_string(\n",
|
||||||
|
"# tf.linalg.matmul(\n",
|
||||||
|
"# tf.transpose(a, [0, 2, 1]),\n",
|
||||||
|
"# tf.reverse(a, (1,)))\n",
|
||||||
|
"# ))\n",
|
||||||
|
"print(\n",
|
||||||
|
" tf.einsum(\n",
|
||||||
|
" \"ijk,ijk->ij\",\n",
|
||||||
|
" a,\n",
|
||||||
|
" tf.reverse(a, (1,)))\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"# print(tf.constant(range(6), shape=(2,3)))\n",
|
||||||
|
"# tf.map_fn(tf.math.reduce_sum, tf.constant(range(6), shape=(2,3)))\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 15,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"<tf.Tensor: shape=(), dtype=float32, numpy=-0.08645747>"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 15,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"import aesthetic_loss\n",
|
||||||
|
"\n",
|
||||||
|
"es = [[[random.randrange(7) for i in range(8)]\n",
|
||||||
|
" for i in range(8)] for j in range(16)]\n",
|
||||||
|
"\n",
|
||||||
|
"vis = tf.stack(\n",
|
||||||
|
" [tf.map_fn(lambda v: tf.map_fn(lambda s: color_vectors[s], v, fn_output_signature=tf.float32), tf.constant(\n",
|
||||||
|
" e, shape=(8, 8)), fn_output_signature=tf.float32) for e in es]\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"aesthetic_loss.aesthetic_loss(vis)\n",
|
||||||
|
"# aesthetic_loss.compute_score(vis, vis)\n",
|
||||||
|
"# tf.linalg.l2_normalize(vis, axis=-1)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 91,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"<tf.Tensor: shape=(), dtype=int32, numpy=2890>"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 91,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"x = tf.Variable(2)\n",
|
||||||
|
"y = tf.Variable(3)\n",
|
||||||
|
"\n",
|
||||||
|
"xp = tf.pow(x, [0,1,2,3])\n",
|
||||||
|
"yp = tf.pow(y, [0,1,2,3])\n",
|
||||||
|
"\n",
|
||||||
|
"A = tf.Variable([[i+j for i in range(4)] for j in range(4)], shape=(4,4))\n",
|
||||||
|
"tf.einsum(\"i,ij,j\", xp, A, yp)\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 70,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"<tf.Tensor: shape=(16, 8), dtype=float32, numpy=\n",
|
||||||
|
"array([[-6.07829041e+02, -1.43090210e+03, -2.25397559e+03,\n",
|
||||||
|
" -3.07704883e+03, -3.90012207e+03, -4.72319531e+03,\n",
|
||||||
|
" -5.54626855e+03, -6.36934229e+03],\n",
|
||||||
|
" [-4.28289948e+01, -8.77924423e+01, -1.32755875e+02,\n",
|
||||||
|
" -1.77719330e+02, -2.22682800e+02, -2.67646210e+02,\n",
|
||||||
|
" -3.12609680e+02, -3.57573120e+02],\n",
|
||||||
|
" [ 9.78072433e+01, 2.85662659e+02, 4.73518066e+02,\n",
|
||||||
|
" 6.61373413e+02, 8.49228882e+02, 1.03708435e+03,\n",
|
||||||
|
" 1.22493982e+03, 1.41279517e+03],\n",
|
||||||
|
" [ 3.64326048e+00, 2.58858261e+01, 4.81283913e+01,\n",
|
||||||
|
" 7.03709641e+01, 9.26135254e+01, 1.14856094e+02,\n",
|
||||||
|
" 1.37098648e+02, 1.59341217e+02],\n",
|
||||||
|
" [-1.08920467e+00, 3.87969995e+00, 8.84860516e+00,\n",
|
||||||
|
" 1.38175116e+01, 1.87864113e+01, 2.37553177e+01,\n",
|
||||||
|
" 2.87242279e+01, 3.36931267e+01],\n",
|
||||||
|
" [ 2.75129032e+01, 9.32729721e+01, 1.59033051e+02,\n",
|
||||||
|
" 2.24793137e+02, 2.90553192e+02, 3.56313324e+02,\n",
|
||||||
|
" 4.22073334e+02, 4.87833435e+02],\n",
|
||||||
|
" [-5.32951474e-01, 8.84730148e+00, 1.82275543e+01,\n",
|
||||||
|
" 2.76078072e+01, 3.69880600e+01, 4.63683128e+01,\n",
|
||||||
|
" 5.57485619e+01, 6.51288223e+01],\n",
|
||||||
|
" [-7.10889587e+01, -1.97856873e+02, -3.24624786e+02,\n",
|
||||||
|
" -4.51392792e+02, -5.78160645e+02, -7.04928589e+02,\n",
|
||||||
|
" -8.31696533e+02, -9.58464539e+02],\n",
|
||||||
|
" [-3.13533902e-01, 1.77808132e+01, 3.58751602e+01,\n",
|
||||||
|
" 5.39695053e+01, 7.20638504e+01, 9.01582031e+01,\n",
|
||||||
|
" 1.08252541e+02, 1.26346893e+02],\n",
|
||||||
|
" [ 7.97924957e+01, 2.76383240e+02, 4.72974030e+02,\n",
|
||||||
|
" 6.69564819e+02, 8.66155518e+02, 1.06274634e+03,\n",
|
||||||
|
" 1.25933704e+03, 1.45592786e+03],\n",
|
||||||
|
" [-2.12144566e+01, -1.41624451e+02, -2.62034424e+02,\n",
|
||||||
|
" -3.82444489e+02, -5.02854431e+02, -6.23264404e+02,\n",
|
||||||
|
" -7.43674438e+02, -8.64084229e+02],\n",
|
||||||
|
" [-1.87502289e+02, -3.70114594e+02, -5.52726929e+02,\n",
|
||||||
|
" -7.35339233e+02, -9.17951660e+02, -1.10056396e+03,\n",
|
||||||
|
" -1.28317627e+03, -1.46578857e+03],\n",
|
||||||
|
" [-1.07772827e+00, -9.14207935e-01, -7.50688553e-01,\n",
|
||||||
|
" -5.87172508e-01, -4.23656583e-01, -2.60130525e-01,\n",
|
||||||
|
" -9.66115594e-02, 6.69060946e-02],\n",
|
||||||
|
" [ 2.14334583e+01, 1.10324554e+02, 1.99215652e+02,\n",
|
||||||
|
" 2.88106781e+02, 3.76997894e+02, 4.65889008e+02,\n",
|
||||||
|
" 5.54780090e+02, 6.43671143e+02],\n",
|
||||||
|
" [ 5.15513878e+01, 1.88147568e+02, 3.24743774e+02,\n",
|
||||||
|
" 4.61339966e+02, 5.97936157e+02, 7.34532349e+02,\n",
|
||||||
|
" 8.71128601e+02, 1.00772473e+03],\n",
|
||||||
|
" [ 2.68221771e+02, 7.04315491e+02, 1.14040918e+03,\n",
|
||||||
|
" 1.57650305e+03, 2.01259656e+03, 2.44869043e+03,\n",
|
||||||
|
" 2.88478394e+03, 3.32087769e+03]], dtype=float32)>"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 70,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"r = tf.random.normal((16,2))\n",
|
||||||
|
"rp = tf.map_fn(lambda v: tf.pow(tf.repeat(tf.reshape(v, shape=(2,1)), repeats=4, axis=1), tf.constant([[0,1,2,3], [0,1,2,3]], dtype=tf.float32)), r)\n",
|
||||||
|
"A = tf.Variable(tf.reshape([float(i) for i in range(8*4*4)], shape=(8,4,4)), shape=(8,4,4), dtype=tf.float32)\n",
|
||||||
|
"tf.einsum(\"ik,ij,lkj->il\", rp[:,0], rp[:,1], A)\n",
|
||||||
|
"# rp[:,0]\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 82,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"tf.Tensor(\n",
|
||||||
|
"[[ 4.15302128e-01 -1.17927468e+00 2.49658436e-01 6.19551420e-01\n",
|
||||||
|
" -1.39033377e-01 2.79209971e-01 3.97601783e-01 -2.17277408e-01]\n",
|
||||||
|
" [ 1.74276757e+00 5.69359064e-02 5.38259745e-04 1.49045396e+00\n",
|
||||||
|
" -7.87080288e-01 -3.98119450e-01 1.18716180e-01 1.86502957e+00]\n",
|
||||||
|
" [ 1.78081155e+00 -1.29199183e+00 -5.72377980e-01 1.73489356e+00\n",
|
||||||
|
" -1.81523621e+00 -7.56731749e-01 5.50791979e-01 2.23066378e+00]\n",
|
||||||
|
" [-2.86107287e-02 2.38285348e-01 -2.90394664e-01 2.20198780e-02\n",
|
||||||
|
" 3.43291536e-02 2.77290910e-01 -3.25144343e-02 3.50418538e-01]\n",
|
||||||
|
" [-2.32135087e-01 1.95894271e-01 5.75335659e-02 -1.76543556e-03\n",
|
||||||
|
" -4.01449911e-02 2.00845733e-01 -9.64274257e-02 2.31855541e-01]\n",
|
||||||
|
" [ 4.60646629e-01 7.75275290e-01 -9.96064782e-01 -5.15016496e-01\n",
|
||||||
|
" 2.03099363e-02 7.87942708e-02 -3.07008505e-01 7.16598153e-01]\n",
|
||||||
|
" [ 3.02803397e-01 -8.83126408e-02 3.72119397e-01 1.14753373e-01\n",
|
||||||
|
" -8.30153376e-03 6.69163764e-02 -3.60930443e-01 4.32033747e-01]\n",
|
||||||
|
" [-1.55036688e-01 2.96032988e-02 5.69820963e-02 -1.04342289e-01\n",
|
||||||
|
" 2.18345195e-01 1.60513148e-01 -2.30654161e-02 2.92460948e-01]\n",
|
||||||
|
" [-9.37448144e-02 4.92064208e-02 1.25066668e-01 -1.08998209e-01\n",
|
||||||
|
" 1.47912696e-01 7.13704750e-02 -9.69988406e-02 3.16246897e-01]\n",
|
||||||
|
" [ 3.76237392e-01 -8.38992000e-02 3.31230164e-01 2.86825836e-01\n",
|
||||||
|
" 3.67110878e-01 -3.61552715e-01 6.06793404e-01 -1.22976625e+00]\n",
|
||||||
|
" [-1.00743435e-01 -1.46258920e-01 -1.82411104e-01 -3.14753801e-01\n",
|
||||||
|
" 7.92229354e-01 2.98170269e-01 -1.82591528e-01 2.78764069e-01]\n",
|
||||||
|
" [ 3.58289123e-01 1.97965109e+00 -1.18626821e+00 -1.47391105e+00\n",
|
||||||
|
" 6.11245155e-01 -4.61193591e-01 -7.64733970e-01 1.45492101e+00]\n",
|
||||||
|
" [-2.00078160e-01 4.32360053e-01 1.63023725e-01 -3.31694707e-02\n",
|
||||||
|
" -4.53773856e-01 4.14216101e-01 1.76714227e-01 2.12158859e-01]\n",
|
||||||
|
" [ 4.91659880e-01 1.56133473e+00 -8.19299459e-01 9.75016415e-01\n",
|
||||||
|
" -2.52336192e+00 1.69403172e+00 2.05944443e+00 1.03490460e+00]\n",
|
||||||
|
" [-2.33240795e+00 6.64168453e+00 -4.04779243e+00 -5.99489832e+00\n",
|
||||||
|
" 1.60179734e+00 2.19676828e+00 4.11457443e+00 2.81916380e-01]\n",
|
||||||
|
" [ 1.73429586e-02 4.88152355e-02 3.41429591e-01 -2.39478126e-02\n",
|
||||||
|
" -2.88104061e-02 1.31737068e-01 -2.55863100e-01 2.94781595e-01]], shape=(16, 8), dtype=float32)\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"import polymap\n",
|
||||||
|
"r = tf.random.normal((16,2))\n",
|
||||||
|
"\n",
|
||||||
|
"pm = polymap.Polymap(8,4)\n",
|
||||||
|
"# pm.build([16,2])\n",
|
||||||
|
"print(pm(r))\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 111,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"<tf.Tensor: shape=(16, 8, 8, 2), dtype=float32, numpy=\n",
|
||||||
|
"array([[[[ 0.26760346, 0.9635291 ],\n",
|
||||||
|
" [ 0.92846906, 0.37140968],\n",
|
||||||
|
" [ 0.43006882, 0.90279603],\n",
|
||||||
|
" ...,\n",
|
||||||
|
" [ 0.9174245 , 0.39790976],\n",
|
||||||
|
" [ 0.68369484, 0.729768 ],\n",
|
||||||
|
" [ 0.5905379 , -0.8070098 ]],\n",
|
||||||
|
"\n",
|
||||||
|
" [[-0.9665002 , 0.25666556],\n",
|
||||||
|
" [ 0.6174854 , -0.7865824 ],\n",
|
||||||
|
" [-0.8356619 , -0.54924417],\n",
|
||||||
|
" ...,\n",
|
||||||
|
" [-0.42570287, 0.90486294],\n",
|
||||||
|
" [ 0.99368197, 0.11223278],\n",
|
||||||
|
" [ 0.22659355, 0.97398937]],\n",
|
||||||
|
"\n",
|
||||||
|
" [[ 0.9998629 , 0.01655844],\n",
|
||||||
|
" [-0.02171519, 0.9997642 ],\n",
|
||||||
|
" [ 0.765471 , -0.64347047],\n",
|
||||||
|
" ...,\n",
|
||||||
|
" [-0.45096368, 0.8925422 ],\n",
|
||||||
|
" [ 0.23687297, 0.9715406 ],\n",
|
||||||
|
" [-0.2564354 , 0.9665613 ]],\n",
|
||||||
|
"\n",
|
||||||
|
" ...,\n",
|
||||||
|
"\n",
|
||||||
|
" [[ 0.99743736, -0.07154454],\n",
|
||||||
|
" [-0.421027 , 0.907048 ],\n",
|
||||||
|
" [-0.06904475, 0.99761355],\n",
|
||||||
|
" ...,\n",
|
||||||
|
" [-0.14730097, -0.9890917 ],\n",
|
||||||
|
" [-0.76936895, -0.63880455],\n",
|
||||||
|
" [ 0.12321126, 0.99238044]],\n",
|
||||||
|
"\n",
|
||||||
|
" [[ 0.8592728 , 0.51151776],\n",
|
||||||
|
" [ 0.5945256 , -0.80407673],\n",
|
||||||
|
" [-0.926837 , -0.37546387],\n",
|
||||||
|
" ...,\n",
|
||||||
|
" [-0.6461148 , 0.76324016],\n",
|
||||||
|
" [-0.30289924, 0.95302254],\n",
|
||||||
|
" [-0.17756721, 0.9841086 ]],\n",
|
||||||
|
"\n",
|
||||||
|
" [[-0.9989693 , 0.04539123],\n",
|
||||||
|
" [ 0.9668924 , 0.25518435],\n",
|
||||||
|
" [-0.70309734, 0.7110935 ],\n",
|
||||||
|
" ...,\n",
|
||||||
|
" [ 0.9820237 , -0.1887575 ],\n",
|
||||||
|
" [-0.45441642, -0.8907894 ],\n",
|
||||||
|
" [ 0.20170917, 0.9794454 ]]],\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
" [[[-0.9041683 , -0.4271764 ],\n",
|
||||||
|
" [-0.8145172 , 0.5801394 ],\n",
|
||||||
|
" [-0.93892664, 0.34411758],\n",
|
||||||
|
" ...,\n",
|
||||||
|
" [ 0.7724613 , -0.63506186],\n",
|
||||||
|
" [ 0.73973364, 0.6728997 ],\n",
|
||||||
|
" [-0.7584759 , -0.6517011 ]],\n",
|
||||||
|
"\n",
|
||||||
|
" [[ 0.40460777, 0.91449016],\n",
|
||||||
|
" [ 0.9511079 , 0.30885857],\n",
|
||||||
|
" [-0.8532944 , 0.5214295 ],\n",
|
||||||
|
" ...,\n",
|
||||||
|
" [ 0.3517043 , -0.9361111 ],\n",
|
||||||
|
" [-0.81593966, -0.57813704],\n",
|
||||||
|
" [-0.96846116, -0.24916425]],\n",
|
||||||
|
"\n",
|
||||||
|
" [[ 0.9971029 , -0.07606225],\n",
|
||||||
|
" [ 0.8346386 , -0.550798 ],\n",
|
||||||
|
" [ 0.79441607, 0.60737395],\n",
|
||||||
|
" ...,\n",
|
||||||
|
" [-0.17288695, -0.9849416 ],\n",
|
||||||
|
" [ 0.34486932, -0.93865067],\n",
|
||||||
|
" [ 0.15837449, 0.9873792 ]],\n",
|
||||||
|
"\n",
|
||||||
|
" ...,\n",
|
||||||
|
"\n",
|
||||||
|
" [[-0.9557019 , 0.2943367 ],\n",
|
||||||
|
" [-0.1128222 , -0.9936151 ],\n",
|
||||||
|
" [ 0.8779246 , -0.4787989 ],\n",
|
||||||
|
" ...,\n",
|
||||||
|
" [ 0.01654347, 0.9998631 ],\n",
|
||||||
|
" [-0.31221828, 0.9500104 ],\n",
|
||||||
|
" [-0.70988506, 0.70431745]],\n",
|
||||||
|
"\n",
|
||||||
|
" [[ 0.59169656, -0.8061608 ],\n",
|
||||||
|
" [ 0.29944003, 0.95411515],\n",
|
||||||
|
" [ 0.12906635, 0.9916359 ],\n",
|
||||||
|
" ...,\n",
|
||||||
|
" [-0.7345422 , -0.678563 ],\n",
|
||||||
|
" [-0.6101673 , -0.79227257],\n",
|
||||||
|
" [ 0.63129115, 0.7755459 ]],\n",
|
||||||
|
"\n",
|
||||||
|
" [[ 0.15320726, -0.9881941 ],\n",
|
||||||
|
" [ 0.8636006 , -0.5041763 ],\n",
|
||||||
|
" [ 0.5430426 , 0.83970505],\n",
|
||||||
|
" ...,\n",
|
||||||
|
" [ 0.99351764, -0.11367732],\n",
|
||||||
|
" [ 0.8001608 , -0.5997855 ],\n",
|
||||||
|
" [-0.8337144 , 0.55219585]]],\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
" [[[-0.6251835 , 0.78047776],\n",
|
||||||
|
" [ 0.6738433 , 0.7388742 ],\n",
|
||||||
|
" [-0.9261264 , -0.37721333],\n",
|
||||||
|
" ...,\n",
|
||||||
|
" [ 0.77851856, -0.6276216 ],\n",
|
||||||
|
" [ 0.6481785 , 0.76148844],\n",
|
||||||
|
" [ 0.8758924 , -0.4825066 ]],\n",
|
||||||
|
"\n",
|
||||||
|
" [[-0.9522389 , -0.3053537 ],\n",
|
||||||
|
" [-0.94333917, -0.33183014],\n",
|
||||||
|
" [-0.8840897 , 0.46731716],\n",
|
||||||
|
" ...,\n",
|
||||||
|
" [ 0.8191087 , 0.5736384 ],\n",
|
||||||
|
" [ 0.6949664 , 0.7190423 ],\n",
|
||||||
|
" [-0.2683343 , 0.96332586]],\n",
|
||||||
|
"\n",
|
||||||
|
" [[-0.83271 , 0.5537092 ],\n",
|
||||||
|
" [-0.72585243, 0.6878505 ],\n",
|
||||||
|
" [ 0.77075607, 0.63713026],\n",
|
||||||
|
" ...,\n",
|
||||||
|
" [-0.9409917 , 0.33842987],\n",
|
||||||
|
" [-0.9817458 , 0.19019753],\n",
|
||||||
|
" [ 0.9783392 , 0.20700799]],\n",
|
||||||
|
"\n",
|
||||||
|
" ...,\n",
|
||||||
|
"\n",
|
||||||
|
" [[-0.04849374, 0.99882346],\n",
|
||||||
|
" [ 0.9695309 , 0.24496901],\n",
|
||||||
|
" [-0.37880003, -0.9254785 ],\n",
|
||||||
|
" ...,\n",
|
||||||
|
" [ 0.6892487 , -0.7245249 ],\n",
|
||||||
|
" [ 0.81647533, 0.57738036],\n",
|
||||||
|
" [-0.9972612 , -0.07396041]],\n",
|
||||||
|
"\n",
|
||||||
|
" [[-0.7055271 , 0.7086829 ],\n",
|
||||||
|
" [-0.8819862 , 0.4712752 ],\n",
|
||||||
|
" [-0.9642273 , -0.2650766 ],\n",
|
||||||
|
" ...,\n",
|
||||||
|
" [ 0.82875234, -0.55961555],\n",
|
||||||
|
" [-0.06711138, 0.9977455 ],\n",
|
||||||
|
" [-0.1260777 , 0.99202037]],\n",
|
||||||
|
"\n",
|
||||||
|
" [[ 0.99550056, -0.09475542],\n",
|
||||||
|
" [-0.9923597 , 0.1233783 ],\n",
|
||||||
|
" [-0.71074235, -0.70345235],\n",
|
||||||
|
" ...,\n",
|
||||||
|
" [ 0.8930231 , -0.4500106 ],\n",
|
||||||
|
" [ 0.1744522 , 0.9846656 ],\n",
|
||||||
|
" [ 0.8210883 , -0.570801 ]]],\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
" ...,\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
" [[[-0.7103158 , 0.703883 ],\n",
|
||||||
|
" [ 0.37816712, -0.9257372 ],\n",
|
||||||
|
" [ 0.9501393 , 0.31182557],\n",
|
||||||
|
" ...,\n",
|
||||||
|
" [ 0.66041577, 0.75090003],\n",
|
||||||
|
" [ 0.00253691, 0.9999967 ],\n",
|
||||||
|
" [-0.3108576 , -0.95045644]],\n",
|
||||||
|
"\n",
|
||||||
|
" [[-0.9948109 , 0.1017414 ],\n",
|
||||||
|
" [ 0.8379823 , 0.54569715],\n",
|
||||||
|
" [-0.42345405, 0.9059175 ],\n",
|
||||||
|
" ...,\n",
|
||||||
|
" [ 0.7350055 , 0.678061 ],\n",
|
||||||
|
" [-0.6902005 , -0.7236182 ],\n",
|
||||||
|
" [ 0.3050526 , -0.9523355 ]],\n",
|
||||||
|
"\n",
|
||||||
|
" [[-0.9992436 , -0.03888722],\n",
|
||||||
|
" [-0.9524134 , -0.3048092 ],\n",
|
||||||
|
" [ 0.17434528, 0.9846846 ],\n",
|
||||||
|
" ...,\n",
|
||||||
|
" [-0.48622662, -0.8738327 ],\n",
|
||||||
|
" [ 0.5717803 , 0.82040673],\n",
|
||||||
|
" [-0.7644213 , 0.64471704]],\n",
|
||||||
|
"\n",
|
||||||
|
" ...,\n",
|
||||||
|
"\n",
|
||||||
|
" [[-0.0501435 , -0.998742 ],\n",
|
||||||
|
" [-0.915064 , 0.40330857],\n",
|
||||||
|
" [-0.7114537 , -0.7027328 ],\n",
|
||||||
|
" ...,\n",
|
||||||
|
" [ 0.9969975 , -0.07743412],\n",
|
||||||
|
" [ 0.58687896, -0.80967456],\n",
|
||||||
|
" [ 0.9493494 , 0.31422222]],\n",
|
||||||
|
"\n",
|
||||||
|
" [[ 0.49507466, 0.86885047],\n",
|
||||||
|
" [ 0.8797779 , -0.47538495],\n",
|
||||||
|
" [-0.9211454 , 0.38921857],\n",
|
||||||
|
" ...,\n",
|
||||||
|
" [ 0.8058742 , -0.5920867 ],\n",
|
||||||
|
" [ 0.53095114, -0.84740245],\n",
|
||||||
|
" [ 0.9835844 , -0.18044874]],\n",
|
||||||
|
"\n",
|
||||||
|
" [[-0.19601943, -0.98059994],\n",
|
||||||
|
" [ 0.5338162 , -0.8456005 ],\n",
|
||||||
|
" [-0.9987141 , -0.05069751],\n",
|
||||||
|
" ...,\n",
|
||||||
|
" [ 0.73851573, -0.6742363 ],\n",
|
||||||
|
" [-0.99623436, 0.08670113],\n",
|
||||||
|
" [ 0.62578243, 0.7799976 ]]],\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
" [[[ 0.99914247, -0.04140271],\n",
|
||||||
|
" [-0.45437112, 0.8908124 ],\n",
|
||||||
|
" [-0.42213327, -0.9065337 ],\n",
|
||||||
|
" ...,\n",
|
||||||
|
" [ 0.62287563, 0.78232086],\n",
|
||||||
|
" [-0.34597075, -0.9382453 ],\n",
|
||||||
|
" [ 0.79978997, -0.6002798 ]],\n",
|
||||||
|
"\n",
|
||||||
|
" [[ 0.8313185 , -0.5557963 ],\n",
|
||||||
|
" [ 0.7764067 , -0.6302323 ],\n",
|
||||||
|
" [-0.4633804 , 0.8861594 ],\n",
|
||||||
|
" ...,\n",
|
||||||
|
" [-0.06348424, -0.99798286],\n",
|
||||||
|
" [-0.24774338, 0.9688258 ],\n",
|
||||||
|
" [ 0.7214034 , 0.6925151 ]],\n",
|
||||||
|
"\n",
|
||||||
|
" [[-0.7895615 , -0.6136714 ],\n",
|
||||||
|
" [ 0.7669125 , -0.64175165],\n",
|
||||||
|
" [-0.60046875, 0.7996482 ],\n",
|
||||||
|
" ...,\n",
|
||||||
|
" [-0.22642732, 0.974028 ],\n",
|
||||||
|
" [-0.874757 , -0.48456183],\n",
|
||||||
|
" [-0.75133216, -0.6599242 ]],\n",
|
||||||
|
"\n",
|
||||||
|
" ...,\n",
|
||||||
|
"\n",
|
||||||
|
" [[ 0.99788564, 0.06499327],\n",
|
||||||
|
" [-0.9613605 , 0.27529252],\n",
|
||||||
|
" [ 0.8874723 , -0.46086106],\n",
|
||||||
|
" ...,\n",
|
||||||
|
" [ 0.24368125, 0.96985537],\n",
|
||||||
|
" [ 0.9091159 , -0.4165433 ],\n",
|
||||||
|
" [ 0.39610097, 0.9182069 ]],\n",
|
||||||
|
"\n",
|
||||||
|
" [[ 0.15766756, -0.98749226],\n",
|
||||||
|
" [-0.40235043, 0.9154857 ],\n",
|
||||||
|
" [-0.7679899 , -0.64046216],\n",
|
||||||
|
" ...,\n",
|
||||||
|
" [ 0.85627615, 0.5165182 ],\n",
|
||||||
|
" [ 0.43845195, 0.89875454],\n",
|
||||||
|
" [-0.995217 , 0.09768905]],\n",
|
||||||
|
"\n",
|
||||||
|
" [[ 0.31384754, -0.94947344],\n",
|
||||||
|
" [-0.335953 , 0.94187874],\n",
|
||||||
|
" [ 0.32295498, 0.94641435],\n",
|
||||||
|
" ...,\n",
|
||||||
|
" [ 0.56830823, 0.82281566],\n",
|
||||||
|
" [ 0.5089668 , 0.860786 ],\n",
|
||||||
|
" [ 0.43783924, 0.89905316]]],\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
" [[[ 0.00368896, -0.99999326],\n",
|
||||||
|
" [-0.15407981, -0.98805845],\n",
|
||||||
|
" [-0.9993973 , 0.03471346],\n",
|
||||||
|
" ...,\n",
|
||||||
|
" [-0.7566905 , 0.6537732 ],\n",
|
||||||
|
" [-0.58953756, 0.80774087],\n",
|
||||||
|
" [-0.68740994, 0.72626966]],\n",
|
||||||
|
"\n",
|
||||||
|
" [[-0.46347883, 0.8861079 ],\n",
|
||||||
|
" [ 0.00387764, 0.99999243],\n",
|
||||||
|
" [ 0.58661395, -0.80986667],\n",
|
||||||
|
" ...,\n",
|
||||||
|
" [ 0.9909095 , 0.13453025],\n",
|
||||||
|
" [-0.44597933, 0.89504325],\n",
|
||||||
|
" [ 0.6099662 , 0.7924275 ]],\n",
|
||||||
|
"\n",
|
||||||
|
" [[ 0.88540715, 0.46481633],\n",
|
||||||
|
" [-0.9495712 , 0.31355175],\n",
|
||||||
|
" [-0.9835763 , -0.1804929 ],\n",
|
||||||
|
" ...,\n",
|
||||||
|
" [ 0.11676199, 0.99315995],\n",
|
||||||
|
" [ 0.78835183, -0.6152245 ],\n",
|
||||||
|
" [ 0.6080205 , -0.7939213 ]],\n",
|
||||||
|
"\n",
|
||||||
|
" ...,\n",
|
||||||
|
"\n",
|
||||||
|
" [[ 0.8907368 , -0.45451924],\n",
|
||||||
|
" [-0.61389214, -0.7893899 ],\n",
|
||||||
|
" [-0.9944057 , -0.10562801],\n",
|
||||||
|
" ...,\n",
|
||||||
|
" [ 0.37469187, 0.92714936],\n",
|
||||||
|
" [ 0.94401246, -0.32990944],\n",
|
||||||
|
" [ 0.7751419 , 0.6317871 ]],\n",
|
||||||
|
"\n",
|
||||||
|
" [[-0.14082797, -0.99003404],\n",
|
||||||
|
" [ 0.27515417, 0.96140003],\n",
|
||||||
|
" [-0.40955904, 0.9122836 ],\n",
|
||||||
|
" ...,\n",
|
||||||
|
" [ 0.8188169 , 0.5740549 ],\n",
|
||||||
|
" [-0.9845477 , -0.17511682],\n",
|
||||||
|
" [-0.71323305, 0.7009269 ]],\n",
|
||||||
|
"\n",
|
||||||
|
" [[-0.9327261 , -0.36058545],\n",
|
||||||
|
" [-0.2761146 , -0.9611246 ],\n",
|
||||||
|
" [ 0.7902948 , -0.6127267 ],\n",
|
||||||
|
" ...,\n",
|
||||||
|
" [ 0.999295 , 0.03754444],\n",
|
||||||
|
" [ 0.95306456, -0.30276698],\n",
|
||||||
|
" [-0.9659615 , 0.2586859 ]]]], dtype=float32)>"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 111,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"r = tf.random.normal(shape = (16, 8, 8, 2))\n",
|
||||||
|
"tf.linalg.l2_normalize(r,axis=-1)"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": ".patch_gen_venv",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.10.9"
|
||||||
|
},
|
||||||
|
"orig_nbformat": 4
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 2
|
||||||
|
}
|
@ -1,155 +0,0 @@
|
|||||||
{
|
|
||||||
"cells": [
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 6,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"import tensorflow as tf\n",
|
|
||||||
"import random"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 29,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"['🟥', '🟦', '🟧', '🟨', '🟩', '🟪', '🟫']\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"squares = [chr(i) for i in range(0x1F7E5, 0x1F7EC)]\n",
|
|
||||||
"tfsquares = tf.constant(squares)\n",
|
|
||||||
"print(squares)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 51,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"def tensor_to_string(tensor):\n",
|
|
||||||
" square_select = tf.math.argmax(tensor, 2)\n",
|
|
||||||
"\n",
|
|
||||||
" tfstring = tf.strings.join(\n",
|
|
||||||
" tf.map_fn(\n",
|
|
||||||
" lambda v:\n",
|
|
||||||
" tf.strings.join(\n",
|
|
||||||
" tf.gather(tfsquares, tf.cast(v, tf.int64))), square_select, fn_output_signature=tf.string), \"\\n\")\n",
|
|
||||||
"\n",
|
|
||||||
" return tfstring.numpy().decode()\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 95,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"tf.Tensor([0. 0. 0. 1. 0. 0. 0.], shape=(7,), dtype=float32)\n",
|
|
||||||
"tf.Tensor([[3]], shape=(1, 1), dtype=int64)\n",
|
|
||||||
"🟦🟪🟦🟫🟦🟧🟨🟨\n",
|
|
||||||
"🟦🟫🟧🟫🟪🟥🟧🟪\n",
|
|
||||||
"🟫🟩🟨🟧🟨🟫🟪🟨\n",
|
|
||||||
"🟧🟦🟨🟪🟦🟨🟥🟨\n",
|
|
||||||
"🟫🟧🟧🟨🟧🟧🟩🟪\n",
|
|
||||||
"🟦🟦🟨🟦🟥🟫🟧🟩\n",
|
|
||||||
"🟫🟩🟧🟫🟨🟫🟥🟩\n",
|
|
||||||
"🟧🟫🟫🟩🟧🟧🟫🟩\n",
|
|
||||||
"--------------\n",
|
|
||||||
"tf.Tensor(\n",
|
|
||||||
"[[0. 0. 0. 0. 0. 0. 0. 0.]\n",
|
|
||||||
" [0. 0. 0. 0. 0. 0. 0. 0.]\n",
|
|
||||||
" [0. 0. 0. 0. 0. 0. 0. 0.]\n",
|
|
||||||
" [0. 0. 1. 0. 0. 1. 0. 0.]\n",
|
|
||||||
" [0. 0. 1. 0. 0. 1. 0. 0.]\n",
|
|
||||||
" [0. 0. 0. 0. 0. 0. 0. 0.]\n",
|
|
||||||
" [0. 0. 0. 0. 0. 0. 0. 0.]\n",
|
|
||||||
" [0. 1. 0. 0. 0. 0. 1. 0.]], shape=(8, 8), dtype=float32)\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"oh = tf.one_hot(3, 7)\n",
|
|
||||||
"print(oh)\n",
|
|
||||||
"print(tf.where(oh))\n",
|
|
||||||
"\n",
|
|
||||||
"# a = tf.constant([[list(tf.one_hot(random.randrange(7), 7)) for i in range(8) ] for i in range(8)], shape=(8,8,len(squares)))\n",
|
|
||||||
"\n",
|
|
||||||
"# print([[list(tf.one_hot(random.randrange(7), 7)) for i in range(8) ] for i in range(8)])\n",
|
|
||||||
"e = [[random.randrange(7) for i in range(8)] for i in range(8)]\n",
|
|
||||||
"# print(e)\n",
|
|
||||||
"\n",
|
|
||||||
"a = tf.one_hot(tf.constant(e, shape=(8, 8)), len(squares))\n",
|
|
||||||
"\n",
|
|
||||||
"print(tensor_to_string(a))\n",
|
|
||||||
"print(\"--------------\")\n",
|
|
||||||
"# print(tensor_to_string(tf.transpose(a, [1, 0, 2])))\n",
|
|
||||||
"# print(\"--------------\")\n",
|
|
||||||
"# print(tensor_to_string(tf.linalg.matmul( a , tf.constant(\n",
|
|
||||||
"# [\n",
|
|
||||||
"# [0, 0, 0, 0, 0, 0, 0, 1],\n",
|
|
||||||
"# [0, 0, 0, 0, 0, 0, 1, 0],\n",
|
|
||||||
"# [0, 0, 0, 0, 0, 1, 0, 0],\n",
|
|
||||||
"# [0, 0, 0, 0, 1, 0, 0, 0],\n",
|
|
||||||
"# [0, 0, 0, 1, 0, 0, 0, 0],\n",
|
|
||||||
"# [0, 0, 1, 0, 0, 0, 0, 0],\n",
|
|
||||||
"# [0, 1, 0, 0, 0, 0, 0, 0],\n",
|
|
||||||
"# [1, 0, 0, 0, 0, 0, 0, 0]\n",
|
|
||||||
"# ], dtype=tf.float32, shape=(8,8,1)\n",
|
|
||||||
"# ), transpose_a=True)))\n",
|
|
||||||
"# print(\"--------------\")\n",
|
|
||||||
"# print(tensor_to_string(tf.reverse(a, (0,))))\n",
|
|
||||||
"# print(\"--------------\")\n",
|
|
||||||
"# print(tensor_to_string(tf.reverse(a, (1,))))\n",
|
|
||||||
"# print(tensor_to_string(\n",
|
|
||||||
"# tf.linalg.matmul(\n",
|
|
||||||
"# tf.transpose(a, [0, 2, 1]),\n",
|
|
||||||
"# tf.reverse(a, (1,)))\n",
|
|
||||||
"# ))\n",
|
|
||||||
"print(\n",
|
|
||||||
" tf.einsum(\n",
|
|
||||||
" \"ijk,ijk->ij\",\n",
|
|
||||||
" a,\n",
|
|
||||||
" tf.reverse(a, (1,)))\n",
|
|
||||||
")\n",
|
|
||||||
"\n",
|
|
||||||
"\n",
|
|
||||||
"# print(tf.constant(range(6), shape=(2,3)))\n",
|
|
||||||
"# tf.map_fn(tf.math.reduce_sum, tf.constant(range(6), shape=(2,3)))\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"metadata": {
|
|
||||||
"kernelspec": {
|
|
||||||
"display_name": ".patch_gen_venv",
|
|
||||||
"language": "python",
|
|
||||||
"name": "python3"
|
|
||||||
},
|
|
||||||
"language_info": {
|
|
||||||
"codemirror_mode": {
|
|
||||||
"name": "ipython",
|
|
||||||
"version": 3
|
|
||||||
},
|
|
||||||
"file_extension": ".py",
|
|
||||||
"mimetype": "text/x-python",
|
|
||||||
"name": "python",
|
|
||||||
"nbconvert_exporter": "python",
|
|
||||||
"pygments_lexer": "ipython3",
|
|
||||||
"version": "3.10.0"
|
|
||||||
},
|
|
||||||
"orig_nbformat": 4
|
|
||||||
},
|
|
||||||
"nbformat": 4,
|
|
||||||
"nbformat_minor": 2
|
|
||||||
}
|
|
@ -0,0 +1,28 @@
|
|||||||
|
import tensorflow as tf
|
||||||
|
# from tensorflow import Layer
|
||||||
|
|
||||||
|
|
||||||
|
class Polymap(tf.keras.layers.Layer):
|
||||||
|
def __init__(self, num_poly, poly_deg):
|
||||||
|
super(Polymap, self).__init__()
|
||||||
|
self.num_poly = num_poly
|
||||||
|
self.poly_deg = poly_deg
|
||||||
|
self.exponents = tf.constant(
|
||||||
|
[range(self.poly_deg + 1), range(self.poly_deg + 1)], dtype=tf.float32)
|
||||||
|
self.kernels = self.add_weight("kernels",
|
||||||
|
shape=[
|
||||||
|
self.num_poly, self.poly_deg + 1, self.poly_deg + 1],
|
||||||
|
trainable=True)
|
||||||
|
|
||||||
|
def build(self, input_shape):
|
||||||
|
if input_shape[-1] != 2:
|
||||||
|
raise "Input shape must be Ix2"
|
||||||
|
|
||||||
|
|
||||||
|
def call(self, input):
|
||||||
|
powers = tf.map_fn(
|
||||||
|
lambda v: tf.pow(
|
||||||
|
tf.repeat(tf.reshape(v, shape=(2, 1)),
|
||||||
|
repeats=self.poly_deg + 1, axis=1),
|
||||||
|
self.exponents), input)
|
||||||
|
return tf.einsum("ik,ij,lkj->il", powers[:, 0], powers[:, 1], self.kernels)
|
Loading…
Reference in New Issue