Source code for symforce.test_util.test_case
# ----------------------------------------------------------------------------
# SymForce - Copyright 2022, Skydio, Inc.
# This source code is under the Apache 2.0 license found in the LICENSE file.
# ----------------------------------------------------------------------------
import random
import sys
import unittest
import numpy as np
import symforce
from symforce import typing as T
from symforce.test_util.test_case_mixin import SymforceTestCaseMixin
class TestCase(SymforceTestCaseMixin):
    """
    Common base class for symforce unit tests, layering a few conveniences on top of
    :class:`SymforceTestCaseMixin`: handling of the ``--run_slow_tests`` command line flag,
    a ``main`` entry point, and deterministic random seeding in ``setUp``.
    """

    # Becomes True once the --run_slow_tests flag has been seen on the command line, meaning
    # every test should run even when the symbolic API is SymPy.
    _RUN_SLOW_TESTS = False

    @staticmethod
    def should_run_slow_tests() -> bool:
        """
        Report whether slow tests were requested via ``--run_slow_tests``.

        If the flag is present in ``sys.argv``, one occurrence is removed and the result is
        remembered on the class, so later calls keep returning True.
        """
        # NOTE(aaron): This has to work before main() is called, which is why the flag is
        # parsed here. main() calls this as well so that it runs at least once.
        try:
            sys.argv.remove("--run_slow_tests")
        except ValueError:
            # Flag not present (or already consumed by an earlier call); keep the stored value
            pass
        else:
            TestCase._RUN_SLOW_TESTS = True

        return TestCase._RUN_SLOW_TESTS

    @staticmethod
    def main(*args: T.Any, **kwargs: T.Any) -> None:
        """
        Run all tests in scope. All arguments are forwarded to ``SymforceTestCaseMixin.main``.
        """
        # Make sure the --run_slow_tests flag has been processed before tests run
        TestCase.should_run_slow_tests()

        # On py3.12, unittest exits with _NO_TESTS_EXITCODE = 5 when no tests are found. As of
        # py3.12.1 that also covers the case where every test is skipped, which happens for
        # some of our tests depending on the symbolic API and installed packages. We therefore
        # monkey patch unittest to exit with 0 instead (this also applies when no tests were
        # found at all, which isn't great). Remove this hack once the fix is in a release:
        # https://github.com/python/cpython/commit/159e3db1f7697b9aecdf674bb833fbb87f3dcad3
        if (3, 12, 0) <= sys.version_info < (3, 12, 2):
            unittest_main_module = sys.modules[unittest.main.__module__]
            unittest_main_module._NO_TESTS_EXITCODE = 0  # type: ignore[attr-defined] # pylint: disable=protected-access

        SymforceTestCaseMixin.main(*args, **kwargs)

    def setUp(self) -> None:
        """
        Per-test setup: seed the random number generators and record the verbosity flag.
        """
        super().setUp()

        # Fixed seeds for both numpy and the stdlib random module
        np.random.seed(42)
        random.seed(42)

        # Expose whether -v / --verbose was passed, so tests can use it
        self.verbose = any(flag in sys.argv for flag in ("-v", "--verbose"))
def sympy_only(func: T.Callable) -> T.Callable:
    """
    Decorator for tests that should run only when the symbolic API is SymPy.

    Under any other symbolic API, the test is skipped.
    """
    if symforce.get_symbolic_api() == "sympy":
        return func

    skip_decorator = unittest.skip("This test only runs on SymPy symbolic API.")
    return skip_decorator(func)
def symengine_only(func: T.Callable) -> T.Callable:
    """
    Decorator for tests that should run only when the symbolic API is SymEngine.

    Under any other symbolic API, the test is skipped.
    """
    if symforce.get_symbolic_api() == "symengine":
        return func

    skip_decorator = unittest.skip("This test only runs on the SymEngine symbolic API")
    return skip_decorator(func)
def expected_failure_on_sympy(func: T.Callable) -> T.Callable:
    """
    Decorator marking a test as an expected failure when the symbolic API is SymPy.

    Under any other symbolic API, the test is returned unchanged.
    """
    on_sympy = symforce.get_symbolic_api() == "sympy"
    return unittest.expectedFailure(func) if on_sympy else func
def slow_on_sympy(func: T.Callable) -> T.Callable:
    """
    Decorator marking a test as slow when the symbolic API is SymPy. Such a test is skipped
    on SymPy unless the ``--run_slow_tests`` flag was passed.
    """
    # Not on SymPy: run the test as-is, without consulting the slow-tests flag
    if symforce.get_symbolic_api() != "sympy":
        return func

    # On SymPy, but slow tests were explicitly requested
    if TestCase.should_run_slow_tests():
        return func

    return unittest.skip("This test is too slow on SymPy.")(func)