# blob: 511aed3c270a7dda6d3e695824651b24676d0907 [file] [log] [blame]
import os
from typing import List, Dict, Optional, Tuple, Set, Any, Union, Sequence, TypeVar
from typing_extensions import Literal
import yaml
from collections import OrderedDict, defaultdict, namedtuple
import argparse
import pathlib
import json
from dataclasses import dataclass
import functools
from torchgen.model import (
STRUCTURED_DISPATCH_KEYS,
Argument,
DispatchKey,
FunctionSchema,
Location,
NativeFunction,
NativeFunctionsGroup,
OperatorName,
BackendIndex,
BackendMetadata,
OptionalType,
SchemaKind,
SelfArgument,
TensorOptionsArguments,
Type,
Variant,
is_cuda_dispatch_key,
is_generic_dispatch_key,
is_ufunc_dispatch_key,
NativeFunctionsViewGroup,
ViewSchemaKind,
BaseOperatorName,
)
from torchgen.native_function_generation import (
pre_group_native_functions,
add_generated_native_functions,
gen_composite_functional_kernel,
gen_composite_out_kernel,
)
from torchgen.api.types import (
Binding,
CppSignatureGroup,
DispatcherSignature,
NamedCType,
NativeSignature,
SpecialArgName,
)
from torchgen.api import cpp
import torchgen.api.dispatcher as dispatcher
import torchgen.api.native as native
import torchgen.api.meta as meta
import torchgen.api.structured as structured
from torchgen.api.translate import translate
from torchgen.selective_build.selector import SelectiveBuilder
from torchgen.utils import (
Target,
concatMap,
context,
mapMaybe,
YamlDumper,
YamlLoader,
FileManager,
assert_never,
make_file_manager,
)
from torchgen.context import (
method_with_native_function,
native_function_manager,
with_native_function_and_indices,
with_native_function,
)
import torchgen.dest as dest
from torchgen.gen_functionalization_type import (
gen_functionalization_definition,
gen_functionalization_registration,
gen_functionalization_view_inverse_declaration,
gen_composite_view_copy_kernel,
gen_symint_view_copy_kernel,
)
T = TypeVar("T")
# Welcome to the ATen code generator v2! The ATen code generator is
# responsible for parsing native_functions.yaml and then generating
# various generated files (e.g., TypeDefault.cpp) based on the operators
# defined in this file. This means that the code generator knows how to
# parse function schema, and then translate this into various C++ types
# and boilerplate code.
#
# Some things to know about this file when you modify it:
#
# - This file has STRICT mypy typechecking. Typecheck it with
# `mypy --config mypy-strict.ini` in the root source directory
#
# - Most of the heavy lifting lives in external modules:
#   - 'model' has the data model for native_functions.yaml. The classes
#     in that file represent what you see when you look at
#     a native_functions.yaml
#   - 'api' has conversions for how to translate JIT schema into
#     the various C++ APIs that the codegen interacts with. There
#     are in fact THREE different C++ APIs: the public C++ API,
#     the dispatcher API, and the legacy dispatcher API. See each
#     of these respective files for more information
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
# HELPER FUNCTIONS
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
class NamespaceHelper:
    """Build the C++ text that opens and closes a chain of nested namespaces.

    For the namespace string ``torch::lazy`` the generated text is:

    prologue:
        namespace torch {
        namespace lazy {
    epilogue:
        } // namespace lazy
        } // namespace torch
    """

    def __init__(self, namespace_str: str):
        # Each "::"-separated segment (e.g. torch::lazy) names one namespace.
        opening_lines = []
        closing_lines = []
        for segment in namespace_str.split("::"):
            opening_lines.append(f"namespace {segment} {{")
            # Namespaces close innermost-first, so every closer is prepended.
            closing_lines.insert(0, f"}} // namespace {segment}")
        self.prologue_ = "\n".join(opening_lines)
        self.epilogue_ = "\n".join(closing_lines)

    @property
    def prologue(self) -> str:
        """Newline-joined ``namespace X {`` lines, outermost first."""
        return self.prologue_

    @property
    def epilogue(self) -> str:
        """Newline-joined ``} // namespace X`` lines, innermost first."""
        return self.epilogue_
# A custom loader for YAML to let us also keep track of line numbers
# of each entry in the YAML file
class LineLoader(YamlLoader):
    """YAML loader that records, on every mapping, the line it started on."""

    def construct_mapping(self, node, deep=False):  # type: ignore[no-untyped-def]
        """Build the mapping as usual, then store its line under ``__line__``."""
        loaded = super().construct_mapping(node, deep=deep)  # type: ignore[no-untyped-call]
        # Shift by one so that reported line numbers start at 1.
        first_line = node.start_mark.line + 1
        loaded["__line__"] = first_line
        return loaded
# Module-level memoization of parse results. Both caches are keyed by the
# path of the YAML file that was parsed (see parse_native_yaml and
# parse_tags_yaml, which are the functions that fill them).
_GLOBAL_PARSE_NATIVE_YAML_CACHE: Dict[str, "ParsedYaml"] = {}
_GLOBAL_PARSE_TAGS_YAML_CACHE: Dict[str, Set[str]] = {}
# Parse native_functions.yaml into a sequence of NativeFunctions and Backend Indices.
ParsedYaml = namedtuple("ParsedYaml", ["native_functions", "backend_indices"])
def parse_native_yaml_struct(
    es: object,
    valid_tags: Set[str],
    ignore_keys: Optional[Set[DispatchKey]] = None,
    path: str = "<stdin>",
    skip_native_fns_gen: bool = False,
) -> ParsedYaml:
    """Convert the loaded entries of a native functions YAML into a ParsedYaml.

    Args:
        es: The loaded YAML document; must be a list of mappings, each
            carrying the integer ``__line__`` key.
        valid_tags: Tag names forwarded to ``NativeFunction.from_yaml``.
        ignore_keys: Dispatch keys forwarded to ``NativeFunction.from_yaml``.
        path: File name used in the ``Location`` attached to each entry.
        skip_native_fns_gen: When true, ``add_generated_native_functions``
            is not run on the parsed functions.

    Returns:
        ``ParsedYaml(native_functions, backend_indices)``.
    """
    assert isinstance(es, list)
    native_functions: List[NativeFunction] = []
    kernels_by_key: Dict[
        DispatchKey, Dict[OperatorName, BackendMetadata]
    ] = defaultdict(dict)

    for entry in es:
        line_no = entry.get("__line__")
        assert isinstance(line_no, int), entry
        loc = Location(path, entry["__line__"])
        funcs = entry.get("func")
        with context(lambda: f"in {loc}:\n {funcs}"):
            parsed_func, metadata = NativeFunction.from_yaml(
                entry, loc, valid_tags, ignore_keys
            )
            native_functions.append(parsed_func)
            BackendIndex.grow_index(kernels_by_key, metadata)

    # Cross-function validation runs before any functions are generated.
    error_check_native_functions(native_functions)

    def undefined_index() -> BackendIndex:
        # Default dict is to prevent the codegen from barfing when we have
        # a dispatch key that has no kernels yet.
        return BackendIndex(
            dispatch_key=DispatchKey.Undefined,
            use_out_as_primary=True,
            external=False,
            device_guard=False,
            index={},
        )

    indices: Dict[DispatchKey, BackendIndex] = defaultdict(undefined_index)

    if not skip_native_fns_gen:
        add_generated_native_functions(native_functions, kernels_by_key)

    for key, kernel_index in kernels_by_key.items():
        # All structured in-tree operators are implemented in terms of their out operator.
        # Only cuda-like devices in tree require device guards
        needs_guard = is_cuda_dispatch_key(key)
        indices[key] = BackendIndex(
            dispatch_key=key,
            use_out_as_primary=True,
            external=False,
            device_guard=needs_guard,
            index=kernel_index,
        )
    return ParsedYaml(native_functions, indices)
def parse_tags_yaml_struct(es: object, path: str = "<stdin>") -> Set[str]:
    """Collect the tag names declared in a loaded tags YAML document.

    Args:
        es: The loaded YAML document; must be a list of mappings, each
            carrying ``__line__`` (int), ``tag`` and a non-empty ``desc``.
        path: File name used in the ``Location`` attached to each entry.

    Returns:
        The set of every ``tag`` value found.
    """
    assert isinstance(es, list)
    tag_names: Set[str] = set()
    for entry in es:
        assert isinstance(entry.get("__line__"), int), entry
        loc = Location(path, entry["__line__"])
        tags = entry.get("tag")
        with context(lambda: f"in {loc}:\n {tags}"):
            # A missing "tag" key surfaces as KeyError.
            name = entry["tag"]
            desc = entry.get("desc", "")
            # ensure that each tag has a non-empty description
            assert desc != ""
            tag_names.add(name)
    return tag_names
@functools.lru_cache(maxsize=None)
def parse_tags_yaml(path: str) -> Set[str]:
    """Return the set of tag names declared in the YAML file at ``path``.

    The parsed result is stored in ``_GLOBAL_PARSE_TAGS_YAML_CACHE`` under
    ``path`` and reused on later calls for the same path.
    """
    global _GLOBAL_PARSE_TAGS_YAML_CACHE
    if path in _GLOBAL_PARSE_TAGS_YAML_CACHE:
        return _GLOBAL_PARSE_TAGS_YAML_CACHE[path]
    with open(path, "r") as f:
        entries = yaml.load(f, Loader=LineLoader)
    parsed = parse_tags_yaml_struct(entries, path=path)
    _GLOBAL_PARSE_TAGS_YAML_CACHE[path] = parsed
    return parsed
def parse_native_yaml(
    path: str,
    tags_yaml_path: str,
    ignore_keys: Optional[Set[DispatchKey]] = None,
    *,
    skip_native_fns_gen: bool = False,
) -> ParsedYaml:
    """Load and parse the native functions YAML at ``path``.

    The result is stored in ``_GLOBAL_PARSE_NATIVE_YAML_CACHE`` under
    ``path`` alone, so on a cache hit the other arguments are not consulted.

    Args:
        path: Location of the native functions YAML file.
        tags_yaml_path: Location of the tags YAML file supplying valid tags.
        ignore_keys: Forwarded to ``parse_native_yaml_struct``.
        skip_native_fns_gen: Forwarded to ``parse_native_yaml_struct``.
    """
    global _GLOBAL_PARSE_NATIVE_YAML_CACHE
    if path in _GLOBAL_PARSE_NATIVE_YAML_CACHE:
        return _GLOBAL_PARSE_NATIVE_YAML_CACHE[path]

    valid_tags = parse_tags_yaml(tags_yaml_path)
    with open(path, "r") as f:
        entries = yaml.load(f, Loader=LineLoader)
    parsed = parse_native_yaml_struct(
        entries,
        valid_tags,
        ignore_keys,
        path=path,
        skip_native_fns_gen=skip_native_fns_gen,
    )
    _GLOBAL_PARSE_NATIVE_YAML_CACHE[path] = parsed
    return parsed
# Some assertions are already performed during parsing, but those are only within a single NativeFunction.
# Assertions here are meant to be performed across NativeFunctions.
def error_check_native_functions(funcs: Sequence[NativeFunction]) -> None:
    """Check invariants that span more than one NativeFunction.

    Two rules are enforced:

    * An operator naming a ``structured_delegate`` must point at an operator
      that is itself marked ``structured``.
    * An operator tagged ``inplace_view`` must have an in-place base name and
      at least one operator registered under the matching out-of-place base
      name.

    Raises:
        AssertionError: If either rule is violated.
        KeyError: If a ``structured_delegate`` names an operator that is not
            present in ``funcs``.
    """
    func_map: Dict[OperatorName, NativeFunction] = {}
    base_func_map: Dict[BaseOperatorName, List[NativeFunction]] = defaultdict(list)
    for f in funcs:
        func_map[f.func.name] = f
        base_func_map[f.func.name.name].append(f)
    for f in funcs:
        if f.structured_delegate is not None:
            delegate_func = func_map[f.structured_delegate]
            assert delegate_func.structured, (
                f"{f.func.name} is marked as a structured_delegate pointing to "
                f"{f.structured_delegate}, but {f.structured_delegate} is not marked as structured. "
                f"Consider adding 'structured=True' to the delegated operator"
            )
        if "inplace_view" in f.tags:
            base_name = f.func.name.name
            assert base_name.inplace, (
                f"{f.func.name} is marked with tag: inplace_view, but it doesn't follow the naming "
                "convention for inplace ops - the codegen expects the base name to have a trailing underscore. "
            )
            out_of_place_base_name = BaseOperatorName(
                base_name.base, False, base_name.dunder_method
            )
            # Report the out-of-place name that was looked up: that is the
            # operator the user has to add, not the in-place one.
            assert len(base_func_map[out_of_place_base_name]) > 0, (
                f"{f.func.name} is marked with tag: inplace_view. The codegen expects there to be a corresponding "
                f"out-of-place view op with the name '{out_of_place_base_name}' and matching schema, but it didn't find one. "
            )
