# ----------------------------------------------------------------------------
# SymForce - Copyright 2022, Skydio, Inc.
# This source code is under the Apache 2.0 license found in the LICENSE file.
# ----------------------------------------------------------------------------
"""
Shared helper code between codegen of all languages.
"""
from __future__ import annotations
import dataclasses
import importlib.abc
import importlib.util
import itertools
import sys
from pathlib import Path
import sympy
import symforce
import symforce.symbolic as sf
from symforce import _sympy_count_ops
from symforce import ops
from symforce import typing as T
from symforce import typing_util
from symforce.codegen import codegen_config
from symforce.codegen import format_util
from symforce.values import IndexEntry
from symforce.values import Values
# Mapping from C++ scalar type name to the corresponding numpy dtype, as a string
# (the string form is substituted into generated Python code)
NUMPY_DTYPE_FROM_SCALAR_TYPE = {"double": "numpy.float64", "float": "numpy.float32"}

# Type representing generated code (list of lhs and rhs terms)
T_terms = T.Sequence[T.Tuple[sf.Symbol, sf.Expr]]
# A sequence of T_terms, one per output
T_nested_terms = T.Sequence[T_terms]
# Like T_terms, but with both sides already printed to code strings
T_terms_printed = T.Sequence[T.Tuple[str, str]]
class DenseAndSparseOutputTerms(T.NamedTuple):
    """
    Storage expressions for all outputs, split into dense outputs and sparse outputs.
    """

    # One list of storage expressions per dense output
    dense: T.List[T.List[sf.Expr]]
    # One list of nonzero-element expressions per sparse output
    sparse: T.List[T.List[sf.Expr]]
class OutputWithTerms(T.NamedTuple):
    """
    A single output value, paired with its printed (lhs, rhs) code terms.
    """

    # Name of the output in the outputs Values
    name: str
    # The symbolic value of the output
    type: T.Element
    # Printed (lhs, rhs) code strings for each term of the output
    terms: T_terms_printed
class PrintCodeResult(T.NamedTuple):
    """
    Result of :func:`print_code`: printed temporaries and outputs, plus the total op count.
    """

    # Printed (lhs, rhs) code strings, one per intermediate (CSE temporary) term
    intermediate_terms: T_terms_printed
    # Printed code for each dense output
    dense_terms: T.List[OutputWithTerms]
    # Printed code for each sparse output
    sparse_terms: T.List[OutputWithTerms]
    # Total number of ops in all printed expressions
    total_ops: int
