Source code for symforce.codegen.codegen_util

# ----------------------------------------------------------------------------
# SymForce - Copyright 2022, Skydio, Inc.
# This source code is under the Apache 2.0 license found in the LICENSE file.
# ----------------------------------------------------------------------------

"""
Shared helper code between codegen of all languages.
"""

from __future__ import annotations

import dataclasses
import importlib.abc
import importlib.util
import itertools
import sys
from pathlib import Path

import sympy

import symforce
import symforce.symbolic as sf
from symforce import _sympy_count_ops
from symforce import ops
from symforce import typing as T
from symforce import typing_util
from symforce.codegen import codegen_config
from symforce.codegen import format_util
from symforce.values import IndexEntry
from symforce.values import Values

# Maps a scalar type name (e.g. ``"double"``) to the name of the matching numpy dtype, as a string
NUMPY_DTYPE_FROM_SCALAR_TYPE = dict(double="numpy.float64", float="numpy.float32")

# Type representing generated code: a sequence of (lhs symbol, rhs expression) assignments
T_terms = T.Sequence[T.Tuple[sf.Symbol, sf.Expr]]
# One group of terms per output
T_nested_terms = T.Sequence[T_terms]
# Generated code after printing: (lhs, rhs) pairs as source-code strings
T_terms_printed = T.Sequence[T.Tuple[str, str]]


class DenseAndSparseOutputTerms(T.NamedTuple):
    """
    Output expressions grouped by how they are stored in generated code.

    Each inner list holds the flattened scalar expressions of a single output.
    """

    # One list of scalar expressions per dense output
    dense: T.List[T.List[sf.Expr]]
    # One list of scalar expressions per sparse output (presumably just the nonzero entries;
    # confirm against the callers that build this)
    sparse: T.List[T.List[sf.Expr]]
class OutputWithTerms(T.NamedTuple):
    """
    A single output paired with the printed code terms that compute it.
    """

    name: str  # name of the output
    type: T.Element  # type of the output object
    terms: T_terms_printed  # printed (lhs, rhs) source-code pairs for this output
class PrintCodeResult(T.NamedTuple):
    """
    Printed code for intermediate temporaries and for every dense and sparse output.
    """

    # Printed (lhs, rhs) pairs for the intermediate (e.g. common sub-expression) terms
    intermediate_terms: T_terms_printed
    # Printed terms for each dense output
    dense_terms: T.List[OutputWithTerms]
    # Printed terms for each sparse output
    sparse_terms: T.List[OutputWithTerms]
    # Presumably the total operation count of the generated code -- confirm with the producer
    total_ops: int
@dataclasses.dataclass
class CSCFormat:
    """
    A matrix written in Compressed Sparse Column format.
    """

    kRows: int  # Number of rows
    kCols: int  # Number of columns
    kNumNonZero: int  # Number of nonzero entries
    kColPtrs: T.List[int]  # nonzero_elements[kColPtrs[col]] is the first nonzero entry of col
    kRowIndices: T.List[int]  # row indices of nonzero entries written in column-major order
    nonzero_elements: T.List[sf.Scalar]  # nonzero entries written in column-major order

    @staticmethod
    def from_matrix(sparse_matrix: sf.Matrix) -> CSCFormat:
        """
        Returns a :class:`CSCFormat` holding the metadata required to represent a matrix as a
        sparse matrix in CSC form.

        ``kColPtrs`` has ``kCols + 1`` entries; its last entry equals ``kNumNonZero``.

        Args:
            sparse_matrix: A symbolic :class:`sf.Matrix <symforce.geo.matrix.Matrix>` where
                sparsity is given by exact zero equality.
        """
        kColPtrs = []
        kRowIndices = []
        nonzero_elements = []
        # Loop through columns first because CSC stores entries in column-major order
        for j in range(sparse_matrix.cols):
            # Start of column j within nonzero_elements
            kColPtrs.append(len(nonzero_elements))
            for i in range(sparse_matrix.rows):
                element = sparse_matrix[i, j]
                # Structural sparsity: only entries exactly equal to zero are dropped
                if element == 0:
                    continue
                kRowIndices.append(i)
                nonzero_elements.append(element)
        # Trailing pointer marks the end of the last column
        kColPtrs.append(len(nonzero_elements))

        return CSCFormat(
            kRows=sparse_matrix.rows,
            kCols=sparse_matrix.cols,
            kNumNonZero=len(nonzero_elements),
            kColPtrs=kColPtrs,
            kRowIndices=kRowIndices,
            nonzero_elements=nonzero_elements,
        )

    def to_matrix(self) -> sf.Matrix:
        """
        Returns a dense matrix representing this CSC sparse matrix.
        """
        dense_matrix = sf.M.zeros(self.kRows, self.kCols)
        for j in range(self.kCols):
            # Tolerate a kColPtrs without the trailing end pointer for the last column
            end_inx = self.kColPtrs[j + 1] if j + 1 < self.kCols else self.kNumNonZero
            for k in range(self.kColPtrs[j], end_inx):
                dense_matrix[self.kRowIndices[k], j] = self.nonzero_elements[k]
        return dense_matrix
