Source code for symforce.test_util.storage_ops_test_mixin

# ----------------------------------------------------------------------------
# SymForce - Copyright 2022, Skydio, Inc.
# This source code is under the Apache 2.0 license found in the LICENSE file.
# ----------------------------------------------------------------------------

import numpy as np

import symforce.symbolic as sf
from symforce import typing as T
from symforce.ops import StorageOps
from symforce.test_util import TestCase

if T.TYPE_CHECKING:
    _Base = TestCase
else:
    _Base = object


class StorageOpsTestMixin(_Base):
    """
    Reusable checks for a type that implements the StorageOps concept.

    Mix this into a test case and override `element` to supply a sample value.
    """

    @classmethod
    def element(cls) -> T.Any:
        """
        Supply an example non-identity element of the type under test.

        Must be overridden by the child class; this base version always raises
        NotImplementedError.
        """
        raise NotImplementedError()

    @classmethod
    def element_type(cls) -> T.Type:
        """
        Return the type of the StorageOps-compatible class being tested, taken
        from the value produced by `element`.
        """
        sample = cls.element()
        return type(sample)

    def test_storage_ops(self) -> None:
        """
        Exercises:
        - storage_dim
        - from_storage / to_storage round trip
        - equality, inequality and printing of the rebuilt value
        - symbolic, subs and simplify
        """
        sample = self.element()

        # The storage dimension must be strictly positive
        dim = StorageOps.storage_dim(sample)
        self.assertGreater(dim, 0)

        # Construct a value from a random flat list of the right length
        flat_in = np.random.normal(size=(dim,)).tolist()
        built = StorageOps.from_storage(sample, flat_in)
        self.assertEqual(type(built), self.element_type())

        # Flatten it again; the round trip must reproduce the input exactly
        flat_out = StorageOps.to_storage(built)
        self.assertEqual(len(flat_out), dim)
        self.assertListEqual(flat_in, flat_out)

        # Rebuilding from the flattened form must compare equal ...
        rebuilt = StorageOps.from_storage(built, flat_out)
        self.assertEqual(built, rebuilt)

        # ... while a value built after changing one entry must not
        flat_out[0] = 10000.0
        perturbed = StorageOps.from_storage(built, flat_out)
        self.assertNotEqual(built, perturbed)

        # The string form must be non-empty
        printed = str(built)
        self.assertGreater(len(printed), 0)

        # Symbolic operations: substituting a symbol named "var_not_in_element"
        # and simplifying are both expected to leave the symbolic value unchanged
        symbolic_sample = StorageOps.symbolic(sample, "name")
        unrelated_mapping = {sf.Symbol("var_not_in_element"): sf.Symbol("new_var")}
        self.assertEqual(symbolic_sample, StorageOps.subs(symbolic_sample, unrelated_mapping))
        self.assertEqual(symbolic_sample, StorageOps.simplify(symbolic_sample))

        # The replacement list carries an extra trailing entry compared to what
        # to_storage produced for it; subs is expected to reject this with ValueError
        with self.assertRaises(ValueError):
            old_storage = StorageOps.to_storage(sf.Symbol("var_not_in_element"))
            new_storage = StorageOps.to_storage(sf.Symbol("new_var")) + [0.0]
            StorageOps.subs(symbolic_sample, old_storage, new_storage)