Source code for symforce.codegen.ops_codegen_util

# ----------------------------------------------------------------------------
# SymForce - Copyright 2022, Skydio, Inc.
# This source code is under the Apache 2.0 license found in the LICENSE file.
# ----------------------------------------------------------------------------

import dataclasses

import symforce.symbolic as sf
from symforce import ops
from symforce import typing as T
from symforce.codegen import Codegen
from symforce.codegen import CodegenConfig


def make_group_ops_funcs(cls: T.Type, config: CodegenConfig) -> T.List[Codegen]:
    """
    Create func spec arguments for group ops on the given class.

    Builds one ``Codegen`` object per ``ops.GroupOps`` operation (identity, inverse, compose,
    between), followed by the ``with_jacobians()`` variants of inverse, compose and between.

    Args:
        cls: The group type to generate operations for.
        config: Codegen configuration used for every generated function. The identity function
            uses a copy of it with ``normalize_results=False``.

    Returns:
        ``[identity, inverse, compose, between, inverse.with_jacobians(),
        compose.with_jacobians(), between.with_jacobians()]``
    """
    # identity takes no arguments, so it is wrapped in a lambda that binds ``cls``; a lambda has
    # no meaningful __name__, hence the explicit name. Result normalization is disabled for this
    # function only.
    identity_codegen = Codegen.function(
        func=(lambda: ops.GroupOps.identity(cls)),
        name="identity",
        config=dataclasses.replace(config, normalize_results=False),
        input_types=[],
    )

    inverse_codegen = Codegen.function(
        func=ops.GroupOps.inverse,
        config=config,
        input_types=[cls],
        docstring=ops.GroupOps.inverse.__doc__,
    )
    compose_codegen = Codegen.function(
        func=ops.GroupOps.compose,
        config=config,
        input_types=[cls, cls],
        docstring=ops.GroupOps.compose.__doc__,
    )
    # NOTE(review): unlike inverse/compose, no explicit docstring is passed for between —
    # confirm this is intentional.
    between_codegen = Codegen.function(
        func=ops.GroupOps.between,
        config=config,
        input_types=[cls, cls],
    )

    # Jacobian variants are produced for inverse, compose and between (not identity), in that
    # order, after the plain versions.
    differentiable = (inverse_codegen, compose_codegen, between_codegen)
    jacobian_variants = [codegen.with_jacobians() for codegen in differentiable]

    return [identity_codegen, *differentiable, *jacobian_variants]
def make_lie_group_ops_funcs(cls: T.Type, config: CodegenConfig) -> T.List[Codegen]:
    """
    Create func spec arguments for lie group ops on the given class.

    Builds one ``Codegen`` object for each of ``ops.LieGroupOps`` from_tangent, to_tangent,
    retract, local_coordinates and interpolate, in that order, all sharing ``config``.

    Args:
        cls: The Lie group type to generate operations for.
        config: Codegen configuration used for every generated function.

    Returns:
        ``[from_tangent, to_tangent, retract, local_coordinates, interpolate]``
    """
    # Placeholder tangent vector with entries 0..tangent_dim-1. It is only used inside the
    # input type specs below (presumably Codegen only reads its shape — confirm).
    tangent_dim = ops.LieGroupOps.tangent_dim(cls)
    tangent_vec = sf.M(list(range(tangent_dim)))

    # from_tangent is wrapped in a lambda to bind ``cls``. A lambda has no meaningful
    # __name__ or __doc__, so both are supplied explicitly.
    from_tangent_codegen = Codegen.function(
        name="from_tangent",
        func=(lambda vec, epsilon: ops.LieGroupOps.from_tangent(cls, vec, epsilon)),
        input_types=[tangent_vec, sf.Symbol],
        config=config,
        docstring=ops.LieGroupOps.from_tangent.__doc__,
    )

    # The remaining ops are passed directly, each paired with its input type spec; their
    # docstrings are forwarded from the ops themselves.
    signatures = (
        (ops.LieGroupOps.to_tangent, [cls, sf.Symbol]),
        (ops.LieGroupOps.retract, [cls, tangent_vec, sf.Symbol]),
        (ops.LieGroupOps.local_coordinates, [cls, cls, sf.Symbol]),
        (ops.LieGroupOps.interpolate, [cls, cls, sf.Symbol, sf.Symbol]),
    )
    direct_codegens = [
        Codegen.function(
            func=op_func,
            input_types=op_input_types,
            config=config,
            docstring=op_func.__doc__,
        )
        for op_func, op_input_types in signatures
    ]

    return [from_tangent_codegen, *direct_codegens]