Source code for symforce.codegen.backends.python.python_code_printer
# ----------------------------------------------------------------------------# SymForce - Copyright 2022, Skydio, Inc.# This source code is under the Apache 2.0 license found in the LICENSE file.# ----------------------------------------------------------------------------importsympyfromsympy.printing.pycodeimportPythonCodePrinteras_PythonCodePrinter
class PythonCodePrinter(_PythonCodePrinter):
    """
    Symforce customized code printer for Python. Modifies the Sympy printing behavior for codegen
    compatibility and efficiency.
    """

    @staticmethod
    def _print_Rational(expr: sympy.Rational) -> str:
        """
        Customizations:
            * Decimal points for Python2 support, doesn't exist in some sympy versions.

        Prints ``p/q`` as ``"p./q."`` so the division is always floating point.
        """
        return f"{expr.p}./{expr.q}."

    def _format_min_max(self, func_name: str, expr: sympy.Expr) -> str:
        """
        Shared implementation for printing Min / Max as a Python builtin call.

        Args:
            func_name: Name of the Python builtin to emit ("min" or "max").
            expr: The sympy Min or Max expression to print.

        Returns:
            The bare printed argument if ``expr`` has exactly one argument. Calling the builtin
            with a single non-iterable argument would fail at runtime. Otherwise returns
            ``func_name(arg0, arg1, ...)``.
        """
        if len(expr.args) == 1:
            return self._print(expr.args[0])
        printed_args = ", ".join(self._print(arg) for arg in expr.args)
        return f"{func_name}({printed_args})"

    def _print_Max(self, expr: sympy.Max) -> str:
        """
        Max is not supported by default, so we add a version here.
        """
        return self._format_min_max("max", expr)

    def _print_Min(self, expr: sympy.Min) -> str:
        """
        Min is not supported by default, so we add a version here.
        """
        return self._format_min_max("min", expr)

    # NOTE(brad): We type ignore the signature because mypy complains that it
    # does not match that of the sympy base class CodePrinter. This is because the base class
    # defines _print_Heaviside with: _print_Heaviside = None (see
    # https://github.com/sympy/sympy/blob/95f0228c033d27731f8707cdbb5bb672e500847d/sympy/printing/codeprinter.py#L446
    # ).
    # Despite this, our signature here matches the signatures of the sympy defined subclasses
    # of CodePrinter. I don't know of any other way to resolve this issue other than to
    # type ignore.
    def _print_Heaviside(self, expr: "sympy.Heaviside") -> str:  # type: ignore[override]
        """
        Heaviside is not supported by default, so we add a version here.

        Emits a conditional expression that is 0.0 for negative inputs and 1.0 otherwise
        (including at exactly zero). Only ``expr.args[0]`` is printed; any further arguments
        are ignored.
        """
        return f"(0.0 if ({self._print(expr.args[0])}) < 0 else 1.0)"

    def _print_MatrixElement(self, expr: sympy.matrices.expressions.matexpr.MatrixElement) -> str:
        """
        Print a matrix element as a flat, row-major subscript into its parent.

        The default printer doesn't cast to int. Here the flattened index
        ``j + i * ncols`` is wrapped in ``int(...)`` so the subscript stays valid even if the
        index expression prints as a float, e.g. when it contains a Rational, which this
        printer emits as ``p./q.``.

        The parent is printed through this printer rather than with ``str()``. That way any
        printer customizations apply to it too, and a non-atomic parent gets parenthesized
        before being subscripted.
        """
        # Local import: keeps this module's top-level imports unchanged.
        from sympy.printing.precedence import PRECEDENCE

        parent = self.parenthesize(expr.parent, PRECEDENCE["Atom"], strict=True)
        flat_index = self._print(expr.j + expr.i * expr.parent.shape[1])
        return f"{parent}[int({flat_index})]"