Source code for symforce.caspar.memory.special_square_matrices
# ----------------------------------------------------------------------------
# SymForce - Copyright 2025, Skydio, Inc.
# This source code is under the Apache 2.0 license found in the LICENSE file.
# ----------------------------------------------------------------------------
from __future__ import annotations
from abc import abstractmethod
import symforce.symbolic as sf
from symforce import typing as T
from symforce.ops.interfaces import Storage
# Type variable bound to _SpecialSquareMatrix. It annotates `cls` in `__new__` and
# `from_storage` so that their return type is the class they were called on.
SpecialSquareMatT = T.TypeVar("SpecialSquareMatT", bound="_SpecialSquareMatrix")
class _SpecialSquareMatrix(Storage):
    """
    Base class for square matrices whose entries are kept in a flat ``storage`` list.

    ``SHAPE`` is the side length of the matrix; the value ``-1`` marks a class whose size
    has not been fixed yet. Subclasses decide which entries are stored by implementing
    :meth:`mat` and :meth:`storage_dim`.
    """

    storage: T.List[sf.Scalar]
    SHAPE: T.ClassVar[int] = -1

    @classmethod
    def get_class(cls, shape: int) -> T.Type[_SpecialSquareMatrix]:
        """
        Build a subclass of ``cls`` named ``<cls name><shape>`` with ``SHAPE`` set to ``shape``.

        A new class object is created on every call.
        """
        sized_name = f"{cls.__name__}{shape}"
        return type(sized_name, (cls,), {"SHAPE": shape})

    def __new__(cls: T.Type[SpecialSquareMatT], matrix: sf.Matrix) -> SpecialSquareMatT:
        """
        Allocate the instance, picking a sized subclass when ``cls`` has no fixed ``SHAPE``.
        """
        # Already sized: allocate an instance of exactly this class.
        if cls.SHAPE != -1:
            return super().__new__(cls)

        # Unsized: derive the side length from the first dimension of the given matrix and
        # allocate an instance of a freshly created subclass carrying that SHAPE.
        dim = matrix.shape[0]
        sized_type = type(f"{cls.__name__}{dim}", (cls,), {"SHAPE": dim})
        return T.cast(SpecialSquareMatT, super().__new__(sized_type))

    @classmethod
    def check_size(cls, mat: sf.Matrix) -> None:
        """
        Raise ``ValueError`` unless ``mat.SHAPE`` equals ``(cls.SHAPE, cls.SHAPE)``.
        """
        target = (cls.SHAPE, cls.SHAPE)
        if target == mat.SHAPE:
            return
        raise ValueError(f"Expected matrix of dim {target}, got shape {mat.SHAPE}")

    def __repr__(self) -> str:
        """Use the representation of the matrix returned by :meth:`mat`."""
        return repr(self.mat())

    def to_storage(self) -> T.List[T.Scalar]:
        """
        Return a copy of the flat storage list; its length is :meth:`storage_dim`.

        This is only a plumbing representation, it is NOT like a tangent space.
        """
        return list(self.storage)

    @classmethod
    def from_storage(
        cls: T.Type[SpecialSquareMatT], elements: T.Sequence[T.Scalar]
    ) -> SpecialSquareMatT:
        """
        Build an instance directly from a flat list. Inverse of :meth:`to_storage`.

        Raises:
            ValueError: if ``cls`` has no fixed ``SHAPE``, or if the number of elements
                differs from :meth:`storage_dim`.
        """
        if cls.SHAPE == -1:
            raise ValueError("SHAPE not set")

        expected_len = cls.storage_dim()
        if len(elements) != expected_len:
            raise ValueError(
                f"Expected {cls.storage_dim()} elements for {cls}, got {len(elements)}"
            )

        # Bypass __new__/__init__ (which expect a matrix) and fill the storage directly.
        instance = object.__new__(cls)
        instance.storage = list(elements)
        return instance

    @abstractmethod
    def mat(self) -> sf.Matrix:
        """Return the matrix rebuilt from ``storage``; implemented by subclasses."""
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def storage_dim(cls) -> int:
        """Return the number of stored entries; implemented by subclasses."""
        raise NotImplementedError
