blob: 42548b9afa11d53a4202c4a941cb641cf539f93d [file] [log] [blame]
# This code should be shared between cwrap and ATen preprocessing.
# For now it lives in one place, but it is currently a copy of the cwrap code.
import copy
from typing import Any, Dict, Iterable, List, Union
Arg = Dict[str, Any]


def parse_arguments(args: List[Union[str, Arg]]) -> List[Arg]:
    """Normalise a list of argument declarations into argument dicts.

    Each entry of ``args`` is either

    * a string of the form ``"<type> <name>"``, which becomes a new
      ``{"type": ..., "name": ...}`` dict, or
    * a dict, which is passed through as the same object.  If it has an
      ``"arg"`` key holding a ``"<type> <name>"`` string, that key is replaced
      in place by separate ``"type"`` and ``"name"`` keys.

    The split happens at the first space only, so everything after it is the
    name (an entry without a space yields an empty name).

    Raises:
        AssertionError: if an entry is neither a string nor a dict.
    """
    parsed: List[Arg] = []
    for entry in args:
        if isinstance(entry, str):
            # Simple declaration of the form "<type> <name>"
            type_name, _, arg_name = entry.partition(" ")
            parsed.append({"type": type_name, "name": arg_name})
            continue

        if not isinstance(entry, dict):
            raise AssertionError()

        if "arg" in entry:
            # Expand the combined "arg" shorthand on the caller's dict itself.
            type_name, _, arg_name = entry["arg"].partition(" ")
            entry["type"] = type_name
            entry["name"] = arg_name
            del entry["arg"]
        parsed.append(entry)
    return parsed
Declaration = Dict[str, Any]


def set_declaration_defaults(declaration: Declaration) -> None:
    """Fill in default and derived entries of ``declaration`` in place.

    Missing ``schema_string``, ``arguments``, ``return``, ``cname`` and
    ``backends`` entries receive defaults.  ``api_name``,
    ``type_wrapper_name`` and the operator-name entries are derived from
    ``name``, ``overload_name`` and ``schema_string``.  Finally the
    declaration is normalised to always carry an ``options`` list: every
    option gets parsed ``arguments`` / ``schema_order_arguments`` and inherits
    each declaration-level entry (except ``options``) that it does not define
    itself.

    Args:
        declaration: Mutable declaration dict.  It must contain ``name`` and
            must not already contain ``api_name``.  If it has no ``options``
            it must contain ``schema_order_arguments``; in that case
            ``arguments`` and ``schema_order_arguments`` are moved (as deep
            copies) into a single generated option and removed from the
            declaration.  If it has ``options``, each option must contain
            ``arguments`` and ``schema_order_arguments``.

    Raises:
        AssertionError: if ``api_name`` is already present.
        KeyError: if one of the required entries is missing.
        IndexError: if a non-empty ``schema_string`` contains no ``::``.
    """
    # Legacy TH bindings such as _thnn_conv_depthwise2d_backward come without
    # a schema string.
    declaration.setdefault("schema_string", "")
    declaration.setdefault("arguments", [])
    declaration.setdefault("return", "void")
    if "cname" not in declaration:
        declaration["cname"] = declaration["name"]
    declaration.setdefault("backends", ["CPU", "CUDA"])
    assert "api_name" not in declaration
    declaration["api_name"] = declaration["name"]

    # NB: keep this in sync with gen_autograd.py
    if declaration.get("overload_name"):
        declaration["type_wrapper_name"] = "{}_{}".format(
            declaration["name"], declaration["overload_name"]
        )
    else:
        declaration["type_wrapper_name"] = declaration["name"]

    # TODO: Uggggh, parsing the schema string here, really???
    schema_string = declaration["schema_string"]
    operator_name = schema_string.split("(")[0]
    declaration["operator_name_with_overload"] = operator_name
    if schema_string:
        # Strip only the leading "<namespace>::" qualifier.  Splitting once
        # keeps any later "::" in the remainder instead of truncating the
        # result at it.
        declaration["unqual_schema_string"] = schema_string.split("::", 1)[1]
        declaration["unqual_operator_name_with_overload"] = operator_name.split(
            "::", 1
        )[1]
    else:
        declaration["unqual_schema_string"] = ""
        declaration["unqual_operator_name_with_overload"] = ""

    # Simulate multiple dispatch, even if it's not necessary
    if "options" not in declaration:
        declaration["options"] = [
            {
                "arguments": copy.deepcopy(declaration["arguments"]),
                "schema_order_arguments": copy.deepcopy(
                    declaration["schema_order_arguments"]
                ),
            }
        ]
        del declaration["arguments"]
        del declaration["schema_order_arguments"]

    # Parse arguments (some of them can be strings)
    for option in declaration["options"]:
        option["arguments"] = parse_arguments(option["arguments"])
        option["schema_order_arguments"] = parse_arguments(
            option["schema_order_arguments"]
        )

    # Propagate defaults from declaration to options
    for option in declaration["options"]:
        for key, value in declaration.items():
            # TODO(zach): why does cwrap not propagate 'name'? I need it
            # propagated for ATen
            if key != "options":
                option.setdefault(key, value)
# TODO(zach): added option to remove keyword handling for C++ which cannot
# support it.
Option = Dict[str, Any]


