# ----------------------------------------------------------------------------# SymForce - Copyright 2022, Skydio, Inc.# This source code is under the Apache 2.0 license found in the LICENSE file.# ----------------------------------------------------------------------------"""General python typing-related utilities"""importdataclassesimportnumpyasnpimportsymforce.internal.symbolicassffromsymforceimporttypingasT
def get_type(a: T.Any) -> T.Type:
    """
    Normalize ``a`` to a type object.

    Args:
        a: Either a type (class) or an instance of some type.

    Returns:
        ``a`` itself when it is already a type (``isinstance(a, type)``), otherwise ``type(a)``.
    """
    return a if isinstance(a, type) else type(a)
# NOTE(brad): Each of these classes is automatically registered with ScalarLieGroupOps. (see# ops/__init__.py)SCALAR_TYPES=(float,np.float16,np.float32,np.float64,# NOTE(hayk): It's weird to call integers lie groups, but the implementation of ScalarLieGroupOps# converts everything to symbolic types so it acts like a floating point.int,np.int8,np.int16,np.int32,np.int64,)
def scalar_like(a: T.Any) -> bool:
    """
    Report whether ``a`` behaves like a scalar: an int, a float (including the NumPy numeric
    types in ``SCALAR_TYPES``), or a non-matrix symbolic expression.

    Only the type of ``a`` is inspected (plus its ``is_Matrix`` attribute for symbolic
    expressions); the numeric value is never examined.

    Args:
        a: An instance or a type.

    Returns:
        True if ``a`` is scalar-like, False otherwise.
    """
    kind = get_type(a)

    # Plain numeric types are always scalars.
    if issubclass(kind, SCALAR_TYPES):
        return True

    # Anything that is neither numeric nor symbolic is not a scalar.
    if not issubclass(kind, sf.Expr):
        return False

    # Symbolic expressions count only when they are not matrix-valued.
    if issubclass(kind, sf.sympy.MatrixBase):
        return False
    return not (hasattr(a, "is_Matrix") and a.is_Matrix)
def get_sequence_from_dataclass_sequence_field(
    field: dataclasses.Field, field_type: T.Type
) -> T.Sequence[T.Any]:
    """
    Expand a fixed-length sequence field of a dataclass into a list of its element types.

    The field must carry a ``length`` entry in its ``metadata`` and be annotated with a
    parameterized sequence type such as ``T.Sequence[X]`` or ``T.List[X]``.

    Args:
        field: The dataclass field whose ``metadata["length"]`` gives the number of elements.
        field_type: The (resolved) type annotation of the field.

    Returns:
        A list containing ``length`` copies of the first type argument of ``field_type``.

    Raises:
        TypeError: If ``field_type`` is not a parameterized sequence type, or if the field's
            ``length`` metadata is missing or not an int.
    """
    origin = T.get_origin(field_type)
    length = field.metadata.get("length")

    # get_origin may return None or a non-class object (e.g. typing.Union for T.Optional[X]);
    # guard with isinstance before issubclass, which would otherwise raise an unrelated error.
    if not isinstance(origin, type) or not issubclass(origin, T.Sequence):
        raise TypeError(
            f"Annotated field with `length={length}` that is of type {field_type}, not T.Sequence"
        )

    # Explicit check instead of `assert`, which is stripped when running with -O.
    if not isinstance(length, int):
        raise TypeError(
            f"Expected the `length` metadata of a field of type {field_type} to be an int, "
            f"got {length!r}"
        )

    args = T.get_args(field_type)
    # A bare alias such as T.List has an origin but no type arguments to repeat.
    if not args:
        raise TypeError(
            f"Annotated field with `length={length}` has unparameterized type {field_type}; "
            "expected e.g. T.Sequence[X]"
        )

    return [args[0]] * length
def maybe_tuples_of_types_from_annotation(
    annotation: T.Union[T.Type, T.Any], return_annotation_if_not_tuple: bool = False
) -> T.Optional[T.Union[T.Tuple[T.Union[T.Tuple, T.Type]], T.Any]]:
    """
    Attempt to construct a tuple of types from an annotation of the form ``T.Tuple[A, B, C]`` of
    any fixed length, recursively. For example ``T.Tuple[int, T.Tuple[str, float]]`` becomes
    ``(int, (str, float))``.

    If this is not possible, because the annotation is not a ``T.Tuple``, returns:

    1) The annotation itself, if ``return_annotation_if_not_tuple`` is True
    2) ``None``, otherwise

    If the annotation is a ``T.Tuple``, but is of unknown length (``T.Tuple[X, ...]``), returns
    ``None`` when ``return_annotation_if_not_tuple`` is False.

    Raises:
        ValueError: If the annotation is a ``T.Tuple`` of unknown length and
            ``return_annotation_if_not_tuple`` is True. Nested elements are always processed
            with ``return_annotation_if_not_tuple=True``, so this is also raised whenever a
            variable-length tuple appears inside a fixed-length one.
    """
    origin = T.get_origin(annotation)
    # get_origin may return None or a non-class object (e.g. typing.Union), which issubclass
    # would reject, so check isinstance first.
    if not isinstance(origin, type) or not issubclass(origin, T.cast(T.Type, T.Tuple)):
        return annotation if return_annotation_if_not_tuple else None

    args = T.get_args(annotation)
    if Ellipsis in args:
        if return_annotation_if_not_tuple:
            raise ValueError(
                f"Cannot convert {annotation} to a tuple of types: variable-length tuples "
                "(T.Tuple[X, ...]) are not supported"
            )
        return None

    # Elements that are not themselves tuples are passed through unchanged.
    return tuple(
        maybe_tuples_of_types_from_annotation(arg, return_annotation_if_not_tuple=True)
        for arg in args
    )