| import contextlib |
| import functools |
| import hashlib |
| import os |
| import re |
| import sys |
| import textwrap |
| from argparse import Namespace |
| from dataclasses import fields, is_dataclass |
| from enum import auto, Enum |
| from typing import ( |
| Any, |
| Callable, |
| Dict, |
| Generic, |
| Iterable, |
| Iterator, |
| List, |
| NoReturn, |
| Optional, |
| Sequence, |
| Set, |
| Tuple, |
| TypeVar, |
| Union, |
| ) |
| |
| from typing_extensions import Literal # Python 3.8+ |
| |
| from torchgen.code_template import CodeTemplate |
| |
| # Safely load fast C Yaml loader/dumper if they are available |
| try: |
| from yaml import CSafeLoader as Loader |
| except ImportError: |
| from yaml import SafeLoader as Loader # type: ignore[misc] |
| |
| try: |
| from yaml import CSafeDumper as Dumper |
| except ImportError: |
| from yaml import SafeDumper as Dumper # type: ignore[misc] |
| YamlDumper = Dumper |
| |
# A custom loader for YAML that errors on duplicate keys.
# This doesn't happen by default: see https://github.com/yaml/pyyaml/issues/165
class YamlLoader(Loader):
    """YAML loader whose mappings may not contain the same key twice."""

    def construct_mapping(self, node, deep=False):  # type: ignore[no-untyped-def]
        """Check ``node`` for repeated keys, then build the mapping as usual.

        Raises AssertionError naming the key and the mapping's starting line
        when a key occurs more than once.
        """
        # Keys are collected in a list (not a set) so that the membership
        # test only relies on equality.
        seen_keys = []
        for key_node, _value_node in node.value:
            current = self.construct_object(key_node, deep=deep)  # type: ignore[no-untyped-call]
            assert (
                current not in seen_keys
            ), f"Found a duplicate key in the yaml. key={current}, line={node.start_mark.line}"
            seen_keys.append(current)
        # No duplicates: let the parent class construct the real mapping.
        return super().construct_mapping(node, deep=deep)  # type: ignore[no-untyped-call]
| |
| |
# Definitions and declarations are frequently generated by the same code
# (the function signature, for instance, is identical in both), so those
# generators are written as a single function that receives a Target
# telling it which flavor of code to emit.
#
# This enum is OPEN (more cases may be added later), so always spell out
# the targets your code accepts with Literal[Target.XXX] or
# Literal[Target.XXX, Target.YYY].
class Target(Enum):
    """The kind of code a generator function is asked to produce."""

    # top level namespace (not including at)
    DEFINITION = auto()
    DECLARATION = auto()
    # TORCH_LIBRARY(...) { ... }
    REGISTRATION = auto()
    # namespace { ... }
    ANONYMOUS_DEFINITION = auto()
    # namespace cpu { ... }
    NAMESPACED_DEFINITION = auto()
    NAMESPACED_DECLARATION = auto()
| |
| |
# Matches "foo" in "foo, bar" but not "foobar". Used to search for the
# occurrence of a parameter in the derivative formula.
# This is a str.format template, not a finished pattern: the "{}" placeholder
# must be filled in with the identifier to look for before it is used as a regex.
IDENT_REGEX = r"(^|\W){}($|\W)"
| |
# TODO: Use a real parser here; this will get bamboozled
def split_name_params(schema: str) -> Tuple[str, List[str]]:
    """Split a schema string of the form ``name[.overload](params)``.

    Returns the base name (the optional ``.overload`` suffix is dropped) and
    the text between the parentheses split on ``", "``. An empty parameter
    list therefore yields ``[""]``.

    Raises RuntimeError when ``schema`` does not start with that shape.
    """
    match = re.match(r"(\w+)(\.\w+)?\((.*)\)", schema)
    if match is None:
        raise RuntimeError(f"Unsupported function schema: {schema}")
    base_name = match.group(1)
    param_text = match.group(3)
    return base_name, param_text.split(", ")
| |
| |
# Generic type variables used in the annotations of the typed helpers in this module
T = TypeVar("T")
S = TypeVar("S")
| |
| # These two functions purposely return generators in analogy to map() |
| # so that you don't mix up when you need to list() them |
| |
# Map over function that may return None; omit Nones from output sequence
def mapMaybe(func: Callable[[T], Optional[S]], xs: Iterable[T]) -> Iterator[S]:
    """Lazily yield ``func(x)`` for each ``x`` in ``xs``, skipping ``None`` results.

    Only ``None`` is dropped; other falsy results (0, "", False) are kept.
    """
    for result in map(func, xs):
        if result is None:
            continue
        yield result
