| # mypy: allow-untyped-defs |
| import dataclasses |
| import inspect |
| import sys |
| from typing import Any, Callable, Dict, Iterable, Tuple |
| |
| import torch |
| import torch._utils_internal as _utils_internal |
| from torch import _C |
| |
| |
@dataclasses.dataclass
class Kernel:
    """A callable bundled with a string describing its source location.

    Calling a ``Kernel`` forwards every positional and keyword argument to
    ``func`` and hands back whatever ``func`` returns.
    """

    # The wrapped callable.
    func: Callable
    # Source location associated with ``func``.
    source: str

    def __call__(self, *args, **kwargs):
        result = self.func(*args, **kwargs)
        return result
| |
| |
class RegistrationHandle:
    """Handle that runs a stored callback every time ``destroy()`` is called."""

    def __init__(self, on_destroy: Callable):
        # Zero-argument callback invoked by destroy().
        self._on_destroy = on_destroy

    def destroy(self) -> None:
        """Invoke the stored callback; its return value is discarded."""
        callback = self._on_destroy
        callback()
| |
| |
def get_source(stacklevel: int) -> str:
    """Describe a frame on the current call stack as "filename:lineno".

    Example: "/path/to/foo.py:42"

    ``stacklevel`` is counted from this function's own frame:
    stacklevel=1 is the caller, stacklevel=2 is the caller's caller,
    and so on.
    """
    # Keep the _getframe call directly in this function: wrapping it in a
    # helper would shift what each stacklevel refers to.
    target_frame = sys._getframe(stacklevel)
    info = inspect.getframeinfo(target_frame)
    return f"{info.filename}:{info.lineno}"
| |
| |
def parse_namespace(qualname: str) -> Tuple[str, str]:
    """Split a qualified name of the form "namespace::name" into its two parts.

    Returns:
        A ``(namespace, name)`` tuple.

    Raises:
        ValueError: if splitting ``qualname`` on "::" does not yield
            exactly two pieces.
    """
    parts = qualname.split("::")
    if len(parts) == 2:
        namespace, name = parts
        return namespace, name
    raise ValueError(
        f"Expected `qualname` to be of the form "
        f'"namespace::name", but got {qualname}. '
        f"The qualname passed to the torch.library APIs must consist "
        f"of a namespace and a name, e.g. aten::sin"
    )
| |
| |
def lookup_op(qualname: str) -> torch._ops.OpOverload:
    """Resolve "namespace::name" or "namespace::name.overload" under ``torch.ops``.

    When the name carries no ".overload" suffix, the "default" overload
    is looked up.
    """
    namespace, name = parse_namespace(qualname)
    overload = "default"
    if "." in name:
        name, overload = name.split(".")
    packet = getattr(getattr(torch.ops, namespace), name)
    return getattr(packet, overload)
| |
| |
def is_builtin(op: torch._ops.OpOverload) -> bool:
    """Return whether ``op`` belongs to the "aten", "prim" or "prims" namespace."""
    assert isinstance(op, torch._ops.OpOverload)
    builtin_namespaces = ("aten", "prim", "prims")
    return op.namespace in builtin_namespaces
| |
| |
def is_functional_schema(schema: Any) -> bool:
    """Decide whether a schema describes a functional operator.

    A schema counts as functional when all of the following hold:
    - none of its inputs are mutated
    - none of its returns is a view on an input
    - it declares at least one return

    ``schema`` may be a ``torch._C.FunctionSchema``, a torchgen
    ``FunctionSchema``, or a schema string (parsed with torchgen).
    """

    def _passes_functional_checks(candidate) -> bool:
        if candidate.is_mutable:
            return False
        # A return that has alias info but is not a write is a
        # non-mutating view of an input.
        for ret in candidate.returns:
            if ret.alias_info is not None and not ret.alias_info.is_write:
                return False
        return bool(candidate.returns)

    if isinstance(schema, torch._C.FunctionSchema):
        return _passes_functional_checks(schema)

    # torchgen is imported lazily: not every PyTorch build ships it.
    from torchgen.model import FunctionSchema

    parsed = FunctionSchema.parse(schema) if isinstance(schema, str) else schema
    assert isinstance(parsed, FunctionSchema)
    return _passes_functional_checks(parsed)
