blob: 6f31c09d3e08c9c64fdec586427e1448f96d3793 [file] [log] [blame]
# Generates Python bindings for ATen functions
#
# The bindings are generated as methods on python_variable or functions on the
# torch._C._nn, torch._C._fft, torch._C._linalg, torch._C._sparse or torch._C._special objects.
#
# Code tries to stick to the following rules:
#
# - templates should be colocated with the functions that use them.
# no templates are currently shared between functions, but if that
# happens, maybe put the template with the first one
#
# - don't use environment dictionaries when calling template.substitute().
# pass named arguments directly for everything, otherwise it's much too
# hard to track what's actually being used and by who
#
# - colocate any new hacks/adjustments with existing ones of the same kind.
# ideally in a data structure rather than code if possible. See e.g.
# SCHEMA_DEFAULT_CONVERSION_HACKS, etc.
#
# - similarly, conversions from one format to another should ideally happen
# all at once in a single place.
#
# - no nontrivial nested functions. couple-liners are ok but please no more.
# especially avoid functions that read/write outer variables defined far away.
#
# - raise RuntimeError instead of asserting, and put as much
# information as is available into the message. I.e. no need to
# plumb in new params whose only purpose is to fill out an error
# message, but use what's there
#
from collections import defaultdict
import itertools
import re
import yaml
from .gen_trace_type import should_trace
from tools.codegen.code_template import CodeTemplate
from tools.codegen.api import cpp
from tools.codegen.api.types import CppSignatureGroup
from tools.codegen.api.python import (
PythonArgument,
PythonSignature,
PythonSignatureDeprecated,
PythonSignatureGroup,
PythonSignatureNativeFunctionPair,
arg_parser_output_exprs,
argument_type_str,
cpp_dispatch_exprs,
cpp_dispatch_target,
dispatch_lambda_args,
dispatch_lambda_exprs,
dispatch_lambda_return_str,
has_tensor_options,
namedtuple_fieldnames,
signature,
)
from tools.codegen.gen import cpp_string, parse_native_yaml
from tools.codegen.context import with_native_function
from tools.codegen.model import (
Argument,
BaseOperatorName,
NativeFunction,
Type,
Variant,
)
from tools.codegen.utils import split_name_params, YamlLoader, FileManager
from typing import Dict, Optional, List, Tuple, Set, Sequence, Callable
#
# declarations blocklist
# We skip codegen for these functions, for various reasons.
# Future PRs will categorize this list and eliminate or hoist
# them out of eager-only codegen.
# See https://github.com/pytorch/pytorch/issues/30788
#
# These functions require manual Python bindings or are not exposed to Python
_SKIP_PYTHON_BINDINGS = [
    "alias",
    "contiguous",
    "is_cuda",
    "is_sparse",
    "is_sparse_csr",
    "size",
    "stride",
    ".*_backward",
    ".*_backward_(out|input|weight|bias)",
    ".*_forward",
    ".*_forward_out",
    "_unsafe_view",
    "tensor",
    "_?sparse_coo_tensor.*",
    "_?sparse_csr_tensor.*",
    "_arange.*",
    "_range.*",
    "linspace.*",
    "logspace.*",
    "_sparse_add_out",
    "_sparse_div.*",
    "_sparse_mul.*",
    "_sparse_sub.*",
    "_sparse_dense_add_out",
    "index",
    "unique_dim_consecutive",
    "_cumsum.*",
    "_cumprod.*",
    "_sum.*",
    "_prod.*",
    "_th_.*",
    "_thnn_.*",
    "arange.*",
    "range.*",
    "_solve.*",
    "_inverse.*",
    "full(_out)?",
    "_cholesky.*",
    "_triangular_solve.*",
    "_qr.*",
    "_symeig.*",
    "_svd.*",
    "slice",
    "randint(_out)?",
    "item",
    "_local_scalar_dense",
    "to",
    "_to_copy",
    "copy_sparse_to_sparse_",
    "copy_",
    "numpy_T",
    "matrix_H",
    "mT",
    "mH",  # these need to be attributes in Python, not functions
    "nonzero(_(out|numpy))?",
    "set_data",
    ".*_overrideable",  # overrideable functions for backend extension
    "data",
    "is_leaf",
    "output_nr",
    "_version",
    "requires_grad_",
    "retains_grad",
    "set_",
    "_fw_primal",
    "fake_quantize_per_tensor_affine_cachemask",
    "fake_quantize_per_channel_affine_cachemask",
    "_new_zeros_with_same_feature_meta",
    "_has_same_storage_numel",  # used for forward AD internals
    "_reshape_alias",
    "replace_",  # only used by the functionalization pass, doesn't need to be exposed to python
]

# Compiled form of the blocklist above. Every entry is wrapped in ^...$ so a
# pattern has to cover the whole operator name, not just a prefix of it.
SKIP_PYTHON_BINDINGS = [
    re.compile(rf"^{pattern}$") for pattern in _SKIP_PYTHON_BINDINGS
]
# These function signatures are not exposed to Python. Unlike the name
# blocklist above, entries here are compared verbatim against the full schema
# string of a native function (str(f.func)); regular expressions are not
# supported.
SKIP_PYTHON_BINDINGS_SIGNATURES = [
    "add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor",
    "add_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!)",
    "sub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor",
    "sub_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!)",
    "mul.Scalar(Tensor self, Scalar other) -> Tensor",
    "mul_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)",
    "div.Scalar(Tensor self, Scalar other) -> Tensor",
    "div_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)",
]
@with_native_function
def should_generate_py_binding(f: NativeFunction) -> bool:
    """Return True unless ``f`` is excluded from Python binding generation.

    A function is excluded when its C++ name fully matches one of the
    compiled patterns in SKIP_PYTHON_BINDINGS, or when its schema string is
    listed verbatim in SKIP_PYTHON_BINDINGS_SIGNATURES.
    """
    op_name = cpp.name(f.func)
    if any(skip_regex.match(op_name) for skip_regex in SKIP_PYTHON_BINDINGS):
        return False
    # The signature list holds exact schema strings, so plain membership is enough.
    return str(f.func) not in SKIP_PYTHON_BINDINGS_SIGNATURES
