Source code for symforce.ops.impl.dataclass_group_ops
# ----------------------------------------------------------------------------
# SymForce - Copyright 2022, Skydio, Inc.
# This source code is under the Apache 2.0 license found in the LICENSE file.
# ----------------------------------------------------------------------------
import dataclasses
from symforce import typing as T
from symforce import typing_util
from symforce.ops import GroupOps
from .dataclass_storage_ops import DataclassStorageOps
[docs]class DataclassGroupOps(DataclassStorageOps):
[docs] @staticmethod
def identity(a: T.DataclassOrType) -> T.Dataclass:
constructed_fields = {}
if isinstance(a, type):
type_hints_map = T.get_type_hints(a)
for field in dataclasses.fields(a):
field_type = type_hints_map[field.name]
if field.metadata.get("length") is not None:
sequence_instance = typing_util.get_sequence_from_dataclass_sequence_field(
field, field_type
)
constructed_fields[field.name] = GroupOps.identity(sequence_instance)
elif (
sequence_types := typing_util.maybe_tuples_of_types_from_annotation(field_type)
) is not None:
# It's a Tuple of known size
constructed_fields[field.name] = GroupOps.identity(sequence_types)
else:
constructed_fields[field.name] = GroupOps.identity(field_type)
return a(**constructed_fields)
else:
for field in dataclasses.fields(a):
constructed_fields[field.name] = GroupOps.identity(getattr(a, field.name))
return typing_util.get_type(a)(**constructed_fields)
[docs] @staticmethod
def compose(a: T.Dataclass, b: T.Dataclass) -> T.Dataclass:
assert typing_util.get_type(a) == typing_util.get_type(b)
constructed_fields = {}
for field in dataclasses.fields(a):
constructed_fields[field.name] = GroupOps.compose(
getattr(a, field.name), getattr(b, field.name)
)
return typing_util.get_type(a)(**constructed_fields)
[docs] @staticmethod
def inverse(a: T.Dataclass) -> T.Dataclass:
constructed_fields = {}
for field in dataclasses.fields(a):
constructed_fields[field.name] = GroupOps.inverse(getattr(a, field.name))
return typing_util.get_type(a)(**constructed_fields)