| |
| |
# `typ` should be annotated as torch._C.JitType, but that annotation is busted
def is_tensorlist_like_type(typ: Any) -> bool:
    """Return whether ``typ`` equals one of the tensor-list JIT types.

    Accepted types: Tensor[], Tensor?[], Tensor[]? and Tensor?[]?.
    """
    tensor = _C.TensorType.get()
    tensor_list = _C.ListType(tensor)
    optional_tensor_list = _C.ListType(_C.OptionalType(tensor))
    return (
        typ == tensor_list
        or typ == optional_tensor_list
        or typ == _C.OptionalType(tensor_list)
        or typ == _C.OptionalType(optional_tensor_list)
    )
| |
| |
# `typ` should be annotated as torch._C.JitType, but that annotation is busted
def is_tensor_like_type(typ: Any) -> bool:
    """Return whether ``typ`` equals the JIT type Tensor or Tensor?."""
    tensor = _C.TensorType.get()
    return typ == tensor or typ == _C.OptionalType(tensor)
| |
| |
def mutates_and_returns_first_arg(op: torch._ops.OpOverload):
    """Check if an op is an inplace aten op, i.e. it mutates and returns the first arg.

    TODO: torchgen/model.py's FunctionSchema.parse is the source of truth for this,
    but not all PyTorch builds have torchgen (due to the yaml dependency being weird).
    Figure this out.

    Example: add_(Tensor(a!) x, Tensor y) -> Tensor(a)

    The op qualifies only when it is in the "aten" namespace, has exactly
    one return whose alias set holds a single location, its first argument
    is written to and aliases that same single location, and no other
    argument carries alias info.
    """
    if op.namespace != "aten":
        return False
    schema = op._schema

    # The single return must alias exactly one location.
    returns = schema.returns
    if len(returns) != 1:
        return False
    return_alias = returns[0].alias_info
    if return_alias is None:
        return False
    return_locations = return_alias.after_set
    if len(return_locations) != 1:
        return False
    (returned_location,) = return_locations

    # The first argument must be written to and alias that same location.
    params = schema.arguments
    if len(params) < 1:
        return False
    first_alias = params[0].alias_info
    if first_alias is None or not first_alias.is_write:
        return False
    first_locations = first_alias.after_set
    if len(first_locations) != 1:
        return False
    if returned_location != next(iter(first_locations)):
        return False

    # Every remaining argument must be free of alias info.
    return all(param.alias_info is None for param in params[1:])
| |
| |
def fill_defaults(schema, args, kwargs):
    """Expand ``(args, kwargs)`` so every argument in ``schema`` has a value.

    Walks ``schema.arguments`` in order. Keyword-only arguments are taken
    from ``kwargs`` by name; all other arguments are taken from ``args`` by
    position. Anything not supplied falls back to the argument's
    ``default_value``.

    Returns:
        A ``(new_args, new_kwargs)`` pair: a tuple of positional values and
        a dict of keyword-only values.
    """
    positional = []
    keyword = {}
    for index, arg_info in enumerate(schema.arguments):
        if not arg_info.kwarg_only:
            positional.append(
                args[index] if index < len(args) else arg_info.default_value
            )
            continue
        arg_name = arg_info.name
        keyword[arg_name] = (
            kwargs[arg_name] if arg_name in kwargs else arg_info.default_value
        )
    return tuple(positional), keyword