# Characters that cannot appear verbatim inside a C++ string literal, mapped
# to their escape sequences. Applied in one pass, so the backslashes that
# the escapes introduce are never escaped a second time.
_CPP_STRING_ESCAPES = str.maketrans(
    {
        "\\": "\\\\",
        '"': '\\"',
        "\a": "\\a",
        "\b": "\\b",
        "\f": "\\f",
        "\n": "\\n",
        "\r": "\\r",
        "\v": "\\v",
        "\t": "\\t",
    }
)


def cpp_string(s: str) -> str:
    """Convert a python string into a c++ string literal.

    Backslashes, double quotes and the control characters ``\\a``, ``\\b``,
    ``\\f``, ``\\n``, ``\\r``, ``\\v`` and ``\\t`` are replaced by their
    escape sequences, and the result is wrapped in double quotes.
    """
    return f'"{s.translate(_CPP_STRING_ESCAPES)}"'
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
# C++ CODE GENERATION
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
# Most functions in this section are curried: they consist of a function
# that takes some parameters (e.g., what is to be generated) which itself
# returns a function that actually maps NativeFunction to the code
# to be generated. This pattern makes it convenient to use map, concatMap
# and similar functional combinators.
def static_dispatch_keys(backends: List[BackendIndex]) -> List[DispatchKey]:
    """List the dispatch keys that take part in static dispatch.

    Returns an empty list when no backends are given; otherwise the key of
    each backend, in order, followed by the three composite keys.
    """
    if not backends:
        return []
    keys = [backend.dispatch_key for backend in backends]
    keys.extend(
        [
            DispatchKey.CompositeImplicitAutograd,
            DispatchKey.CompositeExplicitAutograd,
            DispatchKey.CompositeExplicitAutogradNonFunctional,
        ]
    )
    return keys
def get_static_dispatch_backend(
    f: NativeFunction, backend_index: BackendIndex
) -> Optional[DispatchKey]:
    """Pick the dispatch key whose kernel serves ``f`` under static dispatch.

    The backend's own key wins when ``f`` has a structured delegate or the
    backend has a kernel for it; otherwise the first composite kernel that
    ``f`` provides is chosen. Returns None when nothing applies.
    """
    # TODO: for ops with structured_delegate it should check the dispatch table of
    # the out variant instead. For now, these structured ops all have CPU/CUDA kernels
    # so we always dispatch to the `backend`, but this could be wrong when we
    # migrate math/default_backend ops to use structured delegate.
    if f.structured_delegate is not None or backend_index.has_kernel(f):
        return backend_index.dispatch_key
    if f.has_composite_explicit_autograd_kernel:
        return DispatchKey.CompositeExplicitAutograd
    if f.has_composite_explicit_autograd_non_functional_kernel:
        return DispatchKey.CompositeExplicitAutogradNonFunctional
    if f.has_composite_implicit_autograd_kernel:
        return DispatchKey.CompositeImplicitAutograd
    return None
def static_dispatch_ops_header(
    f: NativeFunction, backend_index: List[BackendIndex]
) -> Optional[str]:
    """Build the per-operator ``#include`` lines needed for static dispatch.

    Returns None when ``backend_index`` is None or ``f`` uses manual kernel
    registration. Otherwise returns one include line per backend for which
    a static dispatch key could be determined, joined by newlines.
    """
    if backend_index is None or f.manual_kernel_registration:
        return None

    candidate_keys = (
        get_static_dispatch_backend(f, index) for index in backend_index
    )
    return "\n".join(
        f"#include <ATen/ops/{f.root_name}_{dispatch_key.lower()}_dispatch.h>"
        for dispatch_key in candidate_keys
        if dispatch_key is not None
    )
def static_dispatch_extra_headers(backends: List[BackendIndex]) -> List[str]:
    """Return one ``<ATen/{key}Functions.h>`` include per static dispatch key."""
    headers: List[str] = []
    for dispatch_key in static_dispatch_keys(backends):
        headers.append(f"#include <ATen/{dispatch_key}Functions.h>")
    return headers
# Translates arguments of a native function from DispatcherSignature form to CppSignature form,
# including the case where a memory_format argument appears alongside tensor_options arguments.
# That case is not covered by torchgen.api.translate.translate() yet, as its use is limited to static dispatch.
def translate_args_dispatcher_to_cpp(
    f: NativeFunction,
) -> str:
    """Render the C++ call arguments for ``f``, given its dispatcher bindings.

    Returns the translated expressions joined by ", ".
    """

    def add_spl_memory_format_binding(input_bindings: List[Binding]) -> List[Binding]:
        # Re-types every binding named "memory_format" with
        # SpecialArgName.possibly_redundant_memory_format; all other
        # bindings are passed through untouched.
        return [
            Binding(
                nctype=NamedCType(
                    SpecialArgName.possibly_redundant_memory_format,
                    binding.nctype.type,
                ),
                name=binding.name,
                default=binding.default,
                argument=binding.argument,
            )
            if binding.name == "memory_format"
            else binding
            for binding in input_bindings
        ]

    disp_sig = DispatcherSignature.from_schema(f.func)
    cpp_sig = CppSignatureGroup.from_native_function(
        f, method=False, fallback_binding=False
    ).signature
    disp_bindings = disp_sig.arguments()

    # When the CPP signature carries an argument with the
    # SpecialArgName.possibly_redundant_memory_format NCType, give the
    # dispatcher signature's memory_format bindings the same NCType.
    cpp_has_special_memory_format = any(
        arg.nctype.name == SpecialArgName.possibly_redundant_memory_format
        for arg in cpp_sig.arguments()
    )
    if cpp_has_special_memory_format:
        disp_bindings = add_spl_memory_format_binding(disp_sig.arguments())

    translated = translate(disp_bindings, cpp_sig.arguments())
    return ", ".join(a.expr for a in translated)
def generate_static_dispatch_backend_call(
    f: NativeFunction,
    backend_index: BackendIndex,
    ns: str = "at",
) -> str:
    """Emit a C++ ``return`` statement calling ``f``'s kernel in one backend.

    The call targets ``{ns}::{lower-cased dispatch key}::{name}``.
    """
    kernel_name = DispatcherSignature.from_schema(f.func).name()
    call_args = translate_args_dispatcher_to_cpp(f)
    backend_ns = backend_index.dispatch_key.lower()
    return f"return {ns}::{backend_ns}::{kernel_name}({call_args});"
def generate_static_dispatch_fallback_call(
    f: NativeFunction,
    backend_indices: List[BackendIndex],
    ns: str = "at",
) -> str:
    """Emit the C++ statement used when no backend-specific kernel applies.

    If ``f`` has a composite kernel, the statement returns a call into the
    first matching composite namespace (explicit autograd, then explicit
    autograd non-functional, then implicit autograd). Otherwise it is a
    ``TORCH_CHECK(false, ...)`` naming the operator and the backends.

    Args:
        f: The operator to generate the call for.
        backend_indices: Backends listed in the failure message.
        ns: C++ namespace prefix of the generated call.
    """
    name = DispatcherSignature.from_schema(f.func).name()
    exprs = translate_args_dispatcher_to_cpp(f)
    if f.has_composite_explicit_autograd_kernel:
        composite_key = DispatchKey.CompositeExplicitAutograd
    elif f.has_composite_explicit_autograd_non_functional_kernel:
        composite_key = DispatchKey.CompositeExplicitAutogradNonFunctional
    elif f.has_composite_implicit_autograd_kernel:
        composite_key = DispatchKey.CompositeImplicitAutograd
    else:
        backend_names = ", ".join(
            str(index.dispatch_key) for index in backend_indices
        )
        # Keep a space between "for" and the backend list so the message
        # does not read "... forCPU, CUDA".
        return (
            f'TORCH_CHECK(false, "Static dispatch does not support {name} for '
            f'{backend_names} ");'
        )
    return f"return {ns}::{composite_key.lower()}::{name}({exprs});"
def static_dispatch(
    f: NativeFunction,
    backend_indices: List[BackendIndex],
    namespace: str = "at",
) -> str:
    """Generate the C++ body that statically dispatches a call to ``f``.

    Args:
        f: The operator being called.
        backend_indices: The backends taking part in static dispatch.
        namespace: C++ namespace prefix used for the generated kernel calls.

    Returns:
        ``""`` when there are no backends or ``f`` uses manual kernel
        registration; a direct backend call when exactly one backend
        applies; the fallback call when none applies; otherwise a C++
        ``switch`` over a dispatch key computed at runtime.
    """
    if len(backend_indices) == 0 or f.manual_kernel_registration:
        return ""
    # A backend applies if it has a kernel for f, or if f has a structured
    # delegate and the backend's key is one of the structured dispatch keys.
    keys = [
        b
        for b in backend_indices
        if b.has_kernel(f)
        or (
            f.structured_delegate is not None
            and b.dispatch_key in STRUCTURED_DISPATCH_KEYS
        )
    ]
    if len(keys) == 1:
        return generate_static_dispatch_backend_call(f, keys[0], namespace)
    elif len(keys) == 0:
        return generate_static_dispatch_fallback_call(f, backend_indices, namespace)
    sig = DispatcherSignature.from_schema(f.func)
    # `and` binds tighter than `or`: SelfArgument bindings are always kept,
    # plain Argument bindings only when their type is tensor-like.
    native_tensor_args = [
        a.name
        for a in sig.arguments()
        if isinstance(a.argument, SelfArgument)
        or isinstance(a.argument, Argument)
        and a.argument.type.is_tensor_like()
    ]
    tensor_args = ", ".join(native_tensor_args)
    tensor_opts = f.func.arguments.tensor_options
    stmts = []
    # Pieces of the dispatch key set, OR-ed together in the generated C++.
    # NOTE(review): if f has neither tensor options nor tensor arguments,
    # subexprs stays empty and the generated initializer is blank - confirm
    # that case cannot reach this point.
    subexprs: List[str] = []
    if tensor_opts is not None:
        subexprs.append(
            "DispatchKeySet(c10::computeDispatchKey(dtype, layout, device))"
        )
    if tensor_args != "":
        subexprs.append(f"c10::detail::multi_dispatch_key_set({tensor_args})")
    stmts.append(f"""DispatchKeySet _dk_set = {' | '.join(subexprs)};""")
    stmts.append("DispatchKey _dk = c10::highestPriorityBackendTypeId(_dk_set);")
    dispatch_code = []
    for index in keys:
        dispatch_code.append(f"""case DispatchKey::{index.dispatch_key}:""")
        # The backend call already ends in ";", so the one appended here
        # yields an extra empty statement in the generated C++.
        dispatch_code.append(
            f"""\t{generate_static_dispatch_backend_call(f, index, namespace)};"""
        )
    fallback = generate_static_dispatch_fallback_call(f, backend_indices, namespace)
    connector = "\n\t\t"
    return f"""
{connector.join(stmts)}
switch (_dk) {{
{connector.join(dispatch_code)}
default:
{fallback}
}}
"""
# Generates RegisterSchema.cpp. Depending on the selector, either
# all schemas are registered, or only some are (in the case of
# selective build)
@dataclass(frozen=True)
class RegisterSchema:
    """Callable that renders the ``m.def(...)`` schema registration of an op."""

    # Decides which operators get a registration line.
    selector: SelectiveBuilder

    @method_with_native_function
    def __call__(self, f: NativeFunction) -> Optional[str]:
        """Return the ``m.def`` line for ``f``, or None if it is not selected."""
        if self.selector.is_native_function_selected(f):
            tag_list = ", ".join(f"at::Tag::{tag}" for tag in f.tags)
            tags = "{" + tag_list + "}"
            return f"m.def({cpp_string(str(f.func))}, {tags});\n"
        return None
