Source code for symforce.codegen.backends.rust.rust_code_printer
# ----------------------------------------------------------------------------# SymForce - Copyright 2022, Skydio, Inc.# This source code is under the Apache 2.0 license found in the LICENSE file.# ----------------------------------------------------------------------------fromenumimportEnumimportsympyfromsympy.codegen.astimportfloat32fromsympy.codegen.astimportfloat64fromsympy.printing.rustimportRustCodePrinterasSympyRustCodePrinterfromsymforceimporttypingasT
class RustCodePrinter(SympyRustCodePrinter):
    """
    SymForce code printer for Rust. Based on the SymPy Rust printer.

    Numeric literals are emitted with an explicit ``_f32``/``_f64`` suffix chosen from
    ``scalar_type``, because Rust performs no implicit numeric conversions.
    """

    def __init__(
        self,
        scalar_type: ScalarType,
        settings: T.Optional[T.Dict[str, T.Any]] = None,
        override_methods: T.Optional[T.Dict[sympy.Function, str]] = None,
    ) -> None:
        """
        Args:
            scalar_type: Scalar type of the generated code. Its ``value`` is compared by
                identity against ``float32``/``float64`` to pick the literal suffix.
            settings: Printer settings forwarded to the SymPy Rust printer.
            override_methods: Map from a SymPy function to the name of the Rust function that
                should be called in its place, printed as ``name(arg0, arg1, ...)``.
        """
        super().__init__(dict(settings or {}))

        self.scalar_type = scalar_type.value
        self.override_methods = override_methods or {}
        for expr, name in self.override_methods.items():
            self._set_override_methods(expr, name)

    def _set_override_methods(self, expr: sympy.Function, name: str) -> None:
        """
        Install a ``_print_<expr>`` method on this instance that prints calls to ``expr`` as
        ``name(arg0, arg1, ...)``.
        """
        method_name = f"_print_{str(expr)}"

        def _print_expr(expr: sympy.Expr, **_kwargs: T.Any) -> str:
            # Keyword arguments the dispatcher may forward (e.g. ``_type``, which the other
            # number printers in this class accept) are ignored rather than raising TypeError.
            expr_string = ", ".join(map(self._print, expr.args))
            return f"{name}({expr_string})"

        setattr(self, method_name, _print_expr)

    def _float_suffix(self) -> str:
        """
        Return the Rust float type name for ``self.scalar_type``: ``"f32"`` or ``"f64"``.

        Raises:
            NotImplementedError: if ``self.scalar_type`` is neither float32 nor float64.
        """
        if self.scalar_type is float32:
            return "f32"
        if self.scalar_type is float64:
            return "f64"
        raise NotImplementedError(f"Scalar type {self.scalar_type} not supported")

    def _print_method_receiver(self, expr: sympy.Expr) -> str:
        """
        Print ``expr`` so it can be the receiver of a Rust method call (``<expr>.method()``).

        Rust method calls bind tighter than binary operators and unary minus, so compound
        expressions and negative literals are parenthesized. Without this, ``log(x + y)``
        would print as ``x + y.ln()`` and ``Max(-2, x)`` as ``-2_f64.max(x)``, which Rust
        parses as ``x + (y.ln())`` and ``-(2_f64.max(x))``.
        """
        printed = self._print(expr)
        if len(expr.args) > 1 or printed.startswith("-"):
            return f"({printed})"
        return printed

    def _print_method_chain(self, expr: sympy.Expr, method: str) -> str:
        """
        Fold the arguments of ``expr`` into ``a0.method(a1).method(a2)...``.
        """
        first, *rest = expr.args
        printed = self._print_method_receiver(first)
        for arg in rest:
            printed = f"{printed}.{method}({self._print(arg)})"
        return printed

    def _print_Zero(self, expr: sympy.Expr, _type: T.Any = None) -> str:
        """
        Customizations:
            * Print zero as a typed float literal (``0.0_f32`` / ``0.0_f64``), consistent with
              the other literal printers. An unsuffixed ``0.0`` has the ambiguous ``{float}``
              type, which Rust rejects as a method receiver (e.g. ``0.0.max(x)``).
        """
        return f"0.0_{self._float_suffix()}"

    def _print_Integer(self, expr: sympy.Integer, _type: T.Any = None) -> T.Any:
        """
        Customizations:
            * Cast all integers to either f32 or f64 because Rust does not have implicit
              casting and needs to know the type of the literal at compile time. We assume
              that we are only ever operating on floats in SymForce which should make this
              safe.

        Raises:
            NotImplementedError: for an unsupported scalar type (an explicit raise rather than
                ``assert``, which would be stripped under ``python -O``).
        """
        return f"{expr.p}_{self._float_suffix()}"

    def _print_Pow(self, expr: T.Any, rational: T.Any = None) -> str:
        """
        Customizations:
            * Always print ``base.powf(exp)``. Literal rational exponents (including integers)
              are printed as a typed division, e.g. ``x.powf((1_f64/2_f64))``.
            * ``powi`` is not emitted: it takes an ``i32`` exponent, whereas every number this
              printer emits is an f32/f64.
            * The base is printed with this printer (parenthesized when needed), so it uses
              Rust syntax and typed literals.
        """
        if expr.exp.is_Rational:
            # ``is_Rational`` is the class flag for Integer/Rational literals. The assumption
            # ``is_rational`` is also true for e.g. integer-valued symbols, which have no
            # numerator/denominator for ``_print_Rational`` to read.
            power = self._print_Rational(expr.exp)
        else:
            power = self._print(expr.exp)
        return f"{self._print_method_receiver(expr.base)}.powf({power})"

    @staticmethod
    def _print_ImaginaryUnit(expr: sympy.Expr, _type: T.Any = None) -> str:
        """
        Customizations:
            * Print ``Scalar(1i)`` instead of ``I``.

        NOTE(review): this mirrors a C++-style literal. ``1i`` is not a valid Rust literal, and
        ``Scalar`` may not be defined in the generated Rust code. Confirm whether the Rust
        backend needs complex-number support.
        """
        return "Scalar(1i)"

    def _print_Float(self, flt: sympy.Float, _type: T.Any = None) -> T.Any:
        """
        Customizations:
            * Append an ``_f32`` / ``_f64`` suffix to every float literal so that its type is
              fixed at compile time.
        """
        return f"{super()._print_Float(flt)}_{self._float_suffix()}"

    def _print_Pi(self, expr: T.Any, _type: bool = False) -> str:
        """
        Customizations:
            * Print pi as ``core::f32::consts::PI`` or ``core::f64::consts::PI``.
        """
        return f"core::{self._float_suffix()}::consts::PI"

    def _print_Max(self, expr: sympy.Max) -> str:
        """
        Customizations:
            * Chain ``max`` method calls: ``Max(a, b, c)`` prints as ``a.max(b).max(c)``,
              so no argument is dropped.
        """
        return self._print_method_chain(expr, "max")

    def _print_Min(self, expr: sympy.Min) -> str:
        """
        Customizations:
            * Chain ``min`` method calls: ``Min(a, b, c)`` prints as ``a.min(b).min(c)``,
              so no argument is dropped.
        """
        return self._print_method_chain(expr, "min")

    def _print_log(self, expr: sympy.log) -> str:
        """
        Customizations:
            * Print the natural logarithm as Rust's ``ln`` method: ``log(x)`` -> ``x.ln()``.
        """
        return f"{self._print_method_receiver(expr.args[0])}.ln()"

    def _print_Rational(self, expr: sympy.Rational, _type: T.Any = None) -> str:
        """
        Customizations:
            * Print ``p/q`` as a parenthesized division of typed float literals, e.g.
              ``(1_f64/2_f64)``.

        ``_type`` is accepted and ignored, matching ``_print_Integer`` and ``_print_Float``,
        since the printer dispatch may forward it when printing numbers.

        Raises:
            NotImplementedError: for an unsupported scalar type.
        """
        p, q = int(expr.p), int(expr.q)
        float_suffix = self._float_suffix()
        return f"({p}_{float_suffix}/{q}_{float_suffix})"

    def _print_Exp1(self, expr: T.Any, _type: bool = False) -> str:
        """
        Customizations:
            * Print Euler's number as ``core::f32::consts::E`` or ``core::f64::consts::E``.
        """
        return f"core::{self._float_suffix()}::consts::E"