# -----------------------------------------------------------------------------
# This file was autogenerated by symforce from template:
#     geo_package/CLASS.py.jinja
# Do NOT modify by hand.
# -----------------------------------------------------------------------------
# ruff: noqa: PLR0915, F401, PLW0211, PLR0914
import math
import random
import typing as T
import numpy
# isort: split
from .ops import unit3 as ops
class Unit3(object):
    """
    Autogenerated Python implementation of :py:class:`symforce.geo.unit3.Unit3`.

    Direction in R^3 represented as a unit vector on the S^2 sphere manifold.

    Storage is three dimensional, and tangent space is two dimensional. Due to the nature of the
    manifold, the unit X vector is handled as a singularity.

    The implementation of the retract and local_coordinates functions is based on Appendix B.2:

    [Hertzberg 2013] Integrating Generic Sensor Fusion Algorithms with Sound State Representations
    through Encapsulation of Manifolds

    The retract operation applies a perturbation to the unit X vector, which is then rotated to the
    desired position along the actual stored unit vector through a Householder reflection followed
    by a reflection across the XZ plane.

        x.retract(delta) = x [+] delta = Rx * Exp(delta), where
        Exp(delta) = [cos(||delta||), sinc(||delta||) * delta], and
        Rx = (I - 2 vv^T / (v^Tv))X, v = x - e_x != 0, X is a matrix negating 2nd vector component
           = diag(1, -1, -1)       , x = e_x

    See: `unit3_visualization.ipynb` for a visualization of the Unit3 manifold.
    """

    __slots__ = ["data"]

    def __repr__(self):
        # type: () -> str
        return "<{} {}>".format(self.__class__.__name__, self.data)

    # --------------------------------------------------------------------------
    # Handwritten methods included from "custom_methods/unit3.py.jinja"
    # --------------------------------------------------------------------------

    def __init__(self, vec):
        # type: (T.Union[T.Sequence[float], numpy.ndarray]) -> None
        """
        Construct from a length-3 direction given as a sequence or a numpy array.

        The components are copied into a new list and stored as given; no normalization is
        applied here (use :meth:`from_vector` to normalize).

        Args:
            vec: length-3 sequence, or numpy array of shape (3,), (3, 1) or (1, 3).

        Raises:
            IndexError: if ``vec`` does not have one of the accepted lengths/shapes.
        """
        if isinstance(vec, numpy.ndarray):
            if vec.shape in {(3, 1), (1, 3)}:
                # Column/row vectors are accepted and flattened to shape (3,).
                vec = vec.flatten()
            elif vec.shape != (3,):
                raise IndexError(
                    "Expected vec to be a vector of length 3; instead had shape {}".format(
                        vec.shape
                    )
                )
        elif len(vec) != 3:
            raise IndexError(
                "Expected vec to be a sequence of length 3, was instead length {}.".format(len(vec))
            )
        self.data = list(vec)  # type: T.List[float]

    @classmethod
    def random(cls, epsilon=1e-8):
        # type: (float) -> Unit3
        """
        Return a random :class:`Unit3` object, uniformly distributed over the sphere.

        Draws two samples with :func:`random.uniform` on [0, 1] and maps them through
        :meth:`random_from_uniform_samples`.
        """
        u1 = random.uniform(0, 1)
        u2 = random.uniform(0, 1)
        # Dispatch through ``cls`` so subclasses pick up their own sampler override.
        return cls.random_from_uniform_samples(u1, u2, epsilon=epsilon)

    @staticmethod
    def random_from_uniform_samples(u1, u2, epsilon=1e-8):
        # type: (float, float, float) -> Unit3
        """
        Map two samples drawn uniformly from [0, 1] to a direction on the unit sphere.

        Uses the cylinder-projection construction: ``z = 2 * u1 - 1`` and azimuth
        ``phi = 2 * pi * u2``. By Archimedes' hat-box theorem this yields directions uniformly
        distributed over S^2 when ``u1`` and ``u2`` are independent uniform samples.

        Args:
            u1: sample in [0, 1] selecting the Z component (values outside are clamped).
            u2: sample in [0, 1] selecting the azimuth around the Z axis.
            epsilon: accepted for API symmetry with the other generated constructors; this
                construction is well defined on its whole input range, so it is unused.
        """
        # Clamp so out-of-range samples still produce a valid unit vector.
        z = min(1.0, max(-1.0, 2.0 * u1 - 1.0))
        radius = math.sqrt(1.0 - z * z)
        phi = 2.0 * math.pi * u2
        return Unit3.from_storage([radius * math.cos(phi), radius * math.sin(phi), z])

    # --------------------------------------------------------------------------
    # Custom generated methods
    # --------------------------------------------------------------------------

    def basis(self, epsilon):
        # type: (Unit3, float) -> numpy.ndarray
        """
        Returns a :class:`Matrix32` with the basis vectors of the tangent space (in R^3) at the
        current Unit3 direction.

        In Python the result is a ``(3, 2)`` numpy array; each column is one tangent basis vector.

        Args:
            epsilon: small constant used by the generated expressions to stay away from
                divisions by zero near singular configurations.
        """
        # Total ops: 50

        # Input arrays
        _self = self.data

        # Intermediate terms (12)
        _tmp0 = _self[1] ** 2
        # Branch selector: 1 when y^2 + z^2 - 10 * epsilon * sign(x) < 0, i.e. the direction lies
        # within sqrt(10 * epsilon) of the +X half-axis (the singularity); 0 otherwise.
        _tmp1 = max(
            0,
            -(
                0.0
                if _self[2] ** 2 + _tmp0 - 10 * epsilon * math.copysign(1, _self[0]) == 0
                else math.copysign(
                    1, _self[2] ** 2 + _tmp0 - 10 * epsilon * math.copysign(1, _self[0])
                )
            ),
        )
        # Weight of the general (away from +X) expressions; complementary to _tmp1.
        _tmp2 = 1 - _tmp1
        _tmp3 = _self[0] - 1
        # z nudged away from zero by epsilon, which keeps _tmp6 (and the divisions below) nonzero
        # for epsilon > 0.
        _tmp4 = _self[2] + epsilon * math.copysign(1, _self[2])
        _tmp5 = _tmp4**2
        _tmp6 = _tmp0 + _tmp5
        # 2 / |x - e_x|^2 with the nudged z: the Householder scale 2 / (v^T v), v = x - e_x.
        _tmp7 = 2 / (_tmp3**2 + _tmp6)
        # 2 / (y^2 + z^2) with the nudged z, used by the near-+X expressions.
        _tmp8 = 2 / _tmp6
        _tmp9 = _tmp2 * _tmp4 * _tmp7
        _tmp10 = _self[1] * _tmp9
        _tmp11 = _self[1] * _tmp1 * _tmp4 * _tmp8

        # Output terms: entries combine the general expression (weighted by _tmp2) with the
        # near-+X expression (weighted by _tmp1).
        _res = numpy.zeros((3, 2))
        _res[0, 0] = _self[1] * _tmp2 * _tmp3 * _tmp7
        _res[1, 0] = -_tmp1 * (-_tmp0 * _tmp8 + 1) - _tmp2 * (-_tmp0 * _tmp7 + 1)
        _res[2, 0] = _tmp10 + _tmp11
        _res[0, 1] = -_tmp3 * _tmp9
        _res[1, 1] = -_tmp10 - _tmp11
        _res[2, 1] = _tmp1 * (-_tmp5 * _tmp8 + 1) + _tmp2 * (-_tmp5 * _tmp7 + 1)
        return _res

    def to_unit_vector(self):
        # type: (Unit3) -> numpy.ndarray
        """
        Returns a :class:`Vector3` version of the unit direction.

        In Python this is a new numpy array of shape ``(3,)`` holding a copy of the storage.
        """
        # Total ops: 0

        # Input arrays
        _self = self.data

        # Intermediate terms (0)

        # Output terms
        _res = numpy.zeros(3)
        _res[0] = _self[0]
        _res[1] = _self[1]
        _res[2] = _self[2]
        return _res

    @staticmethod
    def from_vector(a, epsilon):
        # type: (numpy.ndarray, float) -> Unit3
        """
        Return a :class:`Unit3` that points along the direction of vector ``a``.

        ``a`` is normalized as ``a / sqrt(|a|^2 + epsilon)``; the epsilon guards against a zero
        vector, so the result is only approximately unit length when ``epsilon > 0``.

        Args:
            a: numpy array of shape (3,) or (3, 1).
            epsilon: small constant added under the square root.

        Raises:
            IndexError: if ``a`` has any other shape.
        """
        # Total ops: 10

        # Input arrays
        if a.shape == (3,):
            a = a.reshape((3, 1))
        elif a.shape != (3, 1):
            raise IndexError(
                "a is expected to have shape (3, 1) or (3,); instead had shape {}".format(a.shape)
            )

        # Intermediate terms (1)
        _tmp0 = 1 / math.sqrt(a[0, 0] ** 2 + a[1, 0] ** 2 + a[2, 0] ** 2 + epsilon)

        # Output terms
        _res = [0.0] * 3
        _res[0] = _tmp0 * a[0, 0]
        _res[1] = _tmp0 * a[1, 0]
        _res[2] = _tmp0 * a[2, 0]
        return Unit3.from_storage(_res)

    @staticmethod
    def from_unit_vector(a):
        # type: (numpy.ndarray) -> Unit3
        """
        Return a :class:`Unit3` that points along the direction of vector ``a``.

        ``a`` is expected to already be a unit vector; its components are stored without
        normalization.

        Raises:
            IndexError: if ``a`` does not have shape (3,) or (3, 1).
        """
        # Total ops: 0

        # Input arrays
        if a.shape == (3,):
            a = a.reshape((3, 1))
        elif a.shape != (3, 1):
            raise IndexError(
                "a is expected to have shape (3, 1) or (3,); instead had shape {}".format(a.shape)
            )

        # Intermediate terms (0)

        # Output terms
        _res = [0.0] * 3
        _res[0] = a[0, 0]
        _res[1] = a[1, 0]
        _res[2] = a[2, 0]
        return Unit3.from_storage(_res)

    # --------------------------------------------------------------------------
    # StorageOps concept
    # --------------------------------------------------------------------------

    @staticmethod
    def storage_dim():
        # type: () -> int
        """Number of scalars in the storage representation (3)."""
        return 3

    def to_storage(self):
        # type: () -> T.List[float]
        """Return a copy of the storage as a new list."""
        return list(self.data)

    @classmethod
    def from_storage(cls, vec):
        # type: (T.Sequence[float]) -> Unit3
        """
        Construct from a flat storage of :meth:`storage_dim` scalars (unit norm is not checked).

        A ``list`` argument is adopted directly rather than copied, so later mutation of that list
        is visible through the returned object; any other iterable is copied into a new list.

        Raises:
            ValueError: if the storage does not have exactly :meth:`storage_dim` entries.
        """
        instance = cls.__new__(cls)
        if isinstance(vec, list):
            instance.data = vec
        else:
            instance.data = list(vec)
        # Validate the materialized list, not ``vec``: a one-shot iterator has no len() and has
        # already been consumed by list() above.
        if len(instance.data) != cls.storage_dim():
            raise ValueError(
                "{} has storage dim {}, got {}.".format(
                    cls.__name__, cls.storage_dim(), len(instance.data)
                )
            )
        return instance

    # --------------------------------------------------------------------------
    # LieGroupOps concept
    # --------------------------------------------------------------------------

    @staticmethod
    def tangent_dim():
        # type: () -> int
        """Dimension of the tangent space (2)."""
        return 2

    def retract(self, vec, epsilon=1e-8):
        # type: (numpy.ndarray, float) -> Unit3
        """
        Apply the tangent-space perturbation ``vec`` to this direction (``self [+] vec``).

        Delegates to ``ops.LieGroupOps.retract`` after checking the length of ``vec``.

        Raises:
            ValueError: if ``len(vec)`` differs from :meth:`tangent_dim`.
        """
        if len(vec) != self.tangent_dim():
            raise ValueError(
                "Vector dimension ({}) not equal to tangent space dimension ({}).".format(
                    len(vec), self.tangent_dim()
                )
            )
        return ops.LieGroupOps.retract(self, vec, epsilon)

    def local_coordinates(self, b, epsilon=1e-8):
        # type: (Unit3, float) -> numpy.ndarray
        """
        Tangent-space coordinates between this direction and ``b``.

        Delegates to ``ops.LieGroupOps.local_coordinates(self, b, epsilon)``.
        """
        return ops.LieGroupOps.local_coordinates(self, b, epsilon)

    def interpolate(self, b, alpha, epsilon=1e-8):
        # type: (Unit3, float, float) -> Unit3
        """
        Interpolate between this direction and ``b`` with parameter ``alpha``.

        Delegates to ``ops.LieGroupOps.interpolate(self, b, alpha, epsilon)``.
        """
        return ops.LieGroupOps.interpolate(self, b, alpha, epsilon)

    # --------------------------------------------------------------------------
    # General Helpers
    # --------------------------------------------------------------------------

    def __eq__(self, other):
        # type: (T.Any) -> bool
        """
        Exact component-wise equality of the storage.

        Objects that are not :class:`Unit3` always compare unequal.
        """
        if isinstance(other, Unit3):
            return self.data == other.data
        else:
            return False