| |
| |
# Map over function that returns sequences and cat them all together
def concatMap(func: Callable[[T], Sequence[S]], xs: Iterable[T]) -> Iterator[S]:
    """Lazily yield every element of ``func(x)`` for each ``x`` in ``xs``, in order."""
    for item in xs:
        yield from func(item)
| |
| |
# Conveniently add error context to exceptions raised. Lets us
# easily say that an error occurred while processing a specific
# context.
@contextlib.contextmanager
def context(msg_fn: Callable[[], str]) -> Iterator[None]:
    """Context manager that appends ``msg_fn()`` to any Exception raised inside it.

    ``msg_fn`` is only called when an exception actually occurs. Its result is
    indented and placed on a new line after the exception's first argument
    (or used on its own when the exception has no arguments); the same
    exception object is then re-raised.
    """
    try:
        yield
    except Exception as err:
        # TODO: this does the wrong thing with KeyError
        note = textwrap.indent(msg_fn(), " ")
        if err.args:
            note = f"{err.args[0]}\n{note}"
        # Replace only the first argument; any further arguments are kept.
        err.args = (note, *err.args[1:])
        raise
| |
| |
# A little trick from https://github.com/python/mypy/issues/6366
# for getting mypy to do exhaustiveness checking
# TODO: put this somewhere else, maybe
def assert_never(x: NoReturn) -> NoReturn:
    """Always raise AssertionError naming the runtime type of ``x``.

    Meant for the final ``else`` of an exhaustive type dispatch.
    """
    type_name = type(x).__name__
    raise AssertionError(f"Unhandled type: {type_name}")
| |
| |
@functools.lru_cache(maxsize=None)
def _read_template(template_fn: str) -> CodeTemplate:
    """Load the template at path ``template_fn``.

    Results are memoized per path for the lifetime of the process, so each
    template file is only read through CodeTemplate.from_file once.
    """
    loaded = CodeTemplate.from_file(template_fn)
    return loaded
| |
| |
# String hash that's stable across different executions, unlike builtin hash
def string_stable_hash(s: str) -> int:
    """Return the SHA-1 digest of ``s`` (latin-1 encoded) as a little-endian int."""
    encoded = bytes(s, "latin1")
    digest = hashlib.sha1(encoded).digest()
    return int.from_bytes(digest, "little")
