# blob: 68371c12dc72f9306bda28d81ee7560762e2f238 [file] [log] [blame]
import contextlib
import functools
import hashlib
import os
import re
import textwrap
import sys
from argparse import Namespace
from dataclasses import (
fields,
is_dataclass,
)
from typing import (
Tuple,
List,
Iterable,
Iterator,
Callable,
Sequence,
TypeVar,
Optional,
Dict,
Any,
Union,
Set,
NoReturn,
)
from enum import Enum
from torchgen.code_template import CodeTemplate
# Safely load fast C Yaml loader/dumper if they are available
try:
from yaml import CSafeLoader as Loader
except ImportError:
from yaml import SafeLoader as Loader # type: ignore[misc]
try:
from yaml import CSafeDumper as Dumper
except ImportError:
from yaml import SafeDumper as Dumper # type: ignore[misc]
YamlDumper = Dumper
# A custom loader for YAML that errors on duplicate keys.
# This doesn't happen by default: see https://github.com/yaml/pyyaml/issues/165
class YamlLoader(Loader):
    """YAML loader that rejects mappings containing the same key twice."""

    def construct_mapping(self, node, deep=False):  # type: ignore[no-untyped-def]
        """Build the mapping for ``node`` after checking its keys are unique.

        Every key node is constructed once up front purely for the duplicate
        check; the actual mapping is then produced by the parent loader.
        """
        seen_keys = []
        for key_node, _value_node in node.value:
            current = self.construct_object(key_node, deep=deep)  # type: ignore[no-untyped-call]
            assert (
                current not in seen_keys
            ), f"Found a duplicate key in the yaml. key={current}, line={node.start_mark.line}"
            seen_keys.append(current)
        return super().construct_mapping(node, deep=deep)  # type: ignore[no-untyped-call]
# Many of these functions share logic for defining both the definition
# and declaration (for example, the function signature is the same), so
# we organize them into one function that takes a Target to say which
# code we want.
#
# This is an OPEN enum (we may add more cases to it in the future), so be sure
# to explicitly specify with Union[Literal[Target.XXX]] what targets are valid
# for your use.
class Target(Enum):
    # Members are numbered consecutively from 1, in declaration order.

    # top level namespace (not including at)
    DEFINITION = 1
    DECLARATION = 2
    # TORCH_LIBRARY(...) { ... }
    REGISTRATION = 3
    # namespace { ... }
    ANONYMOUS_DEFINITION = 4
    # namespace cpu { ... }
    NAMESPACED_DEFINITION = 5
    NAMESPACED_DECLARATION = 6
# Matches "foo" in "foo, bar" but not "foobar". Used to search for the
# occurrence of a parameter in the derivative formula.
# The "{}" is a placeholder for the identifier being searched for (presumably
# filled in by callers via str.format -- confirm at the use sites). The groups
# on either side require the identifier to be bounded by the start/end of the
# string or by a non-word character.
IDENT_REGEX = r"(^|\W){}($|\W)"
# TODO: Use a real parser here; this regex/split approach will get bamboozled
# by parameters that themselves contain ", " or parentheses.
def split_name_params(schema: str) -> Tuple[str, List[str]]:
    """Split a schema string such as ``name.overload(a, b)`` into its parts.

    Returns the base name (without any ``.overload`` suffix) and the list of
    parameter strings obtained by splitting the parenthesised text on ", ".
    An empty parameter list yields ``[""]``.

    Raises:
        RuntimeError: if ``schema`` does not start with ``name(...)`` or
            ``name.overload(...)``.
    """
    matched = re.match(r"(\w+)(\.\w+)?\((.*)\)", schema)
    if matched is None:
        raise RuntimeError(f"Unsupported function schema: {schema}")
    base_name = matched.group(1)
    param_text = matched.group(3)
    return base_name, param_text.split(", ")