def perform_cse(
    output_exprs: DenseAndSparseOutputTerms,
    cse_optimizations: T.Optional[
        T.Union[T.Literal["basic"], T.Sequence[T.Tuple[T.Callable, T.Callable]]]
    ] = None,
) -> T.Tuple[T_terms, DenseAndSparseOutputTerms]:
    """
    Eliminate common sub-expressions shared across all dense and sparse output expressions.

    Args:
        output_exprs: expressions on which to perform cse
        cse_optimizations: optimizations to be forwarded to
            :func:`sf.cse <symforce.symbolic.cse>`. Not supported with the symengine API.

    Returns:
        T_terms: Temporary variables (named ``_tmp0``, ``_tmp1``, ...) holding the common
            sub-expressions found within output_exprs
        DenseAndSparseOutputTerms: output_exprs, but in terms of the returned temporaries.

    Raises:
        ValueError: if cse_optimizations is given while the symbolic API is symengine.
    """
    # CSE runs over one flat list so sub-expressions shared between different outputs are found
    flat_outputs = [
        expr for storage in (output_exprs.dense + output_exprs.sparse) for expr in storage
    ]

    temp_names = (sf.Symbol(f"_tmp{idx}") for idx in itertools.count())

    cse_kwargs = {}
    if cse_optimizations is not None:
        if symforce.get_symbolic_api() == "symengine":
            raise ValueError("cse_optimizations is not supported on symengine")
        cse_kwargs["optimizations"] = cse_optimizations
    temps, flat_simplified_outputs = sf.cse(flat_outputs, symbols=temp_names, **cse_kwargs)

    # Slice the flat result back into per-output lists, dense outputs first, then sparse
    simplified_outputs = DenseAndSparseOutputTerms(dense=[], sparse=[])
    offset = 0
    for source, destination in (
        (output_exprs.dense, simplified_outputs.dense),
        (output_exprs.sparse, simplified_outputs.sparse),
    ):
        for storage in source:
            next_offset = offset + len(storage)
            destination.append(flat_simplified_outputs[offset:next_offset])
            offset = next_offset

    return temps, simplified_outputs