| |
| |
# A small abstraction for writing out generated files and keeping track
# of what files have been written (so you can write out a list of output
# files)
class FileManager:
    """Writes generated files below ``install_dir`` and records their names.

    Templates are looked up relative to ``template_dir``. When ``dry_run`` is
    true, file names are still recorded in ``filenames`` but no template is
    rendered and nothing is written by the ``write*`` methods.
    """

    install_dir: str
    template_dir: str
    dry_run: bool
    filenames: Set[str]

    def __init__(self, install_dir: str, template_dir: str, dry_run: bool) -> None:
        self.install_dir = install_dir
        self.template_dir = template_dir
        self.filenames = set()
        self.dry_run = dry_run

    def _write_if_changed(self, filename: str, contents: str) -> None:
        """Write ``contents`` to ``filename`` unless the file already holds that text.

        An unreadable or missing file counts as "changed". Missing parent
        directories are created.
        """
        old_contents: Optional[str]
        try:
            with open(filename, "r") as f:
                old_contents = f.read()
        except IOError:
            old_contents = None
        if contents != old_contents:
            # Create output directory if it doesn't exist. A bare file name
            # has an empty dirname, and os.makedirs("") raises, so only
            # create a directory when there is one to create.
            out_dir = os.path.dirname(filename)
            if out_dir:
                os.makedirs(out_dir, exist_ok=True)
            with open(filename, "w") as f:
                f.write(contents)

    # Read from template file and replace pattern with callable (type could be dict or str).
    def substitute_with_template(
        self, template_fn: str, env_callable: Callable[[], Union[str, Dict[str, Any]]]
    ) -> str:
        """Return the text for one output file.

        ``env_callable`` is invoked once. If it returns a string, that string
        is returned verbatim and the template is not read. If it returns a
        dict, the template ``template_fn`` (relative to ``template_dir``) is
        substituted with it; a default "generated_comment" entry is added to
        that dict (in place) when the key is absent.
        """
        template_path = os.path.join(self.template_dir, template_fn)
        env = env_callable()
        if isinstance(env, dict):
            # TODO: Update the comment reference to the correct location
            if "generated_comment" not in env:
                comment = "@" + "generated by torchgen/gen.py"
                comment += " from {}".format(os.path.basename(template_path))
                env["generated_comment"] = comment
            template = _read_template(template_path)
            return template.substitute(env)
        elif isinstance(env, str):
            return env
        else:
            assert_never(env)

    def write_with_template(
        self,
        filename: str,
        template_fn: str,
        env_callable: Callable[[], Union[str, Dict[str, Any]]],
    ) -> None:
        """Render ``template_fn`` and write it to ``install_dir/filename``.

        The output name is recorded in ``filenames`` even in dry-run mode.
        Raises AssertionError if the same output file is written twice.
        """
        filename = "{}/{}".format(self.install_dir, filename)
        # The message must be an f-string, otherwise the offending name is
        # never shown.
        assert filename not in self.filenames, f"duplicate file write {filename}"
        self.filenames.add(filename)
        if not self.dry_run:
            substitute_out = self.substitute_with_template(
                template_fn=template_fn,
                env_callable=env_callable,
            )
            self._write_if_changed(filename=filename, contents=substitute_out)

    def write(
        self,
        filename: str,
        env_callable: Callable[[], Union[str, Dict[str, Any]]],
    ) -> None:
        """Write ``filename`` using the template that has the same name."""
        self.write_with_template(filename, filename, env_callable)

    def write_sharded(
        self,
        filename: str,
        items: Iterable[T],
        *,
        key_fn: Callable[[T], str],
        env_callable: Callable[[T], Dict[str, List[str]]],
        num_shards: int,
        base_env: Optional[Dict[str, Any]] = None,
        sharded_keys: Set[str],
    ) -> None:
        """Render the template ``filename`` once per shard, spreading ``items`` over them.

        Each item is assigned to shard ``string_stable_hash(key_fn(item)) % num_shards``
        and its environment (from ``env_callable``) is appended to that shard's
        lists as well as to an extra "Everything" environment that receives all
        items. Output files are named ``<base>_<i><ext>`` plus
        ``<base>Everything<ext>``; the latter is written but removed from
        ``filenames`` afterwards.

        ``base_env`` entries are copied into every environment. Every key
        produced by ``env_callable`` must be listed in ``sharded_keys``.
        """

        everything: Dict[str, Any] = {"shard_id": "Everything"}
        shards: List[Dict[str, Any]] = [
            {"shard_id": f"_{i}"} for i in range(num_shards)
        ]
        all_shards = [everything] + shards

        if base_env is not None:
            for shard in all_shards:
                shard.update(base_env)

        for key in sharded_keys:
            for shard in all_shards:
                if key in shard:
                    assert isinstance(
                        shard[key], list
                    ), "sharded keys in base_env must be a list"
                    # Copy so that appending to one shard does not mutate
                    # base_env or the other shards.
                    shard[key] = shard[key].copy()
                else:
                    shard[key] = []

        def merge_env(into: Dict[str, List[str]], from_: Dict[str, List[str]]) -> None:
            for k, v in from_.items():
                assert k in sharded_keys, f"undeclared sharded key {k}"
                into[k] += v

        if self.dry_run:
            # Dry runs don't write any templates, so incomplete environments are fine
            items = ()

        for item in items:
            key = key_fn(item)
            sid = string_stable_hash(key) % num_shards
            env = env_callable(item)

            merge_env(shards[sid], env)
            merge_env(everything, env)

        dot_pos = filename.rfind(".")
        if dot_pos == -1:
            dot_pos = len(filename)
        base_filename = filename[:dot_pos]
        extension = filename[dot_pos:]

        for shard in all_shards:
            shard_id = shard["shard_id"]
            # The lambda is called (if at all) inside write_with_template,
            # i.e. before the loop variable moves on.
            self.write_with_template(
                f"{base_filename}{shard_id}{extension}", filename, lambda: shard
            )

        # filenames is used to track compiled files, but FooEverything.cpp isn't meant to be compiled
        self.filenames.discard(
            f"{self.install_dir}/{base_filename}Everything{extension}"
        )

    def write_outputs(self, variable_name: str, filename: str) -> None:
        """Write a file containing the list of all outputs which are
        generated by this script."""
        content = "set({}\n {})".format(
            variable_name,
            "\n ".join('"' + name + '"' for name in sorted(self.filenames)),
        )
        self._write_if_changed(filename, content)

    def template_dir_for_comments(self) -> str:
        """
        This needs to be deterministic. The template dir is an absolute path
        that varies across builds. So, just use the path relative to this file,
        which will point to the codegen source but will be stable.
        """
        return os.path.relpath(self.template_dir, os.path.dirname(__file__))
