symforce.test_util.random_expressions.unary_binary_expression_gen module

class UnaryBinaryExpressionGen(unary_ops, binary_ops, leaves)

Bases: object

Helper to generate random symbolic expressions composed of a given set of unary and binary operators. The user provides these operators, each with an assigned probability, and can then sample expressions with a target number of operations.

Takes care to sample uniformly from the space of all such possible expressions. Uniform sampling isn't necessarily important for many use cases, but it seems as good as any other single strategy. The downside is that the distributions become slow to compute for large expressions.

Note that the probabilities of unary and binary ops are completely independent. Whether the next op is unary or binary is determined by the tree counts D (see generate_D); the assigned probabilities only choose among ops of that arity.
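The independence described above can be read as a two-stage choice: the arity comes from D (not shown here), and the op probabilities only weight an op against others of the same arity. The following is a minimal sketch under that reading. `OpProbability` here is a local stand-in that mirrors the fields shown in the `default()` signature, and `pick_op` is a hypothetical helper, not part of the documented API.

```python
import random
from typing import Callable, NamedTuple, Sequence


class OpProbability(NamedTuple):
    # Local stand-in mirroring the fields shown in default(): name, func, prob.
    name: str
    func: Callable
    prob: float


def pick_op(
    arity: int,
    unary_ops: Sequence[OpProbability],
    binary_ops: Sequence[OpProbability],
    rng: random.Random,
) -> OpProbability:
    """Pick an op of the given arity (the arity itself is decided from D)."""
    group = unary_ops if arity == 1 else binary_ops
    # Weights are only compared within one arity group, so rescaling every
    # unary prob does not change which binary op gets picked (and vice versa).
    return rng.choices(group, weights=[op.prob for op in group], k=1)[0]
```

For example, if `binary_ops` holds a single op, `pick_op(2, ...)` always returns it, however large the unary probabilities are.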

Implementation reference: Appendix C of Deep Learning for Symbolic Mathematics (https://arxiv.org/abs/1912.01412).

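Parameters:
  • unary_ops – unary operators, each an OpProbability with a name, a function, and an assigned probability

  • binary_ops – binary operators, in the same OpProbability form

  • leaves – values used at the leaves of the tree, such as integer constants and symbols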
__init__(unary_ops, binary_ops, leaves)
static generate_D(max_ops, num_leaves=1, p1=1, p2=1)

Enumerate the number of possible unary-binary trees that can be generated from empty nodes. D[e][n] is the number of different unary-binary trees with n operator nodes that can be generated from e empty nodes, computed with the following recursion, where L = num_leaves, p_1 = p1, and p_2 = p2 (in the reference paper, p_1 and p_2 are the numbers of unary and binary operators):

$$D(0, n) = 0$$

$$D(e, 0) = L^{e}$$

$$D(e, n) = L \, D(e - 1, n) + p_1 \, D(e, n - 1) + p_2 \, D(e + 1, n - 1)$$
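The recursion can be checked with a minimal sketch that memoizes it instead of precomputing the table that generate_D returns; `count_trees` is a hypothetical name, and plain Python integers are used so the counts never overflow. Where the two base cases overlap at e = n = 0, this sketch follows D(0, n) = 0.

```python
from functools import lru_cache


@lru_cache(maxsize=None)
def count_trees(e: int, n: int, num_leaves: int = 1, p1: int = 1, p2: int = 1) -> int:
    """D(e, n): unary-binary trees with n ops that can fill e empty nodes."""
    if e == 0:
        # No empty node to fill (this sketch also sets D(0, 0) = 0).
        return 0
    if n == 0:
        # No ops left: each of the e empty nodes becomes one of L leaves.
        return num_leaves ** e
    # The first empty node becomes a leaf, a unary op, or a binary op.
    return (
        num_leaves * count_trees(e - 1, n, num_leaves, p1, p2)
        + p1 * count_trees(e, n - 1, num_leaves, p1, p2)
        + p2 * count_trees(e + 1, n - 1, num_leaves, p1, p2)
    )
```

With the defaults, `count_trees(1, 2)` is 6: with one leaf x, one unary op u, and one binary op b, the trees are u(u(x)), u(b(x, x)), b(u(x), x), b(x, u(x)), b(b(x, x), x), and b(x, b(x, x)).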

Parameters:
  • max_ops (int) –

  • num_leaves (int) –

  • p1 (int) –

  • p2 (int) –

Return type:

List[ndarray]

sample_next_pos(nb_empty, nb_ops, num_leaves=1, p1=1, p2=1)

Sample the position of the next node (unary-binary case): a position in {0, …, nb_empty - 1}, along with an arity (1 for a unary op, 2 for a binary op).
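A sketch of the sampling step, following the scheme in Appendix C of the reference paper: position k with arity a is weighted by L^k · p_a · D(e − k + a − 1, n − 1), where the first k empty nodes become leaves and node k receives the op. Several details are assumptions, not symforce's documented behavior: the `rng` argument, the (k, arity) order of the returned tuple, and reading nb_ops as "ops still to place, including this one". `count_trees` is a hypothetical memoized D.

```python
import random
from functools import lru_cache
from typing import List, Tuple


@lru_cache(maxsize=None)
def count_trees(e: int, n: int, num_leaves: int = 1, p1: int = 1, p2: int = 1) -> int:
    # D(e, n) from the recursion in generate_D (memoized).
    if e == 0:
        return 0
    if n == 0:
        return num_leaves ** e
    return (num_leaves * count_trees(e - 1, n, num_leaves, p1, p2)
            + p1 * count_trees(e, n - 1, num_leaves, p1, p2)
            + p2 * count_trees(e + 1, n - 1, num_leaves, p1, p2))


def sample_next_pos(nb_empty: int, nb_ops: int, rng: random.Random,
                    num_leaves: int = 1, p1: int = 1, p2: int = 1) -> Tuple[int, int]:
    """Return (k, arity): the first k empty nodes become leaves, node k gets the op."""
    candidates: List[Tuple[int, int]] = []
    weights: List[int] = []
    for k in range(nb_empty):
        leaf_ways = num_leaves ** k  # choices for the k skipped leaves
        # A unary op at node k leaves nb_empty - k empty nodes for the remaining ops.
        candidates.append((k, 1))
        weights.append(leaf_ways * p1 * count_trees(nb_empty - k, nb_ops - 1, num_leaves, p1, p2))
        # A binary op at node k adds one more empty node.
        candidates.append((k, 2))
        weights.append(leaf_ways * p2 * count_trees(nb_empty - k + 1, nb_ops - 1, num_leaves, p1, p2))
    # The weights sum to D(nb_empty, nb_ops), so each choice is drawn in
    # proportion to the number of complete trees it leads to.
    return rng.choices(candidates, weights=weights, k=1)[0]
```

With one empty node and one op left, both arities have weight 1, so unary and binary are equally likely.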

Parameters:
  • nb_empty (int) –

  • nb_ops (int) –

  • num_leaves (int) –

  • p1 (int) –

  • p2 (int) –

Return type:

Tuple[int, int]

build_tree_sequence(num_ops_target)

Sample a random expression tree with num_ops_target operators and return it as a prefix-notation sequence.
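A compact sketch of how sampled positions can be turned into a prefix sequence, following the procedure in Appendix C of the reference paper. Assumptions for brevity: tree shapes are counted with L = p1 = p2 = 1, ops are chosen uniformly within their arity group rather than by their assigned probabilities, and ops and leaves are plain strings or numbers. All names are hypothetical.

```python
import random
from functools import lru_cache
from typing import List, Optional, Sequence, Tuple, Union

Token = Union[str, int, float]


@lru_cache(maxsize=None)
def count_trees(e: int, n: int) -> int:
    # D(e, n) with L = p1 = p2 = 1 (assumption for this sketch).
    if e == 0:
        return 0
    if n == 0:
        return 1
    return count_trees(e - 1, n) + count_trees(e, n - 1) + count_trees(e + 1, n - 1)


def sample_next_pos(nb_empty: int, nb_ops: int, rng: random.Random) -> Tuple[int, int]:
    # Weight each (k, arity) by the number of trees that complete it.
    candidates = [(k, a) for k in range(nb_empty) for a in (1, 2)]
    weights = [count_trees(nb_empty - k + a - 1, nb_ops - 1) for k, a in candidates]
    return rng.choices(candidates, weights=weights, k=1)[0]


def build_tree_sequence(num_ops: int, unary: Sequence[str], binary: Sequence[str],
                        leaves: Sequence[Token], rng: random.Random) -> List[Token]:
    """Sample a tree with num_ops operators and return it in prefix notation."""
    stack: List[Optional[Token]] = [None]  # None marks an empty node
    nb_empty = 1     # empty nodes still available for ops
    nb_skipped = 0   # empty nodes already committed to become leaves
    for nb_ops in range(num_ops, 0, -1):
        skipped, arity = sample_next_pos(nb_empty, nb_ops, rng)
        op = rng.choice(unary if arity == 1 else binary)
        nb_skipped += skipped
        nb_empty += arity - 1 - skipped
        # Committed leaves precede the live empty nodes, so skip past all of them.
        empty_positions = [i for i, tok in enumerate(stack) if tok is None]
        pos = empty_positions[nb_skipped]
        stack = stack[:pos] + [op] + [None] * arity + stack[pos + 1:]
    # Fill every remaining empty node with a random leaf.
    return [rng.choice(leaves) if tok is None else tok for tok in stack]
```

Calling `build_tree_sequence(3, ["neg", "sin"], ["add", "mul"], ["x0", "x1", 2], random.Random(0))` yields a prefix list containing exactly three operator tokens and one more leaf than there are binary ops.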

Parameters:

num_ops_target (int) –

Return type:

List

seq_to_expr(seq)

Convert a prefix notation sequence into a sympy expression.
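Decoding prefix notation only requires each token's arity: an operator consumes the next `arity` complete subtrees. The sketch below shows that recursive descent but evaluates to a float instead of building a sympy expression. The `OPS` table and `eval_prefix` are hypothetical; their names echo the ops in default(), but the functions are plain-float stand-ins.

```python
import math
from typing import Callable, Dict, List, Sequence, Tuple, Union

Token = Union[str, float]

# Hypothetical op table: name -> (arity, function), plain-float stand-ins.
OPS: Dict[str, Tuple[int, Callable[..., float]]] = {
    "neg": (1, lambda a: -a),
    "sin": (1, math.sin),
    "pow2": (1, lambda a: a ** 2),
    "add": (2, lambda a, b: a + b),
    "sub": (2, lambda a, b: a - b),
    "mul": (2, lambda a, b: a * b),
}


def eval_prefix(seq: Sequence[Token], env: Dict[str, float]) -> float:
    """Evaluate a prefix-notation sequence; symbol leaves are looked up in env."""

    def parse(i: int) -> Tuple[float, int]:
        # Returns (value of the subtree starting at i, index just past it).
        tok = seq[i]
        if isinstance(tok, str) and tok in OPS:
            arity, func = OPS[tok]
            args: List[float] = []
            j = i + 1
            for _ in range(arity):
                val, j = parse(j)
                args.append(val)
            return func(*args), j
        # Leaf: a symbol name looked up in env, or a numeric constant.
        return (env[tok] if isinstance(tok, str) else float(tok)), i + 1

    value, end = parse(0)
    if end != len(seq):
        raise ValueError("trailing tokens after a complete expression")
    return value
```

For example, `eval_prefix(["add", "mul", 2, "x0", "neg", 3], {"x0": 5.0})` returns 7.0, since the sequence encodes 2 * x0 + (-3).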

Parameters:

seq (Sequence[Union[str, float]]) –

Return type:

Expr

build_expr(num_ops_target)

Return a random expression built with num_ops_target operations.

Parameters:

num_ops_target (int) –

Return type:

float

build_expr_vec(num_ops_target, num_exprs=None)

Return a vector of random expressions whose op counts add up to num_ops_target. If num_exprs is not provided, the number of expressions is approximately the square root of num_ops_target.
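One way to divide the total op budget across the expressions is sketched below. The rounding rule for the "approximate square root" and the even split, with the remainder going to the first expressions, are assumptions. `split_ops` is a hypothetical helper, not part of the documented API.

```python
import math
from typing import List, Optional


def split_ops(num_ops_target: int, num_exprs: Optional[int] = None) -> List[int]:
    """Split a total op budget across several expressions (hypothetical helper)."""
    if num_exprs is None:
        # Assumption: "approximate square root" taken as the rounded sqrt, at least 1.
        num_exprs = max(1, round(math.sqrt(num_ops_target)))
    base, extra = divmod(num_ops_target, num_exprs)
    # The first `extra` expressions get one more op so the budgets sum exactly.
    return [base + 1 if i < extra else base for i in range(num_exprs)]
```

For example, `split_ops(10)` gives `[4, 3, 3]`. Each budget could then be passed to a build_expr-style call, and the results stacked into a vector.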

Parameters:
  • num_ops_target (int) –

  • num_exprs (Optional[int]) –

Return type:

Matrix

classmethod default(unary_ops, binary_ops, leaves)

Each default op is an OpProbability(name, func, prob). The defaults, listed as name (prob):

  • unary_ops: neg (3), abs (3), sign (3), sqrt (2), exp (0.1), log (0.1), sin (0.5), cos (0.5), tan (0.3), pow2 (3), pow3 (1), asin (0.2), acos (0.2), atan (0.1)

  • binary_ops: add (4), sub (2), mul (5), div (1), pow (0.5), atan2 (0.2)

  • leaves: -5, -4, -3, -2, -1, 1, 2, 3, 4, 5, 7, 9, x0, x1, x2, x3, x4, x5, x6, x7, x8, x9

The abs, sin, cos, tan, and atan ops use the symengine Abs, sin, cos, tan, and atan classes; sign, asin, and acos use sign_no_zero, asin_safe, and acos_safe; exp uses a built-in exp function; atan2 uses an atan2 function; the remaining ops use lambdas.

Construct with a reasonable default op distribution.
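To see what the default distribution implies in practice, the sketch below normalizes the default weights. It assumes, as the note on independence suggests, that probabilities are normalized within each arity group. The dictionaries copy the name/prob pairs from the default() signature; `normalize` is a hypothetical helper.

```python
from typing import Dict

# Relative probabilities copied from the default() signature.
UNARY_DEFAULTS: Dict[str, float] = {
    "neg": 3, "abs": 3, "sign": 3, "sqrt": 2, "exp": 0.1, "log": 0.1,
    "sin": 0.5, "cos": 0.5, "tan": 0.3, "pow2": 3, "pow3": 1,
    "asin": 0.2, "acos": 0.2, "atan": 0.1,
}
BINARY_DEFAULTS: Dict[str, float] = {
    "add": 4, "sub": 2, "mul": 5, "div": 1, "pow": 0.5, "atan2": 0.2,
}


def normalize(weights: Dict[str, float]) -> Dict[str, float]:
    """Turn relative weights into probabilities within one arity group."""
    total = sum(weights.values())
    return {name: w / total for name, w in weights.items()}


unary_p = normalize(UNARY_DEFAULTS)
binary_p = normalize(BINARY_DEFAULTS)
```

The unary weights sum to 17, so neg is chosen about 3/17 ≈ 17.6% of the time when a unary op is placed. The binary weights sum to 12.7, so add and mul together account for about 71% of binary ops.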

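Parameters:
  • unary_ops – unary operators with their probabilities (defaults shown above)

  • binary_ops – binary operators with their probabilities (defaults shown above)

  • leaves – leaf constants and symbols (defaults shown above)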
Return type:

UnaryBinaryExpressionGen