# Generic type variables used in the signatures of the helpers in this file:
# T stands for the element type of an input iterable, S for the element type
# produced by a mapped function.
T = TypeVar("T")
S = TypeVar("S")
# These two functions purposely return generators in analogy to map()
# so that you don't mix up when you need to list() them
# Map over function that may return None; omit Nones from output sequence
def mapMaybe(func: Callable[[T], Optional[S]], xs: Iterable[T]) -> Iterator[S]:
    """Lazily yield ``func(x)`` for each ``x`` in ``xs``, dropping ``None`` results.

    Only ``None`` is filtered out; other falsy results (``0``, ``""``) are kept.
    """
    for result in map(func, xs):
        if result is None:
            continue
        yield result
# Map over function that returns sequences and cat them all together
def concatMap(func: Callable[[T], Sequence[S]], xs: Iterable[T]) -> Iterator[S]:
    """Lazily yield every element of ``func(x)`` for each ``x`` in ``xs``, in order."""
    for item in xs:
        yield from func(item)
# Conveniently add error context to exceptions raised. Lets us
# easily say that an error occurred while processing a specific
# context.
@contextlib.contextmanager
def context(msg_fn: Callable[[], str]) -> Iterator[None]:
    """Context manager that appends ``msg_fn()`` to any escaping exception.

    ``msg_fn`` is only called when an ``Exception`` propagates out of the
    ``with`` body. Its (indented) text is appended on a new line to the
    exception's first argument, or becomes the first argument when the
    exception has none; the same exception object is then re-raised.
    """
    try:
        yield
    except Exception as err:
        # TODO: this does the wrong thing with KeyError
        extra = textwrap.indent(msg_fn(), " ")
        if err.args:
            combined = f"{err.args[0]}\n{extra}"
        else:
            combined = extra
        err.args = (combined,) + err.args[1:]
        raise
# A little trick from https://github.com/python/mypy/issues/6366
# for getting mypy to do exhaustiveness checking
# TODO: put this somewhere else, maybe
def assert_never(x: NoReturn) -> NoReturn:
    """Always raise ``AssertionError`` naming the runtime type of ``x``.

    Intended for the final ``else`` of an exhaustive type dispatch.
    """
    type_name = type(x).__name__
    raise AssertionError(f"Unhandled type: {type_name}")
@functools.lru_cache(maxsize=None)
def _read_template(template_fn: str) -> CodeTemplate:
    """Load the template at ``template_fn``; results are memoised per path."""
    template = CodeTemplate.from_file(template_fn)
    return template
# String hash that's stable across different executions, unlike builtin hash
def string_stable_hash(s: str) -> int:
    """Return a deterministic integer hash of ``s`` derived from its SHA-1 digest.

    The string is encoded as Latin-1 whenever possible, so every input that
    was hashable before keeps exactly the same value. Strings containing
    characters outside Latin-1 previously raised ``UnicodeEncodeError``; they
    are now encoded as UTF-8 instead. (A UTF-8-encoded string can share its
    bytes, and hence its hash, with some Latin-1 string; that is harmless for
    bucketing purposes.)

    The digest is interpreted as a little-endian unsigned integer.
    """
    try:
        data = s.encode("latin1")
    except UnicodeEncodeError:
        # Fallback for text that Latin-1 cannot represent.
        data = s.encode("utf-8")
    return int.from_bytes(hashlib.sha1(data).digest(), byteorder="little")
