# This code should be shared between cwrap and ATen preprocessing.
# For now it lives in one place, but it is currently a copy of the code in cwrap.
| |
| import copy |
| from typing import Any, Dict, Iterable, List, Union |
| |
Arg = Dict[str, Any]


def parse_arguments(args: List[Union[str, Arg]]) -> List[Arg]:
    """Normalize argument declarations into dicts with ``type`` and ``name`` keys.

    Each entry of ``args`` may be:

    * a string ``"<type> <name>"`` -- converted into a new
      ``{"type": ..., "name": ...}`` dict;
    * a dict with an ``"arg"`` key holding ``"<type> <name>"`` -- the dict is
      updated **in place** (``"type"`` and ``"name"`` are set, ``"arg"`` is
      removed) and then kept;
    * any other dict -- kept unchanged.

    Type and name are split at the first space, so a declaration that
    contains no space yields an empty name.

    Args:
        args: The argument declarations to normalize.

    Returns:
        A new list with one dict per input entry, in the same order.

    Raises:
        AssertionError: if an entry is neither a ``str`` nor a ``dict``.
    """
    new_args: List[Arg] = []
    for arg in args:
        if isinstance(arg, str):
            # Simple arg declaration of form "<type> <name>".
            t, _, name = arg.partition(" ")
            new_args.append({"type": t, "name": name})
        elif isinstance(arg, dict):
            if "arg" in arg:
                # Rewrites the caller's dict rather than a copy. The same dict
                # object ends up in the returned list.
                arg["type"], _, arg["name"] = arg["arg"].partition(" ")
                del arg["arg"]
            new_args.append(arg)
        else:
            raise AssertionError(
                "argument declaration must be a str or a dict, got {}: {!r}".format(
                    type(arg).__name__, arg
                )
            )
    return new_args
| |
| |
Declaration = Dict[str, Any]


def set_declaration_defaults(declaration: Declaration) -> None:
    """Fill in default fields of ``declaration`` and normalize its options, in place.

    Keys set only when missing: ``schema_string`` (``""``), ``arguments``
    (``[]``), ``return`` (``"void"``), ``cname`` (copied from ``name``) and
    ``backends`` (``["CPU", "CUDA"]``).

    Keys always derived: ``api_name``, ``type_wrapper_name``,
    ``operator_name_with_overload``, ``unqual_schema_string`` and
    ``unqual_operator_name_with_overload``.

    If there is no ``options`` list, a single option is built from deep
    copies of ``arguments`` and ``schema_order_arguments``, and those two
    keys are removed from the declaration. Each option's ``arguments`` and
    ``schema_order_arguments`` are then normalized with ``parse_arguments``.
    Finally, every other top-level key (except ``options``) is copied into
    each option that does not already define it. The values are shared
    references, not copies.

    Args:
        declaration: The declaration dict to update.

    Raises:
        AssertionError: (via ``assert``) if ``api_name`` is already present.
        KeyError: if ``name`` is missing, or if ``options`` has to be
            synthesized and ``schema_order_arguments`` is missing.
        IndexError: if a non-empty schema string (or its operator name) has
            no ``"::"`` namespace separator.
    """
    if "schema_string" not in declaration:
        # This happens for legacy TH bindings like
        # _thnn_conv_depthwise2d_backward
        declaration["schema_string"] = ""
    declaration.setdefault("arguments", [])
    declaration.setdefault("return", "void")
    if "cname" not in declaration:
        declaration["cname"] = declaration["name"]
    if "backends" not in declaration:
        declaration["backends"] = ["CPU", "CUDA"]
    assert "api_name" not in declaration
    declaration["api_name"] = declaration["name"]
    # NB: keep this in sync with gen_autograd.py
    if declaration.get("overload_name"):
        declaration["type_wrapper_name"] = "{}_{}".format(
            declaration["name"], declaration["overload_name"]
        )
    else:
        declaration["type_wrapper_name"] = declaration["name"]
    # TODO: Uggggh, parsing the schema string here, really???
    schema_string = declaration["schema_string"]
    # Everything before the first "(" -- presumably "ns::op.overload".
    declaration["operator_name_with_overload"] = schema_string.split("(")[0]
    if schema_string:
        # Strip only the leading namespace. Splitting at most once keeps any
        # later "::" in the unqualified strings instead of truncating them
        # there.
        declaration["unqual_schema_string"] = schema_string.split("::", 1)[1]
        declaration["unqual_operator_name_with_overload"] = declaration[
            "operator_name_with_overload"
        ].split("::", 1)[1]
    else:
        declaration["unqual_schema_string"] = ""
        declaration["unqual_operator_name_with_overload"] = ""
    # Simulate multiple dispatch, even if it's not necessary
    if "options" not in declaration:
        declaration["options"] = [
            {
                "arguments": copy.deepcopy(declaration["arguments"]),
                "schema_order_arguments": copy.deepcopy(
                    declaration["schema_order_arguments"]
                ),
            }
        ]
        del declaration["arguments"]
        del declaration["schema_order_arguments"]
    # Parse arguments (some of them can be strings)
    for option in declaration["options"]:
        option["arguments"] = parse_arguments(option["arguments"])
        option["schema_order_arguments"] = parse_arguments(
            option["schema_order_arguments"]
        )
    # Propagate defaults from declaration to options
    for option in declaration["options"]:
        for k, v in declaration.items():
            # TODO(zach): why does cwrap not propagate 'name'? I need it
            # propagated for ATen
            if k != "options":
                option.setdefault(k, v)