# Generates Operators.h and Operators.cpp.
# These provide macros that, given an operator and overload name, allow users
# to access an "un-overloaded" function version of the operator. This
# is useful for extension writers who (1) want to decltype the operator
# and (2) don't want to worry about method-only operators.
@dataclass(frozen=True)
class ComputeOperators:
    """Callable that renders the ``at::_ops`` struct of one operator.

    For ``Target.DECLARATION`` it returns the struct declaration; for
    ``Target.DEFINITION`` it returns the string constants, the typed
    operator handle factory and the bodies of ``call`` and ``redispatch``.
    """

    # Selects between emitting the declaration and emitting the definitions.
    target: Union[Literal[Target.DECLARATION], Literal[Target.DEFINITION]]
    # When non-empty, the body of call() is produced by static_dispatch()
    # instead of going through the dispatcher.
    static_dispatch_backend_indices: List[BackendIndex]

    @method_with_native_function
    def __call__(self, f: NativeFunction) -> str:
        sig = DispatcherSignature.from_schema(f.func)
        name = f.func.name.unambiguous_name()
        call_method_name = "call"
        redispatch_method_name = "redispatch"
        if self.target is Target.DECLARATION:
            # Note [The ATen Operators API]
            # The ATen Operators API lives in the at::_ops namespace, and contains compile-time
            # metadata about each operator + entry points into the Dispatcher.
            # The C++ function, method, and redispatch API's are all implemented as wrappers
            # into various bits of the structs defined here.
            #
            # Important characteristics about the Operators API:
            # (1) It follows the Dispatcher API.
            # This is kind of necessary to avoid overhead.
            # For example: if it followed the C++ API, then all of the faithful C++ factory functions
            # would need to wrap their arguments into TensorOptions only to unwrap them again.
            # (2) Overload names are disambiguated.
            # This is helpful for pytorch extenders who would like to decltype() an aten operator,
            # that has overloads, e.g. decltype(at::_ops::mul_Tensor::call)
            # (3) No argument defaulting is allowed.
            # This is more of an implementation detail to avoid #include cycles,
            # since TensorBody.h (which defines the Tensor class) needs to include this file.
            # (4) manual_cpp_bindings and faithful names are not included in the API.
            # This applies to stuff like __dispatch__is_complex(), and add_outf().
            # These aren't "real aten ops", they're just additional functions provided by the C++ API.
            # They're implemented as wrappers in Functions.h that call into the actual operators
            # defined here, i.e. at::_ops::is_complex::call() and at::_ops::add_out::call().
            # This means that ATEN_OP(is_complex) will not fastpath, and will go through the dispatcher.
            return f"""
struct TORCH_API {name} {{
using schema = {sig.type()};
using ptr_schema = schema*;
// See Note [static constexpr char* members for windows NVCC]
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::{f.func.name.name}")
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "{f.func.name.overload_name}")
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, {cpp_string(str(f.func))})
static {sig.defn(name=call_method_name, is_redispatching_fn=False)};
static {sig.defn(name=redispatch_method_name, is_redispatching_fn=True)};
}};"""
        elif self.target is Target.DEFINITION:
            # Out-of-line string constants plus a factory for the typed
            # operator handle that both method bodies below rely on.
            defns = f"""
STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA({name}, name, "aten::{f.func.name.name}")
STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA({name}, overload_name, "{f.func.name.overload_name}")
STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA({name}, schema_str, {cpp_string(str(f.func))})
// aten::{f.func}
static C10_NOINLINE c10::TypedOperatorHandle<{name}::schema> create_{name}_typed_handle() {{
return c10::Dispatcher::singleton()
.findSchemaOrThrow({name}::name, {name}::overload_name)
.typed<{name}::schema>();
}}
"""
            # Emit call() first, then redispatch().
            for is_redispatching_fn in [False, True]:
                if is_redispatching_fn:
                    # redispatch() forwards the caller's key set as the
                    # leading argument.
                    dispatcher_exprs_str = ", ".join(
                        ["dispatchKeySet"] + [a.name for a in sig.arguments()]
                    )
                    dispatcher_call = "redispatch"
                    method_name = f"{name}::{redispatch_method_name}"
                else:
                    method_name = f"{name}::{call_method_name}"
                    dispatcher_exprs_str = ", ".join([a.name for a in sig.arguments()])
                    dispatcher_call = "call"
                fn_body = f"""
static auto op = create_{name}_typed_handle();
return op.{dispatcher_call}({dispatcher_exprs_str});"""
                if (
                    not is_redispatching_fn
                    and len(self.static_dispatch_backend_indices) > 0
                ):
                    # call() should go through static dispatch
                    # NOTE(review): static_dispatch() returns "" for ops with
                    # manual_kernel_registration, which leaves this body
                    # empty - confirm that is intended.
                    fn_body = static_dispatch(
                        f, backend_indices=self.static_dispatch_backend_indices
                    )
                defns += f"""
// aten::{f.func}
{sig.defn(name=method_name, is_redispatching_fn=is_redispatching_fn)} {{
{fn_body}
}}
"""
            return defns
        else:
            assert_never(self.target)
# Generates Functions.h, which provides the functional public C++ API,
# and the scaffolding to call into the dispatcher from these functions.
@dataclass(frozen=True)
class ComputeFunction:
    """Callable that renders the inline C++ function wrappers of an operator."""

    @method_with_native_function
    def __call__(self, f: NativeFunction) -> Optional[str]:
        """Return the wrapper definitions for ``f``.

        Returns None when ``f`` has no function variant. A second wrapper is
        added when the signature group provides a faithful signature.
        """
        if Variant.function not in f.variants:
            return None

        sig_group = CppSignatureGroup.from_native_function(
            f, method=False, fallback_binding=f.manual_cpp_binding
        )

        def generate_defn(faithful: bool) -> str:
            if faithful:
                sig = sig_group.faithful_signature
                assert sig is not None
            else:
                sig = sig_group.signature
            # See Note [The ATen Operators API]
            target_sig = DispatcherSignature.from_schema(f.func)
            translated = translate(sig.arguments(), target_sig.arguments())
            exprs_str = ", ".join([binding.expr for binding in translated])
            return f"""
// aten::{f.func}
TORCH_API inline {sig.decl()} {{
return at::_ops::{f.func.name.unambiguous_name()}::call({exprs_str});
}}
"""

        pieces = [generate_defn(False)]
        if sig_group.faithful_signature is not None:
            pieces.append(generate_defn(True))
        return "".join(pieces)
# Generates TensorBody.h. This file provides the object-oriented (method-based)
# public C++ API, and the scaffolding to call into the dispatcher from these functions.
@dataclass(frozen=True)
class ComputeTensorMethod:
    """Callable that renders the ``Tensor`` method wrappers of an operator."""

    # Selects between method declarations and inline method definitions.
    target: Union[Literal[Target.DECLARATION], Literal[Target.DEFINITION]]
    static_dispatch_backend_indices: List[BackendIndex]

    @method_with_native_function
    def __call__(self, f: NativeFunction) -> Optional[str]:
        """Return the method text for ``f``, or None if it has no method variant."""
        if Variant.method not in f.variants:
            return None

        # Methods are never out functions and always have a self argument.
        assert not f.func.is_out_fn()
        assert f.func.arguments.self_arg is not None

        sig_group = CppSignatureGroup.from_native_function(
            f, method=True, fallback_binding=f.manual_cpp_binding
        )

        if self.target is Target.DECLARATION:
            decls = [f"{sig_group.signature.decl()} const;\n"]
            if sig_group.faithful_signature is not None:
                decls.append(f"{sig_group.faithful_signature.decl()} const;\n")
            return "".join(decls)

        if self.target is not Target.DEFINITION:
            assert_never(self.target)

        def generate_defn(faithful: bool) -> str:
            if faithful:
                sig = sig_group.faithful_signature
                assert sig is not None
            else:
                sig = sig_group.signature
            target_sig = DispatcherSignature.from_schema(f.func)
            translated = translate(
                sig.arguments(), target_sig.arguments(), method=True
            )
            exprs_str = ", ".join([binding.expr for binding in translated])
            return f"""
// aten::{f.func}
inline {sig.defn(prefix="Tensor::")} const {{
return at::_ops::{f.func.name.unambiguous_name()}::call({exprs_str});
}}
"""

        definitions = [generate_defn(faithful=False)]
        if sig_group.faithful_signature is not None:
            definitions.append(generate_defn(faithful=True))
        return "".join(definitions)
# Generates RedispatchFunctions.h.
# This is similar to the C++ API defined in Functions.h, but provides access
# to the dispatcher's redispatch API.
@dataclass(frozen=True)
class ComputeRedispatchFunction:
    """Callable that renders the inline redispatch wrappers of an operator."""

    @method_with_native_function
    def __call__(self, f: NativeFunction) -> Optional[str]:
        """Return the redispatch wrapper definitions for ``f``."""
        # The redispatch API is generated as free functions for every
        # operator, regardless of its variants, mainly because functions can
        # be namespaced separately while methods cannot.
        sig_group = CppSignatureGroup.from_native_function(
            f, method=False, fallback_binding=f.manual_cpp_binding
        )

        def generate_defn(faithful: bool) -> str:
            if faithful:
                sig = sig_group.faithful_signature
                assert sig is not None
            else:
                sig = sig_group.signature
            target_sig = DispatcherSignature.from_schema(f.func)
            translated = translate(sig.arguments(), target_sig.arguments())
            # The caller's dispatch key set is forwarded as the first argument.
            call_args = ["dispatchKeySet"]
            call_args.extend(binding.expr for binding in translated)
            exprs_str = ", ".join(call_args)
            return f"""
// aten::{f.func}
TORCH_API inline {sig.decl(is_redispatching_fn=True)} {{
return at::_ops::{f.func.name.unambiguous_name()}::redispatch({exprs_str});
}}
"""

        pieces = [generate_defn(False)]
        if sig_group.faithful_signature is not None:
            pieces.append(generate_defn(True))
        return "".join(pieces)
# Generates ATenOpList.cpp, a runtime accessible list of all aten
# operators.
# TODO: This was historically used to help some JIT interop code
# figure out whether or not to treat aten namespace'd operators
# one way or another, we should reevaluate if this is actually needed.
@with_native_function
def compute_aten_op(f: NativeFunction) -> str:
    """Render the ``{"aten::name", "overload"},`` list entry for ``f``."""
    op_name = f.func.name
    return f'{{"aten::{op_name.name}", "{op_name.overload_name}"}},'
# Generates MetaFunctions.h
def compute_meta_function_declaration(g: NativeFunctionsGroup) -> Optional[str]:
    """Render the C++ ``structured_{name}`` meta struct declaration for ``g``.

    Returns None when the group is not structured. When the out function
    declares precomputed elements, the struct additionally contains a
    templated ``precompute_out`` struct with one setter per element, and
    ``meta`` returns ``meta_return_ty`` instead of ``void``.
    """
    if not g.structured:
        return None
    with native_function_manager(g.out):
        name = meta.name(g)
        args = structured.meta_arguments(g)
        args_str = ", ".join(a.decl() for a in args)
        parent_class = g.out.structured_inherits
        if parent_class is None:
            parent_class = "at::impl::MetaBase"
        meta_return = "void"
        # g.structured is already known to be true here (see the early
        # return above), so this always reads g.out.precomputed.
        precomputed = g.out.precomputed if g.structured else None
        if precomputed:
            # Generate the template declaration with one bool parameter for each
            # precomputed element. Each parameter is true if the corresponding (in
            # terms of position) precomputed element has been set.
            precomputed_values = [*precomputed.replace.values(), precomputed.add]
            # Flatten the lists above into a single list of elements.
            precomputed_elements = [
                elem for replace_list in precomputed_values for elem in replace_list
            ]
            precomputed_template_parameters = [
                elem.name.upper() for elem in precomputed_elements
            ]
            precomputed_template_params_str = ", ".join(
                f"bool {param} = false" for param in precomputed_template_parameters
            )
            precompute_template_decl = f"template <{precomputed_template_params_str}>"
            # Generate a string containing declarations of all precomputed elements.
            precomputed_elements_with_cpp_types = [
                structured.argument_type(elem, binds=elem.name)
                for elem in precomputed_elements
            ]
            precomputed_elements_decl = ";\n".join(
                f"{elem.cpp_type(strip_ref=True)} {elem.name}"
                for elem in precomputed_elements_with_cpp_types
            )
            # Generate "setter" methods for each precomputed element. Each method will return
            # a new instance of precompute_out with the template parameter that corresponds to
            # the member set by the method to true (to indicate that it has been set).
            setter_methods = []
            for i, elem in enumerate(precomputed_elements):
                # Generate the signature. The return type will be the same
                # as the type of `this` but with the template parameter
                # corresponding to the element set by this method set to true.
                # The assert generated below will ensure that this template
                # parameter is false on the type of `this`.
                return_ty_templates = ", ".join(
                    precomputed_template_parameters[:i]
                    + ["true"]
                    + precomputed_template_parameters[i + 1 :]
                )
                return_ty = f"precompute_out<{return_ty_templates}>"
                elem_cpp_ty = precomputed_elements_with_cpp_types[i].cpp_type(
                    strip_ref=True
                )
                signature = f"{return_ty} set_{elem.name}({elem_cpp_ty} value)"
                # Generate an assert which checks that the
                # template parameter corresponding to the precomputed
                # element that is set by this method is false on the
                # class corresponding to the object that `this` points to.
                # This ensures that each element can be set only once.
                assert_msg = f'"{precomputed_elements[i].name} already set"'
                assert_stmt = f"static_assert({precomputed_template_parameters[i]} == false, {assert_msg});"
                # Generate the new object construction block. All state
                # except the element that this method sets is copied from the
                # object that `this` points to. The value for the element that
                # the method sets is taken from a method parameter.
                construction_stmts = []
                construction_stmts.append(f"{return_ty} ret;")
                # NOTE: this inner loop rebinds `elem`; the outer iteration's
                # `elem` is not read again after this point.
                for j, elem in enumerate(precomputed_elements):
                    if i == j:
                        construction_stmts.append(f"ret.{elem.name} = value;")
                    else:
                        construction_stmts.append(
                            f"ret.{elem.name} = this->{elem.name};"
                        )
                construction_stmts.append("return ret;")
                construction_block = "\n".join(construction_stmts)
                setter_methods.append(
                    f"""
{signature} {{
{assert_stmt}
{construction_block}
}}
"""
                )
            setter_methods_decl = "\n".join(setter_methods)
            # Meta should return an instance of the struct containing the precomputed elements.
            meta_return_template_params = ", ".join(
                ["true"] * len(precomputed_template_parameters)
            )
            # This typedef (actually a using statement) is needed so that TORCH_META_FUNC can reuse the return
            # type (which has a variable number of template parameters).
            meta_return_typedef = f"using meta_return_ty = precompute_out <{meta_return_template_params}>;"
            meta_return = "meta_return_ty"
            precomputed_decl = f"""
{precompute_template_decl}
struct TORCH_API precompute_out {{
{setter_methods_decl}
{precomputed_elements_decl};
}};"""
        else:
            # Nothing is precomputed: meta() keeps returning void and no
            # helper struct is emitted.
            meta_return_typedef = ""
            precomputed_decl = ""
        return f"""\
struct TORCH_API structured_{name} : public {parent_class} {{
{precomputed_decl}
{meta_return_typedef}
{meta_return} meta({args_str});
}};
"""
def needs_backend_select(f: NativeFunction, selector: SelectiveBuilder) -> bool:
    """Tell whether a BackendSelect kernel should be generated for ``f``.

    Operators named ``*_like`` or ``new_*`` and operators without tensor
    options never need one; for the rest the selector decides.
    """
    name = str(f.func.name.name)
    is_excluded_by_name = name.endswith("_like") or name.startswith("new_")
    if is_excluded_by_name or f.func.arguments.tensor_options is None:
        return False
    return selector.is_native_function_selected(f)