def print_code(
    inputs: Values,
    outputs: Values,
    sparse_mat_data: T.Dict[str, CSCFormat],
    config: codegen_config.CodegenConfig,
    cse: bool = True,
) -> PrintCodeResult:
    """
    Return executable code lines from the given input/output values.

    Args:
        inputs: Values object specifying names and symbolic inputs
        outputs: Values object specifying names and output expressions (written in terms
            of the symbolic inputs)
        sparse_mat_data: Data associated with sparse matrices. ``sparse_mat_data["keys"]`` stores
            a list of the keys in outputs which should be treated as sparse matrices
        config: Programming language and configuration in which the expressions are to be generated
        cse: Perform common sub-expression elimination

    Returns:
        A :class:`PrintCodeResult` with:
            - ``intermediate_terms``: Line of code per temporary variable
            - ``dense_terms``: Collection of lines of code per dense output variable
            - ``sparse_terms``: Collection of lines of code per sparse output variable
            - ``total_ops``: Total number of ops
    """
    # Split outputs into dense and sparse outputs, since we treat them differently when doing
    # codegen
    dense_outputs = Values()
    sparse_outputs = Values()
    for key, value in outputs.items():
        if key in sparse_mat_data:
            sparse_outputs[key] = sparse_mat_data[key].nonzero_elements
        else:
            dense_outputs[key] = value

    output_exprs = DenseAndSparseOutputTerms(
        dense=[ops.StorageOps.to_storage(value) for _, value in dense_outputs.items()],
        sparse=[ops.StorageOps.to_storage(value) for _, value in sparse_outputs.items()],
    )

    # Perform common sub-expression elimination if requested
    if cse:
        temps, simplified_outputs = perform_cse(
            output_exprs=output_exprs,
            cse_optimizations=config.cse_optimizations,
        )
    else:
        temps = []
        simplified_outputs = output_exprs

    # Replace default symbols with vector notation (e.g. "R_re" -> "_R[0]")
    temps_formatted, dense_outputs_formatted, sparse_outputs_formatted = format_symbols(
        inputs=inputs,
        dense_outputs=dense_outputs,
        sparse_outputs=sparse_outputs,
        intermediate_terms=temps,
        output_terms=simplified_outputs,
        config=config,
    )

    # NOTE(review): local defs instead of assigned lambdas (PEP 8 E731); also fixes the
    # "simpify" -> "sympify" typo in the helper names
    def sympify_list(terms: T.Sequence[T.Any]) -> T.List[T.Any]:
        # Canonicalize each term with sympy.S so printing and op-counting below operate on
        # sympy objects
        return [sympy.S(term) for term in terms]

    def sympify_nested_lists(
        nested_terms: T.Sequence[T.Sequence[T.Any]],
    ) -> T.List[T.List[T.Any]]:
        return [sympify_list(terms) for terms in nested_terms]

    temps_formatted = sympify_list(temps_formatted)
    dense_outputs_formatted = sympify_nested_lists(dense_outputs_formatted)
    sparse_outputs_formatted = sympify_nested_lists(sparse_outputs_formatted)

    def count_ops(expr: T.Any) -> int:
        op_count = _sympy_count_ops.count_ops(expr)
        assert isinstance(op_count, int)
        return op_count

    total_ops = (
        count_ops(temps_formatted)
        + count_ops(dense_outputs_formatted)
        + count_ops(sparse_outputs_formatted)
    )

    # Get printer
    printer = config.printer()

    # Print code
    intermediate_terms = [(str(var), printer.doprint(t)) for var, t in temps_formatted]
    dense_outputs_code_no_names = [
        [(str(var), printer.doprint(t)) for var, t in single_output_terms]
        for single_output_terms in dense_outputs_formatted
    ]
    sparse_outputs_code_no_names = [
        [(str(var), printer.doprint(t)) for var, t in single_output_terms]
        for single_output_terms in sparse_outputs_formatted
    ]

    # Pack names and types with outputs
    dense_terms = [
        OutputWithTerms(key, value, output_code_no_name)
        for output_code_no_name, (key, value) in zip(
            dense_outputs_code_no_names, dense_outputs.items()
        )
    ]
    sparse_terms = [
        OutputWithTerms(key, value, sparse_output_code_no_name)
        for sparse_output_code_no_name, (key, value) in zip(
            sparse_outputs_code_no_names, sparse_outputs.items()
        )
    ]

    return PrintCodeResult(
        intermediate_terms=intermediate_terms,
        dense_terms=dense_terms,
        sparse_terms=sparse_terms,
        total_ops=total_ops,
    )