def get_pycname(name: BaseOperatorName) -> str:
    """Return the identifier of the generated C function for operator ``name``."""
    return "THPVariable_{}".format(name)
def is_noarg(overloads: Sequence[PythonSignatureNativeFunctionPair]) -> bool:
    """Return True when there is exactly one overload and it takes no arguments."""
    if len(overloads) != 1:
        return False
    (only,) = overloads
    return only.signature.arguments_count() == 0
def is_py_variable_method(f: NativeFunction) -> bool:
    """Return True for functions with no python_module that have a method variant."""
    if f.python_module is not None:
        return False
    return Variant.method in f.variants
def is_py_torch_function(f: NativeFunction) -> bool:
    """Return True for functions with no python_module that have a function variant."""
    if f.python_module is not None:
        return False
    return Variant.function in f.variants
def is_py_nn_function(f: NativeFunction) -> bool:
    """Return True when ``f`` declares "nn" as its python_module."""
    module = f.python_module
    return module == "nn"
def is_py_fft_function(f: NativeFunction) -> bool:
    """Return True when ``f`` declares "fft" as its python_module."""
    module = f.python_module
    return module == "fft"
def is_py_linalg_function(f: NativeFunction) -> bool:
    """Return True when ``f`` declares "linalg" as its python_module."""
    module = f.python_module
    return module == "linalg"
def is_py_sparse_function(f: NativeFunction) -> bool:
    """Return True when ``f`` declares "sparse" as its python_module."""
    module = f.python_module
    return module == "sparse"
def is_py_special_function(f: NativeFunction) -> bool:
    """Return True when ``f`` declares "special" as its python_module."""
    module = f.python_module
    return module == "special"
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
# Main Function
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
def gen(
    out: str, native_yaml_path: str, deprecated_yaml_path: str, template_path: str
) -> None:
    """Generate every Python binding source file into ``out``.

    Parses the native function declarations, drops the ones rejected by
    should_generate_py_binding, then emits the Tensor method bindings, the
    sharded ``torch`` function bindings, one file per submodule (nn, fft,
    linalg, sparse, special) and finally the return-type (namedtuple) file.
    """
    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)

    native_functions = [
        f
        for f in parse_native_yaml(native_yaml_path).native_functions
        if should_generate_py_binding(f)
    ]

    # Tensor methods.
    methods = load_signatures(native_functions, deprecated_yaml_path, method=True)
    create_python_bindings(
        fm,
        methods,
        is_py_variable_method,
        None,
        "python_variable_methods.cpp",
        method=True,
    )

    # Free functions in the top-level `torch` namespace.
    # NOTE: num_shards here must be synced with gatherTorchFunctions in
    # torch/csrc/autograd/python_torch_functions_manual.cpp
    functions = load_signatures(native_functions, deprecated_yaml_path, method=False)
    create_python_bindings_sharded(
        fm,
        functions,
        is_py_torch_function,
        "torch",
        "python_torch_functions.cpp",
        method=False,
        num_shards=3,
    )

    # Submodule functions: (predicate, python module, output file).
    submodule_bindings = (
        (is_py_nn_function, "torch.nn", "python_nn_functions.cpp"),
        (is_py_fft_function, "torch.fft", "python_fft_functions.cpp"),
        (is_py_linalg_function, "torch.linalg", "python_linalg_functions.cpp"),
        (is_py_sparse_function, "torch.sparse", "python_sparse_functions.cpp"),
        (is_py_special_function, "torch.special", "python_special_functions.cpp"),
    )
    for pred, module, filename in submodule_bindings:
        create_python_bindings(
            fm,
            functions,
            pred,
            module,
            filename,
            method=False,
        )

    # Currently, we only use `functions` to generate `return_types` bindings.
    # All methods which return namedtuple have function variant at this point.
    # If any method only operator with namedtuple is added in the future,
    # we will have to address that.
    create_python_return_type_bindings(
        fm, functions, lambda fn: True, "python_return_types.cpp"
    )
def group_filter_overloads(
    pairs: Sequence[PythonSignatureNativeFunctionPair],
    pred: Callable[[NativeFunction], bool],
) -> Dict[BaseOperatorName, List[PythonSignatureNativeFunctionPair]]:
    """Keep the pairs whose function satisfies ``pred``, bucketed by operator name.

    The bucket key is ``pair.function.func.name.name``; within a bucket the
    pairs keep their input order. The result is a ``defaultdict(list)``.
    """
    by_name: Dict[
        BaseOperatorName, List[PythonSignatureNativeFunctionPair]
    ] = defaultdict(list)
    for accepted in (p for p in pairs if pred(p.function)):
        by_name[accepted.function.func.name.name].append(accepted)
    return by_name
