| """ |
| For procedural tests needed for __torch_function__, we use this function |
| to export method names and signatures as needed by the tests in |
| test/test_overrides.py. |
| |
| python -m tools.autograd.gen_annotated_fn_args \ |
| aten/src/ATen/native/native_functions.yaml \ |
| aten/src/ATen/native/tags.yaml \ |
| $OUTPUT_DIR \ |
| tools/autograd |
| |
| Where $OUTPUT_DIR is where you would like the files to be |
| generated. In the full build system, OUTPUT_DIR is |
| torch/testing/_internal/generated |
| """ |
| |
| from collections import defaultdict |
| import argparse |
| import os |
| import textwrap |
| |
| from typing import Dict, List, Any |
| |
| from torchgen.gen import parse_native_yaml |
| from torchgen.utils import FileManager |
| from torchgen.context import with_native_function |
| from torchgen.model import BaseOperatorName, NativeFunction |
| import torchgen.api.python as python |
| from .gen_python_functions import ( |
| should_generate_py_binding, |
| is_py_torch_function, |
| is_py_nn_function, |
| is_py_linalg_function, |
| is_py_variable_method, |
| is_py_special_function, |
| is_py_fft_function, |
| ) |
| |
| |
def gen_annotated(
    native_yaml_path: str, tags_yaml_path: str, out: str, autograd_dir: str
) -> None:
    """Generate ``annotated_fn_args.py`` in *out*.

    Parses the native-functions and tags YAML files with ``parse_native_yaml``.
    For every (predicate, namespace) pair below, taken in order, it selects the
    native functions for which ``should_generate_py_binding`` and the
    predicate both hold. It then emits one ``<namespace>.<entry>`` line per
    selected function, where ``<entry>`` comes from ``gen_annotated_args``.

    A function that satisfies several predicates is emitted once per matching
    namespace, because each namespace pass scans all native functions
    independently.

    The collected lines are substituted into the ``annotated_fn_args.py.in``
    template, which is found under ``<autograd_dir>/templates``.

    Args:
        native_yaml_path: path to ``native_functions.yaml``.
        tags_yaml_path: path to ``tags.yaml``.
        out: directory the generated file is written into.
        autograd_dir: directory containing the ``templates`` subdirectory.
    """
    parsed = parse_native_yaml(native_yaml_path, tags_yaml_path)
    native_functions = parsed.native_functions

    # Order matters: it determines the order of sections in the output file.
    mappings = (
        (is_py_torch_function, "torch._C._VariableFunctions"),
        (is_py_nn_function, "torch._C._nn"),
        (is_py_linalg_function, "torch._C._linalg"),
        (is_py_special_function, "torch._C._special"),
        (is_py_fft_function, "torch._C._fft"),
        (is_py_variable_method, "torch.Tensor"),
    )

    lines: List[str] = []
    for predicate, namespace in mappings:
        # Bucket overloads by base operator name. dict keeps first-seen order,
        # so every overload of an operator ends up emitted contiguously.
        by_name: Dict[BaseOperatorName, List[NativeFunction]] = defaultdict(list)
        for fn in native_functions:
            # The predicate is only consulted when a Python binding is generated.
            if should_generate_py_binding(fn) and predicate(fn):
                by_name[fn.func.name.name].append(fn)
        lines.extend(
            f"{namespace}.{gen_annotated_args(fn)}"
            for overloads in by_name.values()
            for fn in overloads
        )

    template_path = os.path.join(autograd_dir, "templates")
    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)

    def _template_env() -> Dict[str, str]:
        # Evaluated lazily by FileManager when the template is rendered.
        return {"annotated_args": textwrap.indent("\n".join(lines), " ")}

    fm.write_with_template(
        "annotated_fn_args.py",
        "annotated_fn_args.py.in",
        _template_env,
    )
| |
| |
@with_native_function
def gen_annotated_args(f: NativeFunction) -> str:
    """Return a one-line dict-entry string describing *f*'s required arguments.

    The result has the form ``"<base operator name>: [<arg dicts>],"``. Only
    arguments from ``f.func.arguments.flat_positional`` whose ``default`` is
    ``None`` are included.

    Each argument dict always has the keys ``"name"`` and ``"simple_type"``,
    in that order. It also has ``"size"`` when ``python.argument_type_size``
    returns a truthy value for the argument's type.
    """
    required = (a for a in f.func.arguments.flat_positional if a.default is None)

    entries: List[Dict[str, Any]] = []
    for a in required:
        entry: Dict[str, Any] = {
            "name": a.name,
            "simple_type": python.argument_type_str(a.type, simple_type=True),
        }
        arg_size = python.argument_type_size(a.type)
        # Falsy sizes (None / 0) are omitted entirely rather than recorded.
        if arg_size:
            entry["size"] = arg_size
        entries.append(entry)

    return f"{f.func.name.name}: {entries!r},"
| |
| |
def main() -> None:
    """Command-line entry point: parse four positional paths and run ``gen_annotated``."""
    parser = argparse.ArgumentParser(description="Generate annotated_fn_args script")

    # (dest, metavar, help) for each positional argument, in command-line order.
    positional = (
        ("native_functions", "NATIVE", "path to native_functions.yaml"),
        ("tags", "TAGS", "path to tags.yaml"),
        ("out", "OUT", "path to output directory"),
        ("autograd", "AUTOGRAD", "path to template directory"),
    )
    for dest, metavar, help_text in positional:
        parser.add_argument(dest, metavar=metavar, help=help_text)

    options = parser.parse_args()
    gen_annotated(
        options.native_functions, options.tags, options.out, options.autograd
    )
| |
| |
# Run the generator only when executed as a script/module
# (e.g. `python -m tools.autograd.gen_annotated_fn_args ...`), not on import.
if __name__ == "__main__":
    main()