Source code for symforce.ops.impl.dataclass_storage_ops

# ----------------------------------------------------------------------------
# SymForce - Copyright 2022, Skydio, Inc.
# This source code is under the Apache 2.0 license found in the LICENSE file.
# ----------------------------------------------------------------------------

import dataclasses

from symforce import typing as T
from symforce import typing_util
from symforce.ops import StorageOps


[docs]class DataclassStorageOps: """ StorageOps implementation for dataclasses Supports nested types. If any of the fields are of unknown size (e.g. sequences), the relevant functions expect to be passed an instance instead of the type. However, the length of sequences can be specified using field metadata, allowing for StorageOps functions such as ``storage_dim``, ``from_storage``, and ``symbolic`` to be passed the dataclass type rather than an instance. Adding a sequence of length 10, for example, would look like:: @dataclass class ExampleDataclass: example_list: T.Sequence[ExampleType] = field(metadata={"length": 10}) """ # NOTE(aaron): We use T.get_type_hints in multiple places in here to the field types, does this # always work? A bit worried that this never uses field.type, e.g. if it isn't a simple # annotation
[docs] @staticmethod def storage_dim(a: T.DataclassOrType) -> int: if isinstance(a, type): count = 0 type_hints_map = T.get_type_hints(a) for field in dataclasses.fields(a): field_type = type_hints_map[field.name] if field.metadata.get("length") is not None: sequence_instance = typing_util.get_sequence_from_dataclass_sequence_field( field, field_type ) count += StorageOps.storage_dim(sequence_instance) elif ( sequence_types := typing_util.maybe_tuples_of_types_from_annotation(field_type) ) is not None: # It's a Tuple of known size count += StorageOps.storage_dim(sequence_types) else: count += StorageOps.storage_dim(field_type) return count else: count = 0 for field in dataclasses.fields(a): count += StorageOps.storage_dim(getattr(a, field.name)) return count
[docs] @staticmethod def to_storage(a: T.Dataclass) -> T.List[T.Scalar]: storage = [] for field in dataclasses.fields(a): storage.extend(StorageOps.to_storage(getattr(a, field.name))) return storage
[docs] @staticmethod def from_storage(a: T.DataclassOrType, elements: T.Sequence[T.Scalar]) -> T.Dataclass: if isinstance(a, type): constructed_fields = {} offset = 0 type_hints_map = T.get_type_hints(a) for field in dataclasses.fields(a): field_type = type_hints_map[field.name] if field.metadata.get("length") is not None: sequence_instance = typing_util.get_sequence_from_dataclass_sequence_field( field, field_type ) storage_dim = StorageOps.storage_dim(sequence_instance) constructed_fields[field.name] = StorageOps.from_storage( sequence_instance, elements[offset : offset + storage_dim] ) elif ( sequence_types := typing_util.maybe_tuples_of_types_from_annotation(field_type) ) is not None: # It's a Tuple of known size storage_dim = StorageOps.storage_dim(sequence_types) constructed_fields[field.name] = StorageOps.from_storage( sequence_types, elements[offset : offset + storage_dim] ) else: storage_dim = StorageOps.storage_dim(field_type) constructed_fields[field.name] = StorageOps.from_storage( field_type, elements[offset : offset + storage_dim] ) offset += storage_dim return a(**constructed_fields) else: constructed_fields = {} offset = 0 for field in dataclasses.fields(a): field_instance = getattr(a, field.name) storage_dim = StorageOps.storage_dim(field_instance) constructed_fields[field.name] = StorageOps.from_storage( field_instance, elements[offset : offset + storage_dim] ) offset += storage_dim return typing_util.get_type(a)(**constructed_fields)
[docs] @staticmethod def symbolic(a: T.DataclassOrType, name: T.Optional[str], **kwargs: T.Dict) -> T.Dataclass: """ Return a symbolic instance of a Dataclass Names are chosen by creating each field with symbolic name {name}.{field_name}. If the `name` argument is not given, that part is left off, and fields are created with just {field_name}. """ if isinstance(a, type): constructed_fields = {} name_prefix = f"{name}." if name is not None else "" type_hints_map = T.get_type_hints(a) for field in dataclasses.fields(a): field_type = type_hints_map[field.name] try: if field.metadata.get("length") is not None: sequence_instance = typing_util.get_sequence_from_dataclass_sequence_field( field, field_type ) constructed_fields[field.name] = StorageOps.symbolic( sequence_instance, f"{name_prefix}{field.name}", **kwargs ) elif ( sequence_types := typing_util.maybe_tuples_of_types_from_annotation( field_type ) ) is not None: # It's a Tuple of known size constructed_fields[field.name] = StorageOps.symbolic( sequence_types, f"{name_prefix}{field.name}", **kwargs ) else: constructed_fields[field.name] = StorageOps.symbolic( field_type, f"{name_prefix}{field.name}", **kwargs ) except NotImplementedError as ex: raise NotImplementedError( f"Could not create field {field.name} of type {field_type}" ) from ex return typing_util.get_type(a)(**constructed_fields) else: constructed_fields = {} name_prefix = f"{name}." if name is not None else "" for field in dataclasses.fields(a): field_instance = getattr(a, field.name) try: constructed_fields[field.name] = StorageOps.symbolic( field_instance, f"{name_prefix}{field.name}", **kwargs ) except NotImplementedError as ex: raise NotImplementedError( f"Could not create field {field.name} of type {field_instance}" ) from ex return typing_util.get_type(a)(**constructed_fields)