def _get_scalar_keys_recursive(
    index_value: IndexEntry, prefix: str, config: codegen_config.CodegenConfig, use_data: bool
) -> T.List[sf.Symbol]:
    """
    Returns a vector of keys, recursing on Values or List objects to get sub-elements.

    Args:
        index_value: Entry in a given index consisting of (inx, datatype, shape, item_index)
            See Values.index() for details on how this entry is built.
        prefix: Symbol used to access parent object, e.g. "my_values.item" or "my_list[i]"
        config: Programming language and configuration for when language-specific formatting is
            required
        use_data: If true, we assume we can have a list of geo/cam objects whose data can be
            accessed with ".data" or ".Data()". Otherwise, assume geo/cam objects are represented
            by a vector of scalars (e.g. as they are in lcm types).

    Returns:
        A flat list of (unique) symbols, one per scalar element of the entry, except for
        DataBuffers which are kept as a single DataBuffer entry.
    """
    vec = []
    datatype = index_value.datatype()
    if issubclass(datatype, sf.Scalar):
        # Element is a scalar, no need to access subvalues
        vec.append(sf.Symbol(prefix))
    elif issubclass(datatype, Values):
        assert index_value.item_index is not None
        # Recursively add subitems using "." to access subvalues
        # NOTE: use_data=False here - sub-Values are accessed field-by-field, not via ".data"
        for name, sub_index_val in index_value.item_index.items():
            vec.extend(
                _get_scalar_keys_recursive(
                    sub_index_val, prefix=f"{prefix}.{name}", config=config, use_data=False
                )
            )
    elif issubclass(datatype, sf.DataBuffer):
        # DataBuffers are kept whole rather than expanded into scalar symbols
        vec.append(sf.DataBuffer(prefix))
    elif issubclass(datatype, (list, tuple)):
        assert index_value.item_index is not None
        # Assume all elements of list are same type as first element
        # Recursively add subitems using "[]" to access subvalues
        for i, sub_index_val in enumerate(index_value.item_index.values()):
            vec.extend(
                _get_scalar_keys_recursive(
                    sub_index_val, prefix=f"{prefix}[{i}]", config=config, use_data=use_data
                )
            )
    elif issubclass(datatype, sf.Matrix) or not use_data:
        # Matrices (or anything else when use_data is False) are accessed as a flat vector
        # of scalars
        if config.use_eigen_types:
            vec.extend(
                sf.Symbol(config.format_eigen_lcm_accessor(prefix, i))
                for i in range(index_value.storage_dim)
            )
        else:
            vec.extend(sf.Symbol(f"{prefix}[{i}]") for i in range(index_value.storage_dim))
    else:
        # We have a geo/cam or other object that uses "data" to store a flat vector of scalars.
        vec.extend(
            sf.Symbol(config.format_data_accessor(prefix=prefix, index=i))
            for i in range(index_value.storage_dim)
        )

    # Codegen relies on each key mapping to a distinct symbol; fail loudly on collisions.
    # (The message expression is only evaluated if the assertion fires.)
    assert len(vec) == len(set(vec)), "Non-unique keys:\n{}".format(
        [symbol for symbol in vec if vec.count(symbol) > 1]
    )

    return vec
def _load_generated_package_internal(name: str, path: Path) -> T.Tuple[T.Any, T.List[str]]:
    """
    Dynamically load generated package (or module).

    Returns the generated package (module) and a list of the names of all modules added
    to sys.modules by this function.

    Does not remove the modules it imports from sys.modules.

    Precondition: If m is a module from the same package as name and is imported by name, then
    there does not exist a different module with the same name as m in sys.modules. This is to
    ensure name imports the correct modules.
    """
    # Allow passing either a package directory or a module file
    if path.is_dir():
        path = path / "__init__.py"  # noqa: PLR6104

    parts = name.split(".")
    if len(parts) > 1:
        # Load parent packages (recursively), collecting the module names they register
        _, added_module_names = _load_generated_package_internal(
            ".".join(parts[:-1]), path.parent / "__init__.py"
        )
    else:
        added_module_names = []

    spec = importlib.util.spec_from_file_location(name, path)
    assert spec is not None
    module = importlib.util.module_from_spec(spec)
    # Register the module before executing it, so imports performed during exec_module
    # (e.g. of submodules or itself) resolve to this module
    sys.modules[name] = module
    added_module_names.append(name)

    # For mypy: https://github.com/python/typeshed/issues/2793
    assert isinstance(spec.loader, importlib.abc.Loader)
    spec.loader.exec_module(module)
    return module, added_module_names