def create_python_bindings(
    fm: FileManager,
    pairs: Sequence[PythonSignatureNativeFunctionPair],
    pred: Callable[[NativeFunction], bool],
    module: Optional[str],
    filename: str,
    *,
    method: bool,
) -> None:
    """Write one binding source file for the functions selected by ``pred``.

    Operators are processed in order of their stringified name. For each one
    the method implementation, its method-table entry, its forward
    declarations and the matching ATen ops header include are collected, and
    the lot is rendered through the template named ``filename``.
    """
    grouped = group_filter_overloads(pairs, pred)

    py_methods: List[str] = []
    ops_headers: List[str] = []
    py_method_defs: List[str] = []
    py_forwards: List[str] = []

    for name in sorted(grouped, key=str):
        overloads = grouped[name]
        py_methods.append(method_impl(name, module, overloads, method=method))
        py_method_defs.append(method_def(name, module, overloads, method=method))
        py_forwards.extend(forward_decls(name, overloads, method=method))
        ops_headers.append(f"#include <ATen/ops/{name.base}.h>")

    def template_env() -> Dict[str, object]:
        # The marker is assembled from two pieces, exactly as before.
        return {
            "generated_comment": "@" + f"generated from {fm.template_dir}/{filename}",
            "ops_headers": ops_headers,
            "py_forwards": py_forwards,
            "py_methods": py_methods,
            "py_method_defs": py_method_defs,
        }

    fm.write_with_template(filename, filename, template_env)
def create_python_return_type_bindings(
    fm: FileManager,
    pairs: Sequence[PythonSignatureNativeFunctionPair],
    pred: Callable[[NativeFunction], bool],
    filename: str,
) -> None:
    """
    Write `python_return_types.cpp`: for every native function that returns a
    named tuple, emit the function that initializes and returns the tuple type
    plus the matching entry for the name-to-type map.
    """
    grouped = group_filter_overloads(pairs, pred)

    py_return_types_definition: List[str] = []
    py_return_types_map: List[str] = []

    for name in sorted(grouped, key=str):
        definitions, map_entries = generate_return_type_definition_and_map_entry(
            grouped[name]
        )
        # Joining an empty list yields "", so operators without named tuple
        # returns contribute an empty string, one per operator.
        py_return_types_definition.append("\n".join(definitions))
        py_return_types_map.append("\n".join(map_entries))

    def template_env() -> Dict[str, object]:
        return {
            "generated_comment": "@" + f"generated from {fm.template_dir}/{filename}",
            "py_return_types": py_return_types_definition,
            "py_return_types_map": py_return_types_map,
        }

    fm.write_with_template(filename, filename, template_env)
def create_python_bindings_sharded(
    fm: FileManager,
    pairs: Sequence[PythonSignatureNativeFunctionPair],
    pred: Callable[[NativeFunction], bool],
    module: Optional[str],
    filename: str,
    *,
    method: bool,
    num_shards: int,
) -> None:
    """Write the bindings selected by ``pred`` across ``num_shards`` files.

    Same per-operator content as create_python_bindings, but handed to
    ``fm.write_sharded``, keyed by the operator's base name.
    """
    OperatorEntry = Tuple[BaseOperatorName, List[PythonSignatureNativeFunctionPair]]

    def shard_key(entry: OperatorEntry) -> str:
        return entry[0].base

    def shard_env(entry: OperatorEntry) -> Dict[str, List[str]]:
        op_name, op_pairs = entry
        headers = [f"#include <ATen/ops/{op_name.base}.h>"]
        forwards = list(forward_decls(op_name, op_pairs, method=method))
        impls = [method_impl(op_name, module, op_pairs, method=method)]
        defs = [method_def(op_name, module, op_pairs, method=method)]
        return {
            "ops_headers": headers,
            "py_forwards": forwards,
            "py_methods": impls,
            "py_method_defs": defs,
        }

    grouped = group_filter_overloads(pairs, pred)
    fm.write_sharded(
        filename,
        grouped.items(),
        base_env={
            "generated_comment": "@" + f"generated from {fm.template_dir}/{filename}",
        },
        key_fn=shard_key,
        env_callable=shard_env,
        num_shards=num_shards,
        sharded_keys={"ops_headers", "py_forwards", "py_methods", "py_method_defs"},
    )
def load_signatures(
    native_functions: List[NativeFunction],
    deprecated_yaml_path: str,
    *,
    method: bool,
    skip_deprecated: bool = False,
    pyi: bool = False,
) -> Sequence[PythonSignatureNativeFunctionPair]:
    """Pair every native function with its Python signature.

    Args:
        native_functions: functions to generate signatures for.
        deprecated_yaml_path: path of the YAML file listing deprecated
            overloads; only read when ``skip_deprecated`` is False.
        method: forwarded to ``signature`` and to the deprecated loader.
        skip_deprecated: when True, return only the native pairs.
        pyi: forwarded to ``signature`` and to the deprecated loader.

    Returns:
        The native pairs, followed by the deprecated pairs unless
        ``skip_deprecated`` is set.
    """

    @with_native_function
    def gen_signature_pairs(f: NativeFunction) -> PythonSignatureNativeFunctionPair:
        return PythonSignatureNativeFunctionPair(
            signature=signature(f, method=method, pyi=pyi),
            function=f,
        )

    pairs = [gen_signature_pairs(f) for f in native_functions]
    if skip_deprecated:
        # The deprecated overloads would be thrown away, so don't open and
        # parse the YAML file at all.
        return pairs

    deprecated = load_deprecated_signatures(
        pairs, deprecated_yaml_path, method=method, pyi=pyi
    )
    return pairs + deprecated
