import tensorflow as tf


class Polymap(tf.keras.layers.Layer):
    """Map 2-D points through a bank of trainable bivariate polynomials.

    For an input point ``(x, y)`` the layer produces, for every polynomial
    index ``l`` in ``range(num_poly)``::

        out[l] = sum_{k=0..poly_deg} sum_{j=0..poly_deg}
                     kernels[l, k, j] * x**k * y**j

    Args:
        num_poly: Number of polynomials, i.e. the size of the output's last axis.
        poly_deg: Highest power applied to each of ``x`` and ``y``
            independently, so each polynomial has ``(poly_deg + 1) ** 2``
            trainable coefficients.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, num_poly, poly_deg, **kwargs):
        super().__init__(**kwargs)
        self.num_poly = num_poly
        self.poly_deg = poly_deg
        # Row 0 holds the exponents applied to x, row 1 those applied to y;
        # both are 0, 1, ..., poly_deg.
        self.exponents = tf.constant(
            [range(self.poly_deg + 1), range(self.poly_deg + 1)],
            dtype=tf.float32,
        )
        # kernels[l, k, j] is the coefficient of x**k * y**j in polynomial l.
        # `name` is passed by keyword: in Keras 3 the first positional
        # parameter of add_weight is `shape`, not `name`.
        self.kernels = self.add_weight(
            name="kernels",
            shape=[self.num_poly, self.poly_deg + 1, self.poly_deg + 1],
            trainable=True,
        )

    def build(self, input_shape):
        """Check that the innermost input dimension holds an ``(x, y)`` pair.

        Raises:
            ValueError: If ``input_shape[-1]`` is not 2.
        """
        if input_shape[-1] != 2:
            # Previously `raise "<str>"`, which is itself a TypeError in
            # Python 3 and never surfaced this message.
            raise ValueError(f"Input shape must be Ix2, got {input_shape}")
        super().build(input_shape)

    def call(self, input):
        """Evaluate every polynomial at every input point.

        Args:
            input: Tensor whose last axis has size 2 and holds ``(x, y)``.
                Any number of leading axes is allowed; the original use is
                ``(I, 2)``.

        Returns:
            Tensor of shape ``input.shape[:-1] + (num_poly,)``.
        """
        exponents = tf.cast(self.exponents, input.dtype)
        # (..., 2, 1) ** (2, poly_deg + 1) -> (..., 2, poly_deg + 1): all
        # powers of x and y in one broadcast op instead of a per-sample map_fn.
        powers = tf.pow(tf.expand_dims(input, axis=-1), exponents)
        x_powers = powers[..., 0, :]  # (..., poly_deg + 1): x**k
        y_powers = powers[..., 1, :]  # (..., poly_deg + 1): y**j
        # Contract over k and j against each polynomial's coefficient grid.
        return tf.einsum("...k,...j,lkj->...l", x_powers, y_powers, self.kernels)

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super().get_config()
        config.update({"num_poly": self.num_poly, "poly_deg": self.poly_deg})
        return config