Source code for symforce.opt.timestep_sub_problem
# ----------------------------------------------------------------------------
# SymForce - Copyright 2022, Skydio, Inc.
# This source code is under the Apache 2.0 license found in the LICENSE file.
# ----------------------------------------------------------------------------
import dataclasses
from symforce import ops
from symforce import typing as T
from symforce import typing_util
from symforce.opt.sub_problem import SubProblem
class TimestepSubProblem(SubProblem):
    """
    A SubProblem intended for use when the Inputs block contains sequences tied to timesteps.

    Provides a :attr:`timesteps` variable for the number of timesteps, and a :meth:`build_inputs`
    function which works for Inputs blocks containing sequences as long as the number of timesteps.

    Args:
        timesteps: The number of timesteps
        name: (optional) The name of the subproblem, derived from the class name by default
    """

    timesteps: int

    def __init__(self, timesteps: int, name: T.Optional[str] = None) -> None:
        # NOTE(review): assigned before super().__init__(); presumably the base constructor may
        # end up calling build_inputs, which reads self.timesteps - confirm before reordering.
        self.timesteps = timesteps
        super().__init__(name=name)

    def build_inputs(self) -> None:
        """
        Build the inputs block of the subproblem, and store in self.inputs.

        Each field in the subproblem Inputs that's meant to be a sequence of length
        :attr:`timesteps` should be marked with ``"timestepped": True`` in the field metadata. Other
        sequences of known length should be marked with the ``"length": <sequence length>`` in the
        field metadata, where ``<sequence length>`` is the length of the sequence. For example::

            @dataclass
            class Inputs:
                my_timestepped_field: T.Sequence[sf.Scalar] = field(metadata={"timestepped": True})
                my_sequence_field: T.Sequence[sf.Scalar] = field(metadata={"length": 3})

        Any remaining fields of unknown size will cause an exception.

        Raises:
            TypeError: If a field has neither ``"timestepped"`` nor ``"length"`` metadata and
                ``ops.StorageOps.symbolic`` raises ``NotImplementedError`` for its type.
        """
        constructed_fields: T.Dict[str, T.Any] = {}

        # Resolve annotations up front so string / forward-reference annotations on the Inputs
        # dataclass are turned into real types before they are inspected below.
        type_hints_map = T.get_type_hints(self.Inputs)

        for field in dataclasses.fields(self.Inputs):
            field_type = type_hints_map[field.name]
            symbol_name = f"{self.name}.{field.name}"

            if field.metadata.get("timestepped", False):
                # The annotation is a sequence type; its first type argument is the element type,
                # and one symbolic element is created per timestep.
                element_type = T.get_args(field_type)[0]
                constructed_fields[field.name] = [
                    ops.StorageOps.symbolic(element_type, f"{symbol_name}[{i}]")
                    for i in range(self.timesteps)
                ]
            elif field.metadata.get("length") is not None:
                # Compare against None rather than relying on truthiness, so that an explicit
                # length of 0 is still routed through the fixed-length sequence path instead of
                # being treated as if no length had been given.
                sequence_instance = typing_util.get_sequence_from_dataclass_sequence_field(
                    field, field_type
                )
                constructed_fields[field.name] = ops.StorageOps.symbolic(
                    sequence_instance, symbol_name
                )
            else:
                try:
                    constructed_fields[field.name] = ops.StorageOps.symbolic(
                        field_type, symbol_name
                    )
                except NotImplementedError as ex:
                    # Re-raise with the field name and the available remedies, keeping the
                    # original exception chained for debugging.
                    raise TypeError(
                        f"Could not create instance of type {field_type} for field "
                        f"{self.name}.{field.name}; if this is a sequence, please either annotate "
                        "with timestepped=True or length=<sequence length> in the field "
                        "metadata, or override build_inputs"
                    ) from ex

        self.inputs = self.Inputs(**constructed_fields)