Source code for symforce.ops.impl.dataclass_lie_group_ops
# ----------------------------------------------------------------------------# SymForce - Copyright 2022, Skydio, Inc.# This source code is under the Apache 2.0 license found in the LICENSE file.# ----------------------------------------------------------------------------from__future__importannotationsimportdataclassesfromsymforceimporttypingasTfromsymforceimporttyping_utilfromsymforce.opsimportLieGroupOpsfromsymforce.opsimportStorageOpsfrom.dataclass_group_opsimportDataclassGroupOpsifT.TYPE_CHECKING:fromsymforceimportgeo
@staticmethod
def tangent_dim(a: T.DataclassOrType) -> int:
    """
    Return the tangent-space dimension of a dataclass type or dataclass instance.

    The result is the sum of ``LieGroupOps.tangent_dim`` over the dataclass's fields, visited in
    ``dataclasses.fields`` order.

    Args:
        a: A dataclass type or a dataclass instance.

            - For an instance, each field's current value is measured directly.
            - For a type, each field is measured from its resolved type annotation
              (``T.get_type_hints``).
            - Fields with a ``"length"`` metadata entry are measured via the concrete sequence
              built by ``typing_util.get_sequence_from_dataclass_sequence_field``.
            - Annotations recognized by ``typing_util.maybe_tuples_of_types_from_annotation`` are
              measured via the element types it returns.

    Returns:
        The total tangent-space dimension.
    """
    if not isinstance(a, type):
        # Instance: measure each field value as-is.
        return sum(
            LieGroupOps.tangent_dim(getattr(a, dc_field.name))
            for dc_field in dataclasses.fields(a)
        )

    hints = T.get_type_hints(a)
    total = 0
    for dc_field in dataclasses.fields(a):
        annotation = hints[dc_field.name]
        measured: T.Any
        if dc_field.metadata.get("length") is not None:
            # Sequence field with a declared length: measure a concrete sequence built from it.
            measured = typing_util.get_sequence_from_dataclass_sequence_field(dc_field, annotation)
        else:
            tuple_types = typing_util.maybe_tuples_of_types_from_annotation(annotation)
            # A Tuple of known size is measured through its element types; otherwise the
            # annotated type itself is measured.
            measured = annotation if tuple_types is None else tuple_types
        total += LieGroupOps.tangent_dim(measured)
    return total
@staticmethod
def from_tangent(
    a: T.DataclassOrType, vec: T.Sequence[T.Scalar], epsilon: T.Scalar
) -> T.Dataclass:
    """
    Construct a dataclass instance from a flat tangent-space vector.

    ``vec`` is consumed field by field, in ``dataclasses.fields`` order. Each field takes the next
    ``LieGroupOps.tangent_dim(template)`` entries and is rebuilt with
    ``LieGroupOps.from_tangent(template, chunk, epsilon)``.

    Args:
        a: A dataclass type or a dataclass instance. It supplies the per-field template:

            - For an instance, the template is each field's current value.
            - For a type, the template comes from each field's resolved annotation.
            - Fields with ``"length"`` metadata use the concrete sequence built by
              ``typing_util.get_sequence_from_dataclass_sequence_field``.
            - Annotations recognized by ``typing_util.maybe_tuples_of_types_from_annotation``
              use the element types it returns.
            - All other fields use the annotated type itself.
        vec: Tangent-space coordinates for the whole dataclass, concatenated across fields.
        epsilon: Forwarded unchanged to every per-field ``LieGroupOps.from_tangent`` call.

    Returns:
        A new dataclass instance, constructed by keyword from the rebuilt fields.

        - If ``a`` is a type, the constructor is ``a`` itself.
        - Otherwise it is ``typing_util.get_type(a)``.
    """
    # (field name, template) pairs; a template is whatever LieGroupOps should treat as the
    # "shape" of that field.
    templates: T.List[T.Tuple[str, T.Any]] = []
    constructor: T.Callable[..., T.Dataclass]

    if isinstance(a, type):
        constructor = a
        type_hints_map = T.get_type_hints(a)
        for field in dataclasses.fields(a):
            field_type = type_hints_map[field.name]
            template: T.Any
            if field.metadata.get("length") is not None:
                template = typing_util.get_sequence_from_dataclass_sequence_field(field, field_type)
            elif (
                sequence_types := typing_util.maybe_tuples_of_types_from_annotation(field_type)
            ) is not None:
                # It's a Tuple of known size
                template = sequence_types
            else:
                template = field_type
            templates.append((field.name, template))
    else:
        constructor = typing_util.get_type(a)
        templates = [(field.name, getattr(a, field.name)) for field in dataclasses.fields(a)]

    constructed_fields: T.Dict[str, T.Any] = {}
    offset = 0
    for name, template in templates:
        tangent_dim = LieGroupOps.tangent_dim(template)
        # Every field receives the caller's epsilon. The original "length"-metadata branch
        # omitted it, so those fields silently used LieGroupOps' default instead.
        constructed_fields[name] = LieGroupOps.from_tangent(
            template, vec[offset : offset + tangent_dim], epsilon
        )
        offset += tangent_dim

    return constructor(**constructed_fields)