Source code for symforce.test_util.stubs_util
# ----------------------------------------------------------------------------
# SymForce - Copyright 2022, Skydio, Inc.
# This source code is under the Apache 2.0 license found in the LICENSE file.
# ----------------------------------------------------------------------------
from __future__ import annotations
import typing as T
import pybind11_stubgen
import pybind11_stubgen.parser.mixins.fix as fix_mixins
import pybind11_stubgen.parser.mixins.parse as parse_mixins
from pybind11_stubgen import ExtractSignaturesFromPybind11Docstrings
from pybind11_stubgen import FixCurrentModulePrefixInTypeNames
from pybind11_stubgen import FixMissingImports as PybindFixMissingImports
from pybind11_stubgen import (
FixMissingNoneHashFieldAnnotation as PybindFixMissingNoneHashFieldAnnotation,
)
from pybind11_stubgen import FixTypingTypeNames as PybindFixTypingTypeNames
from pybind11_stubgen import IParser
from pybind11_stubgen.structs import Class
from pybind11_stubgen.structs import Docstring
from pybind11_stubgen.structs import Field
from pybind11_stubgen.structs import Identifier
from pybind11_stubgen.structs import Import
from pybind11_stubgen.structs import InvalidExpression
from pybind11_stubgen.structs import QualifiedName
from pybind11_stubgen.structs import ResolvedType
from pybind11_stubgen.structs import Value
[docs]
class FixMissingImports(PybindFixMissingImports):
    """
    pybind11-stubgen ``FixMissingImports`` mixin with two changes:

    * names whose first component is ``lcmtypes`` record an import of ``name.parent`` directly
      instead of delegating to the base class, and
    * imports are recorded for every resolved type nested in an annotation's parameters, not only
      for the outermost type.

    This class intentionally keeps the base class's name: ``self.__extra_imports`` is
    name-mangled to ``_FixMissingImports__extra_imports``. That name is only an attribute
    inherited from the base class because both classes are called ``FixMissingImports``.
    """

    def _add_import(self, name: QualifiedName) -> None:
        """
        Record the import required by ``name``.

        Empty names are ignored. ``lcmtypes.*`` names add an ``Import`` whose origin is
        ``name.parent``. Every other name is handled by the base class.
        """
        if len(name) == 0:
            return
        is_lcmtype = name[0] == Identifier("lcmtypes")
        if is_lcmtype:
            self.__extra_imports.add(Import(name=None, origin=name.parent))
        else:
            super()._add_import(name)

    # NOTE(aaron): Fixed in https://github.com/sizmailov/pybind11-stubgen/pull/263
    def parse_annotation_str(self, annotation_str: str) -> ResolvedType | InvalidExpression | Value:
        """
        Parse ``annotation_str`` with the base class, then record an import for every
        ``ResolvedType`` in the result.

        The walk visits each node before its parameters and goes left to right. Non-resolved
        nodes (and their contents) are skipped. The parsed annotation is returned unchanged.
        """
        parsed = super().parse_annotation_str(annotation_str)
        stack = [parsed]
        while stack:
            node = stack.pop()
            if not isinstance(node, ResolvedType):
                continue
            self._add_import(node.name)
            if node.parameters is not None:
                # Reversed so the leftmost parameter is popped (visited) first.
                stack.extend(reversed(node.parameters))
        return parsed
[docs]
def patch_lcmtype_imports() -> None:
    """
    Install this module's ``FixMissingImports`` in place of pybind11-stubgen's.

    The replacement is set both on ``pybind11_stubgen.parser.mixins.fix`` and on the top-level
    ``pybind11_stubgen`` namespace, in that order.
    """
    for namespace in (fix_mixins, pybind11_stubgen):
        setattr(namespace, "FixMissingImports", FixMissingImports)