# Generates RegisterBackendSelect.cpp, a series of kernels which provide
# specialized computation of dispatch key for operator signatures which cannot
# be easily done automatically using templating.
@dataclass(frozen=True)
class ComputeBackendSelect:
    """Callable that renders the BackendSelect kernel of an operator.

    For ``Target.DEFINITION`` it returns the C++ kernel, which computes a
    dispatch key set and redispatches; for ``Target.REGISTRATION`` it
    returns the ``m.impl(...)`` line registering that kernel. Operators for
    which ``needs_backend_select`` is false yield None.
    """

    # Selects between the kernel definition and its registration line.
    target: Union[Literal[Target.DEFINITION], Literal[Target.REGISTRATION]]
    # Selector object to determine which operators to generate
    # registration code for.
    selector: SelectiveBuilder

    @method_with_native_function
    def __call__(self, f: NativeFunction) -> Optional[str]:
        if not needs_backend_select(f, self.selector):
            return None
        name = native.name(f.func)
        native_sig = NativeSignature(f.func)
        # Tensor-like arguments also contribute to the computed key set.
        native_tensor_args = [
            a
            for a in native_sig.arguments()
            if isinstance(a.argument, Argument) and a.argument.type.is_tensor_like()
        ]
        dispatcher_sig = DispatcherSignature.from_schema(f.func)
        sig: Union[NativeSignature, DispatcherSignature]
        sig = dispatcher_sig
        dispatcher_exprs = dispatcher_sig.exprs()
        # C++ expression deriving a dispatch key from the tensor options.
        dispatch_key = "c10::computeDispatchKey(dtype, layout, device)"
        if self.target is Target.DEFINITION:
            # I don't think there's actually a good reason to generate
            # these two cases differently
            # The first case could probably be improved though- it calls computeDispatchKeySet(),
            # which looks at TLS dispatch keys- there should not be any by the time we reach backend select.
            if native_tensor_args:
                tensor_args = ", ".join(a.name for a in native_tensor_args)
                compute_dk = f"""\
DispatchKeySet _dk_set = c10::DispatchKeySet({dispatch_key}) | c10::detail::multi_dispatch_key_set({tensor_args});
DispatchKeySet _dk_mask = c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, DispatchKey::BackendSelect);
DispatchKeySet _dk = c10::impl::computeDispatchKeySet(_dk_set, _dk_mask);"""
            else:
                compute_dk = (
                    f"DispatchKeySet _dk = c10::DispatchKeySet({dispatch_key});"
                )
            return f"""\
// aten::{f.func}
C10_ALWAYS_INLINE
{sig.defn(name)} {{
{compute_dk}
return at::_ops::{f.func.name.unambiguous_name()}::redispatch(
_dk, {', '.join(a.expr for a in dispatcher_exprs)});
}}
"""
        elif self.target is Target.REGISTRATION:
            return f"""m.impl("aten::{f.func.name}", TORCH_FN({name}));"""
        else:
            assert_never(self.target)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
# YAML CODE GENERATION
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
def format_yaml(data: object) -> str:
    """Serialize ``data`` to block-style YAML without aliases or line wrapping.

    Note that this configures ``YamlDumper`` itself on every call.
    """

    # Support serializing OrderedDict
    def represent_ordered_dict(dumper: Any, mapping: Any) -> Any:
        return dumper.represent_dict(mapping.items())

    # Ignore alias in Dumper
    YamlDumper.ignore_aliases = lambda self, data: True  # type: ignore[assignment]
    YamlDumper.add_representer(OrderedDict, represent_ordered_dict)  # type: ignore[no-untyped-call]
    # Some yaml parsers (e.g. Haskell's) don't understand line breaks.
    # width=1e9 turns off optional line breaks and improves
    # the portability of the outputted yaml.
    return yaml.dump(data, default_flow_style=False, Dumper=YamlDumper, width=1e9)  # type: ignore[no-any-return, call-overload]
# For some reason, some defaults we write to YAML are written as native
# YAML objects, rather than doing them uniformly as strings. This
# function detects those cases and converts them into native Python
# objects.
def pythonify_default(s: str) -> object:
    """Convert a default-value string into a native Python object if possible.

    The conversions are tried in this order: the literals ``"true"`` and
    ``"false"`` become booleans, then ``int(s)``, then ``float(s)``.  If none
    of them applies the string is returned unchanged.

    Args:
        s: The textual default expression.

    Returns:
        A ``bool``, ``int`` or ``float`` when ``s`` parses as one, else ``s``.
    """
    for literal, value in (("true", True), ("false", False)):
        if s == literal:
            return value

    # Try the narrower numeric type first so "1" stays an int.
    for convert in (int, float):
        try:
            return convert(s)
        except ValueError:
            continue

    return s
# What is a dynamic type? Over time, the semantic meaning of
# dynamic type has degraded to meaninglessness (in the old days,
# it captured dtype-ness of types, but that has gone away with
# the removal of TH). These days, it's mostly the same thing as
# the C++ API argument type, except that Tensor and Tensor?
# arguments simply present as Tensor.
#
# TODO: Get rid of dynamic_type, after getting tools/autograd
# to use the new codegen framework
def dynamic_type(t: Type) -> str:
    """Return the legacy "dynamic type" string for the schema type ``t``.

    Optional wrappers are stripped first.  A type whose string form is exactly
    ``"Tensor"`` maps to ``"at::Tensor"``; everything else is rendered through
    ``cpp.argumenttype_type`` (non-mutable, placeholder binding name).
    """
    # Peel off every layer of optionality, so Tensor? presents as Tensor.
    while isinstance(t, OptionalType):
        t = t.elem

    # Note we don't use t.is_tensor_like() here because it would
    # also include Tensor[]
    if str(t) != "Tensor":
        return cpp.argumenttype_type(
            t, mutable=False, binds="__placeholder__"
        ).cpp_type()
    return "at::Tensor"
def compute_method_of_yaml(variants: Set[Variant]) -> List[str]:
    """Build the ``method_of`` list for a declaration.

    The list always starts with ``"Type"``; ``"Tensor"`` follows when the
    method variant is present and ``"namespace"`` when the function variant
    is present, always in that order.
    """
    # The table is written out explicitly to ensure that Tensor and
    # namespace are put into the list in the right order.
    ordered_labels = (
        (Variant.method, "Tensor"),
        (Variant.function, "namespace"),
    )
    return ["Type"] + [
        label for variant, label in ordered_labels if variant in variants
    ]
def compute_returns_yaml(
    f: NativeFunction,
) -> Tuple[List[Dict[str, str]], Dict[str, str]]:
    """Compute the ``returns`` YAML entries for ``f``.

    Returns:
        A pair ``(returns, name_to_field_name)``.  ``returns`` has one dict
        per return value (``dynamic_type``, ``name``, ``type`` and, for named
        returns, ``field_name``).  ``name_to_field_name`` maps out-argument
        names to the corresponding return field names and is only populated
        for out functions with named returns.
    """
    # Note [name and field_name]
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~
    # To understand name_to_field_name, we must first talk about this
    # schema:
    #
    #   lstsq.X(Tensor self, Tensor A, *, Tensor(a!) X, Tensor(b!) qr) -> (Tensor(a!) solution, Tensor(b!) QR)
    #
    # There is something very odd about this schema: it is an out
    # variant of the function (that is to say, it will convert into
    # at::lstsq_out() in the C++ API), but the names of the output
    # return arguments don't match the keyword argument names of
    # the inputs.  It TURNS OUT that in this situation, the historical
    # Declarations.yaml we want to output is this (abbreviated to
    # only show relevant fields):
    #
    #   arguments:
    #     ...
    #   - field_name: solution
    #     name: X
    #   - field_name: QR
    #     name: qr
    #     ...
    #
    #   returns:
    #   - field_name: solution
    #     name: X
    #   - field_name: QR
    #     name: qr
    #
    # The name of the return fields is stored in 'field_name', and the
    # name of the arguments is stored in 'name'.  So when we process
    # arguments, we need a way to get at the corresponding return.  At
    # the moment, this is most conveniently done by constructing a
    # mapping from name (the argument concept) to field_name (the
    # return concept) while processing return arguments, since we don't
    # directly maintain this correspondence in the modeling of function
    # schema itself.
    #
    # See also https://github.com/pytorch/pytorch/issues/43114
    name_to_field_name: Dict[str, str] = {}

    # Compute the returns field of the YAML entry
    names = cpp.return_names(f)
    returns = []
    for idx, (ret, cpp_name) in enumerate(zip(f.func.returns, names)):
        entry = {
            "dynamic_type": dynamic_type(ret.type),
            "name": cpp_name,
            "type": cpp.return_type(ret).cpp_type(),
        }

        field_name = ret.name
        if field_name:
            # See Note [name and field_name]
            entry["field_name"] = field_name
            if f.func.is_out_fn():
                # The idx-th out argument corresponds to the idx-th return.
                out_argument = f.func.arguments.out[idx]
                name_to_field_name[out_argument.name] = field_name

        returns.append(entry)

    return returns, name_to_field_name
# arguments in yaml roughly corresponds to the public C++ API
def compute_cpp_argument_yaml(
    cpp_a: Binding,
    *,
    schema_order: bool,
    kwarg_only_set: Set[str],
    out_arg_set: Set[str],
    name_to_field_name: Dict[str, str],
) -> object:
    """Compute the YAML entry for one C++ API binding.

    A ``TensorOptionsArguments`` binding gets a fixed-shape entry built right
    here; a plain ``Argument`` is delegated to ``compute_argument_yaml``.

    Raises:
        AssertionError: If the binding wraps a ``SelfArgument``.
    """
    argument = cpp_a.argument

    if isinstance(argument, TensorOptionsArguments):
        options_entry: Dict[str, object] = {
            "annotation": None,
            "dynamic_type": "at::TensorOptions",
            "is_nullable": False,
            "name": cpp_a.name,
            "type": cpp_a.type,
            "kwarg_only": True,
        }
        default = cpp_a.default
        if default is not None:
            options_entry["default"] = default
        return options_entry

    if isinstance(argument, SelfArgument):
        raise AssertionError()

    if isinstance(argument, Argument):
        return compute_argument_yaml(
            argument,
            schema_order=schema_order,
            kwarg_only_set=kwarg_only_set,
            out_arg_set=out_arg_set,
            name_to_field_name=name_to_field_name,
        )

    # Any other kind of argument yields no entry.
    return None
def compute_argument_yaml(
    a: Argument,
    *,
    schema_order: bool,
    kwarg_only_set: Set[str],
    out_arg_set: Set[str],
    name_to_field_name: Dict[str, str],
) -> object:
    """Compute the YAML entry for a single schema argument.

    The entry always carries ``annotation``, ``dynamic_type``,
    ``is_nullable``, ``name`` and ``type``.  ``default``, ``kwarg_only``,
    ``output``/``allocate``, ``field_name`` and ``size`` are added only when
    they apply to ``a``.  ``schema_order`` is accepted but not read here.
    """
    annotation = str(a.annotation) if a.annotation else None
    arg: Dict[str, object] = {
        "annotation": annotation,
        "dynamic_type": dynamic_type(a.type),
        "is_nullable": a.type.is_nullable(),
        "name": a.name,
        "type": cpp.argument_type(a, binds="__placeholder__").cpp_type(),
    }

    if a.default is not None:
        default_text = cpp.default_expr(a.default, a.type)
        arg["default"] = pythonify_default(default_text)

    if a.name in kwarg_only_set:
        arg["kwarg_only"] = True

    if a.name in out_arg_set:
        arg.update(output=True, allocate=True)

    # See Note [name and field_name]
    if a.name in name_to_field_name:
        arg["field_name"] = name_to_field_name[a.name]

    # Historically, booleans don't get their size recorded, because it
    # is already built into the cpp type (e.g., std::array<bool, 4>)
    list_type = a.type.is_list_like()
    if list_type is not None and list_type.size is not None:
        if str(list_type.elem) != "bool":
            arg["size"] = list_type.size

    return arg
