# Source code for symforce.codegen.similarity_index

# ----------------------------------------------------------------------------
# SymForce - Copyright 2022, Skydio, Inc.
# This source code is under the Apache 2.0 license found in the LICENSE file.
# ----------------------------------------------------------------------------

from __future__ import annotations

import dataclasses

from symforce import codegen
from symforce import typing as T
from symforce.codegen import codegen_config
from symforce.values import Values


def _to_hashable(obj: T.Any) -> T.Hashable:
    """
    Recursively converts the output of :func:`dataclasses.asdict` into a hashable value.

    ``dataclasses.asdict`` recurses into lists, tuples, and dicts but rebuilds them with their
    original container types, and copies any other value (including sets) as-is. Lists, dicts,
    and sets are unhashable, so they are normalized here:

    - lists and tuples become tuples,
    - dicts become frozensets of ``(key, value)`` pairs, so insertion order does not affect the
      hash (matching dict equality, which ignores order),
    - sets become frozensets.

    Every other value is returned unchanged.
    """
    if isinstance(obj, (list, tuple)):
        return tuple(_to_hashable(item) for item in obj)
    if isinstance(obj, dict):
        return frozenset((_to_hashable(key), _to_hashable(value)) for key, value in obj.items())
    if isinstance(obj, (set, frozenset)):
        return frozenset(_to_hashable(item) for item in obj)
    return obj


@dataclasses.dataclass
class SimilarityIndex:
    """
    Contains all the information needed to assess if two
    :class:`Codegen <symforce.codegen.codegen.Codegen>` objects would generate the same function,
    modulo function name and docstring.

    WARNING: :class:`SimilarityIndex` is hashable despite being mutable. This means you should be
    careful when storing it as a key of a dict, as ordinary keys are immutable.
    """

    config: codegen_config.CodegenConfig
    inputs: Values
    outputs: Values
    return_key: T.Optional[str]

    # NOTE(brad): Only the keys are needed because Codegen will generate sparse_mat_data
    # to be the same if both the keys and the outputs are the same between two objects.
    # The keys are stored sorted so the order they were supplied in does not affect equality
    # or the hash.
    sorted_sparse_matrices: T.Tuple[str, ...] = dataclasses.field(init=False)

    # Constructor-only argument (an InitVar, so not a stored field); __post_init__ consumes it to
    # populate sorted_sparse_matrices.
    sparse_matrices: dataclasses.InitVar[T.Iterable[str]]

    def __post_init__(self, sparse_matrices: T.Iterable[str]) -> None:
        """
        Stores the sparse matrix keys as a sorted tuple, which is order-independent and
        hashable.
        """
        self.sorted_sparse_matrices = tuple(sorted(sparse_matrices))

    @staticmethod
    def from_codegen(co: codegen.Codegen) -> SimilarityIndex:
        """
        Returns the :class:`SimilarityIndex` of a
        :class:`Codegen <symforce.codegen.codegen.Codegen>` object.

        If co1 and co2 are two :class:`Codegen <symforce.codegen.codegen.Codegen>` objects, then
        ``from_codegen(co1) == from_codegen(co2)`` if and only if the function generated by
        ``co1.generate_function()`` is the same as that of ``co2.generate_function()`` (up to
        differences in function name and docstrings).

        Args:
            co: The Codegen object to summarize. Its ``inputs``, ``outputs``, ``config``,
                ``return_key`` and the keys of ``sparse_mat_data`` are read.
        """
        return SimilarityIndex(
            inputs=co.inputs,
            outputs=co.outputs,
            config=co.config,
            return_key=co.return_key,
            sparse_matrices=co.sparse_mat_data.keys(),
        )

    def __hash__(self) -> int:
        """
        Hashes the storage of the inputs and outputs, the return key, the sorted sparse matrix
        keys, and the field names and values of the config.

        List, dict, and set values inside the config are normalized to hashable equivalents, so
        hashing does not fail on configs that hold such containers.

        WARNING: :class:`SimilarityIndex` is mutable, and you must be mindful of the fact that
        keys of dicts are supposed to be immutable. If seeking to use a
        :class:`SimilarityIndex` in a dict as a key, encapsulate this to make sure others aren't
        able to modify the object after it has been hashed.
        """
        # Convert the config to (key, value) tuples recursively. Unlike astuple, this keeps the
        # field names. asdict leaves list/dict/set values as those unhashable container types,
        # so they are normalized before hashing.
        config_key = _to_hashable(
            dataclasses.asdict(
                self.config, dict_factory=T.cast(T.Callable[[T.List], T.Tuple], tuple)
            )
        )
        return hash(
            (
                tuple(self.inputs.to_storage()),
                tuple(self.outputs.to_storage()),
                self.return_key,
                self.sorted_sparse_matrices,
                config_key,
            )
        )