[docs]
def patch_current_module_prefix() -> None:
    """
    Make ``FixCurrentModulePrefixInTypeNames`` strip the current-module prefix from nested types.

    The patched ``parse_annotation_str`` applies ``_strip_current_module`` to every
    ``ResolvedType`` in a parsed annotation, including those inside parameters, not just the
    outermost one. This could be upstreamed.
    """

    def parse_annotation_str(
        self: FixCurrentModulePrefixInTypeNames,
        annotation_str: str,
    ) -> ResolvedType | InvalidExpression | Value:
        # Explicit two-argument super(): this function is defined outside the class body.
        parsed = super(FixCurrentModulePrefixInTypeNames, self).parse_annotation_str(annotation_str)  # type: ignore[safe-super]
        to_visit = [parsed]
        while to_visit:
            node = to_visit.pop()
            if isinstance(node, ResolvedType):
                node.name = self._strip_current_module(node.name)
                if node.parameters is not None:
                    # Reversed so parameters are processed left to right.
                    to_visit.extend(reversed(node.parameters))
        return parsed

    fix_mixins.FixCurrentModulePrefixInTypeNames.parse_annotation_str = parse_annotation_str  # type: ignore[method-assign]
[docs]
def patch_handle_docstring() -> None:
    """
    Replace ``BaseParser.handle_docstring`` so string docstrings lose leading/trailing empty lines.

    String values are split into lines and passed to ``_strip_empty_lines``. Any non-string
    value yields ``None``.
    """

    def handle_docstring(self: IParser, path: QualifiedName, value: T.Any) -> T.Optional[Docstring]:
        if not isinstance(value, str):
            return None
        # ``_strip_empty_lines`` is provided by ExtractSignaturesFromPybind11Docstrings, which the
        # composed parser is presumably built with. The assert narrows the type for mypy and
        # fails loudly if that assumption breaks.
        assert isinstance(self, ExtractSignaturesFromPybind11Docstrings)
        lines = value.splitlines()
        return self._strip_empty_lines(lines)

    parse_mixins.BaseParser.handle_docstring = handle_docstring  # type: ignore[method-assign]
[docs]
def patch_fix_missing_none_hash_field_annotation() -> None:
    """
    Annotate ``__hash__ = None`` fields as ``typing.ClassVar[typing.Any]``.

    Replaces ``handle_field`` on pybind11-stubgen's ``FixMissingNoneHashFieldAnnotation``.
    See https://github.com/sizmailov/pybind11-stubgen/pull/236
    """

    def handle_field(
        self: PybindFixMissingNoneHashFieldAnnotation,
        path: QualifiedName,
        field: T.Any,
    ) -> T.Optional[Field]:
        base_field = super(PybindFixMissingNoneHashFieldAnnotation, self).handle_field(path, field)  # type: ignore[safe-super]
        # Short-circuit order matters: ``path[-1]`` is only inspected once the base class
        # produced a field and the runtime value is None.
        if base_field is not None and field is None and path[-1] == "__hash__":
            base_field.attribute.annotation = self.parse_annotation_str("typing.ClassVar[typing.Any]")
        return base_field

    fix_mixins.FixMissingNoneHashFieldAnnotation.handle_field = handle_field  # type: ignore[method-assign]
[docs]
def patch_numpy_annotations() -> None:
    """
    Replace pybind11-stubgen's ``FixTypingTypeNames`` mixin with a local subclass.

    The subclass's ``_parse_annotation_str`` rewrites type names, first recursing into
    parameters:

    * ``typing.X`` becomes ``typing_extensions.X`` when ``X`` is a typing_extensions name.
    * A bare typing name ``x`` becomes ``typing.X``, or ``typing_extensions.X`` for
      typing_extensions names, with its first letter upper-cased.
    * Unparameterized ``function`` becomes ``typing.Callable``.
    * Unparameterized ``object`` / ``handle`` becomes ``typing.Any``.

    The subclass is installed on both ``pybind11_stubgen.parser.mixins.fix`` and ``pybind11_stubgen``.
    """

    class FixTypingTypeNames(PybindFixTypingTypeNames):
        # This class must keep the base class's name. ``self.__typing_names`` and
        # ``self.__typing_extensions_names`` are name-mangled to ``_FixTypingTypeNames__...``.
        # Those names only resolve to the inherited attributes because the class names match.

        def _parse_annotation_str(
            self,
            result: ResolvedType | InvalidExpression | Value,
        ) -> ResolvedType | InvalidExpression | Value:
            if not isinstance(result, ResolvedType):
                return result

            # Rewrite nested parameters first (depth-first).
            nested = result.parameters
            if nested is not None:
                result.parameters = [self._parse_annotation_str(p) for p in nested]

            qualified = result.name
            if len(qualified) != 1:
                # Qualified names: only redirect ``typing.<extension name>`` to typing_extensions.
                if qualified[0] == "typing" and qualified[1] in self.__typing_extensions_names:
                    result.name = QualifiedName.from_str(f"typing_extensions.{qualified[1]}")
                return result

            word = qualified[0]
            unparameterized = result.parameters is None
            if word in self.__typing_names:
                if word in self.__typing_extensions_names:
                    module = "typing_extensions"
                else:
                    module = "typing"
                result.name = QualifiedName.from_str(f"{module}.{word[0].upper()}{word[1:]}")
            # These run after the rewrite above and take precedence over it.
            if unparameterized and word == "function":
                result.name = QualifiedName.from_str("typing.Callable")
            if unparameterized and word in {"object", "handle"}:
                result.name = QualifiedName.from_str("typing.Any")
            return result

    fix_mixins.FixTypingTypeNames = FixTypingTypeNames  # type: ignore[misc]
    pybind11_stubgen.FixTypingTypeNames = FixTypingTypeNames  # type: ignore[misc]