@with_native_function
def compute_declaration_yaml(f: NativeFunction) -> object:
    """Build the ordered Declarations.yaml record for the native function ``f``.

    Returns:
        An ``OrderedDict`` whose keys appear in a fixed order, from ``name``
        through ``has_math_kernel``.
    """
    returns, name_to_field_name = compute_returns_yaml(f)

    # These sets are used to conveniently test if an argument is a
    # kwarg-only or out argument
    kwarg_only_set = {a.name for a in f.func.arguments.flat_kwarg_only}
    out_arg_set = {a.name for a in f.func.arguments.out}

    # Arguments as they appear in the (non-method) C++ signature.
    sig_group = CppSignatureGroup.from_native_function(
        f, method=False, fallback_binding=False
    )
    cpp_args = sig_group.signature.arguments()
    arguments = [
        compute_cpp_argument_yaml(
            binding,
            schema_order=False,
            kwarg_only_set=kwarg_only_set,
            out_arg_set=out_arg_set,
            name_to_field_name=name_to_field_name,
        )
        for binding in cpp_args
    ]

    # Arguments as they appear in the function schema.
    schema_order_jit_arguments = list(f.func.schema_order_arguments())
    schema_order_arguments = [
        compute_argument_yaml(
            jit_arg,
            schema_order=True,
            kwarg_only_set=kwarg_only_set,
            out_arg_set=out_arg_set,
            name_to_field_name=name_to_field_name,
        )
        for jit_arg in schema_order_jit_arguments
    ]

    cpp_schema_order_types = []
    for jit_arg in schema_order_jit_arguments:
        # NB: method here doesn't matter
        bindings = cpp.argument(
            jit_arg,
            method=False,
            cpp_no_default_args=set(),
            faithful=False,
            has_tensor_options=False,
        )
        cpp_schema_order_types.extend(binding.type for binding in bindings)

    cpp_returns = cpp.returns_type(f.func.returns).cpp_type()
    schema_order_cpp_signature = f"{cpp_returns} ({', '.join(cpp_schema_order_types)})"

    # Factory: takes TensorOptions in the C++ signature and has no method variant.
    takes_tensor_options = any(
        isinstance(binding.argument, TensorOptionsArguments) for binding in cpp_args
    )
    is_factory_method = takes_tensor_options and Variant.method not in f.variants

    declaration = OrderedDict()
    declaration["name"] = cpp.name(f.func)
    declaration["operator_name"] = str(f.func.name.name)
    declaration["overload_name"] = str(f.func.name.overload_name)
    declaration["manual_kernel_registration"] = f.manual_kernel_registration
    declaration["category_override"] = (
        f.category_override if f.category_override is not None else ""
    )
    declaration["schema_string"] = f"aten::{f.func}"
    declaration["arguments"] = arguments
    declaration["schema_order_cpp_signature"] = schema_order_cpp_signature
    declaration["schema_order_arguments"] = schema_order_arguments
    declaration["method_of"] = compute_method_of_yaml(f.variants)
    declaration["mode"] = "native"
    declaration["python_module"] = "" if f.python_module is None else f.python_module
    declaration["returns"] = returns
    declaration["inplace"] = f.func.name.name.inplace
    declaration["is_factory_method"] = is_factory_method
    declaration["abstract"] = f.is_abstract
    declaration["device_guard"] = f.device_guard
    declaration["with_gil"] = False
    declaration["deprecated"] = False
    declaration["has_math_kernel"] = f.has_composite_implicit_autograd_kernel
    return declaration
# See Note [Auto generated composite kernels]
def has_autogenerated_composite_kernel(f: NativeFunction) -> bool:
    """Tell whether ``f`` is a structured functional or inplace operator.

    True only when ``f`` is structured (or has a structured delegate) and
    its schema kind is functional or inplace.
    """
    in_structured_family = f.structured or f.structured_delegate is not None
    if not in_structured_family:
        return False
    return f.func.kind() in (SchemaKind.functional, SchemaKind.inplace)
@with_native_function_and_indices
def compute_registration_declarations(
    f: NativeFunction, backend_indices: Dict[DispatchKey, BackendIndex]
) -> str:
    """Render one line of RegistrationDeclarations.h for ``f``.

    The line is a C++ declaration using the dispatcher signature (defaults
    removed), followed by a ``//`` comment holding a JSON object with the
    ``schema``, ``dispatch`` and ``default`` fields, and a trailing newline.
    """
    name = dispatcher.name(f.func)
    returns_type = dispatcher.returns_type(
        f.func.returns
    ).cpp_type_registration_declarations()
    args = dispatcher.arguments(f.func)
    args_str = ", ".join(
        arg.no_default().decl_registration_declarations() for arg in args
    )

    # TODO: What exactly is the semantics of the 'dispatch' field?
    # As computed here: "True" unless the set of backends with a kernel for f
    # is exactly {CompositeImplicitAutograd}.
    keys_with_kernel = {
        key for key, index in backend_indices.items() if index.has_kernel(f)
    }
    has_dispatch = keys_with_kernel != {DispatchKey.CompositeImplicitAutograd}
    has_default = f.has_composite_kernel or has_autogenerated_composite_kernel(f)

    comment_data: Dict[str, str] = {
        "schema": f"aten::{f.func}",
        "dispatch": str(has_dispatch),
        "default": str(has_default),
    }
    return f"""{returns_type} {name}({args_str}); // {json.dumps(comment_data)}
"""
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
# RUN IT ALL
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
def get_custom_build_selector(
    provided_op_registration_allowlist: Optional[List[str]],
    op_selection_yaml_path: Optional[str],
) -> SelectiveBuilder:
    """Create the ``SelectiveBuilder`` for this codegen run.

    Args:
        provided_op_registration_allowlist: Legacy list of operator names to
            allow, or ``None``.
        op_selection_yaml_path: Path to an operator-selection YAML file, or
            ``None``.

    Returns:
        A selector built from the allowlist, else from the YAML path, else
        the no-op selector when neither is given.

    Raises:
        AssertionError: If both arguments are provided.
    """
    both_provided = (
        provided_op_registration_allowlist is not None
        and op_selection_yaml_path is not None
    )
    assert not both_provided, (
        "Both provided_op_registration_allowlist and "
        + "op_selection_yaml_path can NOT be provided at the "
        + "same time."
    )

    if provided_op_registration_allowlist is not None:
        return SelectiveBuilder.from_legacy_op_registration_allow_list(
            set(provided_op_registration_allowlist),
            True,
            False,
        )

    if op_selection_yaml_path is not None:
        return SelectiveBuilder.from_yaml_path(op_selection_yaml_path)

    return SelectiveBuilder.get_nop_selector()
def get_grouped_by_view_native_functions(
    native_functions: Sequence[NativeFunction],
) -> Sequence[Union[NativeFunction, NativeFunctionsViewGroup]]:
    """Group native functions into view groups where an aliasing view exists.

    Functions are bucketed by ``f.func.view_signature()``.  A bucket that
    contains an aliasing view op becomes one ``NativeFunctionsViewGroup``
    (with its optional inplace and copy variants); every other function in
    the bucket is emitted on its own.

    Raises:
        AssertionError: If two functions in one bucket share the same kind.
    """

    def maybe_create_view_group(
        d: Dict[Union[ViewSchemaKind, SchemaKind], NativeFunction]
    ) -> List[Union[NativeFunction, NativeFunctionsViewGroup]]:
        funcs: List[Union[NativeFunction, NativeFunctionsViewGroup]] = []
        if ViewSchemaKind.aliasing in d:
            # Members of the group are removed from d as they are claimed.
            view = d.pop(ViewSchemaKind.aliasing)
            view_inplace = d.pop(ViewSchemaKind.aliasing_inplace, None)
            view_copy = d.pop(SchemaKind.functional, None)
            group = NativeFunctionsViewGroup(
                view=view,
                view_copy=view_copy,
                view_inplace=view_inplace,
            )
            funcs.append(group)
        # Take the remaining functions that weren't part of the view group
        # and emit them separately
        funcs.extend(d.values())
        return funcs

    grouped_by_views: Dict[
        FunctionSchema, Dict[Union[SchemaKind, ViewSchemaKind], NativeFunction]
    ] = defaultdict(dict)
    for f in native_functions:
        bucket = grouped_by_views[f.func.view_signature()]
        view_kind: ViewSchemaKind = f.view_schema_kind
        # We need to group up ops relevant to the same "view", consisting of:
        # view op (ViewSchemaKind.aliasing)
        # view_inplace op (ViewSchemaKind.aliasing_inplace)
        # view_copy op (SchemaKind.functional)
        # Non-aliasing ops are keyed by their plain schema kind instead.
        if view_kind == ViewSchemaKind.non_aliasing:
            key: Union[SchemaKind, ViewSchemaKind] = f.func.kind()
        else:
            key = view_kind
        assert key not in bucket
        bucket[key] = f

    return list(concatMap(maybe_create_view_group, grouped_by_views.values()))
def get_grouped_native_functions(
    native_functions: Sequence[NativeFunction],
) -> Sequence[Union[NativeFunction, NativeFunctionsGroup]]:
    """Collapse pre-grouped native functions into ``NativeFunctionsGroup``s.

    Each pre-group that ``NativeFunctionsGroup.from_dict`` accepts becomes a
    single group; otherwise its functions are returned individually.

    Raises:
        AssertionError: If an ungroupable pre-group contains a function
            tagged ``"generated"``.
    """

    def flatten_pre_group(
        d: Dict[SchemaKind, NativeFunction]
    ) -> Sequence[Union[NativeFunction, NativeFunctionsGroup]]:
        group = NativeFunctionsGroup.from_dict(d)
        if group is not None:
            return [group]
        # Invariant: any NativeFunctions that are code-generated
        # should have been grouped into NativeFunctionsGroup objects
        assert all("generated" not in fn.tags for fn in d.values())
        return list(d.values())

    # TODO: how come ValuesView isn't a Sequence lol
    pre_grouped_native_functions = pre_group_native_functions(native_functions)
    pre_groups = list(pre_grouped_native_functions.values())
    return list(concatMap(flatten_pre_group, pre_groups))
def gen_aggregated_headers(
    *,
    native_functions: Sequence[NativeFunction],
    grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
    structured_native_functions: Sequence[NativeFunctionsGroup],
    static_dispatch_idx: List[BackendIndex],
    selector: SelectiveBuilder,
    backend_indices: Dict[DispatchKey, BackendIndex],
    cpu_fm: FileManager,
    cuda_fm: FileManager,
    functions_keys: Set[DispatchKey],
    dispatch_keys: Sequence[DispatchKey],
    rocm: bool,
) -> None:
    """Write the aggregated (one file per category) ATen headers.

    Writes ``NativeMetaFunctions.h``, ``MethodOperators.h``, ``Operators.h``,
    ``Functions.h`` and ``NativeFunctions.h`` through ``cpu_fm``.  For every
    dispatch key that is also in ``functions_keys`` it writes
    ``<Key>Functions.h`` and ``<Key>Functions_inl.h`` through ``cuda_fm`` for
    CUDA dispatch keys and ``cpu_fm`` otherwise.
    """
    # Buck doesn't support dynamic output files, so we aggregate all operator
    # headers into a single file
    cpu_fm.write(
        "NativeMetaFunctions.h",
        lambda: {
            "NativeMetaFunctions_includes": [],
            "NativeMetaFunctions_declarations": list(
                mapMaybe(compute_meta_function_declaration, structured_native_functions)
            ),
        },
    )

    # Partition in a single pass on the same predicate.  Testing membership
    # in the first list instead would cost a linear scan per function.
    method_native_functions: List[NativeFunction] = []
    non_method_native_functions: List[NativeFunction] = []
    for fn in native_functions:
        if Variant.method in fn.variants:
            method_native_functions.append(fn)
        else:
            non_method_native_functions.append(fn)

    cpu_fm.write(
        "MethodOperators.h",
        lambda: {
            "MethodOperators_includes": [],
            "MethodOperators_declarations": list(
                mapMaybe(
                    ComputeOperators(
                        Target.DECLARATION,
                        static_dispatch_backend_indices=static_dispatch_idx,
                    ),
                    method_native_functions,
                )
            ),
        },
    )
    cpu_fm.write(
        "Operators.h",
        lambda: {
            "Operators_includes": ["#include <ATen/MethodOperators.h>"],
            "Operators_declarations": list(
                mapMaybe(
                    ComputeOperators(
                        Target.DECLARATION,
                        static_dispatch_backend_indices=static_dispatch_idx,
                    ),
                    non_method_native_functions,
                )
            ),
        },
    )
    cpu_fm.write(
        "Functions.h",
        lambda: {
            "static_dispatch_extra_headers": static_dispatch_extra_headers(
                static_dispatch_idx
            ),
            "Functions_includes": ["#include <ATen/Operators.h>"],
            "Functions_declarations": list(
                mapMaybe(
                    ComputeFunction(),
                    native_functions,
                )
            ),
        },
    )
    cpu_fm.write(
        "NativeFunctions.h",
        lambda: {
            "NativeFunctions_includes": ["#include <ATen/NativeMetaFunctions.h>"],
            "NativeFunctions_declarations": list(
                concatMap(
                    # De-duplicate kernel declarations while keeping their
                    # first-seen order (OrderedDict.fromkeys).
                    # Backends are allowed to repeat kernel names; only generate the declaration once!
                    lambda f: list(
                        OrderedDict.fromkeys(
                            concatMap(
                                lambda backend_idx: dest.compute_native_function_declaration(
                                    f, backend_idx
                                ),
                                backend_indices.values(),
                            )
                        )
                    ),
                    grouped_native_functions,
                )
            ),
        },
    )

    for dispatch_key in dispatch_keys:
        fm = cuda_fm if is_cuda_dispatch_key(dispatch_key) else cpu_fm
        if dispatch_key in functions_keys:
            inl_headers = f"#include <ATen/{dispatch_key}Functions_inl.h>"

            fm.write_with_template(
                f"{dispatch_key}Functions.h",
                "DispatchKeyFunctions.h",
                lambda: {
                    "dispatch_key": str(dispatch_key),
                    "inline_headers": inl_headers,
                },
            )
            fm.write_with_template(
                f"{dispatch_key}Functions_inl.h",
                "DispatchKeyFunctions_inl.h",
                lambda: {
                    "DispatchKeyFunctions_inl_includes": [],
                    "dispatch_namespace": dispatch_key.lower(),
                    "dispatch_namespaced_declarations": list(
                        concatMap(
                            dest.RegisterDispatchKey(
                                backend_indices[dispatch_key],
                                Target.NAMESPACED_DECLARATION,
                                selector,
                                rocm=rocm,
                                class_method_name=None,
                                skip_dispatcher_op_registration=False,
                            ),
                            grouped_native_functions,
                        )
                    ),
                },
            )

        del fm