def load_generated_package(name: str, path: T.Openable, evict: bool = True) -> T.Any:
    """
    Dynamically load generated package (or module).

    Args:
        name: The full name of the package or module to load (for example, ``"pkg.sub_pkg"``
            for a package called ``sub_pkg`` inside of another package ``pkg``, or
            ``"pkg.sub_pkg.mod"`` for a module called ``mod`` inside of ``pkg.sub_pkg``).
        path: The path to the directory (or ``__init__.py``) of the package, or the python
            file of the module.
        evict: Whether to evict the imported package from sys.modules after loading it. This is
            necessary for functions generated in the ``sym`` namespace, since leaving them would
            make it impossible to ``import sym`` and get the ``symforce-sym`` package as expected.
            For this reason, attempting to load a generated package called ``sym`` with
            ``evict=False`` is disallowed. However, evict should be ``False`` for numba-compiled
            functions.

    Returns:
        The loaded module object
    """
    if not evict:
        if name.split(".")[0] == "sym":
            raise ValueError(
                "Attempted to hotload a generated package called `sym` - see "
                "`help(load_generated_package)` for more information"
            )
        return _load_generated_package_internal(name, Path(path))[0]

    # NOTE(brad): We remove all possibly conflicting modules from the cache. This is
    # to ensure that when name is executed, it loads local modules (if any) rather
    # than any with colliding names that have been loaded elsewhere
    root_package_name = name.split(".")[0]
    callee_saved_modules: T.List[T.Tuple[str, T.Any]] = []
    for module_name in tuple(sys.modules.keys()):
        if root_package_name == module_name.split(".")[0]:
            try:
                # pop may race with other del/pop calls; tolerate a missing key
                callee_saved_modules.append((module_name, sys.modules.pop(module_name)))
            except KeyError:
                pass

    added_module_names: T.List[str] = []
    try:
        module, added_module_names = _load_generated_package_internal(name, Path(path))
    finally:
        # Even if loading raised, evict the temporarily added modules and restore the
        # original removed ones - otherwise sys.modules would be left corrupted.
        # NOTE(review): on failure, modules added by the partial load before the exception
        # cannot be identified here and may remain in sys.modules.
        for added_name in added_module_names:
            try:
                del sys.modules[added_name]
            except KeyError:
                pass
        for removed_name, removed_module in callee_saved_modules:
            sys.modules[removed_name] = removed_module

    return module
def load_generated_function(
    func_name: str, path_to_package: T.Openable, evict: bool = True
) -> T.Callable:
    """
    Returns the function with name ``func_name`` found inside the package located at
    ``path_to_package``.

    Example usage::

        def my_func(...):
            ...

        my_codegen = Codegen.function(my_func, config=PythonConfig())
        codegen_data = my_codegen.generate_function(output_dir=output_dir)
        generated_func = load_generated_function("my_func", codegen_data.function_dir)
        generated_func(...)

    Args:
        path_to_package: a python package with an ``__init__.py`` containing a module defined in
            ``func_name.py`` which in turn defines an attribute named ``func_name``. See the
            example above.
        evict: Whether to evict the imported package from sys.modules after loading it. This is
            necessary for functions generated in the ``sym`` namespace, since leaving them would
            make it impossible to ``import sym`` and get the ``symforce-sym`` package as expected.
            For this reason, attempting to load a generated package called ``sym`` with
            ``evict=False`` is disallowed. However, evict should be ``False`` for numba-compiled
            functions.
    """
    # Accept either the package directory itself or a path to its __init__.py
    package_dir = Path(path_to_package)
    if package_dir.name == "__init__.py":
        package_dir = package_dir.parent

    module = load_generated_package(
        f"{package_dir.name}.{func_name}", package_dir / f"{func_name}.py", evict
    )
    return getattr(module, func_name)