# A small abstraction for writing out generated files and keeping track
# of what files have been written (so you can write out a list of output
# files)
class FileManager:
    """Writes generated files under ``install_dir`` and records their paths.

    Files are only rewritten when their contents change. When ``dry_run`` is
    true no generated file is written, but output paths are still recorded so
    that ``write_outputs`` can list them.
    """

    # Directory prefix prepended to every output filename.
    install_dir: str
    # Directory in which template files are looked up.
    template_dir: str
    # When true, record output paths without evaluating or writing them.
    dry_run: bool
    # Full paths ("{install_dir}/{filename}") of every output requested so far.
    filenames: Set[str]

    def __init__(self, install_dir: str, template_dir: str, dry_run: bool) -> None:
        self.install_dir = install_dir
        self.template_dir = template_dir
        self.filenames = set()
        self.dry_run = dry_run

    def _write_if_changed(self, filename: str, contents: str) -> None:
        """Write ``contents`` to ``filename`` unless the file already holds them.

        A file that cannot be read (for example because it does not exist) is
        treated as having no contents, so it is always written.
        """
        old_contents: Optional[str]
        try:
            with open(filename, "r") as f:
                old_contents = f.read()
        except IOError:
            old_contents = None
        if contents != old_contents:
            # Create output directory if it doesn't exist. A bare filename has
            # an empty dirname, which os.makedirs would reject, so skip it.
            out_dir = os.path.dirname(filename)
            if out_dir:
                os.makedirs(out_dir, exist_ok=True)
            with open(filename, "w") as f:
                f.write(contents)

    def write_with_template(
        self,
        filename: str,
        template_fn: str,
        env_callable: Callable[[], Union[str, Dict[str, Any]]],
    ) -> None:
        """Generate ``{install_dir}/{filename}`` from ``env_callable``'s result.

        If ``env_callable`` returns a dict, the template ``template_fn`` (looked
        up in ``template_dir``) is substituted with it; a ``generated_comment``
        entry is added to that dict when missing. If it returns a string, the
        string is written verbatim. ``env_callable`` is not invoked at all in
        dry-run mode.

        Raises:
            AssertionError: if the same output file is requested twice.
        """
        filename = "{}/{}".format(self.install_dir, filename)
        assert filename not in self.filenames, f"duplicate file write {filename}"
        self.filenames.add(filename)
        if not self.dry_run:
            env = env_callable()
            if isinstance(env, dict):
                # TODO: Update the comment reference to the correct location
                if "generated_comment" not in env:
                    # The marker is assembled from two pieces so that this
                    # source file does not itself contain the literal marker.
                    comment = "@" + "generated by torchgen/gen.py"
                    comment += " from {}".format(os.path.basename(template_fn))
                    env["generated_comment"] = comment
                template = _read_template(os.path.join(self.template_dir, template_fn))
                self._write_if_changed(filename, template.substitute(env))
            elif isinstance(env, str):
                self._write_if_changed(filename, env)
            else:
                assert_never(env)

    def write(
        self,
        filename: str,
        env_callable: Callable[[], Union[str, Dict[str, Any]]],
    ) -> None:
        """Like ``write_with_template`` with the template named after the output."""
        self.write_with_template(filename, filename, env_callable)

    def write_sharded(
        self,
        filename: str,
        items: Iterable[T],
        *,
        key_fn: Callable[[T], str],
        env_callable: Callable[[T], Dict[str, List[str]]],
        num_shards: int,
        base_env: Optional[Dict[str, Any]] = None,
        sharded_keys: Set[str],
    ) -> None:
        """Write ``num_shards`` shard files plus one combined "Everything" file.

        Each item is assigned to a shard by the stable hash of ``key_fn(item)``;
        the lists in ``env_callable(item)`` are appended to that shard's
        environment and to the "Everything" environment. Every environment
        starts from ``base_env`` (if given) with each key in ``sharded_keys``
        holding its own list. Output names are ``filename`` with the shard id
        ("_0", "_1", ..., "Everything") inserted before the extension, and
        ``filename`` itself is used as the template.

        Raises:
            AssertionError: if ``env_callable`` returns a key not listed in
                ``sharded_keys``, or a sharded key in ``base_env`` is not a list.
        """
        everything: Dict[str, Any] = {"shard_id": "Everything"}
        shards: List[Dict[str, Any]] = [
            {"shard_id": f"_{i}"} for i in range(num_shards)
        ]
        all_shards = [everything] + shards

        if base_env is not None:
            for shard in all_shards:
                shard.update(base_env)

        for key in sharded_keys:
            for shard in all_shards:
                if key in shard:
                    assert isinstance(
                        shard[key], list
                    ), "sharded keys in base_env must be a list"
                    # Copy so that shards never share (and mutate) one list.
                    shard[key] = shard[key].copy()
                else:
                    shard[key] = []

        def merge_env(into: Dict[str, List[str]], from_: Dict[str, List[str]]) -> None:
            for k, v in from_.items():
                assert k in sharded_keys, f"undeclared sharded key {k}"
                into[k] += v

        if self.dry_run:
            # Dry runs don't write any templates, so incomplete environments are fine
            items = ()

        for item in items:
            key = key_fn(item)
            sid = string_stable_hash(key) % num_shards
            env = env_callable(item)
            merge_env(shards[sid], env)
            merge_env(everything, env)

        dot_pos = filename.rfind(".")
        if dot_pos == -1:
            dot_pos = len(filename)
        base_filename = filename[:dot_pos]
        extension = filename[dot_pos:]

        for shard in all_shards:
            shard_id = shard["shard_id"]
            # Bind the current shard eagerly so the callable never depends on
            # the loop variable's later value.
            self.write_with_template(
                f"{base_filename}{shard_id}{extension}",
                filename,
                lambda shard=shard: shard,
            )

        # filenames is used to track compiled files, but FooEverything.cpp isn't meant to be compiled
        self.filenames.discard(
            f"{self.install_dir}/{base_filename}Everything{extension}"
        )

    def write_outputs(self, variable_name: str, filename: str) -> None:
        """Write a file containing the list of all outputs which are
        generated by this script."""
        content = "set({}\n {})".format(
            variable_name,
            "\n ".join('"' + name + '"' for name in sorted(self.filenames)),
        )
        self._write_if_changed(filename, content)
