You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

29 lines
1.0 KiB
Python

2 years ago
import tensorflow as tf
# from tensorflow import Layer
class Polymap(tf.keras.layers.Layer):
    """Map 2-D points through a bank of learnable bivariate polynomials.

    For an input whose last axis holds a pair ``(x, y)``, output channel ``l``
    is::

        out[..., l] = sum_{k, j} x**k * y**j * kernels[l, k, j]

    with ``k, j`` ranging over ``0 .. poly_deg``. The coefficient tensor
    ``kernels`` has shape ``(num_poly, poly_deg + 1, poly_deg + 1)`` and is
    trainable.

    Args:
        num_poly: Number of polynomials, i.e. size of the output's last axis.
        poly_deg: Maximum power applied to each of ``x`` and ``y``.
        **kwargs: Standard Keras layer arguments (``name``, ``dtype``, ...),
            forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, num_poly, poly_deg, **kwargs):
        super().__init__(**kwargs)
        self.num_poly = num_poly
        self.poly_deg = poly_deg
        # Shape (2, poly_deg + 1): row 0 holds the exponents for x and row 1
        # those for y. Both rows are [0, 1, ..., poly_deg].
        self.exponents = tf.constant(
            [range(self.poly_deg + 1), range(self.poly_deg + 1)],
            dtype=tf.float32,
        )
        # Pass the name by keyword: in Keras 3 the first positional
        # parameter of add_weight is `shape`, not `name`.
        self.kernels = self.add_weight(
            name="kernels",
            shape=[self.num_poly, self.poly_deg + 1, self.poly_deg + 1],
            trainable=True,
        )

    def build(self, input_shape):
        """Validate that the last input axis holds exactly two values.

        Raises:
            ValueError: If ``input_shape[-1]`` is not 2.
        """
        if input_shape[-1] != 2:
            # Raising a plain string is itself a TypeError in Python 3, so
            # raise a real exception carrying the offending shape.
            raise ValueError(
                "Input shape must be Ix2 (last axis of size 2), "
                f"got {input_shape}"
            )
        super().build(input_shape)

    def call(self, input):
        """Evaluate every polynomial at every input point.

        Args:
            input: Tensor of shape ``(..., 2)``; the last axis is ``(x, y)``.
                (The parameter name shadows the builtin but is kept for
                backward compatibility with keyword callers.)

        Returns:
            Tensor of shape ``(..., num_poly)``.
        """
        # (..., 2, 1) broadcast against (2, poly_deg + 1) gives
        # powers[..., i, k] = input[..., i] ** k. This is the same values the
        # previous per-row map_fn/repeat produced, without a Python-level map.
        points = tf.expand_dims(input, axis=-1)
        powers = tf.pow(points, tf.cast(self.exponents, points.dtype))
        x_powers = powers[..., 0, :]
        y_powers = powers[..., 1, :]
        return tf.einsum("...k,...j,lkj->...l", x_powers, y_powers, self.kernels)

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super().get_config()
        config.update({"num_poly": self.num_poly, "poly_deg": self.poly_deg})
        return config