def load_deprecated_signatures(
    pairs: Sequence[PythonSignatureNativeFunctionPair],
    deprecated_yaml_path: str,
    *,
    method: bool,
    pyi: bool,
) -> List[PythonSignatureNativeFunctionPair]:
    """Build signature pairs for the overloads listed in the deprecated YAML file.

    The deprecated.yaml doesn't have complete type information, so each entry
    is matched to the original ATen signature it delegates to, and the full
    Python signature is derived from that one. Deprecated and original
    signatures are joined on a type-only form ``opname(type, type, ...)``.

    Each YAML entry is read through its "name" and "aten" fields, both parsed
    with ``split_name_params``.
    """

    # native function -> type-only signature
    @with_native_function
    def signature_original(f: NativeFunction) -> str:
        # remove inplace suffix but keep outplace suffix
        opname = str(f.func.name.name.base)
        if f.func.is_out_fn():
            opname = opname + "_out"
        if f.func.name.name.inplace and pyi:
            opname = opname + "_"
        cpp_args = CppSignatureGroup.from_native_function(
            f, method=False
        ).signature.arguments()
        # Simply ignore TensorOptionsArguments as it does not exist in deprecated.yaml.
        arg_types = [
            argument_type_str(a.argument.type)
            for a in cpp_args
            if isinstance(a.argument, Argument)
        ]
        joined = ", ".join(arg_types)
        return f"{opname}({joined})"

    # deprecated -> type-only native signature (according to the call order)
    def signature_deprecated(
        opname: str, params: List[str], call_args: List[str]
    ) -> str:
        # parameter name -> parameter type, skipping the kwarg-only marker
        type_by_name: Dict[str, str] = {}
        for param in params:
            if param != "*":
                param_type, param_name = param.split(" ")
                type_by_name[param_name] = param_type
        # if the name in the call is not in the parameter list, assume it's
        # a literal Scalar
        call_types = [type_by_name.get(arg, "Scalar") for arg in call_args]
        joined = ", ".join(call_types)
        return f"{opname}({joined})"

    # group the original ATen signatures by type-only signature
    originals: Dict[str, List[PythonSignatureNativeFunctionPair]] = defaultdict(list)
    for original in pairs:
        originals[signature_original(original.function)].append(original)

    # NOTE(review): the loader is the project's YamlLoader; whether it is a
    # safe loader is not visible from here - confirm before feeding it
    # untrusted input.
    with open(deprecated_yaml_path, "r") as f:
        deprecated_defs = yaml.load(f, Loader=YamlLoader)

    # find matching original signatures for each deprecated signature
    results: List[PythonSignatureNativeFunctionPair] = []
    for deprecated in deprecated_defs:
        _, params = split_name_params(deprecated["name"])
        aten_name, call_args = split_name_params(deprecated["aten"])
        lookup_key = signature_deprecated(aten_name, params, call_args)

        for pair in originals[lookup_key]:
            # It uses the types from the original ATen declaration, but the
            # ordering and parameter names from the deprecated overload. Any
            # default parameter values from the original ATen declaration are
            # ignored.
            # Deprecated signature might reorder input_args and input_kwargs,
            # but never changes output_args nor TensorOptions (if any?),
            # so here we only look into these two types of args.
            python_sig = pair.signature
            src_args: Dict[str, PythonArgument] = {}
            for src in itertools.chain(
                python_sig.input_args, python_sig.input_kwargs
            ):
                src_args[src.name] = PythonArgument(
                    name=src.name,
                    type=src.type,
                    default=None,
                    default_init=None,
                )

            deprecated_names: List[str] = []
            input_args: List[PythonArgument] = []
            input_kwargs: List[PythonArgument] = []

            kwarg_only = False
            for param in params:
                if param == "*":
                    # everything after the marker is keyword-only
                    kwarg_only = True
                    continue
                _, param_name = param.split(" ")
                deprecated_names.append(param_name)

                source_arg = src_args.get(param_name)
                if source_arg is None:
                    # output argument
                    continue

                if kwarg_only:
                    input_kwargs.append(source_arg)
                elif not (method and param_name == "self"):
                    input_args.append(source_arg)

            deprecated_sig = PythonSignatureDeprecated(
                name=python_sig.name,
                input_args=tuple(input_args),
                input_kwargs=tuple(input_kwargs),
                output_args=python_sig.output_args,
                tensor_options_args=python_sig.tensor_options_args,
                method=python_sig.method,
                deprecated_args_names=tuple(deprecated_names),
                deprecated_args_exprs=tuple(call_args),
                returns=python_sig.returns,
            )
            results.append(
                PythonSignatureNativeFunctionPair(
                    signature=deprecated_sig,
                    function=pair.function,
                )
            )

    return results
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
# Named Tuple Codegen
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
@with_native_function
def gen_namedtuple_typename_key(f: NativeFunction) -> str:
    """Return the lookup key for ``f``'s named tuple type.

    The key is the C++ name of the function followed by its return field
    names, all joined with underscores.
    """
    parts = [cpp.name(f.func)]
    parts += namedtuple_fieldnames(f.func.returns)
    return "_".join(parts)
def emit_namedtuple_call(
    overloads: Sequence[PythonSignatureNativeFunctionPair],
) -> Tuple[List[str], Dict[str, str]]:
    """
    Generate block of named tuple type def inits, and add typeref snippets
    to declarations that use them.

    Returns the list of C++ declaration lines together with a map from
    typename key (see gen_namedtuple_typename_key) to the C++ variable name.
    """
    # map from unique name + field name lists to typedef name
    typenames: Dict[str, str] = {}
    # typedef declarations and init code
    typedefs: List[str] = []

    for overload in overloads:
        native = overload.function
        if not namedtuple_fieldnames(native.func.returns):
            # not a named tuple return
            continue

        name = cpp.name(native.func)  # use @with_native_function?
        tn_key = gen_namedtuple_typename_key(native)
        if tn_key in typenames:
            continue

        # The first type is plain "NamedTuple"; later ones get a numeric suffix.
        suffix = str(len(typedefs)) if typedefs else ""
        typename = f"NamedTuple{suffix}"
        typenames[tn_key] = typename
        typedefs.append(
            f'static PyTypeObject* {typename} = get_namedtuple("{name}");'
        )

    return typedefs, typenames
