Source code for symforce.ops.impl.array_storage_ops
# ----------------------------------------------------------------------------# SymForce - Copyright 2022, Skydio, Inc.# This source code is under the Apache 2.0 license found in the LICENSE file.# ----------------------------------------------------------------------------importnumpyasnpfromsymforceimporttypingasTfromsymforce.opsimportStorageOps
class ArrayStorageOps:
    """
    Implements Storage operations for numpy ndarrays.

    Storage order is column-major (Fortran order): ``to_storage`` flattens with
    ``order="F"`` and ``from_storage`` rebuilds an array from that same ordering.
    """

    @staticmethod
    def storage_dim(a: T.ArrayElementOrType) -> int:
        """
        Return the number of storage entries of ``a``, which is ``a.size``.

        Raises:
            AssertionError: if ``a`` is not an ``np.ndarray``.
        """
        # NOTE(brad): Must take T.ArrayElementOrType to match AbstractStorageOps
        assert isinstance(a, np.ndarray)
        return a.size

    @staticmethod
    def to_storage(a: T.ArrayElement) -> T.List[T.Scalar]:
        """
        Flatten ``a`` into a list of its elements in column-major (Fortran) order.

        Zero-size arrays produce an empty list, consistent with ``storage_dim``
        returning 0 for them.
        """
        # "refs_ok" permits object-dtype arrays (e.g. arrays of symbolic expressions).
        # "zerosize_ok" is required because np.nditer otherwise raises ValueError on
        # zero-size operands.
        # Each item yielded by nditer is a 0-d array; indexing with [()] unwraps it
        # into the contained element.
        # NOTE(brad): I have the T.cast because mypy thinks the values of np.nditer are tuples.
        return [
            T.cast(np.ndarray, scalar)[()]
            for scalar in np.nditer(a, order="F", flags=["refs_ok", "zerosize_ok"])
        ]

    @staticmethod
    def from_storage(
        a: T.ArrayElementOrType, elements: T.Sequence[T.Scalar]
    ) -> T.ArrayElement:
        """
        Build an array with the shape of ``a`` from ``elements`` given in column-major
        (Fortran) order, i.e. the inverse of ``to_storage``.

        Raises:
            AssertionError: if ``a`` is not an ``np.ndarray`` or if ``len(elements)``
                differs from ``a.size``.
        """
        # NOTE(brad): Must take T.ArrayElementOrType to match AbstractStorageOps
        assert isinstance(a, np.ndarray)
        assert len(elements) == ArrayStorageOps.storage_dim(a)
        # Filling the reversed shape in C order and then transposing is equivalent
        # to filling a.shape in Fortran order.
        return np.array(elements).reshape(tuple(reversed(a.shape))).transpose()

    @staticmethod
    def symbolic(a: T.ArrayElementOrType, name: str, **kwargs: T.Any) -> T.ArrayElement:
        """
        Return an array with the shape of ``a`` whose entries come from
        ``StorageOps.symbolic``.

        Each slice along the first axis (a sub-array when ``a.ndim > 1``) is passed to
        ``StorageOps.symbolic`` with the name ``f"{name}_{i}"``. A 0-d array's single
        element is passed with ``name`` unchanged. ``kwargs`` are forwarded as-is.

        Raises:
            AssertionError: if ``a`` is not an ``np.ndarray``.
        """
        # NOTE(brad): Must take T.ArrayElementOrType to match AbstractStorageOps
        assert isinstance(a, np.ndarray)
        if a.ndim == 0:
            # A 0-d array cannot be iterated/enumerated, so convert its single
            # element directly.
            return np.array(StorageOps.symbolic(a[()], name, **kwargs)).reshape(a.shape)
        return np.array(
            [StorageOps.symbolic(v, f"{name}_{i}", **kwargs) for i, v in enumerate(a)]
        ).reshape(a.shape)