[docs]
class LowerTriangularMatrix(_SpecialSquareMatrix):
    """
    Lower triangular matrix, including the diagonal.

    Entries are stored row by row; each row contributes its columns from 0 up to and
    including the diagonal.
    """

    def __init__(self, mat: sf.Matrix):
        """
        Store the lower triangle of ``mat``.

        Raises:
            ValueError: if ``mat`` has the wrong size or has a nonzero entry above the diagonal.
        """
        self.check_size(mat)
        entries: T.List[sf.Scalar] = []
        self.storage = entries
        for row in range(self.SHAPE):
            # Strictly-lower entries of this row; the mirrored (upper) entry must be zero.
            for col in range(row):
                if mat[col, row] != 0:
                    raise ValueError("Matrix is not lower triangular")
                entries.append(mat[row, col])
            # The diagonal entry closes the row and needs no check.
            entries.append(mat[row, row])
        # Sanity check: rebuilding the matrix from storage must reproduce the input.
        assert self.mat() == mat

    @classmethod
    def lower_storage_indices(cls) -> tuple[int, ...]:
        """
        Return ``row + col * SHAPE`` for every stored ``(row, col)``, in storage order.

        NOTE(review): presumably the column-major flat index into the full matrix; confirm
        against the storage order of ``sf.Matrix``.
        """
        dim = cls.SHAPE
        return tuple(row + col * dim for row in range(dim) for col in range(row + 1))

    def mat(self) -> sf.Matrix:
        """
        Rebuild the ``SHAPE`` x ``SHAPE`` matrix from ``storage``.

        Only entries on or below the diagonal are written.
        """
        dim = self.SHAPE
        dense = sf.Matrix(dim, dim)
        positions = ((row, col) for row in range(dim) for col in range(row + 1))
        for offset, (row, col) in enumerate(positions):
            dense[row, col] = self.storage[offset]
        return dense

    @classmethod
    def storage_dim(cls) -> int:
        """Return the number of entries on or below the diagonal."""
        dim = cls.SHAPE
        return (dim * (dim + 1)) // 2
[docs]
class UpperTriangularMatrix(_SpecialSquareMatrix):
    """
    Upper triangular matrix, including the diagonal.

    Entries are stored row by row; each row contributes its columns from the diagonal up to
    the last column.
    """

    def __init__(self, mat: sf.Matrix):
        """
        Store the upper triangle of ``mat``.

        Raises:
            ValueError: if ``mat`` has the wrong size or has a nonzero entry below the diagonal.
        """
        self.check_size(mat)
        entries: T.List[sf.Scalar] = []
        self.storage = entries
        dim = self.SHAPE
        for row in range(dim):
            # The diagonal entry opens the row and needs no check.
            entries.append(mat[row, row])
            # Strictly-upper entries of this row; the mirrored (lower) entry must be zero.
            for col in range(row + 1, dim):
                if mat[col, row] != 0:
                    raise ValueError("Matrix is not upper triangular")
                entries.append(mat[row, col])
        # Sanity check: rebuilding the matrix from storage must reproduce the input.
        assert self.mat() == mat

    @classmethod
    def upper_storage_indices(cls) -> tuple[int, ...]:
        """
        Return ``row + col * SHAPE`` for every stored ``(row, col)``, in storage order.

        NOTE(review): presumably the column-major flat index into the full matrix; confirm
        against the storage order of ``sf.Matrix``.
        """
        dim = cls.SHAPE
        return tuple(row + col * dim for row in range(dim) for col in range(row, dim))

    def mat(self) -> sf.Matrix:
        """
        Rebuild the ``SHAPE`` x ``SHAPE`` matrix from ``storage``.

        Only entries on or above the diagonal are written.
        """
        dim = self.SHAPE
        dense = sf.Matrix(dim, dim)
        positions = ((row, col) for row in range(dim) for col in range(row, dim))
        for offset, (row, col) in enumerate(positions):
            dense[row, col] = self.storage[offset]
        return dense

    @classmethod
    def storage_dim(cls) -> int:
        """Return the number of entries on or above the diagonal."""
        dim = cls.SHAPE
        return (dim * (dim + 1)) // 2
[docs]
class SymmetricMatrix(UpperTriangularMatrix):
    """
    Symmetric matrix.

    Only the upper triangle (including the diagonal) is stored, in the same order as
    :class:`UpperTriangularMatrix`; the lower triangle is mirrored from it in :meth:`mat`.
    """

    def __init__(self, mat: sf.Matrix):
        """
        Store the upper triangle of ``mat``.

        Raises:
            ValueError: if ``mat`` has the wrong size or ``mat[i, j] != mat[j, i]`` for some
                off-diagonal pair.
        """
        self.check_size(mat)
        entries: T.List[sf.Scalar] = []
        self.storage = entries
        dim = self.SHAPE
        for row in range(dim):
            # The diagonal entry opens the row and needs no check.
            entries.append(mat[row, row])
            # Each strictly-upper entry must match its mirror below the diagonal.
            for col in range(row + 1, dim):
                if mat[row, col] != mat[col, row]:
                    raise ValueError("Matrix is not symmetric")
                entries.append(mat[row, col])
        # Sanity check: rebuilding the matrix from storage must reproduce the input.
        assert self.mat() == mat

    def mat(self) -> sf.Matrix:
        """
        Rebuild the full matrix: the parent's upper triangle plus its mirror image below.
        """
        full = super().mat()
        # Row 0 has nothing below the diagonal to fill, so start from row 1.
        for row in range(1, self.SHAPE):
            for col in range(row):
                full[row, col] = full[col, row]
        return full