def generate_return_type_definition_and_map_entry(
    overloads: Sequence[PythonSignatureNativeFunctionPair],
) -> Tuple[List[str], List[str]]:
    """
    Generate block of function in `python_return_types.cpp` to initialize
    and return named tuple for a native function which returns named tuple
    and relevant entry for the map in same file.
    """
    # map from unique name + field name lists to typedef name
    typenames: Dict[str, str] = {}
    # function definitions that register the typedef
    definitions: List[str] = []
    # C++ map entries of <function_name, function that creates its namedtuple>
    map_entries: List[str] = []

    for overload in overloads:
        native = overload.function
        fieldnames = namedtuple_fieldnames(native.func.returns)
        if not fieldnames:
            continue

        fields = ", ".join(f'{{"{fn}", ""}}' for fn in fieldnames)
        name = cpp.name(native.func)  # use @with_native_function?
        tn_key = gen_namedtuple_typename_key(native)
        if tn_key in typenames:
            # this name + field list was already emitted
            continue

        suffix = str(len(definitions)) if definitions else ""
        typename = f"{name}NamedTuple{suffix}"
        typenames[tn_key] = typename

        definitions.append(
            f"""\
PyTypeObject* get_{name}_namedtuple() {{
static PyStructSequence_Field NamedTuple_fields[] = {{ {fields}, {{nullptr}} }};
static PyTypeObject {typename};
static bool is_initialized = false;
static PyStructSequence_Desc desc = {{ "torch.return_types.{name}", nullptr, NamedTuple_fields, {len(fieldnames)} }};
if (!is_initialized) {{
PyStructSequence_InitType(&{typename}, &desc);
{typename}.tp_repr = (reprfunc)torch::utils::returned_structseq_repr;
is_initialized = true;
}}
return &{typename};
}}
"""
        )
        map_entries.append(f'{{"{name}", get_{name}_namedtuple()}}, ')

    return definitions, map_entries
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
# Method Impl Codegen
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
# python binding for all overloads of a particular function/method
#
# Placeholders: ${name}, ${pycname}, ${method_header}, ${signatures},
# ${traceable}, ${max_args}, ${self_}, ${check_has_torch_function},
# ${dispatch} (one PY_VARIABLE_CASE per parsed signature) and
# ${method_footer}.
#
# This must NOT be a raw string: in a raw literal the leading backslash-newline
# is kept as two literal characters instead of acting as a line continuation,
# which would put a stray "\" line at the top of every generated function.
PY_VARIABLE_METHOD_VARARGS = CodeTemplate(
    """\
// ${name}
static PyObject * ${pycname}(PyObject* self_, PyObject* args, PyObject* kwargs)
{
${method_header}
static PythonArgParser parser({
${signatures}
}, /*traceable=*/${traceable});
ParsedArgs<${max_args}> parsed_args;
auto _r = parser.parse(${self_}, args, kwargs, parsed_args);
${check_has_torch_function}
switch (_r.idx) {
${dispatch}
}
${method_footer}
}
"""
)
# handler for a single parsed signature - may be a single overload or
# a pair of overloads whose signatures only differ in output params
# (plugged into PY_VARIABLE_METHOD_VARARGS as an item in ${dispatch}).
# ${overload_index} is the position of the signature in the parser's
# signature list; ${body} is the dispatch code for that signature.
PY_VARIABLE_CASE = CodeTemplate(
    """\
case ${overload_index}: {
${body}
}
"""
)
# python binding for single-overload function/method
# Same layout as PY_VARIABLE_METHOD_VARARGS, except that ${dispatch} is
# emitted directly instead of inside a `switch (_r.idx)`.
PY_VARIABLE_METHOD_VARARGS_SINGLETON = CodeTemplate(
    """\
// ${name}
static PyObject * ${pycname}(PyObject* self_, PyObject* args, PyObject* kwargs)
{
${method_header}
static PythonArgParser parser({
${signatures}
}, /*traceable=*/${traceable});
ParsedArgs<${max_args}> parsed_args;
auto _r = parser.parse(${self_}, args, kwargs, parsed_args);
${check_has_torch_function}
${dispatch}
${method_footer}
}
"""
)
# python binding for a method with no args, shortcuts parsing
# No PythonArgParser is instantiated here: the generated function takes no
# kwargs parameter and goes straight to ${dispatch}.
PY_VARIABLE_METHOD_NOARGS = CodeTemplate(
    """\
// ${name}
static PyObject * ${pycname}(PyObject* self_, PyObject* args)
{
${method_header}
${check_has_torch_function}
${dispatch}
${method_footer}
}
"""
)
def method_impl(
    name: BaseOperatorName,
    module: Optional[str],
    overloads: Sequence[PythonSignatureNativeFunctionPair],
    *,
    method: bool,
) -> str:
    """
    Generate a python binding for all overloads of an op.

    Picks one of three templates: the no-argument shortcut, the single
    signature form, or the multi-signature form that switches on the index
    of the parsed signature.
    """
    pycname = get_pycname(name)
    noarg = is_noarg(overloads)
    namedtuple_inits, namedtuple_typenames = emit_namedtuple_call(overloads)

    method_header = ["HANDLE_TH_ERRORS", *namedtuple_inits]
    if method:
        method_header.append("const Tensor& self = THPVariable_Unpack(self_);")

    method_footer = ["END_HANDLE_TH_ERRORS"]
    if not noarg:
        method_footer.insert(0, "Py_RETURN_NONE;")

    # Traceable only when every overload is.
    if all(should_trace(o.function) for o in overloads):
        traceable = "true"
    else:
        traceable = "false"

    grouped_overloads: Sequence[PythonSignatureGroup] = group_overloads(overloads)
    is_singleton = len(grouped_overloads) == 1

    signatures: List[str] = []
    dispatch: List[str] = []
    for overload_index, overload in enumerate(grouped_overloads):
        sig_str = overload.signature.signature_str()
        signatures.append(f"{cpp_string(str(sig_str))},")
        dispatch_body = emit_dispatch_case(overload, namedtuple_typenames)
        if is_singleton:
            # A lone signature needs no `case` wrapper.
            dispatch.append(dispatch_body)
        else:
            dispatch.append(
                PY_VARIABLE_CASE.substitute(
                    overload_index=overload_index, body=dispatch_body
                )
            )

    if noarg:
        template = PY_VARIABLE_METHOD_NOARGS
    else:
        template = (
            PY_VARIABLE_METHOD_VARARGS_SINGLETON
            if is_singleton
            else PY_VARIABLE_METHOD_VARARGS
        )

    max_args = max(o.signature.arguments_count() for o in overloads)
    torch_function_check = gen_has_torch_function_check(
        name=name,
        module=module,
        noarg=noarg,
        method=method,
    )

    return template.substitute(
        name=name,
        pycname=pycname,
        method_header=method_header,
        max_args=max_args,
        signatures=signatures,
        traceable=traceable,
        check_has_torch_function=torch_function_check,
        dispatch=dispatch,
        method_footer=method_footer,
        self_="self_" if method else "nullptr",
    )