# Helper function to generate file manager
def make_file_manager(
    options: Namespace, install_dir: Optional[str] = None
) -> FileManager:
    """Build a ``FileManager`` from parsed command-line ``options``.

    Templates are read from ``<options.source_path>/templates``. Output goes
    to ``install_dir`` when it is given and non-empty, otherwise to
    ``options.install_dir``.
    """
    template_dir = os.path.join(options.source_path, "templates")
    if not install_dir:
        install_dir = options.install_dir
    return FileManager(
        install_dir=install_dir,
        template_dir=template_dir,
        dry_run=options.dry_run,
    )
# Helper function to create a pretty representation for dataclasses
def dataclass_repr(
    obj: Any,
    indent: int = 0,
    width: int = 80,
) -> str:
    """Return a pretty, possibly multi-line, string representation of ``obj``.

    Delegates to ``pprint.pformat`` on Python 3.10 and newer, and to the local
    ``_pformat`` fallback on older interpreters.
    """
    if sys.version_info < (3, 10):
        return _pformat(obj, indent=indent, width=width)

    # built-in pprint module support dataclasses from python 3.10
    from pprint import pformat

    return pformat(obj, indent, width)
def _pformat(
obj: Any,
indent: int,
width: int,
curr_indent: int = 0,
) -> str:
assert is_dataclass(obj), f"obj should be a dataclass, received: {type(obj)}"
class_name = obj.__class__.__name__
# update current indentation level with class name
curr_indent += len(class_name) + 1
fields_list = [(f.name, getattr(obj, f.name)) for f in fields(obj) if f.repr]
fields_str = []
for name, attr in fields_list:
# update the current indent level with the field name
# dict, list, set and tuple also add indent as done in pprint
_curr_indent = curr_indent + len(name) + 1
if is_dataclass(attr):
str_repr = _pformat(attr, indent, width, _curr_indent)
elif isinstance(attr, dict):
str_repr = _format_dict(attr, indent, width, _curr_indent)
elif isinstance(attr, (list, set, tuple)):
str_repr = _format_list(attr, indent, width, _curr_indent)
else:
str_repr = repr(attr)
fields_str.append(f"{name}={str_repr}")
indent_str = curr_indent * " "
body = f",\n{indent_str}".join(fields_str)
return f"{class_name}({body})"
def _format_dict(
    attr: Dict[Any, Any],
    indent: int,
    width: int,
    curr_indent: int,
) -> str:
    """Render a dict as ``{key: value, ...}``, pretty-printing dataclass values.

    Keys and non-dataclass values use ``repr``; the entries are then laid out
    by ``_format`` on one line or one per line depending on ``width``.
    """
    curr_indent += indent + 3
    entries = []
    for key, value in attr.items():
        key_text = repr(key)
        if is_dataclass(value):
            value_text = _pformat(value, indent, width, curr_indent + len(key_text))
        else:
            value_text = repr(value)
        entries.append(f"{key_text}: {value_text}")
    return _format(entries, indent, width, curr_indent, "{", "}")
