blob: df889e379cbade90edaa9e8e2850791ec1ceb022 [file] [log] [blame] [edit]
# Generates RegisterCodegenUnboxedKernels.cpp, UnboxingFunctions.h and UnboxingFunctions.cpp.
from __future__ import annotations
import argparse
import os
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Literal, Sequence, TYPE_CHECKING
import yaml
from torchgen.api import cpp, unboxing
from torchgen.api.translate import translate
from torchgen.api.types import CppSignatureGroup
from torchgen.api.unboxing import convert_arguments
from torchgen.context import method_with_native_function
from torchgen.gen import cpp_string, get_custom_build_selector, parse_native_yaml
from torchgen.model import Argument, NativeFunction, NativeFunctionsGroup, Variant
from torchgen.utils import FileManager, make_file_manager, mapMaybe, Target
if TYPE_CHECKING:
from torchgen.selective_build.selector import SelectiveBuilder
# Generates UnboxingFunctions.h & UnboxingFunctions.cpp.
@dataclass(frozen=True)
class ComputeUnboxingFunctions:
target: Literal[Target.DECLARATION, Target.DEFINITION]
selector: SelectiveBuilder
@method_with_native_function
def __call__(self, f: NativeFunction) -> str:
if not self.selector.is_root_operator(f"aten::{f.func.name}"):
return ""
if self.target is Target.DECLARATION:
# Note [The ATen Codegen Unboxing API]
# Similar to the ATen Operators API, ATen Codegen Unboxing API lives in the at::unboxing namespace, and
# will be used by codegen unboxing wrappers (CodegenUnboxingWrappers.cpp).
# The Wrappers will be registered into torch::jit::OperatorRegistry using RegisterOperators API.
#
# Important characteristics about the Codegen Unboxing API:
# (1) It follows the OperatorRegistry API.
# This is kind of necessary to avoid overhead.
# For example: if it followed the C++ API, then all of the faithful C++ factory functions
# would need to wrap their arguments into TensorOptions only to unwrap them again.
# (2) Under the hood it calls C++ API.
return f"""
// aten::{f.func}
TORCH_API void {f.func.name.unambiguous_name()}(Stack & stack);
"""
else:
sig_group = CppSignatureGroup.from_native_function(
f, method=(Variant.method in f.variants)
)
sig = sig_group.most_faithful_signature()
# parse arguments into C++ code
binding_list, code_list = convert_arguments(f)
# for each C++ argument, generate the conversion code
code_connector = "\n\t"
arg_connector = ", "
# function call and push back to stack
prefix = "self_base." if sig.method else "at::"
translated_args = translate(
binding_list, sig.arguments(), method=sig.method
)
args_str = f"{arg_connector.join(e.expr for e in translated_args)}"
if len(f.func.returns) == 0:
ret_str = ""
push_str = ""
else:
ret_str = "auto result_ = "
push_str = """
pack(stack, std::move(result_));
"""
return f"""
// aten::{f.func}
TORCH_API void {f.func.name.unambiguous_name()}(Stack & stack) {{
{code_connector.join(code_list)}
drop(stack, {len(binding_list)});
{ret_str}{prefix}{sig.name()}({args_str});
{push_str}
}}
"""
# Generates RegisterCodegenUnboxedKernels.cpp.
@dataclass(frozen=True)
class ComputeCodegenUnboxedKernels:
selector: SelectiveBuilder
@method_with_native_function
def __call__(self, f: NativeFunction) -> str:
if not self.selector.is_root_operator(f"aten::{f.func.name}"):
return ""
# We unconditionally generate function wrappers,
sig_group = CppSignatureGroup.from_native_function(f, method=False)
sig = sig_group.most_faithful_signature()
# escape double quote in schema, get rid of extra double quotes
schema = cpp_string(str(sig.func))[1:-1]
# arguments
args = sig.arguments()
connector = ",\n\t\t"
args_code = []
for arg in args:
# Using method=False faithful C++ API, so we should not see SelfArgument/TensorOptionsArgument
assert isinstance(arg.argument, Argument)
if not arg.argument.default:
arg_cpp = "c10::IValue(::std::nullopt)"
else:
# The unboxing code uses the faithful C++ API to avoid the overhead
# from wrapping/unwrapping TensorOptios.
# However, we would look to include default args for schema parsing.
# Default args only show up in the nonfaithful C++ API,
arg_default = cpp.default_expr(
arg.argument.default, arg.argument.type, symint=False
)
if arg_default.startswith("{"):
arg_cpp = f"c10::IntArrayRef({arg_default})"
else:
arg_cpp = f"c10::IValue({arg_default})"
args_code.append(
f"""c10::Argument("{arg.name}", nullptr, ::std::nullopt, {arg_cpp})"""
)
returns = f.func.returns
returns_code = []
for ret in returns:
returns_code.append(f"""c10::Argument("{ret.name if ret.name else ""}")""")
return f"""
// aten::{schema}
OperatorGenerator(
"aten::{f.func.name.name}",
"{f.func.name.overload_name}",
{{
{connector.join(args_code)}
}},
{{
{connector.join(returns_code)}
}},
[](Stack & stack) {{
RECORD_FUNCTION("{sig.name()}", std::vector<c10::IValue>());
at::unboxing::{unboxing.name(f)}(stack);
}},
aliasAnalysisFromSchema()
),
"""
def gen_unboxing(
*,
native_functions: Sequence[NativeFunction],
cpu_fm: FileManager,
selector: SelectiveBuilder,
) -> None:
def key_func(fn: NativeFunction | NativeFunctionsGroup) -> str:
return fn.root_name
selected_op_num: int = len(selector.operators)
# a best practice threshold of operators to enable sharding
sharding_threshold: int = 100
cpu_fm.write_sharded(
"UnboxingFunctions.cpp",
native_functions,
key_fn=key_func,
env_callable=lambda fn: {
"definitions": [ComputeUnboxingFunctions(Target.DEFINITION, selector)(fn)]
},
num_shards=1 if selected_op_num < sharding_threshold else 5,
sharded_keys={"definitions"},
)
cpu_fm.write(
"UnboxingFunctions.h",
lambda: {
"declarations": list(
mapMaybe(
ComputeUnboxingFunctions(Target.DECLARATION, selector),
native_functions,
)
),
},
)
cpu_fm.write_sharded(
"RegisterCodegenUnboxedKernels.cpp",
native_functions,
key_fn=key_func,
env_callable=lambda fn: {
"unboxed_ops": [ComputeCodegenUnboxedKernels(selector)(fn)]
},
num_shards=1 if selected_op_num < sharding_threshold else 10,
sharded_keys={"unboxed_ops"},
)
def main(args: list[str]) -> None:
parser = argparse.ArgumentParser(description="Generate unboxing source files")
parser.add_argument(
"-s",
"--source-path",
help="path to source directory for ATen",
default="aten/src/ATen",
)
parser.add_argument(
"-d",
"--install-dir",
"--install_dir",
help="output directory",
default="build/aten/src/ATen",
)
parser.add_argument(
"-o",
"--output-dependencies",
help="output a list of dependencies into the given file and exit",
)
parser.add_argument(
"--dry-run",
action="store_true",
help="run without writing any files (still updates outputs)",
)
parser.add_argument(
"--op-selection-yaml-path",
"--op_selection_yaml_path",
help="Provide a path to the operator selection (for custom build) YAML "
"that contains the information about the set of selected operators "
"and their categories (training, ...). Each operator is either a "
"full operator name with overload or just a bare operator name. "
"The operator names also contain the namespace prefix (e.g. aten::)",
)
parser.add_argument(
"--op-registration-allowlist",
"--op_registration_allowlist",
nargs="*",
help="filter op registrations by the allowlist (if set); "
"each item is `namespace`::`operator name` without overload name; "
"e.g.: aten::empty aten::conv2d ...",
)
parser.add_argument(
"--TEST-ONLY-op-registration-allowlist-yaml-path",
"--TEST_ONLY_op_registration_allowlist_yaml_path",
help="Provide a path to the operator selection (for custom build) YAML "
"which contains a list of operators. It is to serve testing purpose and "
"each item is `namespace`::`operator name` without overload name; "
"e.g.: aten::empty aten::conv2d ...",
)
options = parser.parse_args(args)
if options.op_registration_allowlist:
op_registration_allowlist = options.op_registration_allowlist
elif options.TEST_ONLY_op_registration_allowlist_yaml_path:
with open(options.TEST_ONLY_op_registration_allowlist_yaml_path) as f:
op_registration_allowlist = yaml.safe_load(f)
else:
op_registration_allowlist = None
selector = get_custom_build_selector(
op_registration_allowlist,
options.op_selection_yaml_path,
)
native_yaml_path = os.path.join(options.source_path, "native/native_functions.yaml")
tags_yaml_path = os.path.join(options.source_path, "native/tags.yaml")
parsed_yaml = parse_native_yaml(native_yaml_path, tags_yaml_path)
native_functions, backend_indices = (
parsed_yaml.native_functions,
parsed_yaml.backend_indices,
)
cpu_fm = make_file_manager(options=options)
gen_unboxing(native_functions=native_functions, cpu_fm=cpu_fm, selector=selector)
if options.output_dependencies:
depfile_path = Path(options.output_dependencies).resolve()
depfile_name = depfile_path.name
depfile_stem = depfile_path.stem
path = depfile_path.parent / depfile_name
cpu_fm.write_outputs(depfile_stem, str(path))
if __name__ == "__main__":
main(sys.argv[1:])