[docs]
class LMat11(LowerTriangularMatrix):
    """Lower triangular matrix with fixed size 1x1."""
    SHAPE = 1
[docs]
class LMat22(LowerTriangularMatrix):
    """Lower triangular matrix with fixed size 2x2."""
    SHAPE = 2
[docs]
class LMat33(LowerTriangularMatrix):
    """Lower triangular matrix with fixed size 3x3."""
    SHAPE = 3
[docs]
class LMat44(LowerTriangularMatrix):
    """Lower triangular matrix with fixed size 4x4."""
    SHAPE = 4
[docs]
class LMat55(LowerTriangularMatrix):
    """Lower triangular matrix with fixed size 5x5."""
    SHAPE = 5
[docs]
class LMat66(LowerTriangularMatrix):
    """Lower triangular matrix with fixed size 6x6."""
    SHAPE = 6
[docs]
class LMat77(LowerTriangularMatrix):
    """Lower triangular matrix with fixed size 7x7."""
    SHAPE = 7
[docs]
class LMat88(LowerTriangularMatrix):
    """Lower triangular matrix with fixed size 8x8."""
    SHAPE = 8
[docs]
class LMat99(LowerTriangularMatrix):
    """Lower triangular matrix with fixed size 9x9."""
    SHAPE = 9
[docs]
class UMat11(UpperTriangularMatrix):
    """Upper triangular matrix with fixed size 1x1."""
    SHAPE = 1
[docs]
class UMat22(UpperTriangularMatrix):
    """Upper triangular matrix with fixed size 2x2."""
    SHAPE = 2
[docs]
class UMat33(UpperTriangularMatrix):
    """Upper triangular matrix with fixed size 3x3."""
    SHAPE = 3
[docs]
class UMat44(UpperTriangularMatrix):
    """Upper triangular matrix with fixed size 4x4."""
    SHAPE = 4
[docs]
class UMat55(UpperTriangularMatrix):
    """Upper triangular matrix with fixed size 5x5."""
    SHAPE = 5
[docs]
class UMat66(UpperTriangularMatrix):
    """Upper triangular matrix with fixed size 6x6."""
    SHAPE = 6
[docs]
class UMat77(UpperTriangularMatrix):
    """Upper triangular matrix with fixed size 7x7."""
    SHAPE = 7
[docs]
class UMat88(UpperTriangularMatrix):
    """Upper triangular matrix with fixed size 8x8."""
    SHAPE = 8
[docs]
class UMat99(UpperTriangularMatrix):
    """Upper triangular matrix with fixed size 9x9."""
    SHAPE = 9
[docs]
class SMat11(SymmetricMatrix):
    """Symmetric matrix with fixed size 1x1."""
    SHAPE = 1
[docs]
class SMat22(SymmetricMatrix):
    """Symmetric matrix with fixed size 2x2."""
    SHAPE = 2
[docs]
class SMat33(SymmetricMatrix):
    """Symmetric matrix with fixed size 3x3."""
    SHAPE = 3
[docs]
class SMat44(SymmetricMatrix):
    """Symmetric matrix with fixed size 4x4."""
    SHAPE = 4
[docs]
class SMat55(SymmetricMatrix):
    """Symmetric matrix with fixed size 5x5."""
    SHAPE = 5
[docs]
class SMat66(SymmetricMatrix):
    """Symmetric matrix with fixed size 6x6."""
    SHAPE = 6
[docs]
class SMat77(SymmetricMatrix):
    """Symmetric matrix with fixed size 7x7."""
    SHAPE = 7
[docs]
class SMat88(SymmetricMatrix):
    """Symmetric matrix with fixed size 8x8."""
    SHAPE = 8
[docs]
class SMat99(SymmetricMatrix):
    """Symmetric matrix with fixed size 9x9."""
    SHAPE = 9