Source code for symforce.caspar.memory.pair

# ----------------------------------------------------------------------------
# SymForce - Copyright 2025, Skydio, Inc.
# This source code is under the Apache 2.0 license found in the LICENSE file.
# ----------------------------------------------------------------------------

from __future__ import annotations

from textwrap import indent

import symforce.symbolic as sf
from symforce import jacobian_helpers
from symforce import typing as T
from symforce.ops import StorageOps as Ops
from symforce.ops.interfaces import Storage

storage_t = T.TypeVar("storage_t", bound=T.Storable)


[docs] class Pair(T.Generic[storage_t]): def __init__(self, first: storage_t, second: storage_t): if type(first) is not type(second): raise ValueError("First and second must be of the same type") self.data = [first, second] def __iter__(self) -> T.Iterable[storage_t]: return iter(self.data) def __getitem__(self, idx: int) -> storage_t: return self.data[idx] def __repr__(self) -> str: return f"Pair(\n{indent(repr(self[0]), ' ')},\n{indent(repr(self[1]), ' ')}\n)"
[docs] def is_pair(thing: T.Union[T.Type[Storage], T.Type[Pair], Storage, Pair]) -> bool: return thing is Pair or T.get_origin(thing) is Pair or isinstance(thing, Pair)
[docs] def get_symbolic(storage_or_pair: T.Union[Storage, Pair], name: str) -> T.Union[Storage, Pair]: if is_pair(storage_or_pair): storage_t = get_memtype(storage_or_pair) return Pair(*(Ops.symbolic(storage_t, f"{name}_{k}") for k in ["first", "second"])) return Ops.symbolic(storage_or_pair, name)
[docs] def jacobians( fx: Storage, storage_or_pair: T.Union[Storage, Pair] ) -> T.Union[Pair[sf.Matrix], sf.Matrix]: diff = lambda fx, x: jacobian_helpers.tangent_jacobians(fx, [x])[0] if is_pair(storage_or_pair): storage_or_pair = T.cast(Pair, storage_or_pair) return Pair(diff(fx, storage_or_pair[0]), diff(fx, storage_or_pair[1])) return diff(fx, storage_or_pair)
[docs] def get_elements(storage_or_pair: T.Union[Storage, Pair]) -> T.List: if is_pair(storage_or_pair): storage_or_pair = T.cast(Pair, storage_or_pair) return [ *Ops.to_storage(storage_or_pair[0]), *Ops.to_storage(storage_or_pair[1]), ] return Ops.to_storage(storage_or_pair)
[docs] def get_memtype( storage_or_pair: T.Union[T.Type[Storage], T.Type[Pair], Storage, Pair], ) -> T.Type[Storage]: if is_pair(storage_or_pair): if storage_or_pair is Pair: raise ValueError("Cannot get type from unannotated Pair") elif T.get_origin(storage_or_pair) is Pair: return T.get_args(storage_or_pair)[0] elif isinstance(storage_or_pair, Pair): return storage_or_pair.data[0].__class__ if isinstance(storage_or_pair, type): storage_or_pair = T.cast(T.Type[Storage], storage_or_pair) return storage_or_pair storage_or_pair = T.cast(Storage, storage_or_pair) return storage_or_pair.__class__