def format_symbols(
    inputs: Values,
    dense_outputs: Values,
    sparse_outputs: Values,
    intermediate_terms: T_terms,
    output_terms: DenseAndSparseOutputTerms,
    config: codegen_config.CodegenConfig,
) -> T.Tuple[T_terms, T_nested_terms, T_nested_terms]:
    """
    Reformats symbolic variables used in intermediate and outputs terms to match structure of
    inputs/outputs.

    For example, if we have an input array ``"arr"`` with symbolic elements ``[arr0, arr1]``, we
    will remap symbol ``"arr0"`` to ``"arr[0]"`` and symbol ``"arr1"`` to ``"arr[1]"``.

    Returns:
        The intermediate terms, then the dense output terms (one list per output), then the
        sparse output terms (one list per output), all with input symbols substituted by their
        formatted accessors.
    """
    # Map each original input scalar to the symbol used to access it in generated code
    formatted_input_args, original_args = get_formatted_list(inputs, config, format_as_inputs=True)
    input_subs = {
        original: formatted
        for original, formatted in zip(
            itertools.chain.from_iterable(original_args),
            itertools.chain.from_iterable(formatted_input_args),
        )
    }

    def substitute_inputs(storages: T.Any) -> T.Any:
        return ops.StorageOps.subs(storages, input_subs, dont_flatten_args=True)

    def pair_up(lhs_lists: T.Any, rhs_lists: T.Any) -> T.List[T.List[T.Any]]:
        # Zip each output's lhs accessors with its substituted rhs expressions
        return [
            list(zip(lhs_formatted, subbed_storage))
            for lhs_formatted, subbed_storage in zip(lhs_lists, substitute_inputs(rhs_lists))
        ]

    # Intermediate terms keep their lhs temporaries; only the rhs references inputs
    intermediate_lhs = [lhs for lhs, _ in intermediate_terms]
    intermediate_rhs = substitute_inputs([rhs for _, rhs in intermediate_terms])
    intermediate_terms_formatted = list(zip(intermediate_lhs, intermediate_rhs))

    dense_output_lhs_formatted, _ = get_formatted_list(
        dense_outputs, config, format_as_inputs=False
    )
    dense_output_terms_formatted = pair_up(dense_output_lhs_formatted, output_terms.dense)

    sparse_output_lhs_formatted = get_formatted_sparse_list(sparse_outputs)
    sparse_output_terms_formatted = pair_up(sparse_output_lhs_formatted, output_terms.sparse)

    return intermediate_terms_formatted, dense_output_terms_formatted, sparse_output_terms_formatted
def get_formatted_list(
    values: Values, config: codegen_config.CodegenConfig, format_as_inputs: bool
) -> T.Tuple[T.List[T.List[T.Union[sf.Symbol, sf.DataBuffer]]], T.List[T.List[sf.Scalar]]]:
    """
    Returns a nested list of formatted symbols, as well as a nested list of the corresponding
    original scalar values. For use in generated functions.

    Args:
        values: Values object mapping keys to different objects. Here we only use the
                object types, not their actual values.
        config: Programming language and configuration for when language-specific formatting
                is required
        format_as_inputs: True if values defines the input symbols, false if values defines
                output expressions.

    Returns:
        flattened_formatted_symbolic_values: nested list of formatted scalar symbols, one
            inner list per key of ``values``
        flattened_original_values: nested list of original scalar values, aligned with the
            formatted symbols

    Raises:
        ValueError: if, for some key, the number of formatted symbols differs from the number
            of flattened values.
    """

    def assert_unique(symbols: T.List[T.Any]) -> None:
        # Two scalars mapped to the same accessor would alias each other in generated code
        assert len(symbols) == len(set(symbols)), "Non-unique keys:\n{}".format(
            [symbol for symbol in symbols if symbols.count(symbol) > 1]
        )

    all_formatted = []
    all_original = []

    for key, value in values.items():
        arg_cls = typing_util.get_type(value)
        storage_dim = ops.StorageOps.storage_dim(value)

        # For each item we build the list of symbols used to access its scalar elements; these
        # are later matched up one-to-one with the flattened values.
        if issubclass(arg_cls, sf.DataBuffer):
            symbols = [sf.DataBuffer(key, value.shape[0])]
            original = [value]
        elif isinstance(value, (sf.Expr, sf.Symbol)):
            symbols = [sf.Symbol(key)]
            original = [value]
        elif issubclass(arg_cls, sf.Matrix):
            # NOTE(brad): The order of the symbols must match the storage order of sf.Matrix
            # (as returned by sf.Matrix.to_storage). Hence, if there storage order were
            # changed to, say, row major, the below for loops would have to be swapped to
            # reflect that.
            symbols = [
                sf.Symbol(config.format_matrix_accessor(key, row, col, shape=value.shape))
                for col in range(value.shape[1])
                for row in range(value.shape[0])
            ]
            original = ops.StorageOps.to_storage(value)
        elif issubclass(arg_cls, Values):
            # Nested Values: walk its index; sub-elements are accessed with the "." operator
            symbols = []
            original = value.to_storage()
            for name, index_value in value.index().items():
                symbols.extend(
                    _get_scalar_keys_recursive(
                        index_value, prefix=f"{key}.{name}", config=config, use_data=False
                    )
                )
            assert_unique(symbols)
        elif issubclass(arg_cls, (list, tuple)):
            # Sequence: walk values.index()[key].item_index; elements use the "[]" operator
            symbols = []
            original = ops.StorageOps.to_storage(value)
            sub_index = values.index()[key].item_index
            assert sub_index is not None
            for position, sub_index_val in enumerate(sub_index.values()):
                symbols.extend(
                    _get_scalar_keys_recursive(
                        sub_index_val,
                        prefix=f"{key}[{position}]",
                        config=config,
                        use_data=format_as_inputs,
                    )
                )
            assert_unique(symbols)
        else:
            if format_as_inputs:
                # Inputs: the object's data is read from a temp vector named "_<key>"
                # (key can be "self" for member functions accessing object data)
                symbols = [sf.Symbol(f"_{key}[{idx}]") for idx in range(storage_dim)]
            else:
                # Outputs: "data" can't be accessed directly, so the template builds a new
                # object from a vector named "<key>"
                symbols = [sf.Symbol(f"{key}[{idx}]") for idx in range(storage_dim)]
            original = ops.StorageOps.to_storage(value)

        if len(symbols) != len(original):
            error_text = (
                "Number of symbols does not match number of values. "
                + "This can happen if a databuffer is included in a Values object used as an input "
                + "to the codegen function (databuffers should be top-level arguments/inputs). "
            )
            # Only list the pairs for inputs, where the values are symbols rather than
            # arbitrary expressions
            if format_as_inputs:
                matches = list(zip(symbols, original))
                error_text += f"The following symbol/value pairs should match: {matches}"
            raise ValueError(error_text)

        all_formatted.append(symbols)
        all_original.append(original)

    return all_formatted, all_original