def gen_per_operator_headers(
    *,
    native_functions: Sequence[NativeFunction],
    grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
    static_dispatch_idx: List[BackendIndex],
    selector: SelectiveBuilder,
    backend_indices: Dict[DispatchKey, BackendIndex],
    cpu_fm: FileManager,
    cuda_fm: FileManager,
    ops_fm: FileManager,
    functions_keys: Set[DispatchKey],
    dispatch_keys: Sequence[DispatchKey],
    rocm: bool,
) -> None:
    """Write one set of headers per operator root name, plus umbrella headers.

    For each root name this writes ``<name>_ops.h``, ``<name>.h``,
    ``<name>_native.h`` and, when the name has structured groups,
    ``<name>_meta.h`` through ``ops_fm``.  The category headers
    (``Functions.h``, ``Operators.h``, ``NativeMetaFunctions.h``,
    ``NativeFunctions.h``, ``MethodOperators.h``) then only contain includes
    of the per-operator files.  For every dispatch key in ``functions_keys``
    the per-operator ``<name>_<namespace>_dispatch.h`` files and the
    ``<Key>Functions.h`` / ``<Key>Functions_inl.h`` pair are written as well.
    """
    # For CMake builds, split operator declarations into separate headers in
    # the ATen/ops folder to split up header dependencies
    functions_by_root_name: Dict[str, List[NativeFunction]] = defaultdict(list)
    for fn in native_functions:
        functions_by_root_name[fn.root_name].append(fn)

    grouped_functions_by_root_name: Dict[
        str, List[Union[NativeFunction, NativeFunctionsGroup]]
    ] = defaultdict(list)
    for group in grouped_native_functions:
        grouped_functions_by_root_name[group.root_name].append(group)

    for name, functions in functions_by_root_name.items():
        ops_fm.write_with_template(
            f"{name}_ops.h",
            "Operator.h",
            lambda: {
                "declarations": list(
                    mapMaybe(
                        ComputeOperators(
                            Target.DECLARATION,
                            static_dispatch_backend_indices=static_dispatch_idx,
                        ),
                        functions,
                    )
                ),
            },
        )

        ops_fm.write_with_template(
            f"{name}.h",
            "Function.h",
            lambda: {
                "static_dispatch_ops_headers": list(
                    mapMaybe(
                        lambda fn: static_dispatch_ops_header(
                            fn, backend_index=static_dispatch_idx
                        ),
                        functions,
                    )
                ),
                "operator_includes": f"#include <ATen/ops/{name}_ops.h>",
                "function_definitions": list(
                    mapMaybe(
                        ComputeFunction(),
                        functions,
                    )
                ),
            },
        )

        # .get() is used so that a missing name does not add an entry.
        grouped_functions = grouped_functions_by_root_name.get(name, [])
        structured_functions = [
            candidate
            for candidate in grouped_functions
            if isinstance(candidate, NativeFunctionsGroup) and candidate.structured
        ]
        is_structured = bool(structured_functions)

        if is_structured:
            ops_fm.write_with_template(
                f"{name}_meta.h",
                "NativeMetaFunction.h",
                lambda: {
                    "meta_function_declarations": list(
                        mapMaybe(
                            compute_meta_function_declaration, structured_functions
                        )
                    ),
                },
            )

        ops_fm.write_with_template(
            f"{name}_native.h",
            "NativeFunction.h",
            lambda: {
                "extra_includes": (
                    f"#include <ATen/ops/{name}_meta.h>" if is_structured else []
                ),
                "native_function_declarations": list(
                    concatMap(
                        # De-duplicate kernel declarations, keeping first-seen order.
                        # Backends are allowed to repeat kernel names; only generate the declaration once!
                        lambda f: list(
                            OrderedDict.fromkeys(
                                concatMap(
                                    lambda backend_idx: dest.compute_native_function_declaration(
                                        f, backend_idx
                                    ),
                                    backend_indices.values(),
                                )
                            )
                        ),
                        grouped_functions,
                    )
                ),
            },
        )

    # The aggregated category headers become pure include lists.
    category_suffixes = (
        ("Functions", ""),
        ("Operators", "_ops"),
        ("NativeMetaFunctions", "_meta"),
        ("NativeFunctions", "_native"),
    )
    for category, suffix in category_suffixes:
        cpu_fm.write(
            f"{category}.h",
            lambda: {
                f"{category}_includes": [
                    f"#include <ATen/ops/{name}{suffix}.h>"
                    for name in sorted(functions_by_root_name)
                ],
                f"{category}_declarations": [],
            },
        )

    for dispatch_key in dispatch_keys:
        if dispatch_key not in functions_keys:
            continue

        dispatch_namespace = dispatch_key.lower()
        dispatch_names = []

        for name in functions_by_root_name:
            grouped_functions = grouped_functions_by_root_name.get(name, [])
            declarations = list(
                concatMap(
                    dest.RegisterDispatchKey(
                        backend_indices[dispatch_key],
                        Target.NAMESPACED_DECLARATION,
                        selector,
                        rocm=rocm,
                        class_method_name=None,
                        skip_dispatcher_op_registration=False,
                    ),
                    grouped_functions,
                )
            )

            # Operators with nothing to declare for this key get no header.
            if not declarations:
                continue

            dispatch_names.append(name)
            ops_fm.write_with_template(
                f"{name}_{dispatch_namespace}_dispatch.h",
                "DispatchKeyFunction.h",
                lambda: {
                    "dispatch_namespace": dispatch_namespace,
                    "dispatch_namespaced_declarations": declarations,
                },
            )

        fm = cuda_fm if is_cuda_dispatch_key(dispatch_key) else cpu_fm
        inl_headers = f"#include <ATen/{dispatch_key}Functions_inl.h>"

        fm.write_with_template(
            f"{dispatch_key}Functions.h",
            "DispatchKeyFunctions.h",
            lambda: {
                "dispatch_key": str(dispatch_key),
                "inline_headers": inl_headers,
            },
        )
        fm.write_with_template(
            f"{dispatch_key}Functions_inl.h",
            "DispatchKeyFunctions_inl.h",
            lambda: {
                "dispatch_namespace": dispatch_namespace,
                "DispatchKeyFunctions_inl_includes": [
                    f"#include <ATen/ops/{name}_{dispatch_namespace}_dispatch.h>"
                    for name in sorted(dispatch_names)
                ],
                "dispatch_namespaced_declarations": [],
            },
        )
        del fm

    cpu_fm.write(
        "MethodOperators.h",
        lambda: {
            "MethodOperators_includes": sorted(
                f"#include <ATen/ops/{name}_ops.h>"
                for name, functions in functions_by_root_name.items()
                if any(Variant.method in fn.variants for fn in functions)
            ),
            "MethodOperators_declarations": [],
        },
    )
def gen_headers(
    *,
    native_functions: Sequence[NativeFunction],
    valid_tags: Set[str],
    grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
    structured_native_functions: Sequence[NativeFunctionsGroup],
    static_dispatch_idx: List[BackendIndex],
    selector: SelectiveBuilder,
    backend_indices: Dict[DispatchKey, BackendIndex],
    core_fm: FileManager,
    cpu_fm: FileManager,
    cuda_fm: FileManager,
    ops_fm: FileManager,
    dispatch_keys: Sequence[DispatchKey],
    functions_keys: Set[DispatchKey],
    rocm: bool,
    per_operator_headers: bool,
) -> None:
    """Write all generated ATen header files.

    The operator headers are written either per operator or aggregated,
    depending on ``per_operator_headers``.  After that ``TensorBody.h``,
    ``RedispatchFunctions.h``, ``RegistrationDeclarations.h``,
    ``aten_interned_strings.h`` and ``enum_tag.h`` are written.
    """
    if per_operator_headers:
        gen_per_operator_headers(
            native_functions=native_functions,
            grouped_native_functions=grouped_native_functions,
            static_dispatch_idx=static_dispatch_idx,
            selector=selector,
            backend_indices=backend_indices,
            cpu_fm=cpu_fm,
            cuda_fm=cuda_fm,
            ops_fm=ops_fm,
            dispatch_keys=dispatch_keys,
            functions_keys=functions_keys,
            rocm=rocm,
        )
    else:
        gen_aggregated_headers(
            native_functions=native_functions,
            grouped_native_functions=grouped_native_functions,
            structured_native_functions=structured_native_functions,
            static_dispatch_idx=static_dispatch_idx,
            selector=selector,
            backend_indices=backend_indices,
            cpu_fm=cpu_fm,
            cuda_fm=cuda_fm,
            dispatch_keys=dispatch_keys,
            functions_keys=functions_keys,
            rocm=rocm,
        )

    core_fm.write(
        "TensorBody.h",
        lambda: {
            "tensor_method_declarations": list(
                mapMaybe(
                    ComputeTensorMethod(
                        target=Target.DECLARATION,
                        static_dispatch_backend_indices=static_dispatch_idx,
                    ),
                    native_functions,
                )
            ),
            "tensor_method_definitions": list(
                mapMaybe(
                    ComputeTensorMethod(
                        target=Target.DEFINITION,
                        static_dispatch_backend_indices=static_dispatch_idx,
                    ),
                    native_functions,
                )
            ),
        },
    )

    cpu_fm.write(
        "RedispatchFunctions.h",
        lambda: {
            "function_redispatch_definitions": list(
                mapMaybe(ComputeRedispatchFunction(), native_functions)
            ),
        },
    )

    cpu_fm.write(
        "RegistrationDeclarations.h",
        lambda: {
            "registration_declarations": [
                compute_registration_declarations(f, backend_indices)
                for f in native_functions
            ],
        },
    )

    def gen_aten_interned_strings() -> Dict[str, str]:
        attrs = set()  # All function argument names
        names = set()  # All ATen function names
        for func in native_functions:
            names.add(str(func.func.name.name))
            # Some operators don't have a functional variant but we still create a
            # symbol without the underscore
            names.add(func.func.name.name.base)

            for arg in func.func.schema_order_arguments():
                attrs.add(arg.name)

        # These are keywords in C++, so aren't valid symbol names
        # https://en.cppreference.com/w/cpp/language/operator_alternative
        names -= {
            "and",
            "and_eq",
            "bitand",
            "bitor",
            "compl",
            "not",
            "not_eq",
            "or",
            "or_eq",
            "xor",
            "xor_eq",
        }

        return {
            "aten_symbols": " \\\n".join(
                [f"_(aten, {name})" for name in sorted(names)]
            ),
            "attr_symbols": " \\\n".join(
                [f"_(attr, {name})" for name in sorted(attrs)]
            ),
        }

    core_fm.write("aten_interned_strings.h", gen_aten_interned_strings)

    def gen_tags_enum() -> Dict[str, str]:
        # valid_tags is a set, so its iteration order is not stable between
        # interpreter runs.  Sort it so the generated enum is reproducible.
        return {"enum_of_valid_tags": ",\n".join(sorted(valid_tags))}

    core_fm.write("enum_tag.h", gen_tags_enum)