def filter_unique_options(
    options: Iterable[Option],
    allow_kwarg: bool,
    type_to_signature: Dict[str, str],
    remove_self: bool,
) -> List[Option]:
    """Return the options whose call signature has not been seen earlier.

    Options are visited in order.  For each one, the trailing ``k`` arguments
    are tentatively treated as keyword-only, for ``k = 0, 1, ...`` up to the
    number of arguments (only ``k = 0`` when ``allow_kwarg`` is false).  The
    first ``k`` that produces a signature not seen so far wins: those trailing
    arguments get ``"kwarg_only"`` set to ``True`` and the option is kept.  An
    option whose every candidate signature was already seen is dropped.

    A signature is the ``"#"``-joined list of positional argument types (each
    mapped through ``type_to_signature`` when present).  When ``k > 0`` it is
    followed by ``"#-#"`` and the ``"#"``-joined ``name#type`` of each
    keyword-only argument.  Arguments of type ``"CONSTANT"`` are always left
    out; with ``remove_self`` set, positional arguments named ``"self"`` are
    left out as well.
    """

    def is_constant(arg: Arg) -> bool:
        return arg["type"] == "CONSTANT"  # type: ignore[no-any-return]

    def is_hidden_positional(arg: Arg) -> bool:
        if is_constant(arg):
            return True
        return remove_self and arg["name"] == "self"

    def signature(option: Option, num_kwarg_only: int) -> str:
        positional = (
            option["arguments"]
            if num_kwarg_only == 0
            else option["arguments"][:-num_kwarg_only]
        )
        positional_part = "#".join(
            type_to_signature.get(arg["type"], arg["type"])
            for arg in positional
            if not is_hidden_positional(arg)
        )
        if num_kwarg_only == 0:
            return positional_part

        keyword_part = "#".join(
            arg["name"] + "#" + arg["type"]
            for arg in option["arguments"][-num_kwarg_only:]
            if not is_constant(arg)
        )
        return positional_part + "#-#" + keyword_part

    seen = set()
    kept = []
    for option in options:
        # Without keyword support only the all-positional form is considered.
        if allow_kwarg:
            max_kwarg_only = len(option["arguments"])
        else:
            max_kwarg_only = 0

        for count in range(max_kwarg_only + 1):
            candidate = signature(option, count)
            if candidate in seen:
                continue
            if count > 0:
                for arg in option["arguments"][-count:]:
                    arg["kwarg_only"] = True
            kept.append(option)
            seen.add(candidate)
            break
    return kept
def sort_by_number_of_args(declaration: Declaration, reverse: bool = True) -> None:
    """Sort ``declaration["options"]`` in place by argument count.

    By default (``reverse=True``) options with the most arguments come first.
    """
    declaration["options"].sort(
        key=lambda option: len(option["arguments"]),
        reverse=reverse,
    )
class Function:
    """A named function together with its ordered list of arguments."""

    def __init__(self, name: str) -> None:
        self.name = name
        # Filled in through add_argument(), in declaration order.
        self.arguments: List["Argument"] = []

    def add_argument(self, arg: "Argument") -> None:
        """Append ``arg``, which must be an ``Argument`` instance."""
        assert isinstance(arg, Argument)
        self.arguments.append(arg)

    def __repr__(self) -> str:
        rendered_args = ", ".join(repr(argument) for argument in self.arguments)
        return self.name + f"({rendered_args})"
class Argument:
    """A single function argument: its type, its name and an optional flag."""

    def __init__(self, _type: str, name: str, is_optional: bool):
        self.type, self.name, self.is_optional = _type, name, is_optional

    def __repr__(self) -> str:
        # Rendered as "<type> <name>".
        return " ".join((self.type, self.name))
def parse_header(path: str) -> List[Function]:
    """Parse the THNN function declarations found in the header at ``path``.

    The file is read as text and processed line by line:

    * empty lines and lines starting with ``#`` are ignored;
    * the text after ``//`` is split off as the line's comment, and the
      remaining code has trailing ``)`` / ``;`` characters and then trailing
      commas removed;
    * the code is split on ``,`` into fragments, each of which keeps the
      comment of the line it came from.

    A fragment starting with ``TH_API void THNN_``,
    ``TORCH_CUDA_CPP_API void THNN_`` or ``TORCH_CUDA_CU_API void THNN_``
    opens a new ``Function``.  Any other fragment must consist of exactly two
    whitespace-separated tokens, ``<type> <name>``, and is added as an
    ``Argument`` to the most recently opened function.  A ``*`` in the name
    token is moved onto the type, and the argument is optional when its
    line's comment contains ``[OPTIONAL]``.

    Returns:
        The parsed functions, in the order they appear in the file.
    """
    with open(path, "r") as f:
        raw_lines = f.read().split("\n")

    # Collect non-empty (fragment, comment) pairs.
    fragments = []
    for raw in raw_lines:
        # Skip blank lines and preprocessor directives.
        if not raw or raw.startswith("#"):
            continue
        code, _, comment = raw.partition("//")
        comment = comment.strip()
        # Drop the trailing ");" of a declaration's last line and the trailing
        # "," of the others before splitting into fragments.
        code = code.strip().rstrip(");").rstrip(",")
        for piece in code.split(","):
            piece = piece.strip()
            if piece:
                fragments.append((piece, comment))

    declaration_prefixes = (
        "TH_API void THNN_",
        "TORCH_CUDA_CPP_API void THNN_",
        "TORCH_CUDA_CU_API void THNN_",
    )

    generic_functions: List[Function] = []
    for fragment, comment in fragments:
        prefix = next(
            (p for p in declaration_prefixes if fragment.startswith(p)), None
        )
        if prefix is not None:
            fn_name = fragment[len(prefix) :]
            if fn_name[0] == "(" and fn_name[-2] == ")":
                # "(Name)(" -> "Name"
                fn_name = fn_name[1:-2]
            else:
                # Only drop the final character.
                fn_name = fn_name[:-1]
            generic_functions.append(Function(fn_name))
            continue

        t, name = fragment.split()
        if "*" in name:
            # Move the pointer marker from the name onto the type.
            t = t + "*"
            name = name[1:]
        generic_functions[-1].add_argument(
            Argument(t, name, "[OPTIONAL]" in comment)
        )
    return generic_functions