def _get_scalar_keys_recursive(
    index_value: IndexEntry, prefix: str, config: codegen_config.CodegenConfig, use_data: bool
) -> T.List[sf.Symbol]:
    """
    Returns the access symbols for every scalar in an index entry, recursing into Values and
    list/tuple entries.

    Args:
        index_value: Entry in a given index consisting of (inx, datatype, shape, item_index).
            See Values.index() for details on how this entry is built.
        prefix: Symbol used to access the parent object, e.g. "my_values.item" or "my_list[i]"
        config: Programming language and configuration for when language-specific formatting
            is required
        use_data: If true, non-Matrix objects (e.g. geo/cam types) are accessed through
            ``config.format_data_accessor``. Otherwise they are treated as a flat vector of
            scalars, as Matrix entries always are (e.g. as they are in lcm types).
    """
    datatype = index_value.datatype()
    keys = []

    if issubclass(datatype, sf.Scalar):
        # A scalar is accessed directly by its prefix
        keys.append(sf.Symbol(prefix))
    elif issubclass(datatype, Values):
        assert index_value.item_index is not None
        # Sub-values are accessed with "."; data accessors are never used below a Values
        for name, child in index_value.item_index.items():
            keys += _get_scalar_keys_recursive(
                child, prefix=f"{prefix}.{name}", config=config, use_data=False
            )
    elif issubclass(datatype, sf.DataBuffer):
        keys.append(sf.DataBuffer(prefix))
    elif issubclass(datatype, (list, tuple)):
        assert index_value.item_index is not None
        # Every element is visited; elements are accessed with "[]"
        for position, child in enumerate(index_value.item_index.values()):
            keys += _get_scalar_keys_recursive(
                child, prefix=f"{prefix}[{position}]", config=config, use_data=use_data
            )
    elif issubclass(datatype, sf.Matrix) or not use_data:
        # Flat vector of scalars
        if config.use_eigen_types:
            keys.extend(
                sf.Symbol(config.format_eigen_lcm_accessor(prefix, idx))
                for idx in range(index_value.storage_dim)
            )
        else:
            keys.extend(sf.Symbol(f"{prefix}[{idx}]") for idx in range(index_value.storage_dim))
    else:
        # Object storing its scalars in a flat "data" vector
        keys.extend(
            sf.Symbol(config.format_data_accessor(prefix=prefix, index=idx))
            for idx in range(index_value.storage_dim)
        )

    assert len(keys) == len(set(keys)), "Non-unique keys:\n{}".format(
        [symbol for symbol in keys if keys.count(symbol) > 1]
    )
    return keys