def gen_source_files(
*,
native_functions: Sequence[NativeFunction],
grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
structured_native_functions: Sequence[NativeFunctionsGroup],
view_groups: Sequence[NativeFunctionsViewGroup],
selector: SelectiveBuilder,
static_dispatch_idx: List[BackendIndex],
backend_indices: Dict[DispatchKey, BackendIndex],
core_fm: FileManager,
cpu_fm: FileManager,
cpu_vec_fm: FileManager,
cuda_fm: FileManager,
dispatch_keys: Sequence[DispatchKey],
functions_keys: Set[DispatchKey],
rocm: bool,
force_schema_registration: bool,
per_operator_headers: bool,
skip_dispatcher_op_registration: bool,
) -> None:
extra_cuda_headers = """\
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/ATenCUDAGeneral.h>
#include <ATen/cuda/CUDADevice.h>
#include <ATen/cuda/CUDAContext.h>"""
if rocm:
extra_cuda_headers = """\
#include <ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h>
#include <ATen/hip/ATenHIPGeneral.h>
#include <ATen/hip/HIPDevice.h>
#include <ATen/hip/HIPContext.h>"""
for dispatch_key in dispatch_keys:
fm = cuda_fm if is_cuda_dispatch_key(dispatch_key) else cpu_fm
if per_operator_headers:
def operator_headers() -> List[str]:
headers = []
for g in grouped_native_functions:
is_registered = False
if backend_index.has_kernel(g):
is_registered = True
# The above has_kernel test on a group will only test for
# the existence of out dispatch, because that's how
# structured kernels work. But sometimes functions can be
# grouped but not be structured, and then you need to check
# each individual piece, as they may have manual dispatch
# entries.
elif isinstance(g, NativeFunctionsGroup) and any(
backend_index.has_kernel(fn) for fn in g.functions()
):
is_registered = True
# TODO: this condition is a bit questionable
# (It has to do with the fact that structured kernels get generated kernels
# to the Meta + CompositeExplicitAutogradNonFunctional keys).
elif g.structured and dispatch_key in (
DispatchKey.Meta,
DispatchKey.CompositeExplicitAutogradNonFunctional,
):
is_registered = True
if not is_registered:
continue
headers.append(f"#include <ATen/ops/{g.root_name}_native.h>")
if (
dispatch_key
== DispatchKey.CompositeExplicitAutogradNonFunctional
):
headers.append(f"#include <ATen/ops/{g.root_name}.h>")
if dispatch_key in functions_keys:
headers.append(
f"#include <ATen/ops/{g.root_name}_{dispatch_namespace}_dispatch.h>"
)
return sorted(set(headers))
else:
def operator_headers() -> List[str]:
headers = ["#include <ATen/NativeFunctions.h>"]
if dispatch_key == DispatchKey.CompositeExplicitAutogradNonFunctional:
headers.append("#include <ATen/Functions.h>")
if dispatch_key in functions_keys:
headers.append(f"#include <ATen/{dispatch_key!s}Functions.h>")
return headers
backend_index = backend_indices[dispatch_key]
ns_grouped_native_functions = defaultdict(list)
for grouped_native_function in grouped_native_functions:
namespace = (
grouped_native_function.namespace
if isinstance(grouped_native_function, NativeFunction)
else grouped_native_function.functional.namespace
)
ns_grouped_native_functions[namespace].append(grouped_native_function)
static_init_dispatch_registrations = ""
for namespace, functions in ns_grouped_native_functions.items():
dispatch_registrations_body = (
""
if skip_dispatcher_op_registration
else "\n".join(
list(
concatMap(
dest.RegisterDispatchKey(
backend_index,
Target.REGISTRATION,
selector,
rocm=rocm,
class_method_name=None,
skip_dispatcher_op_registration=skip_dispatcher_op_registration,
),
functions,
)
)
)
)
static_init_dispatch_registrations += f"""
TORCH_LIBRARY_IMPL({namespace}, {dispatch_key}, m) {{
{dispatch_registrations_body}
}};"""
dispatch_namespace = str(dispatch_key).lower()
fm.write_with_template(
f"Register{dispatch_key}.cpp",
"RegisterDispatchKey.cpp",
lambda: {
"extra_cuda_headers": extra_cuda_headers
if is_cuda_dispatch_key(dispatch_key)
else "",
"external_backend_headers": "",
"dispatch_headers": dest.gen_registration_headers(
backend_index, per_operator_headers, rocm
),
"ops_headers": operator_headers(),
"DispatchKey": dispatch_key,
"dispatch_namespace": dispatch_key.lower(),
"dispatch_helpers": dest.gen_registration_helpers(backend_index),
"dispatch_namespaced_definitions": list(
concatMap(
dest.RegisterDispatchKey(
backend_index,
Target.NAMESPACED_DEFINITION,
selector,
rocm=rocm,
class_method_name=None,
skip_dispatcher_op_registration=skip_dispatcher_op_registration,
),
grouped_native_functions,
)
),
"dispatch_anonymous_definitions": list(
concatMap(
dest.RegisterDispatchKey(
backend_index,
Target.ANONYMOUS_DEFINITION,
selector,
rocm=rocm,
class_method_name=None,
skip_dispatcher_op_registration=skip_dispatcher_op_registration,
),
grouped_native_functions,
)
),
"static_init_dispatch_registrations": static_init_dispatch_registrations,
"deferred_dispatch_registrations": "",
},
)
for g in structured_native_functions:
if not g.out.ufunc_inner_loop or not is_ufunc_dispatch_key(dispatch_key):
continue
name = g.functional.func.name.name
if dispatch_key is DispatchKey.CPU:
assert fm is cpu_fm
fm.write_with_template(
f"UfuncCPU_{name}.cpp",
"UfuncCPU.cpp",
lambda: {
"meta_declaration": compute_meta_function_declaration(g),
"native_declaration": dest.compute_native_function_declaration(
g, backend_indices[dispatch_key]
),
"native_definitions": dest.compute_ufunc_cpu(g),
},
)
cpu_vec_fm.write_with_template(
f"UfuncCPUKernel_{name}.cpp",
"UfuncCPUKernel.cpp",
lambda: {
"name": name,
"native_definitions": dest.compute_ufunc_cpu_kernel(g),
},
)
elif dispatch_key is DispatchKey.CUDA:
cuda_headers = "#include <ATen/native/cuda/Loops.cuh>"
if rocm:
cuda_headers = "#include <ATen/native/hip/Loops.cuh>"
fm.write_with_template(
f"UfuncCUDA_{name}.cu",
"UfuncCUDA.cu",
lambda: {
"name": name,
"cuda_headers": cuda_headers,
"meta_declaration": compute_meta_function_declaration(g),
"native_declaration": dest.compute_native_function_declaration(
g, backend_indices[dispatch_key]
),
"native_definitions": dest.compute_ufunc_cuda(g),
},
)
else:
raise AssertionError(f"unrecognized {dispatch_key} for ufunc")
del fm
# BackendSelect is generated specially
def gen_backend_select() -> Dict[str, List[str]]:
relevant_fns = [
fn for fn in native_functions if needs_backend_select(fn, selector)
]
return {
"ops_headers": [
f"#include <ATen/ops/{fn.root_name}_ops.h>" for fn in relevant_fns
],
"backend_select_method_definitions": list(
mapMaybe(
ComputeBackendSelect(Target.DEFINITION, selector), relevant_fns
)
),
"backend_select_function_registrations": list(
mapMaybe(
ComputeBackendSelect(Target.REGISTRATION, selector), relevant_fns
)
),
}
cpu_fm.write("RegisterBackendSelect.cpp", gen_backend_select)
schema_selector = selector
if force_schema_registration:
schema_selector = SelectiveBuilder.get_nop_selector()
ns_native_functions: Dict[str, List[NativeFunction]] = defaultdict(list)
for native_function in native_functions:
ns_native_functions[native_function.namespace].append(native_function)
schema_registrations = ""
aten_schema_registrations = []
custom_namespace = None
for namespace, funcs in ns_native_functions.items():
schema_registrations_body = list(
mapMaybe(RegisterSchema(schema_selector), funcs)
)
# NB: we have to separate aten namespace registration from other namespaces,
# because in the template we hardcoded an operator for ATen already.
if namespace == "aten":
aten_schema_registrations = schema_registrations_body
else:
assert custom_namespace is None or namespace == custom_namespace, (
"Only one custom namespace (other than 'aten') is currently supported, "
f" but getting {namespace} and {custom_namespace}"
)
custom_namespace = namespace
tab = "\t"
schema_registrations += f"""
TORCH_LIBRARY({custom_namespace}, m) {{
{tab.join(schema_registrations_body)}
}};"""
cpu_fm.write(
"RegisterSchema.cpp",
lambda: {
"aten_schema_registrations": []
if skip_dispatcher_op_registration
else aten_schema_registrations,
"schema_registrations": []
if skip_dispatcher_op_registration
else schema_registrations,
},
)
def key_func(
fn: Union[NativeFunction, NativeFunctionsGroup, NativeFunctionsViewGroup]
) -> str:
return fn.root_name
cpu_fm.write_sharded(
"Operators.cpp",
native_functions,
key_fn=key_func,
env_callable=lambda fn: {
"operator_headers": [f"#include <ATen/ops/{fn.root_name}.h>"],
"definitions": [
ComputeOperators(
Target.DEFINITION,
static_dispatch_backend_indices=static_dispatch_idx,
)(fn)
],
},
base_env={
"static_dispatch_extra_headers": static_dispatch_extra_headers(
static_dispatch_idx
),
},
num_shards=5,
sharded_keys={
"operator_headers",
"definitions",
"static_dispatch_extra_headers",
},
)
cpu_fm.write("Functions.cpp", lambda: {})
core_fm.write("TensorMethods.cpp", lambda: {})
core_fm.write(
"ATenOpList.cpp",
lambda: {
"aten_ops": list(mapMaybe(compute_aten_op, native_functions)),
},
)
def functionalization_env_callable(
    g: Union[NativeFunction, NativeFunctionsGroup, NativeFunctionsViewGroup]
) -> Dict[str, List[str]]:
    """Build the template environment for one functionalization entry.

    Used as the ``env_callable`` when writing the sharded
    RegisterFunctionalization.cpp.  The result carries the per-operator
    header includes plus the functionalization kernel definitions and
    registrations generated for ``g``.
    """

    def include_pair(root_name: str) -> List[str]:
        # Every operator contributes its `_native.h` and `_ops.h` headers,
        # always in that order.
        return [
            f"#include <ATen/ops/{root_name}_native.h>",
            f"#include <ATen/ops/{root_name}_ops.h>",
        ]

    def gen_op_headers(
        group: Union[NativeFunction, NativeFunctionsGroup, NativeFunctionsViewGroup]
    ) -> List[str]:
        if isinstance(group, NativeFunctionsViewGroup):
            # view ops always get a functionalization kernel
            always_present = [group.view]
            maybe_present = [group.view_copy]
        elif isinstance(group, NativeFunctionsGroup):
            always_present = [group.functional, group.out]
            maybe_present = [group.inplace, group.mutable]
        else:
            # A plain NativeFunction only needs its own pair of headers.
            return include_pair(group.root_name)

        includes: List[str] = []
        for member in always_present:
            includes += include_pair(member.root_name)
        # Optional members are included only when the group actually has them.
        for member in maybe_present:
            if member is not None:
                includes += include_pair(member.root_name)
        return includes

    ops_headers = gen_op_headers(g)
    func_definitions = gen_functionalization_definition(
        selector,
        g,
    )
    func_registrations = gen_functionalization_registration(
        selector,
        g,
        backend_indices[DispatchKey.CompositeImplicitAutograd],
    )
    return {
        "ops_headers": ops_headers,
        "func_definitions": func_definitions,
        "func_registrations": func_registrations,
    }
