Source code for symforce.test_util.storage_ops_test_mixin
# ----------------------------------------------------------------------------# SymForce - Copyright 2022, Skydio, Inc.# This source code is under the Apache 2.0 license found in the LICENSE file.# ----------------------------------------------------------------------------importnumpyasnpimportsymforce.symbolicassffromsymforceimporttypingasTfromsymforce.opsimportStorageOpsfromsymforce.test_utilimportTestCaseifT.TYPE_CHECKING:_Base=TestCaseelse:_Base=object
class StorageOpsTestMixin(_Base):
    """
    Mixin with a generic test for types implementing the StorageOps concept.

    Derive a concrete test case from this mixin and override ``element`` to
    return a sample, non-identity instance of the type under test.
    """

    @classmethod
    def element(cls) -> T.Any:
        """
        Return an example non-identity element of the type under test.

        Must be overridden by the inheriting test case.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError()

    @classmethod
    def element_type(cls) -> T.Type:
        """
        Return the class of the StorageOps-compatible object under test,
        i.e. the type of the value produced by ``element()``.
        """
        sample = cls.element()
        return type(sample)

    def test_storage_ops(self) -> None:
        """
        Exercise the StorageOps interface on ``self.element()``.

        Covers:

            - storage_dim
            - to_storage
            - from_storage

        It also checks equality/inequality after round trips, ``str()``
        output, and the symbolic helpers ``symbolic``, ``subs`` and
        ``simplify``.
        """
        sample = self.element()

        # The storage dimension must be strictly positive.
        dim = StorageOps.storage_dim(sample)
        self.assertGreater(dim, 0)

        # Build an object from a random flat list. The RNG draw happens before
        # element_type() is called, which calls element() again.
        random_values = np.random.normal(size=(dim,)).tolist()
        rebuilt = StorageOps.from_storage(sample, random_values)
        self.assertEqual(type(rebuilt), self.element_type())

        # Flatten it back and compare with the input list.
        flattened = StorageOps.to_storage(rebuilt)
        self.assertEqual(len(flattened), dim)
        self.assertListEqual(random_values, flattened)

        # Rebuilding from the flattened values must give an equal object...
        rebuilt_again = StorageOps.from_storage(rebuilt, flattened)
        self.assertEqual(rebuilt, rebuilt_again)

        # ...while changing a single entry must give an unequal one.
        flattened[0] = 10000.0
        perturbed = StorageOps.from_storage(rebuilt, flattened)
        self.assertNotEqual(rebuilt, perturbed)

        # str() must produce something non-empty.
        text = str(rebuilt)
        self.assertGreater(len(text), 0)

        # Symbolic operations.
        symbolic_sample = StorageOps.symbolic(sample, "name")

        # Substituting a symbol that does not occur must not change the element.
        absent_symbol = sf.Symbol("var_not_in_element")
        replacement_symbol = sf.Symbol("new_var")
        substituted = StorageOps.subs(
            symbolic_sample, {absent_symbol: replacement_symbol}
        )
        self.assertEqual(symbolic_sample, substituted)

        simplified = StorageOps.simplify(symbolic_sample)
        self.assertEqual(symbolic_sample, simplified)

        # subs() with key/value lists of different lengths must raise ValueError.
        # The arguments are built inside the context manager, as before.
        with self.assertRaises(ValueError):
            old_entries = StorageOps.to_storage(sf.Symbol("var_not_in_element"))
            new_entries = StorageOps.to_storage(sf.Symbol("new_var")) + [0.0]
            StorageOps.subs(symbolic_sample, old_entries, new_entries)