def get_formatted_sparse_list(sparse_outputs: Values) -> T.List[T.List[sf.Scalar]]:
    """
    Returns a nested list of symbols for use in generated functions for sparse matrices.

    For each key, entry ``i`` is accessed as ``<key>_value_ptr[i]``; there is one symbol per
    element stored under that key (the nonzero terms of the sparse matrix).
    """
    return [
        [sf.Symbol(f"{key}_value_ptr[{idx}]") for idx in range(len(nonzero_terms))]
        for key, nonzero_terms in sparse_outputs.items()
    ]
def _load_generated_package_internal(name: str, path: Path) -> T.Tuple[T.Any, T.List[str]]:
    """
    Dynamically load generated package (or module).

    Returns the generated package (module) and a list of the names of all modules added to
    sys.modules by this function, parents first.

    Does not remove the modules it imports from sys.modules.

    Precondition:
      If m is a module from the same package as name and is imported by name, then
      there does not exist a different module with the same name as m in sys.modules.
      This is to ensure name imports the correct modules.
    """
    if path.is_dir():
        path = path / "__init__.py"  # noqa: PLR6104

    parts = name.split(".")
    if len(parts) > 1:
        # Load parent packages first so that relative imports inside ``name`` resolve.
        # A plain module's parent package lives in the same directory, but a package's own
        # ``__init__.py`` has its parent package one directory further up. Using
        # ``path.parent`` unconditionally would execute a package's own ``__init__.py`` again
        # under its parent's name.
        parent_dir = path.parent.parent if path.name == "__init__.py" else path.parent
        parent_init = parent_dir / "__init__.py"
        if not parent_init.exists():
            # No regular parent package there; fall back to the directory containing ``path``
            # so that layouts without a parent ``__init__.py`` keep loading.
            parent_init = path.parent / "__init__.py"
        _, added_module_names = _load_generated_package_internal(
            ".".join(parts[:-1]), parent_init
        )
    else:
        added_module_names = []

    spec = importlib.util.spec_from_file_location(name, path)
    assert spec is not None
    module = importlib.util.module_from_spec(spec)
    # Register before executing so that imports of ``name`` from within the module resolve
    sys.modules[name] = module
    added_module_names.append(name)

    # For mypy: https://github.com/python/typeshed/issues/2793
    assert isinstance(spec.loader, importlib.abc.Loader)

    spec.loader.exec_module(module)

    return module, added_module_names
