You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
29 lines
1.0 KiB
Python
29 lines
1.0 KiB
Python
2 years ago
|
import tensorflow as tf
|
||
|
# from tensorflow import Layer
|
||
|
|
||
|
|
||
|
class Polymap(tf.keras.layers.Layer):
    """Keras layer mapping 2-D points through a bank of bivariate polynomials.

    For every input row ``(x, y)`` the layer emits ``num_poly`` values;
    output ``l`` is::

        sum_{k=0}^{poly_deg} sum_{j=0}^{poly_deg} kernels[l, k, j] * x**k * y**j

    The coefficient tensor ``kernels`` has shape
    ``[num_poly, poly_deg + 1, poly_deg + 1]``. It is the layer's only
    trainable weight and is created with ``add_weight``'s default
    initializer (none is specified here).

    Args:
        num_poly: Number of polynomials, i.e. the size of the output's last
            axis.
        poly_deg: Maximum exponent applied to each of ``x`` and ``y``.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``);
            accepted so that ``from_config`` can re-create the layer.

    Input shape:
        ``(I, 2)`` -- one ``(x, y)`` pair per row.

    Output shape:
        ``(I, num_poly)``.
    """

    def __init__(self, num_poly, poly_deg, **kwargs):
        super().__init__(**kwargs)
        self.num_poly = num_poly
        self.poly_deg = poly_deg
        # Row c holds the exponents 0..poly_deg applied to input column c
        # (both rows are identical), shape (2, poly_deg + 1).
        self.exponents = tf.constant(
            [range(self.poly_deg + 1), range(self.poly_deg + 1)],
            dtype=tf.float32,
        )
        # The weight shape does not depend on the input, so it is created
        # eagerly here. The name is passed by keyword because Keras 3 made
        # ``shape`` the first positional parameter of ``add_weight``.
        self.kernels = self.add_weight(
            name="kernels",
            shape=[self.num_poly, self.poly_deg + 1, self.poly_deg + 1],
            trainable=True,
        )

    def build(self, input_shape):
        """Check that the input is rank 2 with a last dimension of 2.

        Dimensions that are unknown (``None``) at build time are accepted,
        since they cannot be checked statically.

        Raises:
            ValueError: If the known input shape is not ``(I, 2)``.
        """
        shape = tf.TensorShape(input_shape)
        if shape.rank is not None:
            dims = shape.as_list()
            if len(dims) != 2 or dims[-1] not in (None, 2):
                raise ValueError(f"Input shape must be Ix2, got {shape}")
        super().build(input_shape)

    def call(self, input):
        """Evaluate every polynomial at every input point.

        Args:
            input: Tensor of shape ``(I, 2)`` holding ``(x, y)`` pairs. It
                must be float32 to match ``exponents``. (The parameter name
                shadows the builtin but is kept for interface compatibility.)

        Returns:
            Tensor of shape ``(I, num_poly)``.
        """
        # (I, 2, 1) ** (2, poly_deg + 1) broadcasts to (I, 2, poly_deg + 1),
        # so powers[i, c, k] == input[i, c] ** k. A single vectorized op
        # replaces a per-row map and also handles an empty batch.
        powers = tf.pow(tf.expand_dims(input, axis=-1), self.exponents)
        # out[i, l] = sum_{k, j} x_i**k * y_i**j * kernels[l, k, j]
        return tf.einsum(
            "ik,ij,lkj->il", powers[:, 0], powers[:, 1], self.kernels
        )

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super().get_config()
        config.update({"num_poly": self.num_poly, "poly_deg": self.poly_deg})
        return config
|