| |
| |
def zip_schema(
    schema: _C.FunctionSchema, args: Tuple[Any, ...], kwargs: Dict[str, Any]
) -> Iterable[Tuple[_C.Argument, Any]]:
    """Pair each entry of schema.arguments with its value from (args, kwargs).

    Assumes that (args, kwargs) were the inputs to some torch._ops.OpOverload:
    that is, kwargs must be keyword-only arguments and default values may be omitted.
    Arguments for which no value was supplied are skipped rather than yielded.
    """
    assert len(schema.arguments) >= len(args) + len(kwargs)
    for index, arg_info in enumerate(schema.arguments):
        if arg_info.kwarg_only:
            if arg_info.name in kwargs:
                yield arg_info, kwargs[arg_info.name]
        elif index < len(args):
            yield arg_info, args[index]
        # Otherwise this is a positional argument past the end of `args`:
        # args that are equal to their default values are not populated
        # if they are followed by args that are equal to their defaults,
        # so there is nothing to yield for it.
| |
| |
def can_generate_trivial_fake_impl(op: torch._ops.OpOverload) -> bool:
    """Return whether ``op`` has a trivial fake impl.

    This is true only for a non-builtin op whose schema is mutable and
    declares no returns.
    """
    assert isinstance(op, torch._ops.OpOverload)
    if is_builtin(op):
        # We control the built-ins. These may (in rare cases)
        # do input metadata mutation (which we have banned on custom ops)
        return False
    schema = op._schema
    # A non-mutable op that returns nothing is suspicious, so it is rejected
    # out of an abundance of caution. A mutable op that returns nothing has
    # a trivial fake impl.
    return bool(schema.is_mutable) and len(schema.returns) == 0
| |
| |
def requires_set_python_module() -> bool:
    """Report whether a C++ ``m.set_python_module(...)`` call is required.

    Applies to ops defined in C++ and extended from Python through the
    torch.library APIs: the result says whether C++ must have called
    m.set_python_module("mylib.ops") to associate the op with a Python
    module.

    The value is read from ``REQUIRES_SET_PYTHON_MODULE`` on
    ``torch._utils_internal`` and defaults to True when that attribute is
    absent.
    """
    required = getattr(_utils_internal, "REQUIRES_SET_PYTHON_MODULE", True)
    return required
| |
| |
def handle_dispatch_mode(curr_mode, op_overload, *args, **kwargs):
    """Call ``curr_mode.__torch_dispatch__`` for ``op_overload``.

    The "types" argument is built from the types of flattened inputs that
    are tensors carrying the Python dispatch key. ``args`` and ``kwargs``
    are forwarded unchanged as a tuple and a dict.
    """
    assert isinstance(curr_mode, torch.utils._python_dispatch.TorchDispatchMode)
    # NOTE(review): kwargs.values() is handed to tree_flatten as a dict view;
    # confirm pytree descends into it rather than treating it as one leaf.
    leaves, _ = torch.utils._pytree.tree_flatten((args, kwargs.values()))
    # TODO: need to double check the semantics of the "types" argument to torch_dispatch.
    # It's generated in PyInterpreter.cpp, but seems to be generated in two places,
    # where in one case we only include tensors with the python key, and in another
    # we include **all** tensors.
    python_key = torch._C.DispatchKey.Python
    overload_types = [
        type(leaf)
        for leaf in leaves
        if isinstance(leaf, torch.Tensor)
        and torch._C._dispatch_keys(leaf).has(python_key)
    ]
    # TODO: check that I got these args correct (in C++, we pass in "0000"??)

    return curr_mode.__torch_dispatch__(op_overload, overload_types, args, kwargs)
| |
| |
def has_kwarg_only_args(schema: _C.FunctionSchema):
    """Return True if at least one argument of ``schema`` is keyword-only."""
    for argument in schema.arguments:
        if argument.kwarg_only:
            return True
    return False
| |
| |
def has_kwarg_only_tensors(schema: _C.FunctionSchema):
    """Return True if ``schema`` has a keyword-only argument of tensor or tensor-list type."""
    return any(
        (is_tensor_like_type(argument.type) or is_tensorlist_like_type(argument.type))
        and argument.kwarg_only
        for argument in schema.arguments
    )