def load_generated_lcmtype(
    package: str, type_name: str, lcmtypes_path: T.Union[str, Path]
) -> T.Type:
    """
    Load an LCM type generated by
    :meth:`Codegen.generate_function <symforce.codegen.codegen.Codegen.generate_function>`

    Example usage::

        my_codegen = Codegen(my_func, config=PythonConfig())
        codegen_data = my_codegen.generate_function(output_dir=output_dir, namespace=namespace)
        my_type_t = codegen_util.load_generated_lcmtype(
            namespace, "my_type_t", codegen_data.python_types_dir
        )
        my_type_msg = my_type_t(foo=5)

    Args:
        package: The name of the LCM package for the type
        type_name: The name of the LCM type itself (not including the package)
        lcmtypes_path: The path to the directory containing the generated lcmtypes package

    Returns:
        The Python LCM type
    """
    # We need to import the lcmtypes package first so that sys.path is set up correctly, since
    # this is a namespace package
    import lcmtypes  # noqa: F401

    module_name = f"lcmtypes.{package}._{type_name}"
    module_path = Path(lcmtypes_path) / "lcmtypes" / package / f"_{type_name}.py"
    lcm_module = load_generated_package(module_name, module_path)
    return getattr(lcm_module, type_name)
def get_base_instance(obj: T.Sequence[T.Any]) -> T.Any:
    """
    Returns an instance of the base element (e.g. Scalar, Values, Matrix, etc.) of an object.

    If input is a list (incl. multidimensional lists), we return an instance of one of the base
    elements (i.e. the first element that isn't a list). If input is a list we assume all
    elements are of the same type/shape.
    """
    # Walk down the first element of each nesting level until we hit a non-sequence
    base = obj
    while isinstance(base, (list, tuple)):
        base = base[0]
    return base
@dataclasses.dataclass
class LcmBindingsDirs:
    """
    Directories into which language-specific LCM bindings are generated.
    """

    # Directory containing the generated Python lcmtypes package
    python_types_dir: Path
    # Directory containing the generated C++ lcmtypes
    cpp_types_dir: Path
def generate_lcm_types(
    lcm_type_dir: T.Openable,
    lcm_files: T.Sequence[str],
    lcm_output_dir: T.Optional[T.Openable] = None,
) -> LcmBindingsDirs:
    """
    Generates the language-specific type files for all symforce generated ".lcm" files.

    Args:
        lcm_type_dir: Directory containing symforce-generated .lcm files
        lcm_files: List of .lcm files to process
        lcm_output_dir: Directory in which to place the generated bindings; defaults to the
            parent of ``lcm_type_dir``
    """
    lcm_type_dir = Path(lcm_type_dir)
    lcm_output_dir = Path(lcm_output_dir) if lcm_output_dir is not None else lcm_type_dir / ".."

    python_types_dir = lcm_output_dir / "python"
    cpp_types_dir = lcm_output_dir / "cpp" / "lcmtypes"
    result = LcmBindingsDirs(python_types_dir=python_types_dir, cpp_types_dir=cpp_types_dir)

    # TODO(brad, aaron): Do something reasonable with lcm_files other than returning early
    # If no LCM files provided, do nothing
    if not lcm_files:
        return result

    from skymarshal import skymarshal
    from skymarshal.emit_cpp import SkymarshalCpp
    from skymarshal.emit_python import SkymarshalPython

    skymarshal.main(
        [SkymarshalPython, SkymarshalCpp],
        args=[
            str(lcm_type_dir),
            "--python",
            "--python-path",
            str(python_types_dir / "lcmtypes"),
            "--python-namespace-packages",
            "--python-package-prefix",
            "lcmtypes",
            "--cpp",
            "--cpp-hpath",
            str(cpp_types_dir),
            "--cpp-include",
            "lcmtypes",
            "--no-source-paths",
        ],
        print_generated=False,
    )

    # Autoformat generated python files
    format_util.format_py_dir(python_types_dir)

    return result
def flat_symbols_from_values(values: Values) -> T.List[T.Any]:
    """
    Returns a flat list of unique symbols in the object for codegen

    Note that this *does not* respect storage ordering
    """
    # DataBuffers are not part of the storage, so append them separately
    data_buffers = [v for v in values.values_recursive() if isinstance(v, sf.DataBuffer)]
    return values.to_storage() + data_buffers