def load_generated_package(name: str, path: T.Openable, evict: bool = True) -> T.Any:
    """
    Dynamically load generated package (or module).

    Args:
        name: The full name of the package or module to load (for example, ``"pkg.sub_pkg"``
            for a package called ``sub_pkg`` inside of another package ``pkg``, or
            ``"pkg.sub_pkg.mod"`` for a module called ``mod`` inside of ``pkg.sub_pkg``).
        path: The path to the directory (or ``__init__.py``) of the package, or the python file
            of the module.
        evict: Whether to evict the imported package from sys.modules after loading it.  This is
            necessary for functions generated in the ``sym`` namespace, since leaving them would
            make it impossible to ``import sym`` and get the ``symforce-sym`` package as
            expected.  For this reason, attempting to load a generated package called ``sym``
            with ``evict=False`` is disallowed.  However, evict should be ``False`` for
            numba-compiled functions.

    Raises:
        ValueError: if ``evict`` is ``False`` and the root package of ``name`` is ``sym``.

    When ``evict`` is ``True``, the modules that were temporarily evicted from sys.modules are
    restored even if loading raises.
    """
    root_package_name = name.split(".")[0]

    if not evict:
        if root_package_name == "sym":
            raise ValueError(
                "Attempted to hotload a generated package called `sym` - see "
                "`help(load_generated_package)` for more information"
            )

        return _load_generated_package_internal(name, Path(path))[0]

    # NOTE(brad): We remove all possibly conflicting modules from the cache. This is
    # to ensure that when name is executed, it loads local modules (if any) rather
    # than any with colliding names that have been loaded elsewhere
    callee_saved_modules: T.List[T.Tuple[str, T.Any]] = []
    for module_name in tuple(sys.modules.keys()):
        if root_package_name == module_name.split(".")[0]:
            try:
                conflicting_module = sys.modules[module_name]
                del sys.modules[module_name]
                callee_saved_modules.append((module_name, conflicting_module))
            except KeyError:
                pass

    added_module_names: T.List[str] = []
    try:
        module, added_module_names = _load_generated_package_internal(name, Path(path))
    except BaseException:
        # Loading failed partway through, so the exact list of registered modules is unknown.
        # Every module under root_package_name was evicted above, so any such module now in
        # sys.modules was added by this failed attempt and must be cleaned up as well.
        added_module_names = [
            module_name
            for module_name in tuple(sys.modules.keys())
            if module_name.split(".")[0] == root_package_name
        ]
        raise
    finally:
        # We remove the temporarily added modules
        for added_name in added_module_names:
            sys.modules.pop(added_name, None)

        # And we restore the original removed modules, even when loading failed
        for removed_name, removed_module in callee_saved_modules:
            sys.modules[removed_name] = removed_module

    return module
def load_generated_function(
    func_name: str, path_to_package: T.Openable, evict: bool = True
) -> T.Callable:
    """
    Returns the function with name ``func_name`` found inside the package located at
    ``path_to_package``.

    Example usage::

        def my_func(...):
            ...

        my_codegen = Codegen.function(my_func, config=PythonConfig())
        codegen_data = my_codegen.generate_function(output_dir=output_dir)
        generated_func = load_generated_function("my_func", codegen_data.function_dir)
        generated_func(...)

    Args:
        path_to_package: a python package with an ``__init__.py`` containing a module defined in
            ``func_name.py`` which in turn defines an attribute named ``func_name``. See the
            example above. The path to that ``__init__.py`` itself is also accepted.
        evict: Whether to evict the imported package from sys.modules after loading it.  This is
            necessary for functions generated in the ``sym`` namespace, since leaving them would
            make it impossible to ``import sym`` and get the ``symforce-sym`` package as
            expected.  For this reason, attempting to load a generated package called ``sym``
            with ``evict=False`` is disallowed.  However, evict should be ``False`` for
            numba-compiled functions.
    """
    given_path = Path(path_to_package)
    # Normalize to the package directory; its name doubles as the package name
    package_dir = given_path.parent if given_path.name == "__init__.py" else given_path

    module_name = f"{package_dir.name}.{func_name}"
    module_path = package_dir / f"{func_name}.py"
    func_module = load_generated_package(module_name, module_path, evict)

    return getattr(func_module, func_name)
def load_generated_lcmtype(
    package: str, type_name: str, lcmtypes_path: T.Union[str, Path]
) -> T.Type:
    """
    Load an LCM type generated by
    :meth:`Codegen.generate_function <symforce.codegen.codegen.Codegen.generate_function>`

    Example usage::

        my_codegen = Codegen(my_func, config=PythonConfig())
        codegen_data = my_codegen.generate_function(output_dir=output_dir, namespace=namespace)
        my_type_t = codegen_util.load_generated_lcmtype(
            namespace, "my_type_t", codegen_data.python_types_dir
        )
        my_type_msg = my_type_t(foo=5)

    Args:
        package: The name of the LCM package for the type
        type_name: The name of the LCM type itself (not including the package)
        lcmtypes_path: The path to the directory containing the generated lcmtypes package

    Returns:
        The Python LCM type, read from ``<lcmtypes_path>/lcmtypes/<package>/_<type_name>.py``
    """
    # We need to import the lcmtypes package first so that sys.path is set up correctly, since this
    # is a namespace package
    import lcmtypes  # noqa: F401

    module_name = f"lcmtypes.{package}._{type_name}"
    module_file = Path(lcmtypes_path) / "lcmtypes" / package / f"_{type_name}.py"
    type_module = load_generated_package(module_name, module_file)

    return getattr(type_module, type_name)
