# ----------------------------------------------------------------------------
# SymForce - Copyright 2022, Skydio, Inc.
# This source code is under the Apache 2.0 license found in the LICENSE file.
# ----------------------------------------------------------------------------
from __future__ import annotations
import dataclasses
import enum
import functools
import os
import textwrap
import warnings
from pathlib import Path
import jinja2
import jinja2.ext
from symforce import logger
from symforce import typing as T
from symforce.codegen import format_util
from symforce.codegen.codegen_config import RenderTemplateConfig
# Directory that contains this module; LCM_TEMPLATE_DIR below is derived from it.
CURRENT_DIR = Path(__file__).parent
# The "lcm_templates" subdirectory next to this module (presumably holds the LCM-related
# templates -- confirm against the package layout).
LCM_TEMPLATE_DIR = CURRENT_DIR / "lcm_templates"
class FileType(enum.Enum):
    """
    File types recognized by the template utilities.

    A member can be looked up from a bare file extension with :meth:`from_extension`, or from
    the file name of a ``*.jinja`` template with :meth:`from_template_path`.
    """

    CPP = enum.auto()
    PYTHON = enum.auto()
    PYTHON_INTERFACE = enum.auto()
    CUDA = enum.auto()
    LCM = enum.auto()
    MAKEFILE = enum.auto()
    TYPESCRIPT = enum.auto()
    TOML = enum.auto()

    @staticmethod
    def from_extension(extension: str) -> FileType:
        """
        Return the file type associated with ``extension``.

        Args:
            extension: File extension without the leading dot (e.g. ``"cpp"``). Matching is
                exact and case-sensitive. ``"Makefile"`` is accepted as well, so that a
                template named ``Makefile.jinja`` resolves to ``FileType.MAKEFILE``.

        Raises:
            ValueError: If ``extension`` is not one of the recognized extensions
        """
        # Built inside the method on purpose: a mapping assigned in the class body of an Enum
        # would itself be turned into an enum member.
        known_extensions = (
            (("c", "cpp", "cxx", "cc", "tcc", "h", "hpp", "hxx", "hh"), FileType.CPP),
            (("cu", "cuh"), FileType.CUDA),
            (("py",), FileType.PYTHON),
            (("pyi",), FileType.PYTHON_INTERFACE),
            (("lcm",), FileType.LCM),
            (("Makefile",), FileType.MAKEFILE),
            (("ts",), FileType.TYPESCRIPT),
            (("toml",), FileType.TOML),
        )
        for extensions, filetype in known_extensions:
            if extension in extensions:
                return filetype
        raise ValueError(f"Could not get FileType from extension {extension}")

    @staticmethod
    def from_template_path(template_path: Path) -> FileType:
        """
        Infer the file type from the file name of a template.

        The name must end in ``.jinja``; the dot-separated component immediately before that
        suffix is passed to :meth:`from_extension`.

        Args:
            template_path: Path of the template, of the form ``path/to/file.ext.jinja``

        Raises:
            ValueError: If the name does not have the form ``<something>.jinja``, or if the
                component before ``.jinja`` is not a recognized extension
        """
        parts = template_path.name.split(".")
        # The length check matters for a file named exactly "jinja": without it, parts[-2]
        # below would raise IndexError instead of the intended ValueError.
        if len(parts) < 2 or parts[-1] != "jinja":
            raise ValueError(
                f"template must be of the form path/to/file.ext.jinja, got {template_path}"
            )
        return FileType.from_extension(parts[-2])
class RelEnvironment(jinja2.Environment):
    """
    Jinja environment whose ``join_path()`` resolves template references relative to the
    template that contains them. Modified from the below post.

    https://stackoverflow.com/questions/8512677/how-to-include-a-template-with-relative-path-in-jinja2
    """

    def join_path(self, template: T.Union[jinja2.Template, str], parent: str) -> str:
        """
        Return the normalized path of ``template`` joined onto the directory of ``parent``.

        Args:
            template: The referenced template; converted with ``str()`` before joining
            parent: Name of the template that contains the reference
        """
        parent_dir = os.path.dirname(parent)
        joined = os.path.join(parent_dir, str(template))
        # normpath collapses components such as ".." and "."
        return os.path.normpath(joined)
def add_preamble(source: str, name: Path, comment_prefix: str, custom_preamble: str) -> str:
    """
    Prepend an "autogenerated" banner, and any custom preamble, to ``source``.

    The result is ``custom_preamble``, followed by a five-line banner in which every line
    starts with ``comment_prefix``, followed by ``source``.

    Args:
        source: The text to put the preamble in front of
        name: Template name to mention in the banner
        comment_prefix: Comment marker placed at the start of each banner line
        custom_preamble: Text emitted before the banner
    """
    rule = "-" * 77
    banner_lines = [
        f"{comment_prefix} {rule}",
        f"{comment_prefix} This file was autogenerated by symforce from template:",
        f"{comment_prefix} {name}",
        f"{comment_prefix} Do NOT modify by hand.",
        f"{comment_prefix} {rule}",
    ]
    # The banner is wrapped in a leading and trailing newline before dedenting; lstrip() then
    # drops the leading newline (and any other leading whitespace).
    banner = textwrap.dedent("\n" + "\n".join(banner_lines) + "\n").lstrip()
    return "".join((custom_preamble, banner, source))