def gen_has_torch_function_check(
    name: BaseOperatorName, module: Optional[str], *, noarg: bool, method: bool
) -> str:
    """Return the C++ snippet that defers to ``handle_torch_function``.

    No-argument bindings check ``self_`` directly (and emit nothing at all
    when they are not methods). All other bindings test the parser result
    ``_r`` and pass along the module object the binding lives in. A non-empty
    ``module`` missing from the table below raises ``KeyError``.
    """
    if noarg:
        if not method:
            return ""
        return f"""\
if(check_has_torch_function(self_)) {{
return handle_torch_function(self_, "{name}");
}}
"""

    self_ = "self_" if method else "nullptr"
    if module:
        namespace = {
            "torch": "THPVariableFunctionsModule",
            "torch.nn": "THPNNVariableFunctionsModule",
            "torch.fft": "THPFFTVariableFunctionsModule",
            "torch.linalg": "THPLinalgVariableFunctionsModule",
            "torch.sparse": "THPSparseVariableFunctionsModule",
            "torch.special": "THPSpecialVariableFunctionsModule",
        }[module]
    else:
        namespace = "THPVariableClass"
    module_name = module or "torch.Tensor"

    return f"""\
if(_r.has_torch_function()) {{
return handle_torch_function(_r, {self_}, args, kwargs, {namespace}, "{module_name}");
}}
"""
# handler for output/no-output overload pair
# ${out_idx} is the parser index of the output argument: when it was not
# passed (isNone) the no-output dispatch in ${call_dispatch} runs, otherwise
# the out variant in ${call_dispatch_out}.
PY_VARIABLE_OUT = CodeTemplate(
    """\
if (_r.isNone(${out_idx})) {
${call_dispatch}
} else {
${call_dispatch_out}
}
"""
)
def emit_dispatch_case(
    overload: PythonSignatureGroup,
    namedtuple_typenames: Dict[str, str],
) -> str:
    """
    Emit dispatch code for a single parsed signature. This corresponds to either
    a single native function, or a pair that differ only in output params. In the
    latter case, a single python signature is used for both and dispatching
    switches on the presence/absence of passed output args.
    """
    if overload.outplace is None:
        # no-output version only
        return emit_single_dispatch(
            overload.signature, overload.base, namedtuple_typenames
        )

    # dispatch output and no-output variants, branch on _r.isNone(<out_idx>)
    out_idx = overload.signature.output_idx()
    without_out = emit_single_dispatch(
        overload.signature, overload.base, namedtuple_typenames
    )
    with_out = emit_single_dispatch(
        overload.signature, overload.outplace, namedtuple_typenames
    )
    return PY_VARIABLE_OUT.substitute(
        out_idx=out_idx,
        call_dispatch=without_out,
        call_dispatch_out=with_out,
    )
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
# Forward Declarations Codegen
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
def forward_decls(
    name: BaseOperatorName,
    overloads: Sequence[PythonSignatureNativeFunctionPair],
    *,
    method: bool,
) -> Tuple[str, ...]:
    """Return the C++ forward declaration for the binding of ``name``.

    Methods get no forward declaration (empty tuple). For functions a
    one-element tuple is returned; the declaration has a ``kwargs`` parameter
    unless the op is a no-argument op.
    """
    if method:
        return ()

    pycname = get_pycname(name)
    params = "PyObject* self_, PyObject* args"
    if not is_noarg(overloads):
        params += ", PyObject* kwargs"
    return (f"static PyObject * {pycname}({params});\n",)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