def get_base_instance(obj: T.Sequence[T.Any]) -> T.Any:
    """
    Returns an instance of the base element (e.g. Scalar, Values, Matrix, etc.) of an object.

    Lists and tuples (including nested ones) are unwrapped by repeatedly taking their first
    element until something that is neither a list nor a tuple is reached. All elements of a
    list are assumed to be of the same type/shape. Any other input is returned unchanged.
    """
    base = obj
    while isinstance(base, (list, tuple)):
        base = base[0]
    return base
@dataclasses.dataclass
class LcmBindingsDirs:
    """
    Output directories for the language bindings produced by ``generate_lcm_types``.
    """

    # Root directory of the generated python LCM bindings
    python_types_dir: Path
    # Directory where the generated C++ LCM headers are written
    cpp_types_dir: Path
def generate_lcm_types(
    lcm_type_dir: T.Openable,
    lcm_files: T.Sequence[str],
    lcm_output_dir: T.Optional[T.Openable] = None,
) -> LcmBindingsDirs:
    """
    Generates the language-specific type files for all symforce generated ".lcm" files.

    Python bindings go under ``<lcm_output_dir>/python`` and C++ headers under
    ``<lcm_output_dir>/cpp/lcmtypes``; generated python files are autoformatted.

    Args:
        lcm_type_dir: Directory containing symforce-generated .lcm files
        lcm_files: List of .lcm files to process. Currently only checked for emptiness: when
            empty, nothing is generated. The generator itself is pointed at ``lcm_type_dir``.
        lcm_output_dir: Root directory for the generated bindings. Defaults to
            ``lcm_type_dir / ".."``.

    Returns:
        LcmBindingsDirs: the python and C++ output directories. These are returned even when
        ``lcm_files`` is empty and nothing was generated.
    """
    type_dir = Path(lcm_type_dir)
    output_root = type_dir / ".." if lcm_output_dir is None else Path(lcm_output_dir)

    result = LcmBindingsDirs(
        python_types_dir=output_root / "python",
        cpp_types_dir=output_root / "cpp" / "lcmtypes",
    )

    # TODO(brad, aaron): Do something reasonable with lcm_files other than returning early
    # If no LCM files provided, do nothing
    if not lcm_files:
        return result

    # Imported here, after the early return, so skymarshal is not imported when there is
    # nothing to generate
    from skymarshal import skymarshal
    from skymarshal.emit_cpp import SkymarshalCpp
    from skymarshal.emit_python import SkymarshalPython

    python_dir = result.python_types_dir
    skymarshal_args = [
        str(type_dir),
        "--python",
        "--python-path",
        str(python_dir / "lcmtypes"),
        "--python-namespace-packages",
        "--python-package-prefix",
        "lcmtypes",
        "--cpp",
        "--cpp-hpath",
        str(result.cpp_types_dir),
        "--cpp-include",
        "lcmtypes",
        "--no-source-paths",
    ]
    skymarshal.main(
        [SkymarshalPython, SkymarshalCpp],
        args=skymarshal_args,
        print_generated=False,
    )

    # Autoformat generated python files
    format_util.format_py_dir(python_dir)

    return result
def flat_symbols_from_values(values: Values) -> T.List[T.Any]:
    """
    Returns a flat list of the symbols in the object for codegen: the scalar storage of
    ``values`` followed by every ``sf.DataBuffer`` found in ``values.values_recursive()``.

    Note that this *does not* respect storage ordering. Uniqueness of the returned symbols is
    not enforced here.
    """
    symbols_list = values.to_storage()
    # DataBuffers are appended explicitly (presumably because to_storage() does not include
    # them -- confirm in Values)
    symbols_list.extend(
        leaf for leaf in values.values_recursive() if isinstance(leaf, sf.DataBuffer)
    )
    return symbols_list