| |
| |
# Helper function to generate file manager
def make_file_manager(
    options: Namespace, install_dir: Optional[str] = None
) -> FileManager:
    """Build a FileManager from parsed command-line ``options``.

    Templates are read from ``<options.source_path>/templates``. Output goes
    to ``install_dir`` when one is given (and non-empty), otherwise to
    ``options.install_dir``. ``options.dry_run`` is passed through.
    """
    target_dir = install_dir or options.install_dir
    templates = os.path.join(options.source_path, "templates")
    return FileManager(
        install_dir=target_dir,
        template_dir=templates,
        dry_run=options.dry_run,
    )
| |
| |
# Helper function to create a pretty representation for dataclasses
def dataclass_repr(
    obj: Any,
    indent: int = 0,
    width: int = 80,
) -> str:
    """Return a pretty, possibly multi-line, representation of ``obj``.

    On Python 3.10 and newer this delegates to pprint.pformat; on older
    interpreters it falls back to the local formatter.
    """
    if sys.version_info < (3, 10):
        return _pformat(obj, indent=indent, width=width)

    # built-in pprint module support dataclasses from python 3.10
    from pprint import pformat

    return pformat(obj, indent, width)
| |
| |
def _pformat(
    obj: Any,
    indent: int,
    width: int,
    curr_indent: int = 0,
) -> str:
    """Render dataclass ``obj`` as ``ClassName(field=value, ...)``, one field per line.

    Continuation lines are indented by ``curr_indent`` plus the length of the
    class name and its opening parenthesis. Nested dataclasses, dicts, lists,
    sets and tuples are formatted recursively; everything else uses ``repr``.
    Fields declared with ``repr=False`` are skipped.

    Raises AssertionError if ``obj`` is not a dataclass.
    """
    assert is_dataclass(obj), f"obj should be a dataclass, received: {type(obj)}"

    class_name = obj.__class__.__name__
    # update current indentation level with class name
    curr_indent += len(class_name) + 1

    rendered = []
    for field in fields(obj):
        if not field.repr:
            continue
        name = field.name
        value = getattr(obj, name)
        # update the current indent level with the field name
        # dict, list, set and tuple also add indent as done in pprint
        nested_indent = curr_indent + len(name) + 1
        if is_dataclass(value):
            text = _pformat(value, indent, width, nested_indent)
        elif isinstance(value, dict):
            text = _format_dict(value, indent, width, nested_indent)
        elif isinstance(value, (list, set, tuple)):
            text = _format_list(value, indent, width, nested_indent)
        else:
            text = repr(value)
        rendered.append(f"{name}={text}")

    separator = ",\n" + " " * curr_indent
    return f"{class_name}({separator.join(rendered)})"
| |
| |
def _format_dict(
    attr: Dict[Any, Any],
    indent: int,
    width: int,
    curr_indent: int,
) -> str:
    """Render dict ``attr`` as ``{key: value, ...}``.

    Dataclass values are formatted recursively; keys and all other values use
    ``repr``. Wrapping onto several lines is decided by ``_format``.
    """
    curr_indent += indent + 3
    entries = []
    for key, value in attr.items():
        key_text = repr(key)
        if is_dataclass(value):
            # Nested dataclass lines are pushed right by the width of the key.
            value_text = _pformat(value, indent, width, curr_indent + len(key_text))
        else:
            value_text = repr(value)
        entries.append(f"{key_text}: {value_text}")

    return _format(entries, indent, width, curr_indent, "{", "}")
| |
| |
def _format_list(
    attr: Union[List[Any], Set[Any], Tuple[Any, ...]],
    indent: int,
    width: int,
    curr_indent: int,
) -> str:
    """Render a list, set or tuple, formatting dataclass elements recursively.

    Lists are wrapped in square brackets; sets and tuples both use
    parentheses. Wrapping onto several lines is decided by ``_format``.
    """
    curr_indent += indent + 1
    elements = []
    for element in attr:
        if is_dataclass(element):
            elements.append(_pformat(element, indent, width, curr_indent))
        else:
            elements.append(repr(element))

    if isinstance(attr, list):
        opening, closing = "[", "]"
    else:
        opening, closing = "(", ")"
    return _format(elements, indent, width, curr_indent, opening, closing)