# Method Def (Binding Table Entry) Codegen
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
def method_def(
    name: BaseOperatorName,
    module: Optional[str],
    overloads: Sequence[PythonSignatureNativeFunctionPair],
    *,
    method: bool,
) -> str:
    """
    Generate method def entry.

    The result is one ``PyMethodDef`` initializer line. Dunder methods are
    wrapped in ``TypeError_to_NotImplemented_<...>``.
    """
    pycname = get_pycname(name)
    noarg = is_noarg(overloads)

    pyfunc_cast = "" if noarg else "castPyCFunctionWithKeywords"
    # Only no-argument *methods* are registered as METH_NOARGS.
    if noarg and method:
        flags = "METH_NOARGS"
    else:
        flags = "METH_VARARGS | METH_KEYWORDS"
    if module == "torch":
        flags += " | METH_STATIC"

    if name.dunder_method:
        # PyMethodDef entry for binary op, throws not implemented error
        target = f"TypeError_to_NotImplemented_<{pycname}>"
    else:
        # PyMethodDef entry
        target = pycname
    return f'{{"{name}", {pyfunc_cast}({target}), {flags}, NULL}},'
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
# Overload Sorting and Grouping
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
def group_overloads(
    overloads: Sequence[PythonSignatureNativeFunctionPair],
) -> Sequence[PythonSignatureGroup]:
    """Pair each base overload with its out variant and sort the groups.

    Overloads are keyed by their signature string with output arguments
    skipped. Raises RuntimeError when two base (or two out) overloads share a
    key, or when an out overload has no base overload with the same key.
    """
    bases: Dict[str, PythonSignatureNativeFunctionPair] = {}
    outplaces: Dict[str, PythonSignatureNativeFunctionPair] = {}

    # first group by signature ignoring out arguments
    for overload in overloads:
        sig = overload.signature.signature_str(skip_outputs=True)
        bucket = outplaces if overload.function.func.is_out_fn() else bases
        if sig in bucket:
            raise RuntimeError(
                f"Found duplicated function definition:\n- {overload.function.func}.\n"
                f"Existing definition:\n- {bucket[sig].function.func}."
            )
        bucket[sig] = overload

    # every out variant needs a base variant to pair with
    for sig, out in outplaces.items():
        if sig in bases:
            continue
        out_name = str(out.function.func.name.name)
        candidates: List[str] = [
            overload.signature.signature_str(skip_outputs=True)
            for overload in overloads
            if str(overload.function.func.name.name) == out_name
            and not overload.function.func.is_out_fn()
            and not overload.signature.deprecated
        ]
        out_sig = out.signature.signature_str()
        raise RuntimeError(
            f"While identifying overloads, we found an out schema {out_sig} without a corresponding non-out variant. "
            f"We expected the non-out variant to have schema: \n- {sig}\nPlease check that you spelled the schema "
            "correctly in native_functions.yaml. We discovered the following candidate(s): \n"
            + "\n".join(f"- {candidate}" for candidate in candidates)
        )

    grouped: List[PythonSignatureGroup] = []
    for sig, base in bases.items():
        outplace = outplaces.get(sig)
        if outplace is None:
            group = PythonSignatureGroup(
                signature=base.signature,
                base=base.function,
                outplace=None,
            )
        else:
            # prefer the signature with optional out=... arguments because it's the
            # superset that can be used to parse input for both base and outplace.
            group = PythonSignatureGroup(
                signature=outplace.signature,
                base=base.function,
                outplace=outplace.function,
            )
        grouped.append(group)

    return sort_overloads(grouped)