def _format_list(
    attr: Union[List[Any], Set[Any], Tuple[Any, ...]],
    indent: int,
    width: int,
    curr_indent: int,
) -> str:
    """Render a list, set or tuple, pretty-printing any dataclass elements.

    Lists are wrapped in square brackets; sets and tuples are both wrapped in
    parentheses. Layout (single line vs. one element per line) is decided by
    ``_format``.
    """
    curr_indent += indent + 1
    rendered = []
    for element in attr:
        if is_dataclass(element):
            rendered.append(_pformat(element, indent, width, curr_indent))
        else:
            rendered.append(repr(element))
    if isinstance(attr, list):
        start, end = "[", "]"
    else:
        start, end = "(", ")"
    return _format(rendered, indent, width, curr_indent, start, end)
def _format(
fields_str: List[str],
indent: int,
width: int,
curr_indent: int,
start: str,
end: str,
) -> str:
delimiter, curr_indent_str = "", ""
# if it exceed the max width then we place one element per line
if len(repr(fields_str)) >= width:
delimiter = "\n"
curr_indent_str = " " * curr_indent
indent_str = " " * indent
body = f", {delimiter}{curr_indent_str}".join(fields_str)
return f"{start}{indent_str}{body}{end}"
class NamespaceHelper:
    """Builds the C++ namespace opening and closing text for nested namespaces.

    For example, for namespace_str "torch::lazy":

    prologue:
        namespace torch {
        namespace lazy {

    epilogue:
        } // namespace lazy
        } // namespace torch
    """

    def __init__(self, namespace_str: str, entity_name: str = "", max_level: int = 2):
        # namespace_str can be a "::"-joined string such as torch::lazy.
        # Only up to max_level levels of nesting are allowed.
        parts = namespace_str.split("::")
        assert (
            len(parts) <= max_level
        ), f"Codegen doesn't support more than {max_level} level(s) of custom namespace. Got {namespace_str}."
        openers = [f"namespace {part} {{" for part in parts]
        closers = [f"}} // namespace {part}" for part in reversed(parts)]
        self.cpp_namespace_ = namespace_str
        self.prologue_ = "\n".join(openers)
        self.epilogue_ = "\n".join(closers)
        self.namespaces_ = parts
        self.entity_name_ = entity_name

    @staticmethod
    def from_namespaced_entity(
        namespaced_entity: str, max_level: int = 2
    ) -> "NamespaceHelper":
        """
        Generate helper from nested namespaces as well as class/function name. E.g.: "torch::lazy::add"
        """
        *namespaces, entity_name = namespaced_entity.split("::")
        return NamespaceHelper(
            namespace_str="::".join(namespaces),
            entity_name=entity_name,
            max_level=max_level,
        )

    @property
    def prologue(self) -> str:
        """The namespace-opening lines, outermost namespace first."""
        return self.prologue_

    @property
    def epilogue(self) -> str:
        """The namespace-closing lines, innermost namespace first."""
        return self.epilogue_

    @property
    def entity_name(self) -> str:
        """The entity name given at construction ("" if none was given)."""
        return self.entity_name_

    def get_cpp_namespace(self, default: str = "") -> str:
        """
        Return the namespace string from joining all the namespaces by "::" (hence no leading "::").
        Return default if namespace string is empty.
        """
        if self.cpp_namespace_:
            return self.cpp_namespace_
        return default