| |
| |
# TODO(zach): added option to remove keyword handling for C++ which cannot
# support it.

Option = Dict[str, Any]


def filter_unique_options(
    options: Iterable[Option],
    allow_kwarg: bool,
    type_to_signature: Dict[str, str],
    remove_self: bool,
) -> List[Option]:
    """Keep only the options whose argument signature has not been seen yet.

    For each option, candidate signatures are tried with the last ``k``
    arguments treated as keyword-only, for ``k = 0, 1, ..., len(arguments)``.
    When ``allow_kwarg`` is false, only ``k = 0`` is tried.

    The first candidate not produced by an earlier option wins. The option
    is kept and the signature is recorded. If ``k > 0``, the last ``k``
    argument dicts are also given ``kwarg_only = True`` (mutated in place).
    Options for which every candidate collides are dropped.

    A signature is made of:

    * a positional part: argument types mapped through ``type_to_signature``
      (falling back to the raw type) and joined with ``"#"``. ``CONSTANT``
      arguments are skipped, as are arguments named ``self`` when
      ``remove_self`` is true;
    * when ``k > 0``, a keyword-only part appended after ``"#-#"``: one
      ``"<name>#<type>"`` entry per non-``CONSTANT`` argument, joined with
      ``"#"``.
    """

    def exclude_arg(arg: Arg) -> bool:
        # CONSTANT arguments never contribute to a signature.
        return arg["type"] == "CONSTANT"  # type: ignore[no-any-return]

    def exclude_arg_with_self_check(arg: Arg) -> bool:
        if exclude_arg(arg):
            return True
        return remove_self and arg["name"] == "self"

    def signature(option: Option, num_kwarg_only: int) -> str:
        arguments = option["arguments"]
        # None keeps the whole list positional. Otherwise the trailing
        # num_kwarg_only entries are split off as keyword-only.
        cut = -num_kwarg_only if num_kwarg_only else None
        positional = "#".join(
            type_to_signature.get(a["type"], a["type"])
            for a in arguments[:cut]
            if not exclude_arg_with_self_check(a)
        )
        if cut is None:
            return positional
        keyword_only = "#".join(
            a["name"] + "#" + a["type"]
            for a in arguments[cut:]
            if not exclude_arg(a)
        )
        return positional + "#-#" + keyword_only

    seen = set()
    kept = []
    for option in options:
        # Without allow_kwarg, only the all-positional signature (k == 0) is tried.
        max_kwarg_only = len(option["arguments"]) if allow_kwarg else 0
        for num_kwarg_only in range(max_kwarg_only + 1):
            candidate = signature(option, num_kwarg_only)
            if candidate in seen:
                continue
            if num_kwarg_only:
                for arg in option["arguments"][-num_kwarg_only:]:
                    arg["kwarg_only"] = True
            kept.append(option)
            seen.add(candidate)
            break
    return kept
