| # mypy: allow-untyped-defs |
| import contextlib |
| import functools |
| import inspect |
| import re |
| import sys |
| import traceback |
| import weakref |
| from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union |
| from typing_extensions import deprecated |
| |
| import torch |
| import torch._library as _library |
| from torch._library.custom_ops import ( |
| _maybe_get_opdef, |
| custom_op, |
| CustomOpDef, |
| device_types_t, |
| ) |
| from torch._library.infer_schema import infer_schema # noqa: F401 |
| from torch._ops import OpOverload |
| |
| |
| __all__ = [ |
| "Library", |
| "impl", |
| "define", |
| "fallthrough_kernel", |
| "impl_abstract", |
| "register_fake", |
| "register_torch_dispatch", |
| "register_vmap", |
| "get_ctx", |
| "custom_op", |
| "infer_schema", |
| ] |
| |
# Set containing the combinations of (namespace, operator, DispatchKey) for which a new kernel has been registered.
# The keys in the set are of the form `namespace + "/" + op_name + "/" + dispatch_key`.
# This set is maintained to ensure that two libraries don't try to override the exact same functionality,
# which would lead to kernels being called that were not intended to be called.
| _impls: Set[str] = set() |
| _defs: Set[str] = set() |
| |
| # prim is reserved by TorchScript interpreter |
| _reserved_namespaces = ["prim"] |
| |
| |
| def fallthrough_kernel(): |
| """ |
| A dummy function to pass to ``Library.impl`` in order to register a fallthrough. |
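
    Example (a minimal sketch, assuming an operator ``mylib::foo`` has already
    been defined)::
        >>> my_lib = Library("mylib", "FRAGMENT")
        >>> my_lib.impl("foo", fallthrough_kernel, "Autocast")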
| """ |
| raise NotImplementedError("fallthrough_kernel() should never be called.") |
| |
| |
| class Library: |
| """ |
| A class to create libraries that can be used to register new operators or |
| override operators in existing libraries from Python. |
    A user can optionally pass in a dispatch key name if they only want to register
    kernels corresponding to one specific dispatch key.
| |
| To create a library to override operators in an existing library (with name ns), set the kind to "IMPL". |
| To create a new library (with name ns) to register new operators, set the kind to "DEF". |
| To create a fragment of a possibly existing library to register operators (and bypass |
| the limitation that there is only one library for a given namespace), set the kind to |
| "FRAGMENT". |
| |
| Args: |
| ns: library name |
| kind: "DEF", "IMPL" (default: "IMPL"), "FRAGMENT" |
| dispatch_key: PyTorch dispatch key (default: "") |
| """ |
| |
| def __init__(self, ns, kind, dispatch_key=""): |
        if kind not in ("IMPL", "DEF", "FRAGMENT"):
            raise ValueError(f"Unsupported kind: {kind}")

        if ns in _reserved_namespaces and (kind == "DEF" or kind == "FRAGMENT"):
            raise ValueError(
                f"{ns} is a reserved namespace. Please try creating a library with another name."
            )
| |
| frame = traceback.extract_stack(limit=3)[0] |
| filename, lineno = frame.filename, frame.lineno |
| self.m: Optional[Any] = torch._C._dispatch_library( |
| kind, ns, dispatch_key, filename, lineno |
| ) |
| self.ns = ns |
| self._op_defs: Set[str] = set() |
| self._op_impls: Set[str] = set() |
| self._registration_handles: List[torch._library.utils.RegistrationHandle] = [] |
| self.kind = kind |
| self.dispatch_key = dispatch_key |
        # Use a finalizer to set up the "destructor" instead of __del__.
        # Python __del__ can lead to weird things (globals and locals may already
        # be gone when __del__ actually gets called!). Finalizers help the
        # situation because they let us capture references and keep them alive.
| weakref.finalize( |
| self, |
| _del_library, |
| _impls, |
| self._op_impls, |
| _defs, |
| self._op_defs, |
| self._registration_handles, |
| ) |
| |
| def __repr__(self): |
| return f"Library(kind={self.kind}, ns={self.ns}, dispatch_key={self.dispatch_key})>" |
| |
| def define(self, schema, alias_analysis="", *, tags=()): |
| r"""Defines a new operator and its semantics in the ns namespace. |
| |
| Args: |
| schema: function schema to define a new operator. |
            alias_analysis (optional): Indicates if the aliasing properties of the operator arguments can be
                                       inferred from the schema ("" or "FROM_SCHEMA", the default behavior) or
                                       should be assumed conservatively ("CONSERVATIVE").
| tags (Tag | Sequence[Tag]): one or more torch.Tag to apply to this |
| operator. Tagging an operator changes the operator's behavior |
| under various PyTorch subsystems; please read the docs for the |
| torch.Tag carefully before applying it. |
| |
| Returns: |
| name of the operator as inferred from the schema. |
| |
| Example:: |
| >>> my_lib = Library("mylib", "DEF") |
| >>> my_lib.define("sum(Tensor self) -> Tensor") |
| """ |
| # This is added because we also want to disallow PURE_FUNCTION alias analysis which is a valid |
| # AliasAnalysis type in C++ |
| if alias_analysis not in ["", "FROM_SCHEMA", "CONSERVATIVE"]: |
| raise RuntimeError(f"Invalid alias_analysis type {alias_analysis}") |
| assert self.m is not None |
| if isinstance(tags, torch.Tag): |
| tags = (tags,) |
| |
| name = schema.split("(")[0] |
| packet_name = name.split(".")[0] if "." in name else name |
| has_preexisting_packet = hasattr(torch.ops, self.ns) and hasattr( |
| getattr(torch.ops, self.ns), packet_name |
| ) |
| |
        result = self.m.define(schema, alias_analysis, tuple(tags))
        qualname = self.ns + "::" + name
| |
| # If the OpOverloadPacket exists already, then this means we're adding a |
| # new OpOverload for it. Refresh the packet to include the new OpOverload. |
| if has_preexisting_packet: |
| ns = getattr(torch.ops, self.ns) |
| packet = getattr(ns, packet_name) |
| torch._ops._refresh_packet(packet) |
| |
| self._op_defs.add(qualname) |
| _defs.add(qualname) |
| return result |
| |
| def _register_fake(self, op_name, fn, _stacklevel=1): |
| r"""Registers the fake impl for an operator defined in the library.""" |
| source = torch._library.utils.get_source(_stacklevel + 1) |
| frame = sys._getframe(_stacklevel) |
| caller_module = inspect.getmodule(frame) |
        # Can be None if you call register_fake from somewhere that isn't a module
        # (e.g. __main__)
| caller_module_name = None if caller_module is None else caller_module.__name__ |
| |
| # TODO(rzou): We're gonna need to stage this change with torchvision, |
| # since torchvision is github first. |
| if caller_module_name is not None and caller_module_name.startswith( |
| "torchvision." |
| ): |
| caller_module_name = None |
| |
| qualname = f"{self.ns}::{op_name}" |
| entry = torch._library.simple_registry.singleton.find(qualname) |
| if caller_module_name is not None: |
| func_to_register = _check_pystubs_once(fn, qualname, caller_module_name) |
| else: |
| func_to_register = fn |
| |
| handle = entry.fake_impl.register(func_to_register, source) |
| self._registration_handles.append(handle) |
| |
| def _register_torch_dispatch_rule(self, op_name, torch_dispatch_class, fn): |
| r"""Registers a torch_dispatch rule for the given operator and torch_dispatch_class. |
| |
| This allows for open registration to specify the behavior between the operator |
| and the torch_dispatch_class without needing to modify the torch_dispatch_class |
| or the operator directly. |
| |
| The torch_dispatch_class is either a Tensor subclass with `__torch_dispatch__` or a |
| TorchDispatchMode. |
| |
| If it is a Tensor subclass, we expect fn to have the following signature: |
| (cls, func: OpOverload, types: Tuple[type, ...], args, kwargs) -> Any |
| |
| If it is a TorchDispatchMode, we expect fn to have the following signature: |
| (mode, func: OpOverload, types: Tuple[type, ...], args, kwargs) -> Any |
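
        Example (a sketch, assuming this library already defines an operator
        ``foo`` and that ``MyMode`` is a ``TorchDispatchMode`` subclass)::
            >>> def rule(mode, func, types, args, kwargs):
            >>>     return func(*args, **kwargs)
            >>> my_lib._register_torch_dispatch_rule("foo", MyMode, rule)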
| """ |
| qualname = f"{self.ns}::{op_name}" |
| entry = torch._library.simple_registry.singleton.find(qualname) |
| handle = entry.torch_dispatch_rules.register(torch_dispatch_class, fn) |
| self._registration_handles.append(handle) |
| |
| def _impl_with_aoti_compile(self, op_name, dispatch_key=""): |
| r"""Register the operator to use the AOTI-compiled implementation. |
| |
| Args: |
| op_name: operator name (along with the overload) or OpOverload object. |
| dispatch_key: dispatch key that the input function should be registered for. By default, it uses |
| the dispatch key that the library was created with. |
| |
| Example:: |
| >>> my_lib = Library("aten", "IMPL") |
| >>> my_lib._impl_with_aoti_compile("div.Tensor", "CPU") |
| """ |
| if dispatch_key == "": |
| dispatch_key = self.dispatch_key |
| assert torch.DispatchKeySet(dispatch_key).has(torch._C.DispatchKey.Dense) |
| |
| if isinstance(op_name, str): |
| name = op_name |
| elif isinstance(op_name, OpOverload): |
| name = op_name._schema.name |
| overload_name = op_name._schema.overload_name |
| if overload_name != "": |
| name = name + "." + overload_name |
| else: |
| raise RuntimeError( |
| "_impl_with_aoti_compile should be passed either a name or an OpOverload object " |
| "as the first argument" |
| ) |
| |
| key = self.ns + "/" + name.split("::")[-1] + "/" + dispatch_key |
| if key in _impls: |
| # TODO: in future, add more info about where the existing function is registered (this info is |
| # today already returned by the C++ warning when _impl_with_aoti_compile is called but we error out before that) |
            raise RuntimeError(
                f"This is not allowed since there's already a kernel registered from python overriding "
                f"{name.split('::')[-1]}'s behavior for {dispatch_key} dispatch key and {self.ns} namespace."
            )
| |
| assert self.m is not None |
| impl_fn: Callable = self.m.impl_with_aoti_compile |
| impl_fn(self.ns, name.split("::")[-1], dispatch_key) |
| |
| _impls.add(key) |
| self._op_impls.add(key) |
| |
| def impl(self, op_name, fn, dispatch_key="", *, with_keyset=False): |
| r"""Registers the function implementation for an operator defined in the library. |
| |
| Args: |
| op_name: operator name (along with the overload) or OpOverload object. |
| fn: function that's the operator implementation for the input dispatch key or :func:`~fallthrough_kernel` |
| to register a fallthrough. |
| dispatch_key: dispatch key that the input function should be registered for. By default, it uses |
| the dispatch key that the library was created with. |
| with_keyset: flag controlling if the current dispatcher call keyset should be passed as the first argument |
| to :attr:`fn` when calling. This should be used to create the appropriate keyset for redispatch calls. |
| |
| Example:: |
| >>> my_lib = Library("aten", "IMPL") |
| >>> def div_cpu(self, other): |
| >>> return self * (1 / other) |
| >>> my_lib.impl("div.Tensor", div_cpu, "CPU") |
| """ |
| if not callable(fn): |
| raise TypeError( |
| f"Input function is required to be a callable but found type {type(fn)}" |
| ) |
| if dispatch_key == "": |
| dispatch_key = self.dispatch_key |
| |
| if isinstance(op_name, str): |
| name = op_name |
| elif isinstance(op_name, OpOverload): |
| name = op_name._schema.name |
| overload_name = op_name._schema.overload_name |
| if overload_name != "": |
| name = name + "." + overload_name |
| else: |
| raise RuntimeError( |
| "impl should be passed either a name or an OpOverload object as the first argument" |
| ) |
| |
| key = self.ns + "/" + name.split("::")[-1] + "/" + dispatch_key |
| if key in _impls: |
| # TODO: in future, add more info about where the existing function is registered (this info is |
| # today already returned by the C++ warning when impl is called but we error out before that) |
            raise RuntimeError(
                f"This is not allowed since there's already a kernel registered from python overriding "
                f"{name.split('::')[-1]}'s behavior for {dispatch_key} dispatch key and {self.ns} namespace."
            )
| |
| if dispatch_key == "Meta": |
| dispatcher_op_name = name |
| if "::" not in dispatcher_op_name: |
| dispatcher_op_name = f"{self.ns}::{dispatcher_op_name}" |
| |
| # Internally, we shouldn't be registering meta kernels for any operators that |
| # have CompositeImplicitAutograd kernels. |
| # Instead, we should be letting those decompositions run, and writing meta kernels |
| # only for the base operators. |
| if torch._C._dispatch_has_kernel_for_dispatch_key( |
| dispatcher_op_name, "CompositeImplicitAutograd" |
| ): |
| raise RuntimeError( |
| f"We should not register a meta kernel directly to the operator '{name}'," |
| " because it has a CompositeImplicitAutograd kernel in core." |
| " Instead we should let the operator decompose, and ensure that we have meta kernels" |
| " for the base ops that it decomposes into." |
| ) |
| |
| assert self.m is not None |
| self.m.impl( |
| name, |
| dispatch_key if dispatch_key != "" else "CompositeImplicitAutograd", |
| fn, |
| with_keyset, |
| ) |
| |
| _impls.add(key) |
| self._op_impls.add(key) |
| |
| def fallback(self, fn, dispatch_key="", *, with_keyset=False): |
| r"""Registers the function implementation as the fallback for the given key. |
| |
| This function only works for a library with global namespace ("_"). |
| |
| Args: |
| fn: function used as fallback for the given dispatch key or :func:`~fallthrough_kernel` |
| to register a fallthrough. |
| dispatch_key: dispatch key that the input function should be registered for. By default, it uses |
| the dispatch key that the library was created with. |
| with_keyset: flag controlling if the current dispatcher call keyset should be passed as the first argument |
| to :attr:`fn` when calling. This should be used to create the appropriate keyset for redispatch calls. |
| |
| Example:: |
| >>> my_lib = Library("_", "IMPL") |
| >>> def fallback_kernel(op, *args, **kwargs): |
| >>> # Handle all autocast ops generically |
| >>> # ... |
| >>> my_lib.fallback(fallback_kernel, "Autocast") |
| """ |
| if dispatch_key == "": |
| dispatch_key = self.dispatch_key |
| |
| if self.ns != "_": |
| raise RuntimeError( |
| f"""Fallback can only be registered using libary fragment on the global namespace "_" but it is {self.ns}""" |
| ) |
| |
| assert dispatch_key != "" |
| assert self.m is not None |
| |
| self.m.fallback(dispatch_key, fn, with_keyset) |
| |
| def _destroy(self): |
| if self.m is not None: |
| self.m.reset() |
| self.m = None |
| for handle in self._registration_handles: |
| handle.destroy() |
| self._registration_handles.clear() |
| global _impls |
| _impls -= self._op_impls |
| for name in self._op_defs: |
| # Delete the cached torch.ops.ns.foo if it was registered. |
| # Otherwise, accessing it leads to a segfault. |
| # It's possible that we only registered an overload in this Library |
| # and another library owns an alive overload. |
| # That's OK - the next time torch.ops.ns.foo gets called, it'll be |
| # recomputed to point at the right collection of overloads. |
| ns, name_with_overload = name.split("::") |
| name = name_with_overload.split(".")[0] |
| if not hasattr(torch.ops, ns): |
| continue |
| namespace = getattr(torch.ops, ns) |
| if not hasattr(namespace, name): |
| continue |
| delattr(namespace, name) |
| |
| |
| def _del_library( |
| captured_impls, |
| op_impls, |
| captured_defs, |
| op_defs, |
| registration_handles, |
| ): |
| captured_impls -= op_impls |
| captured_defs -= op_defs |
| for handle in registration_handles: |
| handle.destroy() |
| |
| |
| @contextlib.contextmanager |
| def _scoped_library(*args, **kwargs): |
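    """Create a ``Library`` whose registrations are torn down when the ``with``
    block exits. A minimal usage sketch::

        with _scoped_library("mylib", "FRAGMENT") as lib:
            lib.define("foo(Tensor x) -> Tensor")
    """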
| try: |
| lib = Library(*args, **kwargs) |
| yield lib |
| finally: |
| lib._destroy() |
| |
| |
# Libraries created implicitly by the module-level APIs below are kept alive
# here so that their registrations last for the lifetime of the process.
_keep_alive: List[Library] = []
| |
| |
# Matches a schema without an operator name, e.g. "(Tensor x) -> Tensor"
NAMELESS_SCHEMA = re.compile(r"\(.*\) -> .*")
| |
| |
| @functools.singledispatch |
| def define(qualname, schema, *, lib=None, tags=()): |
| r"""Defines a new operator. |
| |
    In PyTorch, defining an op (short for "operator") is a two-step process:
| - we need to define the op (by providing an operator name and schema) |
| - we need to implement behavior for how the operator interacts with |
| various PyTorch subsystems, like CPU/CUDA Tensors, Autograd, etc. |
| |
    This entrypoint defines the custom operator (the first step);
    you must then perform the second step by calling various
    ``impl_*`` APIs, like :func:`torch.library.impl` or
    :func:`torch.library.register_fake`.
| |
| Args: |
| qualname (str): The qualified name for the operator. Should be |
| a string that looks like "namespace::name", e.g. "aten::sin". |
| Operators in PyTorch need a namespace to |
| avoid name collisions; a given operator may only be created once. |
| If you are writing a Python library, we recommend the namespace to |
| be the name of your top-level module. |
| schema (str): The schema of the operator. E.g. "(Tensor x) -> Tensor" |
| for an op that accepts one Tensor and returns one Tensor. It does |
| not contain the operator name (that is passed in ``qualname``). |
| lib (Optional[Library]): If provided, the lifetime of this operator |
| will be tied to the lifetime of the Library object. |
| tags (Tag | Sequence[Tag]): one or more torch.Tag to apply to this |
| operator. Tagging an operator changes the operator's behavior |
| under various PyTorch subsystems; please read the docs for the |
| torch.Tag carefully before applying it. |
| |
| Example:: |
| >>> import torch |
| >>> import numpy as np |
| >>> |
| >>> # Define the operator |
| >>> torch.library.define("mylib::sin", "(Tensor x) -> Tensor") |
| >>> |
| >>> # Add implementations for the operator |
| >>> @torch.library.impl("mylib::sin", "cpu") |
| >>> def f(x): |
| >>> return torch.from_numpy(np.sin(x.numpy())) |
| >>> |
| >>> # Call the new operator from torch.ops. |
| >>> x = torch.randn(3) |
| >>> y = torch.ops.mylib.sin(x) |
| >>> assert torch.allclose(y, x.sin()) |
| |
| """ |
| if not isinstance(qualname, str): |
| raise ValueError( |
| f"define(qualname, schema): expected qualname " |
| f"to be instance of str, got {type(qualname)}" |
| ) |
| namespace, name = torch._library.utils.parse_namespace(qualname) |
| if lib is None: |
| lib = Library(namespace, "FRAGMENT") |
| _keep_alive.append(lib) |
| if not NAMELESS_SCHEMA.fullmatch(schema): |
| raise ValueError( |
| f"define(qualname, schema, ...): expected schema " |
| f'to look like e.g. "(Tensor x) -> Tensor" but ' |
| f'got "{schema}"' |
| ) |
| lib.define(name + schema, alias_analysis="", tags=tags) |
| |
| |
| @define.register |
| def _(lib: Library, schema, alias_analysis=""): |
| """The old torch.library.define. |
    We're keeping this around for BC reasons.
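
    Example (a sketch of the legacy decorator form)::
        >>> my_lib = Library("mylib", "DEF")
        >>> @torch.library.define(my_lib, "sum2(Tensor self) -> Tensor")
        >>> def sum2(x):
        >>>     return x.sum()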
| """ |
| |
| def wrap(f): |
| name = lib.define(schema, alias_analysis) |
| lib.impl(name, f) |
| return f |
| |
| return wrap |
| |
| |
| @functools.singledispatch |
| def impl(qualname, types, func=None, *, lib=None): |
| """Register an implementation for a device type for this operator. |
| |
| You may pass "default" for ``types`` to register this implementation as the |
| default implementation for ALL device types. |
| Please only use this if the implementation truly supports all device types; |
| for example, this is true if it is a composition of built-in PyTorch operators. |
| |
| Some valid types are: "cpu", "cuda", "xla", "mps", "ipu", "xpu". |
| |
| Args: |
| qualname (str): Should be a string that looks like "namespace::operator_name". |
| types (str | Sequence[str]): The device types to register an impl to. |
| lib (Optional[Library]): If provided, the lifetime of this registration |
| will be tied to the lifetime of the Library object. |
| |
| Examples: |
| >>> import torch |
| >>> import numpy as np |
| >>> |
| >>> # Define the operator |
| >>> torch.library.define("mylib::mysin", "(Tensor x) -> Tensor") |
| >>> |
| >>> # Add implementations for the cpu device |
| >>> @torch.library.impl("mylib::mysin", "cpu") |
| >>> def f(x): |
| >>> return torch.from_numpy(np.sin(x.numpy())) |
| >>> |
| >>> x = torch.randn(3) |
| >>> y = torch.ops.mylib.mysin(x) |
| >>> assert torch.allclose(y, x.sin()) |
| """ |
| return _impl(qualname, types, func, lib=lib, disable_dynamo=False) |
| |
| |
| def _impl(qualname, types, func=None, *, lib=None, disable_dynamo=False): |
| if isinstance(types, str): |
| types = (types,) |
    keys = set()
| for typ in types: |
| is_dispatch_key = torch._C._parse_dispatch_key(typ) |
| if is_dispatch_key: |
| # We also support passing a DispatchKey to impl. Please prefer using |
| # the higher-level torch.library APIs and only pass DispatchKey to |
| # torch.library.impl with caution (or even better, don't use this |
| # option and file an issue on GitHub for what you need). |
| # We don't advertise this to users because |
| # it is very easy to shoot yourself in the foot. |
| keys.add(typ) |
| else: |
| keys.add(_device_type_to_key(typ)) |
| |
| def register(func): |
| namespace, _ = torch._library.utils.parse_namespace(qualname) |
| |
| if lib is None: |
| use_lib = Library(namespace, "FRAGMENT") |
| _keep_alive.append(use_lib) |
| else: |
| use_lib = lib |
| if disable_dynamo: |
| |
| @torch._disable_dynamo |
| def func_no_dynamo(*args, **kwargs): |
| return func(*args, **kwargs) |
| |
| for key in keys: |
| use_lib.impl(qualname, func_no_dynamo, key) |
| else: |
| for key in keys: |
| use_lib.impl(qualname, func, key) |
| |
| if func is None: |
| return register |
| else: |
| register(func) |
| |
| |
| def _device_type_to_key(device_type: str) -> str: |
| if device_type == "default": |
| # This is technically not correct, because although all device_type |
| # DispatchKeys are included in CompositeExplicitAutograd, |
| # not everything in CompositeExplicitAutograd is associated with a |
| # device_type. I don't really care that much about the difference. |
| return "CompositeExplicitAutograd" |
| return torch._C._dispatch_key_for_device(device_type) |
| |
| |
| @impl.register |
| def _(lib: Library, name, dispatch_key=""): |
| """Legacy torch.library.impl API. Kept around for BC""" |
| |
| def wrap(f): |
| lib.impl(name, f, dispatch_key) |
| return f |
| |
| return wrap |
| |
| |
| @deprecated( |
| "`torch.library.impl_abstract` was renamed to `torch.library.register_fake`. Please use that " |
| "instead; we will remove `torch.library.impl_abstract` in a future version of PyTorch.", |
| category=FutureWarning, |
| ) |
| def impl_abstract(qualname, func=None, *, lib=None, _stacklevel=1): |
| r"""This API was renamed to :func:`torch.library.register_fake` in PyTorch 2.4. |
| Please use that instead. |
| """ |
| if func is not None: |
| _stacklevel = _stacklevel + 1 |
| return register_fake(qualname, func, lib=lib, _stacklevel=_stacklevel) |
| |
| |
| _op_identifier = Union[ |
| str, "torch._ops.OpOverload", "torch._library.custom_ops.CustomOpDef" |
| ] |
| |
| |
| def register_kernel( |
| op: _op_identifier, |
| device_types: device_types_t, |
| func: Optional[Callable] = None, |
| /, |
| *, |
| lib: Optional[Library] = None, |
| ): |
| """Register an implementation for a device type for this operator. |
| |
| Some valid device_types are: "cpu", "cuda", "xla", "mps", "ipu", "xpu". |
| This API may be used as a decorator. |
| |
    Args:
        op (str | OpOverload): The operator to register an impl to.
        device_types (None | str | Sequence[str]): The device_types to register an impl to.
            If None, we will register to all device types -- please only use
            this option if your implementation is truly device-type-agnostic.
        func (Callable): The function to register as the implementation for
            the given device types.
| |
| Examples:: |
| >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) |
| >>> import torch |
| >>> from torch import Tensor |
| >>> from torch.library import custom_op |
| >>> import numpy as np |
| >>> |
| >>> # Create a custom op that works on cpu |
| >>> @custom_op("mylib::numpy_sin", mutates_args=(), device_types="cpu") |
| >>> def numpy_sin(x: Tensor) -> Tensor: |
| >>> x_np = x.numpy() |
| >>> y_np = np.sin(x_np) |
| >>> return torch.from_numpy(y_np) |
| >>> |
| >>> # Add implementations for the cuda device |
| >>> @torch.library.register_kernel("mylib::numpy_sin", "cuda") |
| >>> def _(x): |
| >>> x_np = x.cpu().numpy() |
| >>> y_np = np.sin(x_np) |
| >>> return torch.from_numpy(y_np).to(device=x.device) |
| >>> |
| >>> x_cpu = torch.randn(3) |
| >>> x_cuda = x_cpu.cuda() |
| >>> assert torch.allclose(numpy_sin(x_cpu), x_cpu.sin()) |
| >>> assert torch.allclose(numpy_sin(x_cuda), x_cuda.sin()) |
| |
| """ |
| |
| if not isinstance( |
| op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef) |
| ): |
| raise ValueError("register_kernel(op): got unexpected type for op: {type(op)}") |
| if isinstance(op, torch._ops.OpOverload): |
| op = op._name |
| opdef = _maybe_get_opdef(op) |
| if opdef is not None: |
| return opdef.register_kernel(device_types, func) |
| assert isinstance(op, str) |
| if device_types is None: |
| device_types = "CompositeExplicitAutograd" |
| |
| return _impl(op, device_types, func, lib=lib, disable_dynamo=True) |
| |
| |
| def register_fake( |
| op: _op_identifier, |
| func: Optional[Callable] = None, |
| /, |
| *, |
| lib: Optional[Library] = None, |
| _stacklevel: int = 1, |
| ): |
| r"""Register a FakeTensor implementation ("fake impl") for this operator. |
| |
    Also sometimes known as a "meta kernel" or "abstract impl".

    A "FakeTensor implementation" specifies the behavior of this operator on
| Tensors that carry no data ("FakeTensor"). Given some input Tensors with |
| certain properties (sizes/strides/storage_offset/device), it specifies |
| what the properties of the output Tensors are. |
| |
| The FakeTensor implementation has the same signature as the operator. |
| It is run for both FakeTensors and meta tensors. To write a FakeTensor |
| implementation, assume that all Tensor inputs to the operator are |
| regular CPU/CUDA/Meta tensors, but they do not have storage, and |
| you are trying to return regular CPU/CUDA/Meta tensor(s) as output. |
| The FakeTensor implementation must consist of only PyTorch operations |
| (and may not directly access the storage or data of any input or |
| intermediate Tensors). |
| |
| This API may be used as a decorator (see examples). |
| |
| For a detailed guide on custom ops, please see |
| https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html |
| |
| Examples: |
| >>> import torch |
| >>> import numpy as np |
| >>> from torch import Tensor |
| >>> |
| >>> # Example 1: an operator without data-dependent output shape |
| >>> @torch.library.custom_op("mylib::custom_linear", mutates_args=()) |
| >>> def custom_linear(x: Tensor, weight: Tensor, bias: Tensor) -> Tensor: |
| >>> raise NotImplementedError("Implementation goes here") |
| >>> |
| >>> @torch.library.register_fake("mylib::custom_linear") |
| >>> def _(x, weight, bias): |
| >>> assert x.dim() == 2 |
| >>> assert weight.dim() == 2 |
| >>> assert bias.dim() == 1 |
| >>> assert x.shape[1] == weight.shape[1] |
| >>> assert weight.shape[0] == bias.shape[0] |
| >>> assert x.device == weight.device |
| >>> |
| >>> return (x @ weight.t()) + bias |
| >>> |
| >>> with torch._subclasses.fake_tensor.FakeTensorMode(): |
| >>> x = torch.randn(2, 3) |
| >>> w = torch.randn(3, 3) |
| >>> b = torch.randn(3) |
| >>> y = torch.ops.mylib.custom_linear(x, w, b) |
| >>> |
| >>> assert y.shape == (2, 3) |
| >>> |
| >>> # Example 2: an operator with data-dependent output shape |
| >>> @torch.library.custom_op("mylib::custom_nonzero", mutates_args=()) |
| >>> def custom_nonzero(x: Tensor) -> Tensor: |
| >>> x_np = x.numpy(force=True) |
| >>> res = np.stack(np.nonzero(x_np), axis=1) |
| >>> return torch.tensor(res, device=x.device) |
| >>> |
| >>> @torch.library.register_fake("mylib::custom_nonzero") |
| >>> def _(x): |
| >>> # Number of nonzero-elements is data-dependent. |
        >>> # Since we cannot peek at the data in a fake impl,
| >>> # we use the ctx object to construct a new symint that |
| >>> # represents the data-dependent size. |
| >>> ctx = torch.library.get_ctx() |
| >>> nnz = ctx.new_dynamic_size() |
| >>> shape = [nnz, x.dim()] |
| >>> result = x.new_empty(shape, dtype=torch.int64) |
| >>> return result |
| >>> |
| >>> from torch.fx.experimental.proxy_tensor import make_fx |
| >>> |
| >>> x = torch.tensor([0, 1, 2, 3, 4, 0]) |
| >>> trace = make_fx(torch.ops.mylib.custom_nonzero, tracing_mode="symbolic")(x) |
| >>> trace.print_readable() |
| >>> |
| >>> assert torch.allclose(trace(x), torch.ops.mylib.custom_nonzero(x)) |
| |
| """ |
| if not isinstance( |
| op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef) |
| ): |
| raise ValueError("register_fake(op): got unexpected type for op: {type(op)}") |
| if isinstance(op, torch._ops.OpOverload): |
| op = op._name |
| opdef = _maybe_get_opdef(op) |
| if opdef is not None: |
| if func is None: |
| return opdef.register_fake |
| else: |
| return opdef.register_fake(func) |
| assert isinstance(op, str) |
| |
| stacklevel = _stacklevel |
| |
| def register(func): |
| namespace, op_name = torch._library.utils.parse_namespace(op) |
| if lib is None: |
| use_lib = Library(namespace, "FRAGMENT") |
| _keep_alive.append(use_lib) |
| else: |
| use_lib = lib |
| use_lib._register_fake(op_name, func, _stacklevel=stacklevel + 1) |
| return func |
| |
| if func is None: |
| return register |
| else: |
| stacklevel += 1 |
| return register(func) |
| |
| |
| def register_autograd( |
| op: _op_identifier, |
| backward: Callable, |
| /, |
| *, |
| setup_context: Optional[Callable] = None, |
| lib=None, |
| ) -> None: |
| r"""Register a backward formula for this custom op. |
| |
| In order for an operator to work with autograd, you need to register |
| a backward formula: |
| 1. You must tell us how to compute gradients during the backward pass |
| by providing us a "backward" function. |
| 2. If you need any values from the forward to compute gradients, you can |
| use `setup_context` to save values for backward. |
| |
| ``backward`` runs during the backward pass. It accepts ``(ctx, *grads)``: |
| - ``grads`` is one or more gradients. The number of gradients matches |
| the number of outputs of the operator. |
    The ``ctx`` object is `the same ctx object <context_method_mixins>`_ used by
    :class:`torch.autograd.Function`. The semantics of ``backward`` are the
    same as :meth:`torch.autograd.Function.backward`.
| |
| ``setup_context(ctx, inputs, output)`` runs during the forward pass. |
| Please save quantities needed for backward onto the ``ctx`` object via |
| either :meth:`torch.autograd.function.FunctionCtx.save_for_backward` |
| or assigning them as attributes of ``ctx``. If your custom op has |
| kwarg-only arguments, we expect the signature of ``setup_context`` |
| to be ``setup_context(ctx, inputs, keyword_only_inputs, output)``. |
| |
    Both ``setup_context`` and ``backward`` must be traceable. That is,
| they may not directly access :meth:`torch.Tensor.data_ptr` and they must |
| not depend on or mutate global state. If you need a non-traceable backward, |
| you can make it a separate custom_op that you call inside ``backward_fn``. |
| |
| Examples: |
| >>> import torch |
| >>> import numpy as np |
| >>> from torch import Tensor |
| >>> |
| >>> @torch.library.custom_op("mylib::numpy_sin", mutates_args=()) |
| >>> def numpy_sin(x: Tensor) -> Tensor: |
| >>> x_np = x.cpu().numpy() |
| >>> y_np = np.sin(x_np) |
| >>> return torch.from_numpy(y_np).to(device=x.device) |
| >>> |
        >>> def setup_context(ctx, inputs, output) -> None:
| >>> x, = inputs |
| >>> ctx.save_for_backward(x) |
| >>> |
| >>> def backward(ctx, grad): |
| >>> x, = ctx.saved_tensors |
| >>> return grad * x.cos() |
| >>> |
| >>> torch.library.register_autograd( |
| ... "mylib::numpy_sin", backward, setup_context=setup_context |
| ... ) |
| >>> |
| >>> x = torch.randn(3, requires_grad=True) |
| >>> y = numpy_sin(x) |
| >>> (grad_x,) = torch.autograd.grad(y, x, torch.ones_like(y)) |
| >>> assert torch.allclose(grad_x, x.cos()) |
| >>> |
| >>> # Example with a keyword-only arg |
| >>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=()) |
| >>> def numpy_mul(x: Tensor, *, val: float) -> Tensor: |
| >>> x_np = x.cpu().numpy() |
| >>> y_np = x_np * val |
| >>> return torch.from_numpy(y_np).to(device=x.device) |
| >>> |
        >>> def setup_context(ctx, inputs, keyword_only_inputs, output) -> None:
| >>> ctx.val = keyword_only_inputs["val"] |
| >>> |
| >>> def backward(ctx, grad): |
| >>> return grad * ctx.val |
| >>> |
| >>> torch.library.register_autograd( |
| ... "mylib::numpy_mul", backward, setup_context=setup_context |
| ... ) |
| >>> |
| >>> x = torch.randn(3, requires_grad=True) |
| >>> y = numpy_mul(x, val=3.14) |
| >>> (grad_x,) = torch.autograd.grad(y, x, torch.ones_like(y)) |
| >>> assert torch.allclose(grad_x, torch.full_like(x, 3.14)) |
| |
| """ |
| if not isinstance( |
| op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef) |
| ): |
| raise ValueError( |
| f"register_autograd(op): got unexpected type for op: {type(op)}" |
| ) |
| if isinstance(op, torch._ops.OpOverload): |
| op = op._name |
| opdef = _maybe_get_opdef(op) |
| if opdef is not None: |
| opdef.register_autograd(backward, setup_context=setup_context) |
| return |
| |
| assert isinstance(op, str) |
| qualname = op |
| op = torch._library.utils.lookup_op(qualname) |
| schema = op._schema |
| if not _library.utils.is_functional_schema(schema): |
| raise RuntimeError( |
| f"Cannot register autograd formula for non-functional operator " |
| f"{op} with schema {schema}. Please create " |
| f"a functional operator and register an autograd formula for that." |
| ) |
| if _library.utils.has_kwarg_only_tensors(schema): |
| raise NotImplementedError( |
| f"register_autograd with kwarg-only Tensor args. In the original " |
| f"definition of the op, please make your tensors not kwarg-only. " |
| f"Got: {schema}" |
| ) |
| |
| info = _library.autograd.Info(backward, setup_context) |
| autograd_kernel = _library.autograd.make_autograd_impl(op, info) |
| namespace, opname = torch._library.utils.parse_namespace(qualname) |
| if lib is None: |
| lib = Library(namespace, "FRAGMENT") |
| _keep_alive.append(lib) |
| lib.impl(opname, autograd_kernel, "Autograd", with_keyset=True) |
| |
| |
| def register_torch_dispatch( |
| op: _op_identifier, |
| torch_dispatch_class: Any, |
| func: Optional[Callable] = None, |
| /, |
| *, |
| lib: Optional[Library] = None, |
| ): |
| r"""Registers a torch_dispatch rule for the given operator and ``torch_dispatch_class``. |
| |
| This allows for open registration to specify the behavior between the operator |
| and the ``torch_dispatch_class`` without needing to modify the ``torch_dispatch_class`` |
| or the operator directly. |
| |
| The ``torch_dispatch_class`` is either a Tensor subclass with ``__torch_dispatch__`` or a |
| TorchDispatchMode. |
| |
| If it is a Tensor subclass, we expect ``func`` to have the following signature: |
| ``(cls, func: OpOverload, types: Tuple[type, ...], args, kwargs) -> Any`` |
| |
| If it is a TorchDispatchMode, we expect ``func`` to have the following signature: |
| ``(mode, func: OpOverload, types: Tuple[type, ...], args, kwargs) -> Any`` |
| |
| ``args`` and ``kwargs`` will have been normalized the same way they are |
| in ``__torch_dispatch__`` (see :ref:`torch-dispatch-calling-convention`). |
| |
| Examples: |
| |
| >>> import torch |
| >>> |
| >>> @torch.library.custom_op("mylib::foo", mutates_args={}) |
| >>> def foo(x: torch.Tensor) -> torch.Tensor: |
| >>> return x.clone() |
| >>> |
| >>> class MyMode(torch.utils._python_dispatch.TorchDispatchMode): |
| >>> def __torch_dispatch__(self, func, types, args=(), kwargs=None): |
| >>> return func(*args, **kwargs) |
| >>> |
| >>> @torch.library.register_torch_dispatch("mylib::foo", MyMode) |
| >>> def _(mode, func, types, args, kwargs): |
| >>> x, = args |
| >>> return x + 1 |
| >>> |
| >>> x = torch.randn(3) |
| >>> y = foo(x) |
| >>> assert torch.allclose(y, x) |
| >>> |
| >>> with MyMode(): |
| >>> y = foo(x) |
| >>> assert torch.allclose(y, x + 1) |
| |
| """ |
| if not isinstance( |
| op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef) |
| ): |
| raise ValueError( |
| "register_torch_dispatch(op): got unexpected type for op: {type(op)}" |
| ) |
| if isinstance(op, torch._ops.OpOverload): |
| op = op._name |
| opdef = _maybe_get_opdef(op) |
| if opdef is not None: |
| return opdef.register_torch_dispatch(torch_dispatch_class, func) |
| assert isinstance(op, str) |
| |
| def register(func): |
| namespace, op_name = torch._library.utils.parse_namespace(op) |
| if lib is None: |
| use_lib = Library(namespace, "FRAGMENT") |
| _keep_alive.append(use_lib) |
| else: |
| use_lib = lib |
| use_lib._register_torch_dispatch_rule(op_name, torch_dispatch_class, func) |
| return func |
| |
| if func is None: |
| return register |
| else: |
| return register(func) |
| |
| |
| def register_vmap( |
| op: _op_identifier, |
| func: Optional[Callable] = None, |
| /, |
| *, |
| lib=None, |
| ): |
| r"""Register a vmap implementation to support :func:`torch.vmap` for this custom op. |
| |
| This API may be used as a decorator (see examples). |
| |
| In order for an operator to work with :func:`torch.vmap`, you may need to register a |
| vmap implementation in the following signature: |
| |
| ``vmap_func(info, in_dims: Tuple[Optional[int]], *args, **kwargs)``, |
| |
| where ``*args`` and ``**kwargs`` are the arguments and kwargs for ``op``. |
| We do not support kwarg-only Tensor args. |
| |
    It specifies how we compute the batched version of ``op`` given inputs with an additional
    dimension (specified by ``in_dims``).
| |
| For each arg in ``args``, ``in_dims`` has a corresponding ``Optional[int]``. It is ``None`` |
| if the arg is not a Tensor or if the arg is not being vmapped over, otherwise, it is an integer |
| specifying what dimension of the Tensor is being vmapped over. |
| |
| ``info`` is a collection of additional metadata that may be helpful: |
| ``info.batch_size`` specifies the size of the dimension being vmapped over, while |
| ``info.randomness`` is the ``randomness`` option that was passed to :func:`torch.vmap`. |
| |
    The function ``func`` should return a tuple of ``(output, out_dims)``. Similar to ``in_dims``,
    ``out_dims`` should be of the same structure as ``output`` and contain one ``out_dim``
    per output that specifies whether the output has the vmapped dimension and what index it is in.
| |
| Examples: |
| >>> import torch |
| >>> import numpy as np |
| >>> from torch import Tensor |
| >>> from typing import Tuple |
| >>> |
| >>> def to_numpy(tensor): |
| >>> return tensor.cpu().numpy() |
| >>> |
| >>> lib = torch.library.Library("mylib", "FRAGMENT") |
| >>> @torch.library.custom_op("mylib::numpy_cube", mutates_args=()) |
| >>> def numpy_cube(x: Tensor) -> Tuple[Tensor, Tensor]: |
| >>> x_np = to_numpy(x) |
| >>> dx = torch.tensor(3 * x_np ** 2, device=x.device) |
| >>> return torch.tensor(x_np ** 3, device=x.device), dx |
| >>> |
| >>> def numpy_cube_vmap(info, in_dims, x): |
| >>> result = numpy_cube(x) |
| >>> return result, (in_dims[0], in_dims[0]) |
| >>> |
| >>> torch.library.register_vmap(numpy_cube, numpy_cube_vmap) |
| >>> |
| >>> x = torch.randn(3) |
| >>> torch.vmap(numpy_cube)(x) |
| >>> |
| >>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=()) |
| >>> def numpy_mul(x: Tensor, y: Tensor) -> Tensor: |
| >>> return torch.tensor(to_numpy(x) * to_numpy(y), device=x.device) |
| >>> |
| >>> @torch.library.register_vmap("mylib::numpy_mul") |
| >>> def numpy_mul_vmap(info, in_dims, x, y): |
| >>> x_bdim, y_bdim = in_dims |
| >>> x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1) |
| >>> y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1) |
| >>> result = x * y |
| >>> result = result.movedim(-1, 0) |
| >>> return result, 0 |
| >>> |
| >>> |
| >>> x = torch.randn(3) |
| >>> y = torch.randn(3) |
| >>> torch.vmap(numpy_mul)(x, y) |
| |
| .. note:: |
| The vmap function should aim to preserve the semantics of the entire custom operator. |
    That is, ``grad(vmap(op))`` should be replaceable with ``grad(map(op))``.
| |
| If your custom operator has any custom behavior in the backward pass, please |
| keep this in mind. |
| |
| """ |
| if not isinstance( |
| op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef) |
| ): |
| raise ValueError(f"register_vmap(op): got unexpected type for op: {type(op)}") |
| if isinstance(op, torch._ops.OpOverload): |
| op = op._name |
| opdef = _maybe_get_opdef(op) |
| if opdef is not None: |
| return opdef.register_vmap(func) |
| assert isinstance(op, str) |
| qualname = op |
| op = torch._library.utils.lookup_op(qualname) |
| schema = op._schema |
| if _library.utils.has_kwarg_only_tensors(schema): |
| raise NotImplementedError( |
| f"register_vmap with kwarg-only Tensor args. In the original " |
| f"definition of the op, please make your tensors not kwarg-only. " |
| f"Got: {schema}" |
| ) |
| |
| def register(func): |
| nonlocal op, lib |
| |
| namespace, opname = torch._library.utils.parse_namespace(qualname) |
| if lib is None: |
| lib = Library(namespace, "FRAGMENT") |
| _keep_alive.append(lib) |
| |
| from torch._functorch.autograd_function import custom_function_call_vmap_helper |
| from torch._functorch.pyfunctorch import retrieve_current_functorch_interpreter |
| |
| def wrapped_func(keyset, *args, **kwargs): |
| interpreter = retrieve_current_functorch_interpreter() |
| return custom_function_call_vmap_helper( |
| interpreter, func, op, *args, **kwargs |
| ) |
| |
| lib.impl(opname, wrapped_func, "FuncTorchBatched", with_keyset=True) |
| |
| if func is None: |
| return register |
| else: |
| return register(func) |
| |
| |
| # If the op was defined in C++, then we want to make sure there was an |
| # m.set_python_module(module, ...) call and that the module is the |
| # same as the module that called torch.library.register_fake. |
| def _check_pystubs_once(func, qualname, actual_module_name): |
| checked = False |
| |
| def inner(*args, **kwargs): |
| nonlocal checked |
| if checked: |
| return func(*args, **kwargs) |
| |
| op = torch._library.utils.lookup_op(qualname) |
| if op._defined_in_python: |
| checked = True |
| return func(*args, **kwargs) |
| |
| maybe_pystub = torch._C._dispatch_pystub( |
| op._schema.name, op._schema.overload_name |
| ) |
| if maybe_pystub is None: |
| if torch._library.utils.requires_set_python_module(): |
| namespace = op.namespace |
| cpp_filename = op._handle.debug() |
| raise RuntimeError( |
| f"Operator '{qualname}' was defined in C++ and has a Python " |
| f"fake impl. In this situation, we require there to also be a " |
| f'companion C++ `m.set_python_module("{actual_module_name}")` ' |
| f"call, but we could not find one. Please add that to " |
| f"to the top of the C++ TORCH_LIBRARY({namespace}, ...) block the " |
| f"operator was registered in ({cpp_filename})" |
| ) |
| else: |
| pystub_module = maybe_pystub[0] |
| if actual_module_name != pystub_module: |
| cpp_filename = op._handle.debug() |
| raise RuntimeError( |
| f"Operator '{qualname}' specified that its python fake impl " |
| f"is in the Python module '{pystub_module}' but it was actually found " |
| f"in '{actual_module_name}'. Please either move the fake impl " |
| f"or correct the m.set_python_module call ({cpp_filename})" |
| ) |
| checked = True |
| return func(*args, **kwargs) |
| |
| return inner |
| |
| |
| # NOTE [ctx inside the fake implementation] |
| # If a user has an operator with data-dependent output shape, then when writing |
| # a fake implementation they must query the current ctx and use methods on the |
| # ctx to construct a new unbacked symint. |
| # |
| # This is done via us setting the global_ctx_getter function every time a fake |
| # implementation is invoked. |
| def get_ctx() -> "torch._library.fake_impl.FakeImplCtx": |
| """get_ctx() returns the current AbstractImplCtx object. |
| |
| Calling ``get_ctx()`` is only valid inside of an fake impl |
| (see :func:`torch.library.register_fake` for more usage details. |
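
    Example (a minimal sketch; only valid while a fake impl is running)::
        >>> ctx = torch.library.get_ctx()
        >>> nnz = ctx.new_dynamic_size()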
| """ |
| return torch._library.fake_impl.global_ctx_getter() |
| |
| |
| _OPCHECK_DEFAULT_UTILS = ( |
| "test_schema", |
| "test_autograd_registration", |
| "test_faketensor", |
| "test_aot_dispatch_dynamic", |
| ) |
| |
| |
| def opcheck( |
| op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket, CustomOpDef], |
| args: Tuple[Any, ...], |
| kwargs: Optional[Dict[str, Any]] = None, |
| *, |
| test_utils: Union[str, Sequence[str]] = _OPCHECK_DEFAULT_UTILS, |
| raise_exception: bool = True, |
| ) -> Dict[str, str]: |
| """Given an operator and some sample arguments, tests if the operator is |
| registered correctly. |
| |
    That is, when you use the torch.library/TORCH_LIBRARY APIs to create a
    custom op, you specify metadata (e.g. mutability info) about the custom op
    and these APIs require that the functions you pass them satisfy certain
    properties (e.g. no data pointer access in the fake/meta/abstract kernel).
    ``opcheck`` tests these metadata and properties.
| |
| Concretely, we test the following: |
| |
| - test_schema: If the schema matches the implementation of |
| the operator. For example: if the schema specifies a Tensor is mutated, |
| then we check the implementation mutates the Tensor. If the schema |
| specifies that we return a new Tensor, then we check that the |
| implementation returns a new Tensor (instead of an existing one or |
| a view of an existing one). |
| - test_autograd_registration: If the operator supports training |
| (autograd): we check that its autograd formula is registered via |
| torch.library.register_autograd or a manual registration to one |
| or more DispatchKey::Autograd keys. Any other DispatchKey-based |
| registrations may lead to undefined behavior. |
| - test_faketensor: If the operator has a FakeTensor kernel |
      (and if it is correct). The FakeTensor kernel is necessary
      (but not sufficient) for the operator to work with PyTorch compilation
| APIs (torch.compile/export/FX). We check that a FakeTensor kernel |
| (also sometimes known as a meta kernel) was registered for the |
| operator and that it is correct. This test takes the result of |
| running the operator on real tensors and the result of running |
| the operator on FakeTensors and checks that they have the same |
| Tensor metadata (sizes/strides/dtype/device/etc). |
| - test_aot_dispatch_dynamic: If the operator has correct behavior |
| with PyTorch compilation APIs (torch.compile/export/FX). |
| This checks that the outputs (and gradients, if applicable) are the |
| same under eager-mode PyTorch and torch.compile. |
| This test is a superset of ``test_faketensor`` and is an e2e test; |
| other things it tests are that the operator supports |
| functionalization and that the backward pass (if it exists) also |
| supports FakeTensor and functionalization. |
| |
| For best results, please call ``opcheck`` multiple times with a |
| representative set of inputs. If your operator supports |
| autograd, please use ``opcheck`` with inputs with ``requires_grad = True``; |
| if your operator supports multiple devices (e.g. CPU and CUDA), please |
| use ``opcheck`` with inputs on all supported devices. |
| |
| Args: |
| op: The operator. Must either be a function decorated with |
| :func:`torch.library.custom_op` or an OpOverload/OpOverloadPacket |
| found in torch.ops.* (e.g. torch.ops.aten.sin, torch.ops.mylib.foo) |
| args: The args to the operator |
| kwargs: The kwargs to the operator |
| test_utils: Tests that we should run. Default: all of them. |
| Example: ("test_schema", "test_faketensor") |
| raise_exception: If we should raise an exception on the first |
| error. If False, we will return a dict with information |
| on if each test passed or not. |
| |
| .. warning:: |
| |
| opcheck and :func:`torch.autograd.gradcheck` test different things; |
| opcheck tests if your usage of torch.library APIs is correct while |
| :func:`torch.autograd.gradcheck` tests if your autograd formula is |
| mathematically correct. Use both to test custom ops that support |
| gradient computation. |
| |
| Example: |
| |
| >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) |
| >>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=()) |
| >>> def numpy_add(x: Tensor, y: float) -> Tensor: |
| >>> x_np = x.numpy(force=True) |
| >>> z_np = x_np + y |
| >>> return torch.from_numpy(z_np).to(x.device) |
| >>> |
| >>> @numpy_sin.register_fake |
| >>> def _(x, y): |
| >>> return torch.empty_like(x) |
| >>> |
| >>> def setup_context(ctx, inputs, output): |
| >>> y, = inputs |
| >>> ctx.y = y |
| >>> |
| >>> def backward(ctx, grad): |
| >>> return grad * ctx.y, None |
| >>> |
| >>> numpy_sin.register_autograd(backward, setup_context=setup_context) |
| >>> |
| >>> sample_inputs = [ |
| >>> (torch.randn(3), 3.14), |
| >>> (torch.randn(2, 3, device='cuda'), 2.718), |
| >>> (torch.randn(1, 10, requires_grad=True), 1.234), |
| >>> (torch.randn(64, 64, device='cuda', requires_grad=True), 90.18), |
| >>> ] |
| >>> |
| >>> for args in sample_inputs: |
| >>> torch.library.opcheck(foo, args) |
| |
| """ |
| import torch.testing._internal.optests as optests |
| |
| return optests.opcheck( |
| op, args, kwargs, test_utils=test_utils, raise_exception=raise_exception |
| ) |