# Source code for symforce.test_util.test_case_mixin

# ----------------------------------------------------------------------------
# SymForce - Copyright 2022, Skydio, Inc.
# This source code is under the Apache 2.0 license found in the LICENSE file.
# ----------------------------------------------------------------------------

import difflib
import logging
import re
import sys
import tempfile
import unittest
from pathlib import Path

import numpy as np

import symforce.symbolic as sf
from symforce import logger
from symforce import python_util
from symforce import typing as T
from symforce.codegen import codegen_config
from symforce.ops import LieGroupOps
from symforce.ops import StorageOps
from symforce.ops import interfaces


class SymforceTestCaseMixin(unittest.TestCase):
    """
    Mixin for SymForce tests, adds useful helpers for code generation
    """

    LieGroupOpsType = T.Union[interfaces.LieGroup, sf.Scalar]

    # Set by the --update flag to tell tests that compare against some saved
    # data to update that data instead of failing
    _UPDATE = False

    # Regular expressions for paths that compare_or_update_directory ignores: matching files are
    # neither compared nor removed when updating.  They are matched (with re.match, i.e. anchored
    # at the start only) against the paths returned by python_util.files_in_dir(relative=True)
    KEEP_PATHS = [
        r".*/__pycache__/.*",
        r".*\.pyc",
    ]

    @staticmethod
    def should_update() -> bool:
        """
        Return whether tests should update saved data instead of comparing against it.

        The first call that sees ``--update`` in ``sys.argv`` removes it from there (so that
        unittest does not see an unknown flag) and latches the result on the class.
        """
        # NOTE(aaron): This needs to be accessible before main() is called, so we do it here
        # instead.  This should also be called from main to make sure it runs at least once
        if "--update" in sys.argv:
            SymforceTestCaseMixin._UPDATE = True
            sys.argv.remove("--update")
        return SymforceTestCaseMixin._UPDATE

    def __init__(self, methodName: str = "runTest") -> None:
        super().__init__(methodName)
        # Registers assertArrayEqual with python unittest TestCase such that we use numpy array
        # comparison functions rather than the "==" operator, which throws an error for ndarrays
        self.addTypeEqualityFunc(np.ndarray, SymforceTestCaseMixin.assertArrayEqual)

    @staticmethod
    def main(*args: T.Any, **kwargs: T.Any) -> None:
        """
        Call this to run all tests in scope.
        """
        # Make sure the --update flag is consumed before unittest parses sys.argv
        SymforceTestCaseMixin.should_update()
        unittest.main(*args, **kwargs)

    @staticmethod
    def assertStorageNear(
        actual: T.Any, desired: T.Any, *, places: int = 7, msg: str = "", verbose: bool = True
    ) -> None:
        """
        Check that two elements are close. Handles sequences, scalars, and geometry types
        using StorageOps.

        Args:
            actual: The value under test
            desired: The expected value
            places: Number of decimal places to compare to, forwarded as ``decimal`` to
                ``np.testing.assert_almost_equal``
            msg: Message to print on failure
            verbose: Whether the conflicting values are included in the failure message
        """
        return np.testing.assert_almost_equal(
            actual=np.array(StorageOps.evalf(StorageOps.to_storage(actual)), dtype=np.double),
            desired=np.array(StorageOps.evalf(StorageOps.to_storage(desired)), dtype=np.double),
            decimal=places,
            err_msg=msg,
            verbose=verbose,
        )

    @staticmethod
    def assertLieGroupNear(
        actual: LieGroupOpsType,
        desired: LieGroupOpsType,
        *,
        places: int = 7,
        msg: str = "",
        verbose: bool = True,
    ) -> None:
        """
        Check that two LieGroup elements are close.

        Args:
            actual: The value under test
            desired: The expected value
            places: Number of decimal places to compare to, forwarded as ``decimal`` to
                ``np.testing.assert_almost_equal``
            msg: Message to print on failure
            verbose: Whether the conflicting values are included in the failure message
        """
        # Keep epsilon at least one decimal place below the comparison tolerance (and never
        # larger than 1e-9)
        epsilon = 10 ** (-max(9, places + 1))

        # Compute the tangent space perturbation around `actual` that produces `desired`
        local_coordinates = LieGroupOps.local_coordinates(actual, desired, epsilon=epsilon)

        # Compute the identity tangent space perturbation to compare against
        identity = sf.Matrix.zeros(LieGroupOps.tangent_dim(actual), 1)

        return np.testing.assert_almost_equal(
            actual=StorageOps.evalf(local_coordinates),
            desired=StorageOps.to_storage(identity),
            decimal=places,
            err_msg=msg,
            verbose=verbose,
        )

    @staticmethod
    def assertArrayEqual(
        actual: T.ArrayElement, desired: T.ArrayElement, msg: T.Optional[str] = ""
    ) -> None:
        """
        Called by unittest base class when comparing ndarrays when "assertEqual" is called.
        By default, "assertEqual" uses the "==" operator, which is not implemented for ndarrays.
        """
        # unittest may hand over msg=None when the caller gave no message; numpy wants a str
        return np.testing.assert_array_equal(actual, desired, err_msg=msg or "")

    def assertNotEqual(self, first: T.Any, second: T.Any, msg: T.Optional[str] = "") -> None:
        """
        Overrides unittest.assertNotEqual to handle ndarrays separately. "assertNotEqual"
        uses the "!=" operator, but this is not implemented for ndarrays. Instead, we check that
        np.testing.assert_array_equal raises an assertion error, as numpy testing does not provide
        a assert_array_not_equal function.

        Note that assertNotEqual does not work like assertEqual in unittest. Rather than allowing
        you to register a custom equality evaluator (e.g. with ``addTypeEqualityFunc()``),
        assertNotEqual assumes the "!=" can be used with the arguments regardless of type.

        The ndarray path is taken when either argument is an ndarray, since "!=" is just as
        unusable when only ``second`` is one.
        """
        if isinstance(first, np.ndarray) or isinstance(second, np.ndarray):
            try:
                np.testing.assert_array_equal(first, second)
            except AssertionError:
                # The arrays differ, which is exactly what is being asserted
                return
            # Fail through unittest so that the caller's message is actually reported
            self.fail(msg or "Arrays are equal, but were expected to differ")

        super().assertNotEqual(first, second, msg)

    def make_output_dir(
        self, prefix: T.Optional[str] = None, directory: Path = Path("/tmp")
    ) -> Path:
        """
        Create a temporary output directory, which will be automatically removed (regardless of
        exceptions) on shutdown, unless logger.level is DEBUG

        Args:
            prefix: The prefix for the directory name - a random unique identifier is added to this.
                    Defaults to the name of the test, in snake_case
            directory: Location of the output directory. Defaults to "/tmp".

        Returns:
            Path: The path to the created output directory
        """
        if prefix is None:
            prefix = python_util.camelcase_to_snakecase(self.__class__.__name__)

        output_dir = Path(tempfile.mkdtemp(prefix=prefix, dir=directory))
        logger.debug(f"Creating temp directory: {output_dir}")
        # Registered for removal in tearDown; self.output_dirs is created in setUp
        self.output_dirs.append(output_dir)
        return output_dir

    def setUp(self) -> None:
        """
        Creates list of temporary directories that will be removed before shutdown (unless debug
        mode is on)
        """
        super().setUp()

        # Set to fail on default epsilon == 0
        codegen_config.DEFAULT_ZERO_EPSILON_BEHAVIOR = codegen_config.ZeroEpsilonBehavior.FAIL

        # Storage for temporary output directories
        self.output_dirs: T.List[Path] = []

    def tearDown(self) -> None:
        """
        Removes temporary output directories (unless debug mode is on)
        """
        super().tearDown()

        if logger.level != logging.DEBUG:
            for output_dir in self.output_dirs:
                python_util.remove_if_exists(output_dir)

    def compare_or_update(self, path: T.Openable, data: str) -> None:
        """
        Compare the given data to what is saved in path, OR update the saved data if
        the ``--update`` flag was passed to the test.

        Args:
            path: Location of the saved (expected) data
            data: The newly produced data
        """
        path = Path(path)

        if self.should_update():
            logger.debug(f'Updating data at: "{path}"')

            # exist_ok avoids both the redundant existence check and the race with anything
            # else creating the directory at the same time
            path.parent.mkdir(parents=True, exist_ok=True)
            path.write_text(data)
        else:
            logger.debug(f'Comparing data at: "{path}"')
            expected_data = path.read_text()

            if data != expected_data:
                diff = difflib.unified_diff(
                    expected_data.splitlines(keepends=True),
                    data.splitlines(keepends=True),
                    "expected",
                    "got",
                )
                self.fail(
                    "\n"
                    + "".join(diff)
                    + f"\n\n{80*'='}\nData did not match for file {path}, see diff above. Use "
                    "`--update` to write the changes to the working directory and commit if desired"
                )

    def compare_or_update_file(self, path: T.Openable, new_file: T.Openable) -> None:
        """
        Compare the contents of new_file to what is saved in path, OR update the saved data if
        the ``--update`` flag was passed to the test.

        Args:
            path: Location of the saved (expected) file
            new_file: Location of the newly produced file
        """
        self.compare_or_update(path, Path(new_file).read_text())

    def _filtered_paths_in_dir(self, directory: T.Openable) -> T.List[str]:
        """
        Find the list of paths in a directory not in KEEP_PATHS, recursively.  The result is in
        sorted order
        """
        files_in_dir = python_util.files_in_dir(directory, relative=True)

        # An empty pattern list would join to "", which matches every path and would filter
        # everything out; with nothing to ignore, every file is kept instead
        if not self.KEEP_PATHS:
            return sorted(files_in_dir)

        keep_regex = re.compile("|".join(self.KEEP_PATHS))
        return sorted(path for path in files_in_dir if not keep_regex.match(path))

    def compare_or_update_directory(self, actual_dir: T.Openable, expected_dir: T.Openable) -> None:
        """
        Check the contents of actual_dir match expected_dir, OR update the expected directory
        if the ``--update`` flag was passed to the test.

        Paths matching KEEP_PATHS are ignored in both directories.

        Args:
            actual_dir: Directory containing the newly produced files
            expected_dir: Directory containing the saved (expected) files
        """
        actual_dir = Path(actual_dir)
        expected_dir = Path(expected_dir)

        logger.debug(f'Comparing directories: actual="{actual_dir}", expected="{expected_dir}"')
        actual_paths = self._filtered_paths_in_dir(actual_dir)
        expected_paths = self._filtered_paths_in_dir(expected_dir)

        if not self.should_update():
            # If checking, make sure all file paths are the same
            self.assertSequenceEqual(actual_paths, expected_paths)
        else:
            # If updating, remove any expected files not in actual
            for only_in_expected in set(expected_paths).difference(set(actual_paths)):
                (expected_dir / only_in_expected).unlink()

        for path in actual_paths:
            self.compare_or_update_file(expected_dir / path, actual_dir / path)