# This function declares a partial order on declarations, and sorts them according
# to its linear extension. This is necessary, because there's some ambiguity in the
# choice of overload, and we want a different order.
#
# See Note[Order of overloads matters]
#
# A few examples of ambiguous python signature pairs.
#
# All parameters have the same type, except one taking Tensor the other taking
# Scalar. A numeric PyObject can be casted into Tensor, and a zero-dim Tensor
# object can be accepted as Scalar type parameter (see python_arg_parser.cpp).
# Therefore, same input arguments might be accepted by either python signature.
# We want to always parse the one taking Tensor first.
#
# bitwise_and(Tensor input, Tensor other, *, Tensor out=None)
# bitwise_and(Tensor input, Scalar other, *, Tensor out=None)
#
# If they have different number of parameters then they are not ambiguous - but
# the difference on output param can be ignored as it's optional.
#
# multiply(Tensor input, Tensor other, *, Tensor out=None)
# multiply(Tensor input, Scalar other)
#
# Both positional args and keyword-only args are considered together.
#
# subtract(Tensor other, *, Scalar alpha=1)
# subtract(Scalar other, Scalar alpha=1)
#
# A few ambiguous cases which it does NOT handle yet.
#
# If there is any difference in other parameters besides the Tensor/Scalar
# difference, then they are not considered ambiguous by this method anymore.
# However, the difference could be too trivial to disambiguate.
#
# foo(Tensor input, Scalar other, Scalar bar)
# foo(Tensor input, Tensor other, double bar)
#
# If they are taking different number of parameters then they are not considered
# ambiguous anymore, even if the difference is only on optional kwargs.
#
# foo(Scalar other, Scalar alpha=1)
# foo(Tensor other, *, Scalar alpha=1, Scalar beta=1)
#
def sort_overloads(
    grouped_overloads: Sequence[PythonSignatureGroup],
) -> Sequence[PythonSignatureGroup]:
    """Order signature groups so that ambiguous signatures parse in a fixed order.

    Groups are first sorted by signature string. If no pair of signatures is
    related by the partial order below, that sorted list is returned as is;
    otherwise a topological sort places every "smaller" signature after all
    the signatures it is smaller than.
    """

    def is_arg_smaller(t1: Type, t2: Type) -> bool:
        lhs, rhs = str(t1), str(t2)
        if lhs == "Scalar" and rhs == "Tensor":
            return True
        if "Dimname" in lhs and "Dimname" not in rhs:
            return True
        # In the discussion https://github.com/pytorch/pytorch/issues/54555 it has been
        # discussed why it is important to prioritize int/int? over int[]
        if lhs == "int[]" and rhs in ("int", "int?"):
            return True
        # TensorList currently throws an error during argument parsing, that's why it needs to be
        # last in signature ordering. See discussion: https://github.com/pytorch/pytorch/issues/58087
        return lhs == "Tensor[]" and "[]" in rhs

    def is_smaller(s1: PythonSignature, s2: PythonSignature) -> bool:
        """Returns True if s1 < s2 in the partial order."""
        args1 = s1.arguments(skip_outputs=True)
        args2 = s2.arguments(skip_outputs=True)
        if len(args1) != len(args2):
            return False
        # TODO: should use some canonical form instead of 'str(arg.type)' - see comments
        # above. The old codegen used the deprecated 'dynamic_type(arg.type)', which
        # ignores the optional annotation, i.e. 'Scalar' and 'Scalar?'.
        arg_pairs = list(zip(args1, args2))
        if all(a.type == b.type for a, b in arg_pairs):
            # identical signatures are not ordered relative to each other
            return False
        return all(
            str(a.type) == str(b.type) or is_arg_smaller(a.type, b.type)
            for a, b in arg_pairs
        )

    # First sort by signature
    ordered = sorted(grouped_overloads, key=lambda g: g.signature.signature_str())

    # Construct the relation graph: index -> indices it is smaller than
    larger_than: Dict[int, Set[int]] = defaultdict(set)
    for i1, overload1 in enumerate(ordered):
        for i2, overload2 in enumerate(ordered):
            if is_smaller(overload1.signature, overload2.signature):
                larger_than[i1].add(i2)

    if not larger_than:
        return list(ordered)

    # Use a topological sort to sort overloads according to the partial order.
    total = len(ordered)
    sorted_ids: List[int] = [idx for idx in range(total) if idx not in larger_than]

    for position in range(total):
        # The size of sorted_ids will grow to N eventually.
        placed = sorted_ids[position]
        for pending in sorted(larger_than):
            remaining = larger_than[pending]
            remaining.discard(placed)
            if not remaining:
                # everything `pending` is smaller than has been placed
                del larger_than[pending]
                sorted_ids.append(pending)

    return [ordered[idx] for idx in sorted_ids]
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
# Codegen API Integration
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
def emit_single_dispatch(
    ps: PythonSignature, f: NativeFunction, namedtuple_typenames: Dict[str, str]
) -> str:
    """
    Emit dispatch code for a single native function.

    The emitted C++ declares a lambda that releases the GIL and calls the
    dispatch target, then invokes it with the parsed arguments. Non-void
    results are passed through ``wrap`` (together with the named tuple type
    when one is registered for ``f``); void functions end in
    ``Py_RETURN_NONE``.
    """

    @with_native_function
    def go(f: NativeFunction) -> str:
        # header comments
        deprecated = "[deprecated] " if ps.deprecated else ""
        schema_comment = f"// {deprecated}aten::{f.func}"

        # dispatch lambda signature
        name = cpp.name(f.func)
        lambda_formals = ", ".join(
            f"{a.type_str} {a.name}" for a in dispatch_lambda_args(ps, f)
        )
        lambda_return = dispatch_lambda_return_str(f)

        # dispatch lambda body
        dispatch_callee = cpp_dispatch_target(f)
        dispatch_args = ", ".join(cpp_dispatch_exprs(f, python_signature=ps))

        # from arg parser outputs to dispatch lambda arguments
        parser_outputs = arg_parser_output_exprs(ps, f)
        lambda_arg_exprs = dispatch_lambda_exprs(ps, f)
        inits = "\n".join(lambda_arg_exprs.inits)
        lambda_args = ", ".join(lambda_arg_exprs.exprs)

        # scatter fields
        # TODO: Checking `ps.method and ('requires_grad' in parser_outputs)` is a hacky
        # solution for enabling the 'requires_grad' argument for tensor methods
        # new_full, new_empty, and new_zeros. A much better but more difficult to
        # implement solution involves refactoring according to Ed's description here:
        # https://github.com/pytorch/pytorch/issues/36455#issuecomment-614767589
        need_set_requires_grad = ps.tensor_options_args and (
            not has_tensor_options(f)
            or (ps.method and ("requires_grad" in parser_outputs))
        )
        if need_set_requires_grad:
            requires_grad_expr = parser_outputs["requires_grad"].expr
            set_requires_grad = f".set_requires_grad({requires_grad_expr})"
        else:
            set_requires_grad = ""

        if lambda_return == "void":
            return f"""\
{schema_comment}
{inits}
auto dispatch_{name} = []({lambda_formals}) -> {lambda_return} {{
pybind11::gil_scoped_release no_gil;
{dispatch_callee}({dispatch_args});
}};
dispatch_{name}({lambda_args}){set_requires_grad};
Py_RETURN_NONE;
"""

        # Value-returning function: pass the named tuple type to wrap() when
        # one was registered for this function.
        typename = namedtuple_typenames.get(gen_namedtuple_typename_key(f))
        namedtuple_typeref = "" if typename is None else f"{typename}, "
        return f"""\
{schema_comment}
{inits}
auto dispatch_{name} = []({lambda_formals}) -> {lambda_return} {{
pybind11::gil_scoped_release no_gil;
return {dispatch_callee}({dispatch_args});
}};
return wrap({namedtuple_typeref}dispatch_{name}({lambda_args}){set_requires_grad});
"""

    return go(f)