# Source code for symforce.codegen.backends.python.python_code_printer

# ----------------------------------------------------------------------------
# SymForce - Copyright 2022, Skydio, Inc.
# This source code is under the Apache 2.0 license found in the LICENSE file.
# ----------------------------------------------------------------------------

import sympy
from sympy.printing.pycode import PythonCodePrinter as _PythonCodePrinter


class PythonCodePrinter(_PythonCodePrinter):
    """
    Symforce customized code printer for Python. Modifies the Sympy printing
    behavior for codegen compatibility and efficiency.
    """

    def _print_Rational(self, expr: sympy.Rational) -> str:
        """
        Print a rational as a division of two float literals.

        Customizations:
            * Decimal points for Python2 support, doesn't exist in some sympy versions.

        Args:
            expr: The rational to print; its numerator ``p`` and denominator ``q`` are used.

        Returns:
            A string of the form ``"<p>./<q>."``.
        """
        # The trailing "." on both operands makes each a float literal in the emitted code.
        return f"{expr.p}./{expr.q}."

    def _emit_builtin_call(self, func_name: str, expr: sympy.Expr) -> str:
        """
        Print ``expr`` as a call to the Python builtin ``func_name`` over its arguments.

        Shared implementation for ``_print_Max`` and ``_print_Min``.

        Args:
            func_name: Name of the function to emit (e.g. ``"max"``).
            expr: Expression whose ``args`` become the call arguments.

        Returns:
            The printed single argument when there is exactly one, otherwise
            ``"<func_name>(<arg0>, <arg1>, ...)"``.
        """
        # A single argument needs no call wrapper; just print the argument itself.
        if len(expr.args) == 1:
            return self._print(expr.args[0])

        printed_args = ", ".join(self._print(arg) for arg in expr.args)
        return f"{func_name}({printed_args})"

    def _print_Max(self, expr: sympy.Max) -> str:
        """
        Max is not supported by default, so we add a version here.

        Returns:
            ``"max(...)"`` over the printed arguments, or the lone printed argument.
        """
        return self._emit_builtin_call("max", expr)

    def _print_Min(self, expr: sympy.Min) -> str:
        """
        Min is not supported by default, so we add a version here.

        Returns:
            ``"min(...)"`` over the printed arguments, or the lone printed argument.
        """
        return self._emit_builtin_call("min", expr)

    # NOTE(brad): We type ignore the signature because mypy complains that it
    # does not match that of the sympy base class CodePrinter. This is because the base class
    # defines _print_Heaviside with: _print_Heaviside = None (see
    # https://github.com/sympy/sympy/blob/95f0228c033d27731f8707cdbb5bb672e500847d/sympy/printing/codeprinter.py#L446
    # ).
    # Despite this, our signature here matches the signatures of the sympy defined subclasses
    # of CodePrinter. I don't know of any other way to resolve this issue other than
    # to type ignore.
    def _print_Heaviside(self, expr: "sympy.Heaviside") -> str:  # type: ignore[override]
        """
        Heaviside is not supported by default, so we add a version here.

        Only the first argument of ``expr`` is used.

        Returns:
            A parenthesized conditional expression that evaluates to ``0.0`` when the
            printed argument is negative and ``1.0`` otherwise (including at zero).
        """
        return f"(0.0 if ({self._print(expr.args[0])}) < 0 else 1.0)"

    def _print_MatrixElement(self, expr: sympy.matrices.expressions.matexpr.MatrixElement) -> str:
        """
        default printer doesn't cast to int

        Prints the element access as a single flat index wrapped in ``int(...)``.

        Returns:
            ``"<parent>[int(<index>)]"`` where ``<index>`` is the printed form of
            ``j + i * parent.shape[1]``.
        """
        # Flatten (i, j) into one offset: column plus row times the parent's second dimension.
        flat_index = self._print(expr.j + expr.i * expr.parent.shape[1])
        # The parent is formatted directly (not via self._print), as in the original.
        return f"{expr.parent}[int({flat_index})]"