[docs]
class FixNumpyArrayRemoveParameters(IParser):
    """
    Parser mixin that simplifies NumPy array annotations and de-duplicates class methods.

    * ``typing.Annotated[numpy.typing.ArrayLike, ...]`` is reduced to ``numpy.typing.ArrayLike``.
      The extra ``Annotated`` metadata is dropped.
    * Methods of a class that compare equal to an earlier method of the same class are removed.
    """

    # Names compared against on every parsed annotation; built once instead of per call.
    __ndarray_name = QualifiedName.from_str("numpy.typing.ArrayLike")
    __annotated_name = QualifiedName.from_str("typing.Annotated")

    def handle_class(self, path: QualifiedName, class_: type) -> T.Optional[Class]:
        """
        Build the class with the next parser in the MRO, then drop duplicate methods.

        The first occurrence of each method is kept and the original order is preserved.
        Returns ``None`` if the base parser returned ``None``.
        """
        maybe_class = super().handle_class(path, class_)  # type: ignore[safe-super]
        if maybe_class is None:
            return maybe_class
        # Linear ``in`` scan (equality, not hashing) since method objects may not be hashable.
        # Duplicates presumably come from overloads that differ only in the array annotations
        # collapsed by parse_annotation_str.
        unique_methods = []
        for method in maybe_class.methods:
            if method not in unique_methods:
                unique_methods.append(method)
        maybe_class.methods = unique_methods
        return maybe_class

    def parse_annotation_str(self, annotation_str: str) -> ResolvedType | InvalidExpression | Value:
        """
        Replace ``typing.Annotated[numpy.typing.ArrayLike, ...]`` by its first parameter.

        Any other annotation is returned unchanged. That includes an ``Annotated`` with no
        parameters, or one whose first parameter is not a resolved type (e.g. an unparsable
        expression).
        """
        result = super().parse_annotation_str(annotation_str)  # type: ignore[safe-super]
        if not isinstance(result, ResolvedType) or result.name != self.__annotated_name:
            return result
        params = result.parameters
        # Unexpected ``Annotated`` shapes are passed through rather than asserted on. An assert
        # would abort the whole stub generation (or turn into an AttributeError under -O).
        if not params:
            return result
        first = params[0]
        if isinstance(first, ResolvedType) and first.name == self.__ndarray_name:
            return first
        return result
[docs]
def patch_remove_parameters() -> None:
    """
    Install this module's ``FixNumpyArrayRemoveParameters`` in place of pybind11-stubgen's.

    This makes array-parameter stripping work with pybind 3.x annotations and de-duplicates
    overloads. The replacement is set on ``pybind11_stubgen.parser.mixins.fix`` and on the
    top-level ``pybind11_stubgen`` namespace, in that order.
    """
    for namespace in (fix_mixins, pybind11_stubgen):
        setattr(namespace, "FixNumpyArrayRemoveParameters", FixNumpyArrayRemoveParameters)