# Source code for symforce.ops.impl.array_storage_ops

# ----------------------------------------------------------------------------
# SymForce - Copyright 2022, Skydio, Inc.
# This source code is under the Apache 2.0 license found in the LICENSE file.
# ----------------------------------------------------------------------------

import numpy as np

from symforce import typing as T
from symforce.ops import StorageOps


class ArrayStorageOps:
    """
    Implements Storage operations for numpy ndarrays.

    Arrays are flattened into storage, and rebuilt from it, in column-major (Fortran) order.
    So ``from_storage(a, to_storage(a))`` reproduces the shape and element order of ``a``.
    """

    @staticmethod
    def storage_dim(a: T.ArrayElementOrType) -> int:
        """
        Return the number of scalars needed to store ``a``, i.e. its total element count.
        """
        # NOTE(brad): Must take T.ArrayElementOrType to match AbstractStorageOps
        assert isinstance(a, np.ndarray)
        return a.size

    @staticmethod
    def to_storage(a: T.ArrayElement) -> T.List[T.Scalar]:
        """
        Flatten ``a`` into a list of its elements in column-major (Fortran) order.

        A zero-size array yields an empty list, consistent with ``storage_dim`` returning 0.
        """
        # NOTE(brad): I have the T.cast because mypy thinks the values of np.nditer are tuples.
        # "refs_ok" permits object-dtype arrays (e.g. holding symbolic expressions).
        # "zerosize_ok" is needed because np.nditer otherwise raises ValueError on arrays
        # with no elements.
        # Each iterated item is a 0-d array; indexing it with () extracts the element itself.
        return [
            T.cast(np.ndarray, scalar)[()]
            for scalar in np.nditer(a, order="F", flags=["refs_ok", "zerosize_ok"])
        ]

    @staticmethod
    def from_storage(a: T.ArrayElementOrType, elements: T.Sequence[T.Scalar]) -> T.ArrayElement:
        """
        Build an array with the same shape as ``a`` from ``elements``.

        ``elements`` is read in column-major (Fortran) order, which makes this the inverse of
        ``to_storage``. Its length must equal ``storage_dim(a)``.
        """
        # NOTE(brad): Must take T.ArrayElementOrType to match AbstractStorageOps
        assert isinstance(a, np.ndarray)
        assert len(elements) == ArrayStorageOps.storage_dim(a)
        # A C-order reshape to the reversed shape followed by a transpose lays the elements out
        # exactly as a Fortran-order reshape to a.shape would.
        return np.array(elements).reshape(tuple(reversed(a.shape))).transpose()

    @staticmethod
    def symbolic(a: T.ArrayElementOrType, name: str, **kwargs: T.Dict) -> T.ArrayElement:
        """
        Return an array shaped like ``a`` whose entries are symbolic counterparts of ``a``'s.

        Each entry along the first axis of ``a`` (a sub-array when ``a.ndim > 1``) is passed to
        ``StorageOps.symbolic`` with the name ``f"{name}_{i}"``, where ``i`` is its index.
        ``kwargs`` are forwarded unchanged.
        """
        # NOTE(brad): Must take T.ArrayElementOrType to match AbstractStorageOps
        assert isinstance(a, np.ndarray)
        # NOTE(review): enumerate() on a 0-d array raises TypeError, so 0-d inputs are not
        # supported here even though the other storage ops accept them -- confirm intended.
        return np.array(
            [StorageOps.symbolic(v, f"{name}_{i}", **kwargs) for i, v in enumerate(a)]
        ).reshape(a.shape)