# Source code for symforce.test_util.random_expressions.op_probabilities

# ----------------------------------------------------------------------------
# SymForce - Copyright 2022, Skydio, Inc.
# This source code is under the Apache 2.0 license found in the LICENSE file.
# ----------------------------------------------------------------------------

import inspect
from dataclasses import dataclass

import symforce.symbolic as sf
from symforce import typing as T


@dataclass
class OpProbability:
    """
    Represents a logical symbolic operation and the probability of it occurring within regular
    use, for the purpose of generating and profiling random expressions containing these
    operations.

    Probabilities are relative across ops with the same arity, and do not have to sum to one.

    Attributes:
        name: Human-readable identifier for the operation
        func: Callable that applies the operation to its operand(s)
        prob: Relative weight of this operation among ops of the same arity
    """

    name: str
    func: T.Callable
    prob: float

    @property
    def arity(self) -> int:
        """
        Number of positional parameters accepted by ``func``.

        For plain Python functions and lambdas this is the code object's ``co_argcount``.
        Callables that have no ``__code__`` attribute (builtins, classes, ``functools.partial``
        objects, callable instances) are inspected with ``inspect.signature`` instead, counting
        their positional-only and positional-or-keyword parameters.

        Raises:
            TypeError: If ``func`` has no code object and its signature cannot be introspected
        """
        # Fast path, kept so results for ordinary functions and lambdas are unchanged
        code = getattr(self.func, "__code__", None)
        if code is not None:
            return code.co_argcount

        try:
            signature = inspect.signature(self.func)
        except (TypeError, ValueError) as err:
            raise TypeError(f"Cannot determine the arity of op {self.name!r}") from err

        positional_kinds = (
            inspect.Parameter.POSITIONAL_ONLY,
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
        )
        return sum(
            1 for param in signature.parameters.values() if param.kind in positional_kinds
        )
# Some reasonable defaults for generating expressions.
#
# Each entry is (name, callable, relative weight); weights are only compared against other ops
# in the same tuple.
DEFAULT_UNARY_OPS = (
    # Sign / magnitude
    OpProbability("neg", lambda x: -x, 3),
    OpProbability("abs", sf.Abs, 3),
    OpProbability("sign", sf.sign_no_zero, 3),
    # sqrt and log are applied to Abs(x) rather than to x directly
    OpProbability("sqrt", lambda x: sf.sqrt(sf.Abs(x)), 2),
    OpProbability("exp", sf.exp, 0.1),
    OpProbability("log", lambda x: sf.log(sf.Abs(x)), 0.1),
    # Trigonometric
    OpProbability("sin", sf.sin, 0.5),
    OpProbability("cos", sf.cos, 0.5),
    OpProbability("tan", sf.tan, 0.3),
    # Fixed integer powers
    OpProbability("pow2", lambda x: x**2, 3),
    OpProbability("pow3", lambda x: x**3, 1),
    # Inverse trigonometric
    OpProbability("asin", sf.asin_safe, 0.2),
    OpProbability("acos", sf.acos_safe, 0.2),
    OpProbability("atan", sf.atan, 0.1),
)

DEFAULT_BINARY_OPS = (
    # Arithmetic
    OpProbability("add", lambda x, y: x + y, 4),
    OpProbability("sub", lambda x, y: x - y, 2),
    OpProbability("mul", lambda x, y: x * y, 5),
    OpProbability("div", lambda x, y: x / y, 1),
    # General power and two-argument arctangent
    OpProbability("pow", lambda x, y: x**y, 0.5),
    OpProbability("atan2", sf.atan2, 0.2),
)

# Leaf values for generated expressions: small integer constants (0 is not among them) ...
DEFAULT_INTEGER_LEAVES = (-5, -4, -3, -2, -1, 1, 2, 3, 4, 5, 7, 9)

# ... and the symbols produced by the "x:10" range pattern
DEFAULT_SYMBOL_LEAVES = sf.symbols("x:10")

# Combined pool of integer and symbol leaves
DEFAULT_LEAVES = DEFAULT_INTEGER_LEAVES + DEFAULT_SYMBOL_LEAVES