all_groups: List[
Union[NativeFunction, NativeFunctionsGroup, NativeFunctionsViewGroup]
] = list(structured_native_functions) + list(
view_groups # type: ignore[assignment, arg-type, operator]
)
# Note: all operators that functionalization needs to handle (mutable and aliasing ops) should be grouped properly.
# The only reason we really need to deal with direct NativeFunctions here (instead of the groups) is because:
# (1) We can provide better error checking (error out if someone introduces a mutable op that doesn't obey the grouping logic)
# (2) functionalization needs to manually register CompositeImplicitAutograd kernels, which might not be grouped.
# Although this could go away long-term if we add a dedicated dispatch key for decompositions.
structured_map: Dict[OperatorName, NativeFunction] = {
f.func.name: f
for f in concatMap(lambda g: list(g.functions()), structured_native_functions)
}
view_map: Dict[OperatorName, NativeFunction] = {
f.func.name: f for f in concatMap(lambda g: list(g.functions()), view_groups)
}
for f in native_functions:
if f.func.name not in structured_map and f.func.name not in view_map:
all_groups.append(f)
cpu_fm.write_sharded(
"RegisterFunctionalization.cpp",
all_groups,
key_fn=key_func,
env_callable=functionalization_env_callable,
num_shards=4,
sharded_keys={
"ops_headers",
"func_definitions",
"func_registrations",
"func_add_back_views_definitions",
"func_add_back_views_registrations",
},
)
cpu_fm.write(
"FunctionalInverses.h",
lambda: {
"view_inverse_declarations": list(
mapMaybe(
lambda g: gen_functionalization_view_inverse_declaration(
selector, g
),
view_groups,
)
)
},
)
view_copy_with_symint_pairs: List[Tuple[NativeFunction, NativeFunction]] = []
for g1 in view_groups:
for g2 in view_groups:
if g1.view_copy is None or g2.view_copy is None:
continue
# TODO: make this more first class in the data model
same_base_op = str(g1.view_copy.func.name.name) == str(
g2.view_copy.func.name.name
)
op1_not_symint = "SymInt" not in str(g1.view_copy.func.name.overload_name)
op2_symint = "SymInt" in str(g2.view_copy.func.name.overload_name)
if same_base_op and op1_not_symint and op2_symint:
view_copy_with_symint_pairs.append(
(
g1.view_copy,
g2.view_copy,
)
)
# Note [view_copy NativeFunctions]
# Every view operator in native_functions.yaml that is not CompositeImplicitAutograd
# needs to have a corresponding non-aliasing {view}_copy variant.
# Backends that use functionalization and don't know how to handle aliasing ops
# are expected to implement kernels for these {view}_copy kernels instead.
# The code for {view}_copy operators in core is pretty boilerplate-heavy however,
# so we codegen the following:
# (1) A CompositeExplicitAutogradNonFunctional kernel for every {view}_copy operator.
# These are never explicitly invoked by the functionalization pass,
# but they could theoretically be called from user code (I added these kernels for completeness,
# since the ops are part of the public API).
# (2) A derivative formula for every {view}_copy operator
# {view}_copy operators can re-use the same derivative formulas as their {view} op counterparts,
# so rather than stamping all of the entries out in derivatives.yaml,
# we codegen them in.
# This is similar to how autograd codegen doesn't require inplace ops to have a derivatives.yaml entry.
cpu_fm.write(
"CompositeViewCopyKernels.cpp",
lambda: {
"ops_headers": [
"\n".join(
f"#include <ATen/ops/{f.root_name}_ops.h>"
for f in (
[g.view] if g.view_copy is None else [g.view, g.view_copy]
)
)
for g in view_groups
]
+ [
"\n".join(
f"#include <ATen/ops/{f.root_name}_ops.h>"
for f in [g.inplace, g.mutable]
if f is not None and "generated" not in f.tags
)
for g in structured_native_functions
],
"CompositeViewCopyKernel_Definitions": list(
mapMaybe(gen_composite_view_copy_kernel, view_groups)
),
"SymIntViewCopyKernel_Definitions": list(
mapMaybe(
lambda pair: gen_symint_view_copy_kernel(pair[0], pair[1]),
view_copy_with_symint_pairs,
)
),
"GeneratedCompositeFunctional_Definitions": list(
mapMaybe(
gen_composite_functional_kernel,
structured_native_functions,
)
),
"GeneratedCompositeOut_Definitions": list(
mapMaybe(
gen_composite_out_kernel,
structured_native_functions,
)
),
},
)
def gen_declarations_yaml(
    cpu_fm: FileManager, native_functions: Sequence[NativeFunction]
) -> None:
    """Write Declarations.yaml through ``cpu_fm``.

    The file content is produced lazily: the file manager is handed a
    callable that computes one declaration per native function and formats
    the resulting list as YAML.
    """

    def render_declarations():
        declarations = []
        for native_function in native_functions:
            declarations.append(compute_declaration_yaml(native_function))
        return format_yaml(declarations)

    cpu_fm.write("Declarations.yaml", render_declarations)
def get_torchgen_root() -> pathlib.Path:
    """
    Return the resolved directory that contains this module.

    Out-of-tree users of torchgen can use this root to work out the path to
    native_functions.yaml.
    """
    this_module = pathlib.Path(__file__)
    containing_dir = this_module.parent
    return containing_dir.resolve()
def main() -> None:
    """Command-line entry point of the ATen code generator.

    Parses the command-line options, loads ``native/native_functions.yaml``
    and ``native/tags.yaml`` from the ATen source directory, groups the parsed
    operators, and then generates whichever of "sources", "headers" and
    "declarations_yaml" were requested through ``--generate``.  When
    ``--output-dependencies`` is given, every file manager additionally
    writes its list of outputs next to the requested dependency file.
    """
    parser = argparse.ArgumentParser(description="Generate ATen source files")
    parser.add_argument(
        "-s",
        "--source-path",
        help="path to source directory for ATen",
        default="aten/src/ATen",
    )
    parser.add_argument(
        "-o",
        "--output-dependencies",
        help="output a list of dependencies into the given file and exit",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="run without writing any files (still updates outputs)",
    )
    parser.add_argument(
        "--per-operator-headers",
        action="store_true",
        help="generate separate headers per operator in ATen/ops",
    )
    parser.add_argument(
        "-d", "--install_dir", help="output directory", default="build/aten/src/ATen"
    )
    parser.add_argument(
        "--rocm",
        action="store_true",
        help="reinterpret CUDA as ROCm/HIP and adjust filepaths accordingly",
    )
    parser.add_argument(
        "--mps",
        action="store_true",
        help="Generate MPS registration code when set",
    )
    # TODO: --op_registration_whitelist will be removed when all call-sites
    # for gen.py are moved over to using the operator YAML file for mobile
    # custom build.
    parser.add_argument(
        "--op_registration_whitelist",
        nargs="*",
        help="filter op registrations by the whitelist (if set); "
        "each item is `namespace`::`operator name` without overload name; "
        "e.g.: aten::empty aten::conv2d ...",
    )
    parser.add_argument(
        "--op_selection_yaml_path",
        help="Provide a path to the operator selection (for custom build) YAML "
        "that contains the information about the set of selected operators "
        "and their categories (training, ...). Each operator is either a "
        "full operator name with overload or just a bare operator name. "
        "The operator names also contain the namespace prefix (e.g. aten::)",
    )
    parser.add_argument(
        "--backend_whitelist",
        nargs="*",
        help="filter dispatch backend by the whitelist (if set), "
        "e.g.: CPU CUDA QuantizedCPU ...",
    )
    parser.add_argument(
        "--static_dispatch_backend",
        nargs="*",
        help="generate static dispatch code for the specific backend (if set)",
    )
    parser.add_argument(
        "--skip_dispatcher_op_registration",
        action="store_true",
        help="Avoid registering operators into the dispatcher.",
    )
    parser.add_argument(
        "--force_schema_registration",
        action="store_true",
        # The trailing space matters: adjacent string literals are joined
        # verbatim, so without it the help reads "includingthose".
        help="force it to generate schema-only registrations for all ops, including "
        "those that are not listed on --op_registration_whitelist",
    )
    parser.add_argument(
        "--generate",
        type=str,
        nargs="*",
        choices=["headers", "sources", "declarations_yaml"],
        default=["headers", "sources", "declarations_yaml"],
        help="Generate only a subset of files",
    )

    options = parser.parse_args()

    selector = get_custom_build_selector(
        options.op_registration_whitelist,
        options.op_selection_yaml_path,
    )

    native_yaml_path = os.path.join(options.source_path, "native/native_functions.yaml")
    tags_yaml_path = os.path.join(options.source_path, "native/tags.yaml")

    from torchgen.model import dispatch_keys

    # TODO: stop generating CUDA kernels for non-CUDA builds
    ignore_keys: Set[DispatchKey] = set()
    if not options.mps:
        ignore_keys.add(DispatchKey.MPS)

        # Without --mps, MPS is also dropped from the imported dispatch key
        # list itself (the list object is modified in place).
        if DispatchKey.MPS in dispatch_keys:
            del dispatch_keys[dispatch_keys.index(DispatchKey.MPS)]

    parsed_yaml = parse_native_yaml(native_yaml_path, tags_yaml_path, ignore_keys)
    # Read the tags for this tags.yaml path from the module-level cache; this
    # lookup happens right after parse_native_yaml was called with that path.
    valid_tags = _GLOBAL_PARSE_TAGS_YAML_CACHE[tags_yaml_path]
    native_functions, backend_indices = (
        parsed_yaml.native_functions,
        parsed_yaml.backend_indices,
    )

    grouped_native_functions = get_grouped_native_functions(native_functions)
    structured_native_functions = [
        g for g in grouped_native_functions if isinstance(g, NativeFunctionsGroup)
    ]
    native_functions_with_view_groups = get_grouped_by_view_native_functions(
        native_functions
    )
    view_groups = [
        g
        for g in native_functions_with_view_groups
        if isinstance(g, NativeFunctionsViewGroup)
    ]

    # NB: It is mandatory to NOT use os.path.join here, as the install directory
    # will eventually be ingested by cmake, which does not respect Windows style
    # path slashes.  If you switch this to use os.path.join, you'll get an error
    # like:
    #
    #   Syntax error in cmake code when parsing string
    #
    #     C:/Jenkins/workspace/pytorch-builds/pytorch-win-ws2016-cuda9-cudnn7-py3-build/build/aten/src/ATen\core/TensorMethods.h
    #
    #   Invalid character escape '\c'.
    core_install_dir = f"{options.install_dir}/core"
    pathlib.Path(core_install_dir).mkdir(parents=True, exist_ok=True)
    ops_install_dir = f"{options.install_dir}/ops"
    pathlib.Path(ops_install_dir).mkdir(parents=True, exist_ok=True)

    core_fm = make_file_manager(options=options, install_dir=core_install_dir)
    cpu_fm = make_file_manager(options=options)
    cpu_vec_fm = make_file_manager(options=options)
    cuda_fm = make_file_manager(options=options)
    ops_fm = make_file_manager(options=options, install_dir=ops_install_dir)

    # Only a limited set of dispatch keys get CPUFunctions.h headers generated
    # for them; this is the set
    functions_keys = {
        DispatchKey.CPU,
        DispatchKey.CUDA,
        DispatchKey.CompositeImplicitAutograd,
        DispatchKey.CompositeExplicitAutograd,
        DispatchKey.CompositeExplicitAutogradNonFunctional,
        DispatchKey.Meta,
    }
    if options.mps:
        functions_keys.add(DispatchKey.MPS)

    if options.backend_whitelist:
        dispatch_keys = [
            k
            for k in dispatch_keys
            if is_generic_dispatch_key(k) or str(k) in options.backend_whitelist
        ]

    static_dispatch_idx: List[BackendIndex] = []
    if options.static_dispatch_backend:
        # Parse each requested backend name once and reuse the result both
        # for the backend index lookup and for extending functions_keys.
        static_dispatch_keys = [
            DispatchKey.parse(key) for key in options.static_dispatch_backend
        ]
        static_dispatch_idx = [backend_indices[k] for k in static_dispatch_keys]
        functions_keys.update(static_dispatch_keys)

    if "sources" in options.generate:
        gen_source_files(
            native_functions=native_functions,
            grouped_native_functions=grouped_native_functions,
            structured_native_functions=structured_native_functions,
            view_groups=view_groups,
            selector=selector,
            static_dispatch_idx=static_dispatch_idx,
            backend_indices=backend_indices,
            core_fm=core_fm,
            cpu_fm=cpu_fm,
            cpu_vec_fm=cpu_vec_fm,
            cuda_fm=cuda_fm,
            dispatch_keys=dispatch_keys,
            functions_keys=functions_keys,
            rocm=options.rocm,
            force_schema_registration=options.force_schema_registration,
            per_operator_headers=options.per_operator_headers,
            skip_dispatcher_op_registration=options.skip_dispatcher_op_registration,
        )

    if "headers" in options.generate:
        gen_headers(
            native_functions=native_functions,
            valid_tags=valid_tags,
            grouped_native_functions=grouped_native_functions,
            structured_native_functions=structured_native_functions,
            static_dispatch_idx=static_dispatch_idx,
            selector=selector,
            backend_indices=backend_indices,
            core_fm=core_fm,
            cpu_fm=cpu_fm,
            cuda_fm=cuda_fm,
            ops_fm=ops_fm,
            dispatch_keys=dispatch_keys,
            functions_keys=functions_keys,
            rocm=options.rocm,
            per_operator_headers=options.per_operator_headers,
        )

    if "declarations_yaml" in options.generate:
        gen_declarations_yaml(native_functions=native_functions, cpu_fm=cpu_fm)

    if options.output_dependencies:
        depfile_path = pathlib.Path(options.output_dependencies).resolve()
        depfile_name = depfile_path.name
        depfile_stem = depfile_path.stem

        # Each file manager records its outputs in its own file: the prefix is
        # applied both to the variable name and to the dependency file name.
        for fm, prefix in [
            (cpu_fm, ""),
            (cpu_vec_fm, "cpu_vec_"),
            (core_fm, "core_"),
            (cuda_fm, "cuda_"),
            (ops_fm, "ops_"),
        ]:
            varname = prefix + depfile_stem
            path = depfile_path.parent / (prefix + depfile_name)
            fm.write_outputs(varname, str(path))
# Run the code generator only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()