Source code for symforce.ops.ops

# ----------------------------------------------------------------------------
# SymForce - Copyright 2022, Skydio, Inc.
# This source code is under the Apache 2.0 license found in the LICENSE file.
# ----------------------------------------------------------------------------

import inspect

from symforce import typing as T


[docs]class OpNotImplementedError(NotImplementedError): """ Exception raised when an operation is not implemented for a given type. """
[docs]class Ops: """ Class for specifying how Storage/Group/LieGroup ops functions should be implemented for specific types (e.g. scalars, geo objects, etc.). Classes that inherit from Ops can be considered "concepts" (see `<https://en.wikipedia.org/wiki/Concept_(generic_programming)>`_), meaning that they define a set of valid operations on the types (or subtypes) registered with this base class. As classes are created, they (or one of their parent classes) must be registered by calling :meth:`register`, which specifies a specific implementation of the ops for that class. This is similar to template specialization in C++. """ # An item of IMPLEMENTATIONS is of the form (DataType, (OpsClass, Implementation)), # where OpsClass is the subclass of Ops we are registering with, DataType is the # type we are registering with OpsClass, and Implementation is the class whose # methods know how to perform OpsClass operations on DataType. # # For example, the item that is added by StorageOps.register(float, ScalarStorageOps) # is (float, (StorageOps, ScalarStorageOps)). IMPLEMENTATIONS: T.Dict[T.Type, T.Tuple[T.Type, T.Type]] = {}
[docs] @classmethod def register(cls, impl_type: T.Type, impl_ops: T.Type) -> None: """ Register the operations class for a given type. Once a type is registered, child classes of Ops will be able to call functions defined in ``impl_ops``. Example:: # ScalarStorageOps defines valid storage operations on floats StorageOps.register(float, ScalarStorageOps) # We can now perform storage operations of objects of type float StorageOps.storage_dim(1.0) Args: impl_type: Type to be registered impl_ops: Class defining how each operation is implemented for the given type """ assert impl_type not in cls.IMPLEMENTATIONS cls.IMPLEMENTATIONS[impl_type] = (cls, impl_ops)
[docs] @classmethod def implementation(cls, impl_type: T.Type) -> T.Type: """ Returns the class defining the operations for the given type or one of its parent classes. If multiple parent classes are registered with the calling class, the implementation of the first such parent class in method resolution order is returned. Raises: OpNotImplementedError: If ``impl_type`` or one of its parent classes is not registered with the calling class or one of its subclasses. """ registered_and_base: T.Set[T.Tuple[T.Type, T.Type]] = set() for base_class in inspect.getmro(impl_type): reg_class_and_impl = cls.IMPLEMENTATIONS.get(base_class, None) if reg_class_and_impl is not None: registration_class, impl = reg_class_and_impl if issubclass(registration_class, cls): return impl registered_and_base.add((registration_class, base_class)) # Handle Protocols/Metaclasses that aren't in the MRO (method resolution order), by trying # every class that's registered, and checking if impl_type is a subclass. for base_class, (registration_class, impl) in cls.IMPLEMENTATIONS.items(): if issubclass(impl_type, base_class): if issubclass(registration_class, cls): return impl registered_and_base.add((registration_class, base_class)) if len(registered_and_base) != 0: raise OpNotImplementedError( f"Class {impl_type} is not registered under {cls.__name__}. It does have " "registered implementations for:\n{reg_info}".format( reg_info="\n".join( [ f" - {reg.__name__} (via base class {base})" for reg, base in registered_and_base ] ) ) ) raise OpNotImplementedError(f"{impl_type} is not registered under {cls.__name__}")