@functools.lru_cache
def jinja_env(
    template_dir: T.Openable, search_paths: T.Tuple[T.Openable, ...] = ()
) -> RelEnvironment:
    """
    Build the Jinja environment for a template directory, memoized per argument combination.

    Because of the ``lru_cache``, repeated calls with the same arguments return the same
    environment object; the arguments must therefore be hashable (hence a tuple for
    ``search_paths``).

    Args:
        template_dir: Base directory where templates are found; searched first
        search_paths: Additional directories for the loader to search, in order
    """
    loader_paths = [os.fspath(directory) for directory in (template_dir, *search_paths)]
    return RelEnvironment(
        loader=jinja2.FileSystemLoader(searchpath=loader_paths),
        trim_blocks=True,
        lstrip_blocks=True,
        keep_trailing_newline=True,
        # Raise on use of an undefined template variable
        undefined=jinja2.StrictUndefined,
    )
def render_template(
    template_path: T.Openable,
    data: T.Dict[str, T.Any],
    config: RenderTemplateConfig,
    *,
    template_dir: T.Openable,
    output_path: T.Optional[T.Openable] = None,
    search_paths: T.Iterable[T.Openable] = (),
) -> str:
    """
    Boilerplate to render template. Returns the rendered string and optionally writes to file.

    The rendered text is prefixed with ``config.custom_preamble`` and an "autogenerated"
    banner, and is autoformatted for its file type when ``config.autoformat`` is set.

    Args:
        template_path: file path of the template to render
        data: dictionary of inputs for template
        config: configuration options for template rendering (see RenderTemplateConfig for more
            information)
        template_dir: Base directory where templates are found
        output_path: If provided, writes to file (UTF-8 encoded); missing parent directories
            are created
        search_paths: Additional directories jinja should search when resolving imports

    Returns:
        The rendered (and possibly autoformatted) file contents

    Raises:
        ValueError: If the file type cannot be determined from the name of ``template_path``
    """
    if not isinstance(template_path, Path):
        template_path = Path(template_path)
    if not isinstance(template_dir, Path):
        template_dir = Path(template_dir)

    logger.debug(f"Template IN <-- {template_dir / template_path}")
    if output_path:
        logger.debug(f"Template OUT --> {output_path}")

    # The file type (taken from the template's name) selects the comment prefix and formatter
    filetype = FileType.from_template_path(template_path)

    # The environment is cached, so search_paths is passed as a (hashable) tuple
    template = jinja_env(template_dir, search_paths=tuple(search_paths)).get_template(
        os.fspath(template_path)
    )

    rendered_str = add_preamble(
        str(template.render(**data)),
        template_path,
        comment_prefix=filetype.comment_prefix(),
        custom_preamble=config.custom_preamble,
    )

    if config.autoformat:
        rendered_str = filetype.autoformat(
            file_contents=rendered_str,
            template_name=template_path,
            output_path=output_path,
        )
    else:
        warnings.warn(
            "Config.autoformat == False is deprecated, this option will be removed in a future release",
            DeprecationWarning,
        )

    if output_path:
        output_path = Path(output_path)
        output_path.parent.mkdir(exist_ok=True, parents=True)
        # Write with an explicit encoding rather than the platform locale's default, so
        # non-ASCII output is encoded the same way everywhere. Jinja's FileSystemLoader reads
        # templates as UTF-8 by default, so this keeps input and output consistent.
        output_path.write_text(rendered_str, encoding="utf-8")

    return rendered_str
class TemplateList:
    """
    Helper class to keep a list of (template_path, data, config, template_dir, output_path)
    and render all templates in one go.
    """

    @dataclasses.dataclass
    class TemplateListEntry:
        # One queued render request; the fields are forwarded to render_template() as the
        # arguments of the same names.
        template_path: T.Openable
        data: T.Dict[str, T.Any]
        config: RenderTemplateConfig
        template_dir: T.Openable
        output_path: T.Optional[T.Openable]

    def __init__(self, template_dir: T.Optional[T.Openable] = None) -> None:
        """
        Args:
            template_dir: Default template directory for entries added without their own
        """
        self.items: T.List[TemplateList.TemplateListEntry] = []
        self.common_template_dir = template_dir

    def add(
        self,
        template_path: T.Openable,
        data: T.Dict[str, T.Any],
        config: RenderTemplateConfig,
        *,
        template_dir: T.Optional[T.Openable] = None,
        output_path: T.Optional[T.Openable] = None,
    ) -> None:
        """
        Queue a template to be rendered by :meth:`render`.

        Args:
            template_path: file path of the template to render
            data: dictionary of inputs for template
            config: configuration options for template rendering
            template_dir: Base directory where the template is found; defaults to the
                directory this list was constructed with
            output_path: If provided, the rendered template is written to this file

        Raises:
            ValueError: If ``template_dir`` is not given here and was not given at construction
        """
        resolved_dir = template_dir if template_dir is not None else self.common_template_dir
        if resolved_dir is None:
            raise ValueError(
                "Argument template_dir must be supplied if the TemplateList was not initialized with a template_dir"
            )

        self.items.append(
            self.TemplateListEntry(
                template_path=template_path,
                data=data,
                config=config,
                template_dir=resolved_dir,
                output_path=output_path,
            )
        )

    def render(self, search_paths: T.Iterable[T.Openable] = ()) -> T.List[str]:
        """
        Render every queued template, in insertion order.

        Args:
            search_paths: Additional directories jinja should search when resolving imports;
                applied to every template in the list

        Returns:
            The rendered contents of each template, in the order the templates were added
        """
        # Materialize once up front. render_template() consumes the iterable it is given, so
        # passing a one-shot iterator (e.g. a generator) straight through would leave every
        # template after the first with no extra search paths.
        resolved_search_paths = tuple(search_paths)
        return [
            render_template(
                template_path=entry.template_path,
                data=entry.data,
                config=entry.config,
                template_dir=entry.template_dir,
                output_path=entry.output_path,
                search_paths=resolved_search_paths,
            )
            for entry in self.items
        ]