# ----------------------------------------------------------------------------
# SymForce - Copyright 2022, Skydio, Inc.
# This source code is under the Apache 2.0 license found in the LICENSE file.
# ----------------------------------------------------------------------------
from __future__ import annotations
import dataclasses
from symforce import codegen
from symforce import typing as T
from symforce.codegen import codegen_config
from symforce.values import Values
def _to_hashable(obj: T.Any) -> T.Any:
    """
    Recursively converts ``obj`` into a hashable equivalent, for use by
    :meth:`SimilarityIndex.__hash__`.

    - dataclass instances become tuples of ``(field_name, value)`` pairs in field order (the same
      shape ``dataclasses.asdict(obj, dict_factory=tuple)`` produces for dataclasses)
    - lists and tuples become tuples
    - dicts become frozensets of ``(key, value)`` pairs, so dicts that compare equal hash equally
      regardless of insertion order
    - sets and frozensets become frozensets

    Any other object is returned unchanged and must itself be hashable.
    """
    # Guard against dataclass *types*, for which is_dataclass is also True
    if dataclasses.is_dataclass(obj) and not isinstance(obj, type):
        return tuple(
            (fld.name, _to_hashable(getattr(obj, fld.name))) for fld in dataclasses.fields(obj)
        )
    if isinstance(obj, dict):
        return frozenset((_to_hashable(key), _to_hashable(value)) for key, value in obj.items())
    if isinstance(obj, (list, tuple)):
        return tuple(_to_hashable(item) for item in obj)
    if isinstance(obj, (set, frozenset)):
        return frozenset(_to_hashable(item) for item in obj)
    return obj


@dataclasses.dataclass
class SimilarityIndex:
    """
    Contains all the information needed to assess if two
    :class:`Codegen <symforce.codegen.codegen.Codegen>` objects would generate the same function,
    modulo function name and docstring.

    WARNING: :class:`SimilarityIndex` is hashable despite being mutable. Be careful when using it
    as a dict key: dict keys are expected to be immutable, and mutating an instance after it has
    been inserted leaves it stored under a stale hash.
    """

    config: codegen_config.CodegenConfig
    inputs: Values
    outputs: Values
    return_key: T.Optional[str]

    # NOTE(brad): Only the keys are needed because Codegen will generate sparse_mat_data
    # to be the same if both the keys and the outputs are the same between two objects.
    sorted_sparse_matrices: T.Tuple[str, ...] = dataclasses.field(init=False)
    sparse_matrices: dataclasses.InitVar[T.Iterable[str]]

    def __post_init__(self, sparse_matrices: T.Iterable[str]) -> None:
        # Sort so the order in which sparse matrix keys are supplied does not affect equality
        # or hashing.
        self.sorted_sparse_matrices = tuple(sorted(sparse_matrices))

    @staticmethod
    def from_codegen(co: codegen.Codegen) -> SimilarityIndex:
        """
        Returns the :class:`SimilarityIndex` of a
        :class:`Codegen <symforce.codegen.codegen.Codegen>` object.

        If co1 and co2 are two :class:`Codegen <symforce.codegen.codegen.Codegen>` objects, then
        ``from_codegen(co1) == from_codegen(co2)`` if and only if the function
        generated by ``co1.generate_function()`` is the same as that of ``co2.generate_function()``
        (up to differences in function name and docstrings).
        """
        return SimilarityIndex(
            inputs=co.inputs,
            outputs=co.outputs,
            config=co.config,
            return_key=co.return_key,
            sparse_matrices=co.sparse_mat_data.keys(),
        )

    def __hash__(self) -> int:
        """
        Hashes the same fields that the dataclass-generated ``__eq__`` compares.

        ``inputs`` and ``outputs`` are hashed through their ``to_storage()`` representation.
        ``config`` is converted recursively to a hashable form. As a result, config fields holding
        lists, dicts, or sets (which ``dataclasses.asdict`` would leave as unhashable containers)
        do not make hashing fail, and equal dicts hash equally regardless of insertion order.

        WARNING: :class:`SimilarityIndex` is mutable, and you must be mindful of the fact that
        keys of dicts are supposed to be immutable.
        If seeking to use a :class:`SimilarityIndex` in a dict as a key, encapsulate this to
        make sure others aren't able to modify the object after it has been hashed.
        """
        return hash(
            (
                tuple(self.inputs.to_storage()),
                tuple(self.outputs.to_storage()),
                self.return_key,
                self.sorted_sparse_matrices,
                # Converted recursively; unlike dataclasses.astuple, field names are kept
                _to_hashable(self.config),
            )
        )