| |
| |
def _format(
    fields_str: List[str],
    indent: int,
    width: int,
    curr_indent: int,
    start: str,
    end: str,
) -> str:
    """Join already-rendered elements between the ``start`` and ``end`` delimiters.

    ``indent`` spaces follow ``start``. Elements stay on one line unless the
    ``repr`` of the element list is at least ``width`` characters long, in
    which case each element goes on its own line, indented by ``curr_indent``.
    """
    # if it exceed the max width then we place one element per line
    if len(repr(fields_str)) >= width:
        separator = ", \n" + " " * curr_indent
    else:
        separator = ", "

    padding = " " * indent
    return start + padding + separator.join(fields_str) + end
| |
| |
class NamespaceHelper:
    """Builds the C++ text that opens and closes a chain of nested namespaces.

    e.g. for namespace_str torch::lazy,

    prologue:
    namespace torch {
    namespace lazy {

    epilogue:
    } // namespace lazy
    } // namespace torch
    """

    def __init__(self, namespace_str: str, entity_name: str = "", max_level: int = 2):
        # namespace_str may hold several levels joined by "::", such as torch::lazy
        levels = namespace_str.split("::")
        assert (
            len(levels) <= max_level
        ), f"Codegen doesn't support more than {max_level} level(s) of custom namespace. Got {namespace_str}."
        opening_lines = [f"namespace {level} {{" for level in levels]
        # Namespaces are closed innermost-first.
        closing_lines = [f"}} // namespace {level}" for level in reversed(levels)]
        self.cpp_namespace_ = namespace_str
        self.prologue_ = "\n".join(opening_lines)
        self.epilogue_ = "\n".join(closing_lines)
        self.namespaces_ = levels
        self.entity_name_ = entity_name

    @staticmethod
    def from_namespaced_entity(
        namespaced_entity: str, max_level: int = 2
    ) -> "NamespaceHelper":
        """
        Build a helper from a fully qualified class/function name such as
        "torch::lazy::add": the last "::"-separated component becomes the
        entity name and everything before it the namespace.
        """
        *scopes, entity_name = namespaced_entity.split("::")
        return NamespaceHelper(
            namespace_str="::".join(scopes),
            entity_name=entity_name,
            max_level=max_level,
        )

    @property
    def prologue(self) -> str:
        return self.prologue_

    @property
    def epilogue(self) -> str:
        return self.epilogue_

    @property
    def entity_name(self) -> str:
        return self.entity_name_

    # Only allow certain level of namespaces
    def get_cpp_namespace(self, default: str = "") -> str:
        """
        Return the "::"-joined namespace string (hence no leading "::"),
        or ``default`` when that string is empty.
        """
        if self.cpp_namespace_:
            return self.cpp_namespace_
        return default
| |
| |
class OrderedSet(Generic[T]):
    """A set that iterates in insertion order, backed by a dict with ``None`` values."""

    storage: Dict[T, Literal[None]]

    def __init__(self, iterable: Optional[Iterable[T]] = None):
        """Create a set holding the items of ``iterable`` (empty when omitted)."""
        if iterable is None:
            self.storage = {}
        else:
            self.storage = {k: None for k in iterable}

    def __contains__(self, item: T) -> bool:
        return item in self.storage

    def __iter__(self) -> Iterator[T]:
        return iter(self.storage.keys())

    def update(self, items: "OrderedSet[T]") -> None:
        """Add every element of ``items``; elements already present keep their position."""
        self.storage.update(items.storage)

    def add(self, item: T) -> None:
        """Add ``item``; if it is already present its position is unchanged."""
        self.storage[item] = None

    def copy(self) -> "OrderedSet[T]":
        """Return a new set with the same elements in the same order."""
        ret: OrderedSet[T] = OrderedSet()
        ret.storage = self.storage.copy()
        return ret

    @staticmethod
    def union(*args: "OrderedSet[T]") -> "OrderedSet[T]":
        """Return a new set with the elements of all ``args``, in first-seen order.

        None of the arguments is modified. Called without arguments it returns
        an empty set (like ``set().union()``) instead of failing.
        """
        ret: OrderedSet[T] = OrderedSet()
        for s in args:
            ret.update(s)
        return ret

    def __or__(self, other: "OrderedSet[T]") -> "OrderedSet[T]":
        return OrderedSet.union(self, other)

    def __ior__(self, other: "OrderedSet[T]") -> "OrderedSet[T]":
        self.update(other)
        return self

    def __eq__(self, other: object) -> bool:
        # Dict equality ignores insertion order, so two OrderedSets with the
        # same elements compare equal regardless of order. Anything else is
        # compared against a plain set of our elements.
        if isinstance(other, OrderedSet):
            return self.storage == other.storage
        else:
            return set(self.storage.keys()) == other