Source code for symforce.test_util.test_case

# ----------------------------------------------------------------------------
# SymForce - Copyright 2022, Skydio, Inc.
# This source code is under the Apache 2.0 license found in the LICENSE file.
# ----------------------------------------------------------------------------

import random
import sys
import unittest

import numpy as np

import symforce
from symforce import typing as T
from symforce.test_util.test_case_mixin import SymforceTestCaseMixin


class TestCase(SymforceTestCaseMixin):
    """
    Base class for symforce tests, layering a few helpers on top of
    SymforceTestCaseMixin:

      * a latch for the ``--run_slow_tests`` command line flag
      * fixed random seeds at the start of every test
      * a ``verbose`` attribute reflecting ``-v`` / ``--verbose`` on the command line
    """

    # Latched to True once --run_slow_tests has been seen in sys.argv, indicating that
    # all tests should run even if we're on SymPy.
    _RUN_SLOW_TESTS = False

    @staticmethod
    def should_run_slow_tests() -> bool:
        """
        Return whether slow tests have been requested.

        If ``--run_slow_tests`` is present in ``sys.argv``, its first occurrence is
        removed and the class-level latch is set, so later calls keep returning True
        even though the flag is no longer in ``sys.argv``.
        """
        # NOTE(aaron): This needs to be accessible before main() is called, so the flag
        # is consumed here rather than in main(). main() also calls this to make sure it
        # runs at least once.
        slow_flag = "--run_slow_tests"
        if slow_flag in sys.argv:
            # NOTE(review): presumably stripped so that later argument parsing does not
            # see an unknown option - confirm against SymforceTestCaseMixin.main
            sys.argv.remove(slow_flag)
            TestCase._RUN_SLOW_TESTS = True

        return TestCase._RUN_SLOW_TESTS

    @staticmethod
    def main(*args: T.Any, **kwargs: T.Any) -> None:
        """
        Call this to run all tests in scope.

        Consumes the ``--run_slow_tests`` flag first, then forwards all positional and
        keyword arguments unchanged to ``SymforceTestCaseMixin.main``.
        """
        TestCase.should_run_slow_tests()
        SymforceTestCaseMixin.main(*args, **kwargs)

    def setUp(self) -> None:
        """
        Per-test setup: run the parent setUp, seed the random number generators, and
        record whether verbose output was requested.
        """
        super().setUp()

        # Seed both numpy's and the stdlib's global generators with the same fixed value
        seed = 42
        random.seed(seed)
        np.random.seed(seed)

        # Store verbosity flag so tests can use it
        self.verbose = any(flag in sys.argv for flag in ("-v", "--verbose"))
def sympy_only(func: T.Callable) -> T.Callable:
    """
    Decorator that restricts a test to the SymPy symbolic API.

    The active API is queried via ``symforce.get_symbolic_api()`` when the decorator is
    applied. If it is "sympy", ``func`` is returned untouched; otherwise ``func`` is
    wrapped with ``unittest.skip``.
    """
    if symforce.get_symbolic_api() == "sympy":
        return func

    skip_decorator = unittest.skip("This test only runs on SymPy symbolic API.")
    return skip_decorator(func)
def symengine_only(func: T.Callable) -> T.Callable:
    """
    Decorator that restricts a test to the SymEngine symbolic API.

    The active API is queried via ``symforce.get_symbolic_api()`` when the decorator is
    applied. If it is "symengine", ``func`` is returned untouched; otherwise ``func`` is
    wrapped with ``unittest.skip``.
    """
    if symforce.get_symbolic_api() == "symengine":
        return func

    skip_decorator = unittest.skip("This test only runs on the SymEngine symbolic API")
    return skip_decorator(func)
def expected_failure_on_sympy(func: T.Callable) -> T.Callable:
    """
    Decorator that marks a test as an expected failure on SymPy only.

    The active API is queried via ``symforce.get_symbolic_api()`` when the decorator is
    applied. On "sympy" the test is wrapped with ``unittest.expectedFailure``; on any
    other API ``func`` is returned untouched.
    """
    if symforce.get_symbolic_api() != "sympy":
        return func

    return unittest.expectedFailure(func)
def slow_on_sympy(func: T.Callable) -> T.Callable:
    """
    Decorator that marks a test as slow on SymPy.

    On the "sympy" symbolic API the test is wrapped with ``unittest.skip`` unless
    ``TestCase.should_run_slow_tests()`` reports that the --run_slow_tests flag was
    passed. On any other API ``func`` is returned untouched.
    """
    # Not on SymPy: nothing to do. The slow-test flag is deliberately not consulted in
    # this case, so sys.argv is left alone.
    if symforce.get_symbolic_api() != "sympy":
        return func

    # On SymPy, but slow tests were explicitly requested
    if TestCase.should_run_slow_tests():
        return func

    return unittest.skip("This test is too slow on SymPy.")(func)