"""Operations on geometric algebra tensors used internally."""
from typing import Union
import tensorflow as tf
[docs]def mv_multiply(
a_blade_values: tf.Tensor, b_blade_values: tf.Tensor, cayley: tf.Tensor
) -> tf.Tensor:
# ...i, ijk -> ...jk
x = tf.tensordot(a_blade_values, cayley, axes=[-1, 0])
# ...1j, ...jk -> ...1k
x = tf.expand_dims(b_blade_values, axis=b_blade_values.shape.ndims - 1) @ x
# ...1k -> ...k
x = tf.squeeze(x, axis=-2)
return x
[docs]def mv_conv1d(
a_blade_values: tf.Tensor,
k_blade_values: tf.Tensor,
cayley: tf.Tensor,
stride: int,
padding: str,
dilations: Union[int, None] = None,
) -> tf.Tensor:
# Winograd convolution
# A: [..., S, CI, BI]
# K: [K, CI, CO, BK]
# C: [BI, BK, BO]
kernel_size = k_blade_values.shape[0]
a_batch_shape = tf.shape(a_blade_values)[:-3]
# Reshape a_blade_values to a 2d image (since that's what the tf op expects)
# [*, S, 1, CI*BI]
a_image_shape = tf.concat(
[
a_batch_shape,
tf.shape(a_blade_values)[-3:-2],
[1, tf.reduce_prod(tf.shape(a_blade_values)[-2:])],
],
axis=0,
)
a_image = tf.reshape(a_blade_values, a_image_shape)
sizes = [1, kernel_size, 1, 1]
strides = [1, stride, 1, 1]
# [*, P, 1, K*CI*BI] where eg. number of patches P = S * K for
# stride=1 and "SAME", (S-K+1) * K for "VALID", ...
a_slices = tf.image.extract_patches(
a_image, sizes=sizes, strides=strides, rates=[1, 1, 1, 1], padding=padding
)
# [..., P, K, CI, BI]
out_shape = tf.concat(
[
a_batch_shape,
tf.shape(a_slices)[-3:-2],
tf.shape(k_blade_values)[:1],
tf.shape(a_blade_values)[-2:],
],
axis=0,
)
a_slices = tf.reshape(a_slices, out_shape)
# TODO: Optimize this to not use einsum (since it's slow with ellipses)
# a_...p,k,ci,bi; k_k,ci,co,bk; c_bi,bk,bo -> y_...p,co,bo
# ...a b c d , e c f g , d g h -> ...a f h
x = tf.einsum("...abcd,bcfg,dgh->...afh", a_slices, k_blade_values, cayley)
return x
[docs]def mv_reversion(a_blade_values, algebra_blade_degrees):
algebra_blade_degrees = tf.cast(algebra_blade_degrees, tf.float32)
# for each blade, 0 if even number of swaps required, else 1
odd_swaps = tf.cast(
tf.floor(algebra_blade_degrees * (algebra_blade_degrees - 0.5)) % 2, tf.float32
)
# [0, 1] -> [-1, 1]
reversion_signs = 1.0 - 2.0 * odd_swaps
return reversion_signs * a_blade_values
[docs]def mv_grade_automorphism(a_blade_values, algebra_blade_degrees):
algebra_blade_degrees = tf.cast(algebra_blade_degrees, tf.float32)
signs = 1.0 - 2.0 * (algebra_blade_degrees % 2.0)
return signs * a_blade_values