# Source code for symforce.codegen.backends.pytorch.pytorch_config
# ----------------------------------------------------------------------------
# SymForce - Copyright 2022, Skydio, Inc.
# This source code is under the Apache 2.0 license found in the LICENSE file.
# ----------------------------------------------------------------------------
from dataclasses import dataclass
from pathlib import Path
from sympy.printing.codeprinter import CodePrinter
from symforce import typing as T
from symforce.codegen.backends.pytorch import pytorch_code_printer
from symforce.codegen.codegen_config import CodegenConfig
# Directory that contains this module; template lookups in this file are resolved relative to it.
CURRENT_DIR = Path(__file__).parents[0]
@dataclass
class PyTorchConfig(CodegenConfig):
    """
    Configuration for generating code with the PyTorch backend.

    Args:
        doc_comment_line_prefix: String prepended to every line of a generated docstring
        line_length: Maximum line length used when formatting generated docstrings
        use_eigen_types: Represent vectors with eigen_lcm types rather than lists
        autoformat: Whether to run a code formatter over the generated code
        custom_preamble: Optional text prepended to the front of the rendered template
        cse_optimizations: Optimizations argument forwarded to
            :func:`sf.cse <symforce.symbolic.cse>`
        zero_epsilon_behavior: What codegen should do when no default epsilon is set
        normalize_results: Whether function outputs are explicitly projected onto the manifold
            before being returned
    """

    # No prefix is added to generated docstring lines by default.
    doc_comment_line_prefix: str = ""
    # Wrap width applied when formatting generated docstrings.
    line_length: int = 100
    # Vectors are emitted as lists unless this is switched on.
    use_eigen_types: bool = False

    @classmethod
    def backend_name(cls) -> str:
        """Identifier of this backend, ``"pytorch"``."""
        return "pytorch"

    @classmethod
    def template_dir(cls) -> Path:
        """Directory holding this backend's templates: ``templates`` next to this module."""
        return CURRENT_DIR.joinpath("templates")

    @staticmethod
    def templates_to_render(generated_file_name: str) -> T.List[T.Tuple[str, str]]:
        """
        List the templates rendered for one generated function.

        Args:
            generated_file_name: Name of the generated module without the ``.py`` extension

        Returns:
            ``(template, output file name)`` pairs: first the function module itself, then the
            package ``__init__.py``.
        """
        function_module = ("function/FUNCTION.py.jinja", f"{generated_file_name}.py")
        package_init = ("function/__init__.py.jinja", "__init__.py")
        return [function_module, package_init]

    @staticmethod
    def printer() -> CodePrinter:
        """Build a fresh ``PyTorchCodePrinter`` instance for this backend."""
        printer_cls = pytorch_code_printer.PyTorchCodePrinter
        return printer_cls()