# -----------------------------------------------------------------------------
# This file was autogenerated by symforce from template:
#     geo_package/CLASS.py.jinja
# Do NOT modify by hand.
# -----------------------------------------------------------------------------
# ruff: noqa: PLR0915, F401, PLW0211, PLR0914
import math
import random
import typing as T
import numpy
# isort: split
from .ops import rot3 as ops
class Rot3(object):
    """
    Autogenerated Python implementation of :py:class:`symforce.geo.rot3.Rot3`.

    Group of three-dimensional orthogonal matrices with determinant ``+1``, representing
    rotations in 3D space. Backed by a quaternion with (x, y, z, w) storage.
    """

    __slots__ = ["data"]

    def __repr__(self):
        # type: () -> str
        return "<{} {}>".format(self.__class__.__name__, self.data)

    # --------------------------------------------------------------------------
    # Handwritten methods included from "custom_methods/rot3.py.jinja"
    # --------------------------------------------------------------------------

    def __init__(self, q=None):
        # type: (T.Union[T.Sequence[float], numpy.ndarray, None]) -> None
        """
        Construct from a quaternion in (x, y, z, w) storage order.

        Args:
            q: None for the group identity, a length-4 sequence, or a numpy array of shape
                (4,), (4, 1) or (1, 4).

        Raises:
            IndexError: If ``q`` does not contain exactly four elements in an accepted shape.
        """
        if q is None:
            self.data = ops.GroupOps.identity().data  # type: T.List[float]
        else:
            if isinstance(q, numpy.ndarray):
                # Column and row vectors are accepted and flattened to shape (4,).
                if q.shape in {(4, 1), (1, 4)}:
                    q = q.flatten()
                elif q.shape != (4,):
                    raise IndexError(
                        "Expected q to be a vector of length 4; instead had shape {}".format(
                            q.shape
                        )
                    )
            elif len(q) != 4:
                raise IndexError(
                    "Expected q to be a sequence of length 4, was instead length {}.".format(len(q))
                )
            self.data = list(q)

    @classmethod
    def from_rotation_matrix(cls, R, epsilon=0.0):
        # type: (numpy.ndarray, float) -> Rot3
        """
        Construct from a 3x3 rotation matrix.

        This implementation is based on Shepperd's method (1978)
        https://arc.aiaa.org/doi/abs/10.2514/3.55767b?journalCode=jgc (this is paywalled)

        See the introduction of these papers for a description of the method:
        - https://digital.csic.es/bitstream/10261/179990/1/Accurate%20Computation_Sarabandi.pdf
        - https://arc.aiaa.org/doi/abs/10.2514/1.31730?journalCode=jgcd

        Args:
            R: Rotation matrix of shape (3, 3).
            epsilon: In the three non-trace branches, the square-root argument is clamped from
                below to ``epsilon**2`` before the square root is taken.

        Raises:
            AssertionError: If ``R.shape != (3, 3)``.
        """
        assert R.shape == (3, 3)

        trace = R[0, 0] + R[1, 1] + R[2, 2]

        # trace is larger than any of the diagonal elements
        if trace > R[0, 0] and trace > R[1, 1] and trace > R[2, 2]:
            w = numpy.sqrt(1.0 + trace) / 2.0
            x = (R[2, 1] - R[1, 2]) / (4.0 * w)
            y = (R[0, 2] - R[2, 0]) / (4.0 * w)
            z = (R[1, 0] - R[0, 1]) / (4.0 * w)
        # largest diagonal element is R[0,0]
        elif R[0, 0] > R[1, 1] and R[0, 0] > R[2, 2]:
            x = numpy.sqrt(max(epsilon**2, 1.0 + R[0, 0] - R[1, 1] - R[2, 2])) / 2.0
            y = (R[0, 1] + R[1, 0]) / (4.0 * x)
            z = (R[0, 2] + R[2, 0]) / (4.0 * x)
            w = (R[2, 1] - R[1, 2]) / (4.0 * x)
        # largest diagonal element is R[1,1]
        elif R[1, 1] > R[2, 2]:
            y = numpy.sqrt(max(epsilon**2, 1.0 + R[1, 1] - R[0, 0] - R[2, 2])) / 2.0
            x = (R[0, 1] + R[1, 0]) / (4.0 * y)
            z = (R[1, 2] + R[2, 1]) / (4.0 * y)
            w = (R[0, 2] - R[2, 0]) / (4.0 * y)
        # largest diagonal element is R[2,2]
        else:
            z = numpy.sqrt(max(epsilon**2, 1.0 + R[2, 2] - R[0, 0] - R[1, 1])) / 2.0
            x = (R[0, 2] + R[2, 0]) / (4.0 * z)
            y = (R[1, 2] + R[2, 1]) / (4.0 * z)
            w = (R[1, 0] - R[0, 1]) / (4.0 * z)

        return Rot3.from_storage([x, y, z, w])

    @classmethod
    def random(cls):
        # type: () -> Rot3
        """
        Sample a random rotation by drawing three values from ``random.uniform(0, 1)`` and
        passing them to :meth:`random_from_uniform_samples`.
        """
        return Rot3.random_from_uniform_samples(
            random.uniform(0, 1), random.uniform(0, 1), random.uniform(0, 1)
        )

    # --------------------------------------------------------------------------
    # Custom generated methods
    # --------------------------------------------------------------------------

    def compose_with_point(self, right):
        # type: (Rot3, numpy.ndarray) -> numpy.ndarray
        """
        Left-multiplication of a 3D point by this rotation (point transform).

        Args:
            right: Point of shape (3,) or (3, 1).

        Returns:
            numpy.ndarray of shape (3,) holding the rotated point.

        Raises:
            IndexError: If ``right`` has any other shape.
        """

        # Total ops: 43

        # Input arrays
        _self = self.data
        if right.shape == (3,):
            right = right.reshape((3, 1))
        elif right.shape != (3, 1):
            raise IndexError(
                "right is expected to have shape (3, 1) or (3,); instead had shape {}".format(
                    right.shape
                )
            )

        # Intermediate terms (11)
        _tmp0 = 2 * _self[0]
        _tmp1 = _self[1] * _tmp0
        _tmp2 = 2 * _self[2]
        _tmp3 = _self[3] * _tmp2
        _tmp4 = 2 * _self[1] * _self[3]
        _tmp5 = _self[2] * _tmp0
        _tmp6 = -2 * _self[1] ** 2
        _tmp7 = 1 - 2 * _self[2] ** 2
        _tmp8 = _self[3] * _tmp0
        _tmp9 = _self[1] * _tmp2
        _tmp10 = -2 * _self[0] ** 2

        # Output terms
        _res = numpy.zeros(3)
        _res[0] = (
            right[0, 0] * (_tmp6 + _tmp7)
            + right[1, 0] * (_tmp1 - _tmp3)
            + right[2, 0] * (_tmp4 + _tmp5)
        )
        _res[1] = (
            right[0, 0] * (_tmp1 + _tmp3)
            + right[1, 0] * (_tmp10 + _tmp7)
            + right[2, 0] * (-_tmp8 + _tmp9)
        )
        _res[2] = (
            right[0, 0] * (-_tmp4 + _tmp5)
            + right[1, 0] * (_tmp8 + _tmp9)
            + right[2, 0] * (_tmp10 + _tmp6 + 1)
        )
        return _res

    def to_tangent_norm(self, epsilon):
        # type: (Rot3, float) -> float
        """
        Returns the norm of the tangent vector corresponding to this rotation.

        This is equal to the angle that should be rotated through to get this Rot3, in radians.
        Using this function directly is usually more efficient than computing the norm of the
        tangent vector, both in symbolic and generated code; by default, symbolic APIs will not
        automatically simplify to this.

        Args:
            epsilon: ``abs(w)`` is clamped to at most ``1 - epsilon`` before ``acos``.
        """

        # Total ops: 5

        # Input arrays
        _self = self.data

        # Intermediate terms (0)

        # Output terms
        _res = 2 * math.acos(min(abs(_self[3]), 1 - epsilon))
        return _res

    def to_rotation_matrix(self):
        # type: (Rot3) -> numpy.ndarray
        """
        Converts to a rotation matrix.

        Returns:
            numpy.ndarray of shape (3, 3).
        """

        # Total ops: 28

        # Input arrays
        _self = self.data

        # Intermediate terms (11)
        _tmp0 = -2 * _self[1] ** 2
        _tmp1 = 1 - 2 * _self[2] ** 2
        _tmp2 = 2 * _self[0]
        _tmp3 = _self[1] * _tmp2
        _tmp4 = 2 * _self[2]
        _tmp5 = _self[3] * _tmp4
        _tmp6 = 2 * _self[1] * _self[3]
        _tmp7 = _self[2] * _tmp2
        _tmp8 = -2 * _self[0] ** 2
        _tmp9 = _self[3] * _tmp2
        _tmp10 = _self[1] * _tmp4

        # Output terms
        _res = numpy.zeros((3, 3))
        _res[0, 0] = _tmp0 + _tmp1
        _res[1, 0] = _tmp3 + _tmp5
        _res[2, 0] = -_tmp6 + _tmp7
        _res[0, 1] = _tmp3 - _tmp5
        _res[1, 1] = _tmp1 + _tmp8
        _res[2, 1] = _tmp10 + _tmp9
        _res[0, 2] = _tmp6 + _tmp7
        _res[1, 2] = _tmp10 - _tmp9
        _res[2, 2] = _tmp0 + _tmp8 + 1
        return _res

    @staticmethod
    def random_from_uniform_samples(u1, u2, u3):
        # type: (float, float, float) -> Rot3
        """
        Generate a random element of SO3 from three numbers uniformly sampled in [0, 1].

        Reference:
            http://planning.cs.uiuc.edu/node198.html

        Args:
            u1, u2, u3: Samples in [0, 1]; ``u1`` must lie in [0, 1] for the square roots to be
                defined.

        Returns:
            Rot3 with storage
            ``(sqrt(1-u1) sin(2 pi u2), sqrt(1-u1) cos(2 pi u2), sqrt(u1) sin(2 pi u3),
            sqrt(u1) cos(2 pi u3))``.
        """
        # Intermediate terms (5)
        _tmp0 = math.sqrt(1 - u1)
        _tmp1 = 2 * math.pi
        _tmp2 = _tmp1 * u2
        _tmp3 = math.sqrt(u1)
        _tmp4 = _tmp1 * u3

        # Output terms
        _res = [0.0] * 4
        _res[0] = _tmp0 * math.sin(_tmp2)
        _res[1] = _tmp0 * math.cos(_tmp2)
        _res[2] = _tmp3 * math.sin(_tmp4)
        _res[3] = _tmp3 * math.cos(_tmp4)
        return Rot3.from_storage(_res)

    def to_yaw_pitch_roll(self):
        # type: (Rot3) -> numpy.ndarray
        """
        Compute the yaw, pitch, and roll Euler angles in radians of this rotation.

        Euler angles are subject to gimbal lock: https://en.wikipedia.org/wiki/Gimbal_lock

        This means that when the pitch is close to +/- pi/2, the yaw and roll angles are not
        uniquely defined, so the returned values are not unique in this case.

        Returns:
            numpy.ndarray of shape (3,) holding, in order:
                Yaw angle [radians]
                Pitch angle [radians] (the ``asin`` argument is clamped to [-1, 1])
                Roll angle [radians]
        """

        # Total ops: 27

        # Input arrays
        _self = self.data

        # Intermediate terms (7)
        _tmp0 = 2 * _self[0]
        _tmp1 = 2 * _self[2]
        _tmp2 = _self[2] ** 2
        _tmp3 = _self[0] ** 2
        _tmp4 = -(_self[1] ** 2) + _self[3] ** 2
        _tmp5 = -_tmp2 + _tmp3 + _tmp4
        _tmp6 = _tmp2 - _tmp3 + _tmp4

        # Output terms
        _res = numpy.zeros(3)
        _res[0] = math.atan2(_self[1] * _tmp0 + _self[3] * _tmp1, _tmp5)
        _res[1] = -math.asin(max(-1, min(1, -2 * _self[1] * _self[3] + _self[2] * _tmp0)))
        _res[2] = math.atan2(_self[1] * _tmp1 + _self[3] * _tmp0, _tmp6)
        return _res

    @staticmethod
    def from_yaw_pitch_roll(yaw, pitch, roll):
        # type: (float, float, float) -> Rot3
        """
        Construct from yaw, pitch, and roll Euler angles in radians.
        """

        # Total ops: 25

        # Input arrays

        # Intermediate terms (13)
        _tmp0 = (1.0 / 2.0) * pitch
        _tmp1 = math.sin(_tmp0)
        _tmp2 = (1.0 / 2.0) * yaw
        _tmp3 = math.sin(_tmp2)
        _tmp4 = (1.0 / 2.0) * roll
        _tmp5 = math.cos(_tmp4)
        _tmp6 = _tmp3 * _tmp5
        _tmp7 = math.cos(_tmp0)
        _tmp8 = math.sin(_tmp4)
        _tmp9 = math.cos(_tmp2)
        _tmp10 = _tmp8 * _tmp9
        _tmp11 = _tmp3 * _tmp8
        _tmp12 = _tmp5 * _tmp9

        # Output terms
        _res = [0.0] * 4
        _res[0] = -_tmp1 * _tmp6 + _tmp10 * _tmp7
        _res[1] = _tmp1 * _tmp12 + _tmp11 * _tmp7
        _res[2] = -_tmp1 * _tmp10 + _tmp6 * _tmp7
        _res[3] = _tmp1 * _tmp11 + _tmp12 * _tmp7
        return Rot3.from_storage(_res)

    @staticmethod
    def from_yaw(yaw):
        # type: (float) -> Rot3
        """Construct from yaw angle in radians (rotation about the z axis of storage)."""

        # Total ops: 5

        # Input arrays

        # Intermediate terms (1)
        _tmp0 = (1.0 / 2.0) * yaw

        # Output terms
        _res = [0.0] * 4
        _res[0] = 0
        _res[1] = 0
        _res[2] = 1.0 * math.sin(_tmp0)
        _res[3] = 1.0 * math.cos(_tmp0)
        return Rot3.from_storage(_res)

    @staticmethod
    def from_pitch(pitch):
        # type: (float) -> Rot3
        """Construct from pitch angle in radians (rotation about the y axis of storage)."""

        # Total ops: 5

        # Input arrays

        # Intermediate terms (1)
        _tmp0 = (1.0 / 2.0) * pitch

        # Output terms
        _res = [0.0] * 4
        _res[0] = 0
        _res[1] = 1.0 * math.sin(_tmp0)
        _res[2] = 0
        _res[3] = 1.0 * math.cos(_tmp0)
        return Rot3.from_storage(_res)

    @staticmethod
    def from_roll(roll):
        # type: (float) -> Rot3
        """Construct from roll angle in radians (rotation about the x axis of storage)."""

        # Total ops: 5

        # Input arrays

        # Intermediate terms (1)
        _tmp0 = (1.0 / 2.0) * roll

        # Output terms
        _res = [0.0] * 4
        _res[0] = 1.0 * math.sin(_tmp0)
        _res[1] = 0
        _res[2] = 0
        _res[3] = 1.0 * math.cos(_tmp0)
        return Rot3.from_storage(_res)

    @staticmethod
    def from_angle_axis(angle, axis):
        # type: (float, numpy.ndarray) -> Rot3
        """
        Construct from an angle in radians and a (normalized) axis as a 3-vector.

        The axis is not normalized here; a non-unit axis yields a non-unit quaternion.

        Raises:
            IndexError: If ``axis`` is not of shape (3,) or (3, 1).
        """

        # Total ops: 6

        # Input arrays
        if axis.shape == (3,):
            axis = axis.reshape((3, 1))
        elif axis.shape != (3, 1):
            raise IndexError(
                "axis is expected to have shape (3, 1) or (3,); instead had shape {}".format(
                    axis.shape
                )
            )

        # Intermediate terms (2)
        _tmp0 = (1.0 / 2.0) * angle
        _tmp1 = math.sin(_tmp0)

        # Output terms
        _res = [0.0] * 4
        _res[0] = _tmp1 * axis[0, 0]
        _res[1] = _tmp1 * axis[1, 0]
        _res[2] = _tmp1 * axis[2, 0]
        _res[3] = math.cos(_tmp0)
        return Rot3.from_storage(_res)

    @staticmethod
    def from_two_unit_vectors(a, b, epsilon):
        # type: (numpy.ndarray, numpy.ndarray, float) -> Rot3
        """
        Return a rotation that transforms a to b. Both inputs are three-vectors that
        are expected to be normalized.

        Reference:
            http://lolengine.net/blog/2013/09/18/beautiful-maths-quaternion-from-vectors

        Args:
            a, b: Vectors of shape (3,) or (3, 1).
            epsilon: Small tolerance. It is added under the square root, compared against
                ``|a.b + 1|`` to detect the antiparallel case, and used (squared) to pick the
                fallback rotation axis in that case.

        Raises:
            IndexError: If ``a`` or ``b`` is not of shape (3,) or (3, 1).
        """

        # Total ops: 44

        # Input arrays
        if a.shape == (3,):
            a = a.reshape((3, 1))
        elif a.shape != (3, 1):
            raise IndexError(
                "a is expected to have shape (3, 1) or (3,); instead had shape {}".format(a.shape)
            )
        if b.shape == (3,):
            b = b.reshape((3, 1))
        elif b.shape != (3, 1):
            raise IndexError(
                "b is expected to have shape (3, 1) or (3,); instead had shape {}".format(b.shape)
            )

        # Intermediate terms (7)
        _tmp0 = a[0, 0] * b[0, 0] + a[1, 0] * b[1, 0] + a[2, 0] * b[2, 0]
        _tmp1 = math.sqrt(2 * _tmp0 + epsilon + 2)
        _tmp2 = (
            0.0 if -epsilon + abs(_tmp0 + 1) == 0 else math.copysign(1, -epsilon + abs(_tmp0 + 1))
        ) + 1
        _tmp3 = (1.0 / 2.0) * _tmp2
        _tmp4 = _tmp3 / _tmp1
        _tmp5 = 1.0 / 2.0 - 1.0 / 2.0 * (
            0.0
            if a[1, 0] ** 2 + a[2, 0] ** 2 - epsilon**2 == 0
            else math.copysign(1, a[1, 0] ** 2 + a[2, 0] ** 2 - epsilon**2)
        )
        _tmp6 = 1 - _tmp3

        # Output terms
        _res = [0.0] * 4
        _res[0] = _tmp4 * (a[1, 0] * b[2, 0] - a[2, 0] * b[1, 0]) + _tmp6 * (1 - _tmp5)
        _res[1] = _tmp4 * (-a[0, 0] * b[2, 0] + a[2, 0] * b[0, 0]) + _tmp5 * _tmp6
        _res[2] = _tmp4 * (a[0, 0] * b[1, 0] - a[1, 0] * b[0, 0])
        _res[3] = (1.0 / 4.0) * _tmp1 * _tmp2
        return Rot3.from_storage(_res)

    # --------------------------------------------------------------------------
    # StorageOps concept
    # --------------------------------------------------------------------------

    @staticmethod
    def storage_dim():
        # type: () -> int
        """Number of scalars in the storage representation (x, y, z, w)."""
        return 4

    def to_storage(self):
        # type: () -> T.List[float]
        """Return a new list copy of the storage (x, y, z, w)."""
        return list(self.data)

    @classmethod
    def from_storage(cls, vec):
        # type: (T.Sequence[float]) -> Rot3
        """
        Construct from a flat storage sequence (x, y, z, w).

        A ``list`` argument is stored directly without copying, so later mutation of that
        list is visible through the returned instance; any other sequence is copied.

        Raises:
            ValueError: If ``len(vec) != storage_dim()``.
        """
        # Validate before building the instance.
        if len(vec) != cls.storage_dim():
            raise ValueError(
                "{} has storage dim {}, got {}.".format(cls.__name__, cls.storage_dim(), len(vec))
            )
        instance = cls.__new__(cls)
        instance.data = vec if isinstance(vec, list) else list(vec)
        return instance

    # --------------------------------------------------------------------------
    # GroupOps concept
    # --------------------------------------------------------------------------

    @classmethod
    def identity(cls):
        # type: () -> Rot3
        """Group identity; delegates to ``ops.GroupOps.identity``."""
        return ops.GroupOps.identity()

    def inverse(self):
        # type: () -> Rot3
        """Group inverse; delegates to ``ops.GroupOps.inverse``."""
        return ops.GroupOps.inverse(self)

    def compose(self, b):
        # type: (Rot3) -> Rot3
        """Group composition ``self * b``; delegates to ``ops.GroupOps.compose``."""
        return ops.GroupOps.compose(self, b)

    def between(self, b):
        # type: (Rot3) -> Rot3
        """Delegates to ``ops.GroupOps.between(self, b)``."""
        return ops.GroupOps.between(self, b)

    # --------------------------------------------------------------------------
    # LieGroupOps concept
    # --------------------------------------------------------------------------

    @staticmethod
    def tangent_dim():
        # type: () -> int
        """Dimension of the tangent space."""
        return 3

    @classmethod
    def from_tangent(cls, vec, epsilon=1e-8):
        # type: (numpy.ndarray, float) -> Rot3
        """
        Delegates to ``ops.LieGroupOps.from_tangent``.

        Raises:
            ValueError: If ``len(vec) != tangent_dim()``.
        """
        if len(vec) != cls.tangent_dim():
            raise ValueError(
                "Vector dimension ({}) not equal to tangent space dimension ({}).".format(
                    len(vec), cls.tangent_dim()
                )
            )
        return ops.LieGroupOps.from_tangent(vec, epsilon)

    def to_tangent(self, epsilon=1e-8):
        # type: (float) -> numpy.ndarray
        """Delegates to ``ops.LieGroupOps.to_tangent``."""
        return ops.LieGroupOps.to_tangent(self, epsilon)

    def retract(self, vec, epsilon=1e-8):
        # type: (numpy.ndarray, float) -> Rot3
        """
        Delegates to ``ops.LieGroupOps.retract``.

        Raises:
            ValueError: If ``len(vec) != tangent_dim()``.
        """
        if len(vec) != self.tangent_dim():
            raise ValueError(
                "Vector dimension ({}) not equal to tangent space dimension ({}).".format(
                    len(vec), self.tangent_dim()
                )
            )
        return ops.LieGroupOps.retract(self, vec, epsilon)

    def local_coordinates(self, b, epsilon=1e-8):
        # type: (Rot3, float) -> numpy.ndarray
        """Delegates to ``ops.LieGroupOps.local_coordinates``."""
        return ops.LieGroupOps.local_coordinates(self, b, epsilon)

    def interpolate(self, b, alpha, epsilon=1e-8):
        # type: (Rot3, float, float) -> Rot3
        """Delegates to ``ops.LieGroupOps.interpolate``."""
        return ops.LieGroupOps.interpolate(self, b, alpha, epsilon)

    # --------------------------------------------------------------------------
    # General Helpers
    # --------------------------------------------------------------------------

    def __eq__(self, other):
        # type: (T.Any) -> bool
        # Exact element-wise storage comparison (no tolerance, and q and -q compare unequal).
        # Because __eq__ is defined without __hash__, Python makes instances unhashable.
        if isinstance(other, Rot3):
            return self.data == other.data
        else:
            return False

    @T.overload
    def __mul__(self, other):  # pragma: no cover
        # type: (Rot3) -> Rot3
        pass

    @T.overload
    def __mul__(self, other):  # pragma: no cover
        # type: (numpy.ndarray) -> numpy.ndarray
        pass

    def __mul__(self, other):
        # type: (T.Union[Rot3, numpy.ndarray]) -> T.Union[Rot3, numpy.ndarray]
        """
        ``Rot3 * Rot3`` composes; ``Rot3 * ndarray`` rotates a point and preserves its shape.

        Raises:
            NotImplementedError: For any other operand type.
        """
        if isinstance(other, Rot3):
            return self.compose(other)
        elif isinstance(other, numpy.ndarray) and hasattr(self, "compose_with_point"):
            return self.compose_with_point(other).reshape(other.shape)
        else:
            raise NotImplementedError("Cannot compose {} with {}.".format(type(self), type(other)))