# Source code for symforce.typing_util

# ----------------------------------------------------------------------------
# SymForce - Copyright 2022, Skydio, Inc.
# This source code is under the Apache 2.0 license found in the LICENSE file.
# ----------------------------------------------------------------------------

"""
General python typing-related utilities
"""

import dataclasses

import numpy as np

import symforce.internal.symbolic as sf
from symforce import typing as T


def get_type(a: T.Any) -> T.Type:
    """
    Resolve ``a`` to a type.

    If ``a`` is already a type it is passed through unchanged; otherwise the type of ``a`` is
    returned.
    """
    return a if isinstance(a, type) else type(a)
# NOTE(brad): Every class listed here is automatically registered with ScalarLieGroupOps
# (see ops/__init__.py).
SCALAR_TYPES = (
    float,
    np.float16,
    np.float32,
    np.float64,
    # NOTE(hayk): Calling integers lie groups is odd, but the implementation of
    # ScalarLieGroupOps converts everything to symbolic types, so they act like floating point.
    int,
    np.int8,
    np.int16,
    np.int32,
    np.int64,
)
def scalar_like(a: T.Any) -> bool:
    """
    Returns whether ``a`` is scalar-like: one of ``SCALAR_TYPES`` (ints and floats), or a symbolic
    expression that is not a matrix.

    ``a`` may be an instance or a type. The decision is made from the type of ``a``; the only
    thing read from ``a`` itself is its ``is_Matrix`` attribute, if it has one.
    """
    candidate = get_type(a)

    # Plain numeric types are scalars outright
    if issubclass(candidate, SCALAR_TYPES):
        return True

    # Anything else has to at least be a symbolic expression
    if not issubclass(candidate, sf.Expr):
        return False

    # It is an expression, so the only thing left to rule out is a matrix
    if issubclass(candidate, sf.sympy.MatrixBase):
        return False
    return not (hasattr(a, "is_Matrix") and a.is_Matrix)
def get_sequence_from_dataclass_sequence_field(
    field: dataclasses.Field, field_type: T.Type
) -> T.Sequence[T.Any]:
    """
    Expand a fixed-length sequence field of a dataclass into a list of its element type.

    The length is read from the ``"length"`` entry of ``field.metadata``, and the element type is
    the first type argument of ``field_type`` (e.g. ``A`` for ``T.Sequence[A]``).

    Args:
        field: The dataclass field, whose metadata should contain an integer ``"length"``
        field_type: The type annotation of the field, expected to be a ``T.Sequence`` (or
            subclass) carrying at least one type argument

    Returns:
        A list containing the element type repeated ``length`` times

    Raises:
        TypeError: If the origin of ``field_type`` is not a subclass of ``T.Sequence``
        AssertionError: If the ``"length"`` metadata entry is missing or is not an ``int``
    """
    origin = T.get_origin(field_type)
    length = field.metadata.get("length")

    # The origin may be None (unparameterized annotation) or a non-class object such as
    # ``T.Union`` (from ``T.Optional[...]``). ``issubclass`` only accepts classes, so filter
    # those out first to make sure the caller gets the descriptive error below.
    if not isinstance(origin, type) or not issubclass(origin, T.Sequence):
        raise TypeError(
            f"Annotated field with `length={length}` that is of type {field_type}, not T.Sequence"
        )

    assert isinstance(length, int)

    arg_type = T.get_args(field_type)[0]
    return [arg_type] * length
def maybe_tuples_of_types_from_annotation(
    annotation: T.Union[T.Type, T.Any], return_annotation_if_not_tuple: bool = False
) -> T.Optional[T.Union[T.Tuple[T.Union[T.Tuple, T.Type]], T.Any]]:
    """
    Attempt to construct a tuple of types from an annotation of the form ``T.Tuple[A, B, C]`` of
    any fixed length, recursively.

    If this is not possible, because the annotation is not a ``T.Tuple``, returns:

    1) The annotation itself, if ``return_annotation_if_not_tuple`` is True
    2) ``None``, otherwise

    If the annotation is a ``T.Tuple``, but is of unknown length (i.e. uses ``...``):

    1) Raises ``ValueError``, if ``return_annotation_if_not_tuple`` is True
    2) Returns ``None``, otherwise

    Type arguments are always processed with ``return_annotation_if_not_tuple=True``, so a tuple
    of unknown length nested inside a fixed-length tuple raises ``ValueError``.

    Raises:
        ValueError: If a ``T.Tuple`` of unknown length is encountered while
            ``return_annotation_if_not_tuple`` is True
    """
    origin = T.get_origin(annotation)

    # The origin can be None or a non-class object, which issubclass would not accept
    if not isinstance(origin, type) or not issubclass(origin, T.cast(T.Type, T.Tuple)):
        return annotation if return_annotation_if_not_tuple else None

    args = T.get_args(annotation)

    if Ellipsis in args:
        if return_annotation_if_not_tuple:
            raise ValueError(
                f"Cannot construct a tuple of types from {annotation}, which is a tuple of "
                "unknown length"
            )
        return None

    return tuple(
        maybe_tuples_of_types_from_annotation(arg, return_annotation_if_not_tuple=True)
        for arg in args
    )