| |
| |
def sort_by_number_of_args(declaration: Declaration, reverse: bool = True) -> None:
    """Sort ``declaration["options"]`` in place by argument count.

    With the default ``reverse=True``, options with the most arguments come
    first. ``list.sort`` is stable, so options with equal argument counts
    keep their relative order.
    """
    declaration["options"].sort(
        key=lambda option: len(option["arguments"]), reverse=reverse
    )
| |
| |
class Function:
    """A parsed C function: a name plus its arguments in declaration order."""

    def __init__(self, name: str) -> None:
        self.name = name
        # Appended to by add_argument, in the order arguments are seen.
        self.arguments: List["Argument"] = []

    def add_argument(self, arg: "Argument") -> None:
        """Append ``arg``, which must be an :class:`Argument` instance."""
        assert isinstance(arg, Argument)
        self.arguments.append(arg)

    def __repr__(self) -> str:
        rendered = ", ".join(repr(argument) for argument in self.arguments)
        return self.name + "(" + rendered + ")"
| |
| |
class Argument:
    """A single parsed C argument: its type, name and whether it is optional."""

    def __init__(self, _type: str, name: str, is_optional: bool):
        self.type = _type
        self.name = name
        self.is_optional = is_optional

    def __repr__(self) -> str:
        # Rendered as it would appear in a declaration: "<type> <name>".
        return " ".join((self.type, self.name))
| |
| |
def parse_header(path: str) -> List["Function"]:
    """Parse a THNN-style C header into a list of ``Function`` objects.

    The header is processed as follows:

    * Lines whose first non-whitespace character is ``#`` are skipped, as
      are blank lines.
    * Every remaining line is split into code and a trailing ``//`` comment.
      Trailing ``)``/``;`` characters and then trailing commas are stripped
      from the code.
    * The code is split on commas into fragments.
    * A fragment that starts with one of the ``<API macro> void THNN_``
      prefixes opens a new function.
    * Any other fragment must be ``"<type> <name>"`` and becomes an argument
      of the most recently opened function. It is marked optional when the
      line's comment contains ``[OPTIONAL]``.

    Args:
        path: Path of the header file to read (decoded as UTF-8).

    Returns:
        The parsed functions, in declaration order.

    Raises:
        ValueError: if an argument fragment is not exactly two
            whitespace-separated tokens.
        IndexError: if an argument fragment appears before any function
            declaration.
    """
    # Mutually exclusive prefixes that open a function declaration, so their
    # order here does not matter.
    function_prefixes = (
        "TH_API void THNN_",
        "TORCH_CUDA_CPP_API void THNN_",
        "TORCH_CUDA_CU_API void THNN_",
    )

    with open(path, "r", encoding="utf-8") as f:
        raw_lines = f.read().split("\n")

    # One (code fragment, trailing comment) pair per comma-separated piece,
    # so several arguments on one line are handled individually.
    fragments: List[Any] = []
    for raw in raw_lines:
        # Skip empty lines and preprocessor directives. Leading whitespace is
        # ignored so that indented directives are skipped too.
        leading_stripped = raw.lstrip()
        if not leading_stripped or leading_stripped.startswith("#"):
            continue
        code, _, comment = raw.partition("//")
        code = code.strip().rstrip(");").rstrip(",")
        comment = comment.strip()
        for piece in code.split(","):
            piece = piece.strip()
            if piece:
                fragments.append((piece, comment))

    functions: List[Function] = []
    for code, comment in fragments:
        prefix = next((p for p in function_prefixes if code.startswith(p)), None)
        if prefix is not None:
            fn_name = code[len(prefix):]
            # Either "(Name)" followed by one trailing character, or "Name"
            # followed by one trailing character (presumably the "(" that
            # opens the parameter list).
            if fn_name[0] == "(" and fn_name[-2] == ")":
                fn_name = fn_name[1:-2]
            else:
                fn_name = fn_name[:-1]
            functions.append(Function(fn_name))
        else:
            t, name = code.split()
            if "*" in name:
                # "T *name": move the pointer star onto the type. This assumes
                # the star is the first character of the name token.
                t = t + "*"
                name = name[1:]
            functions[-1].add_argument(Argument(t, name, "[OPTIONAL]" in comment))
    return functions