"""Defines the `MultiVector` class which is used as a convenience wrapper
for `GeometricAlgebra` operations.
"""
from typing import List, Union
import tensorflow as tf
from tfga.blades import BladeKind
class MultiVector:
    """Operator-friendly wrapper around a geometric algebra tensor.

    A `MultiVector` pairs a geometric algebra tensor (the blade values)
    with the `GeometricAlgebra` instance it belongs to, and forwards
    Python operators and a set of convenience methods to the
    corresponding `GeometricAlgebra` operations. Every operation returns
    a new `MultiVector`; the wrapped tensor is never modified in place by
    this class.

    Operator overview (each one delegates to the algebra):
        `a * b`   geometric product (`geom_prod`)
        `a ^ b`   exterior product (`ext_prod`)
        `a | b`   inner product (`inner_prod`)
        `a & b`   regressive product (`reg_prod`)
        `a / b`   `geom_prod(a, inverse(b))`
        `~a`      reversion (`reversion`)
        `a ** n`  integer power (`int_pow`)
        `a + b`, `a - b`, `-a`  element-wise on the blade values
    """

    def __init__(self, blade_values: "tf.Tensor", algebra: "GeometricAlgebra"):
        """Initializes a MultiVector from a geometric algebra `tf.Tensor`
        and its corresponding `GeometricAlgebra`.

        Args:
            blade_values: Geometric algebra `tf.Tensor` with as many elements
                on its last axis as blades in the algebra
            algebra: `GeometricAlgebra` instance corresponding to the
                geometric algebra tensor
        """
        self._blade_values = blade_values
        self._algebra = algebra

    def _wrap(self, blade_values) -> "MultiVector":
        """Wraps `blade_values` in a new `MultiVector` of the same algebra."""
        return MultiVector(blade_values, self._algebra)

    @property
    def tensor(self):
        """Geometric algebra tensor holding the values of this multivector."""
        return self._blade_values

    @property
    def algebra(self):
        """`GeometricAlgebra` instance this multivector belongs to."""
        return self._algebra

    @property
    def batch_shape(self):
        """Batch shape of the multivector (ie. the shape of all axes except
        for the last one in the geometric algebra tensor).
        """
        return self._blade_values.shape[:-1]

    def __len__(self) -> int:
        """Number of elements on the first axis of the geometric algebra
        tensor."""
        return self._blade_values.shape[0]

    def __iter__(self):
        """Iterates over the first axis of the geometric algebra tensor.

        Yields:
            The raw tensor elements if the tensor has only one axis,
            otherwise one `MultiVector` per index of the first axis.
        """
        for n in range(self._blade_values.shape[0]):
            # If we only have one axis left, return the
            # actual numbers, otherwise return a new
            # multivector.
            if self._blade_values.shape.ndims == 1:
                yield self._blade_values[n]
            else:
                yield self._wrap(self._blade_values[n])

    # Binary operators return `NotImplemented` for operands that are not
    # `MultiVector`s instead of asserting: asserts are stripped when Python
    # runs with `-O`, and `NotImplemented` lets the interpreter try the
    # reflected operation and otherwise raise a proper `TypeError`.

    def __xor__(self, other: "MultiVector") -> "MultiVector":
        """Exterior product. See `GeometricAlgebra.ext_prod()`."""
        if not isinstance(other, MultiVector):
            return NotImplemented
        return self._wrap(
            self._algebra.ext_prod(self._blade_values, other._blade_values)
        )

    def __or__(self, other: "MultiVector") -> "MultiVector":
        """Inner product. See `GeometricAlgebra.inner_prod()`."""
        if not isinstance(other, MultiVector):
            return NotImplemented
        return self._wrap(
            self._algebra.inner_prod(self._blade_values, other._blade_values)
        )

    def __mul__(self, other: "MultiVector") -> "MultiVector":
        """Geometric product. See `GeometricAlgebra.geom_prod()`."""
        if not isinstance(other, MultiVector):
            return NotImplemented
        return self._wrap(
            self._algebra.geom_prod(self._blade_values, other._blade_values)
        )

    def __truediv__(self, other: "MultiVector") -> "MultiVector":
        """Division, ie. multiplication with the inverse."""
        if not isinstance(other, MultiVector):
            return NotImplemented
        return self._wrap(
            self._algebra.geom_prod(
                self._blade_values, self._algebra.inverse(other._blade_values)
            )
        )

    def __and__(self, other: "MultiVector") -> "MultiVector":
        """Regressive product. See `GeometricAlgebra.reg_prod()`."""
        if not isinstance(other, MultiVector):
            return NotImplemented
        return self._wrap(
            self._algebra.reg_prod(self._blade_values, other._blade_values)
        )

    def __invert__(self) -> "MultiVector":
        """Reversion. See `GeometricAlgebra.reversion()`."""
        return self._wrap(self._algebra.reversion(self._blade_values))

    def __neg__(self) -> "MultiVector":
        """Negation."""
        return self._wrap(-self._blade_values)

    def __add__(self, other: "MultiVector") -> "MultiVector":
        """Addition of multivectors."""
        if not isinstance(other, MultiVector):
            return NotImplemented
        return self._wrap(self._blade_values + other._blade_values)

    def __sub__(self, other: "MultiVector") -> "MultiVector":
        """Subtraction of multivectors."""
        if not isinstance(other, MultiVector):
            return NotImplemented
        return self._wrap(self._blade_values - other._blade_values)

    def __pow__(self, n: int) -> "MultiVector":
        """Multivector raised to an integer power.
        See `GeometricAlgebra.int_pow()`."""
        return self._wrap(self._algebra.int_pow(self._blade_values, n))

    def __getitem__(self, key: Union[str, List[str]]) -> "MultiVector":
        """`MultiVector` with only passed blade names as non-zeros.
        See `GeometricAlgebra.keep_blades_with_name()`."""
        return self._wrap(
            self._algebra.keep_blades_with_name(self._blade_values, key)
        )

    def __call__(self, key: Union[str, List[str]]):
        """`tf.Tensor` with passed blade names on last axis.
        See `GeometricAlgebra.select_blades_with_name()`."""
        return self._algebra.select_blades_with_name(self._blade_values, key)

    def __repr__(self) -> str:
        """String representation as produced by `GeometricAlgebra.mv_repr()`."""
        return self._algebra.mv_repr(self._blade_values)

    def inverse(self) -> "MultiVector":
        """Inverse. See `GeometricAlgebra.inverse()`."""
        return self._wrap(self._algebra.inverse(self._blade_values))

    def simple_inverse(self) -> "MultiVector":
        """Simple inverse. See `GeometricAlgebra.simple_inverse()`."""
        return self._wrap(self._algebra.simple_inverse(self._blade_values))

    def dual(self) -> "MultiVector":
        """Dual. See `GeometricAlgebra.dual()`."""
        return self._wrap(self._algebra.dual(self._blade_values))

    def conjugation(self) -> "MultiVector":
        """Conjugation. See `GeometricAlgebra.conjugation()`."""
        return self._wrap(self._algebra.conjugation(self._blade_values))

    def grade_automorphism(self) -> "MultiVector":
        """Grade automorphism. See `GeometricAlgebra.grade_automorphism()`."""
        return self._wrap(self._algebra.grade_automorphism(self._blade_values))

    def approx_exp(self, order: int = 50) -> "MultiVector":
        """Approximate exponential. See `GeometricAlgebra.approx_exp()`.

        Args:
            order: Forwarded as the `order` argument of
                `GeometricAlgebra.approx_exp()`
        """
        return self._wrap(
            self._algebra.approx_exp(self._blade_values, order=order)
        )

    def exp(self, square_scalar_tolerance: Union[float, None] = 1e-4) -> "MultiVector":
        """Exponential. See `GeometricAlgebra.exp()`.

        Args:
            square_scalar_tolerance: Forwarded as the
                `square_scalar_tolerance` argument of
                `GeometricAlgebra.exp()`
        """
        return self._wrap(
            self._algebra.exp(
                self._blade_values, square_scalar_tolerance=square_scalar_tolerance
            )
        )

    def approx_log(self, order: int = 50) -> "MultiVector":
        """Approximate logarithm. See `GeometricAlgebra.approx_log()`.

        Args:
            order: Forwarded as the `order` argument of
                `GeometricAlgebra.approx_log()`
        """
        return self._wrap(
            self._algebra.approx_log(self._blade_values, order=order)
        )

    def is_pure_kind(self, kind: "BladeKind") -> bool:
        """Whether the `MultiVector` is of a pure kind.
        See `GeometricAlgebra.is_pure_kind()`.

        Args:
            kind: `BladeKind` to check for
        """
        return self._algebra.is_pure_kind(self._blade_values, kind=kind)

    def geom_conv1d(
        self,
        kernel: "MultiVector",
        stride: int,
        padding: str,
        dilations: Union[int, None] = None,
    ) -> "MultiVector":
        """1D convolution. See `GeometricAlgebra.geom_conv1d()`.

        Args:
            kernel: `MultiVector` whose blade values are used as the
                convolution kernel
            stride: Forwarded as the `stride` argument of
                `GeometricAlgebra.geom_conv1d()`
            padding: Forwarded as the `padding` argument of
                `GeometricAlgebra.geom_conv1d()`
            dilations: Forwarded as the `dilations` argument of
                `GeometricAlgebra.geom_conv1d()`
        """
        return self._wrap(
            self._algebra.geom_conv1d(
                self._blade_values,
                kernel._blade_values,
                stride=stride,
                padding=padding,
                dilations=dilations,
            )
        )