from ._ops import OpOverload
from typing import Set
import traceback
import torch

__all__ = ['Library', 'impl', 'define']

# Set containing the combination of (namespace, operator, DispatchKey) for which a new kernel has been
# registered from Python. The keys in the set are of the form `namespace + "/" + op_name + "/" + dispatch_key`,
# e.g. "aten/div.Tensor/CPU". This set is maintained to ensure that two libraries don't try to override
# the exact same functionality, which could otherwise result in calls dispatching to kernels that were
# never intended to be used.
_impls: Set[str] = set()

# prim is reserved by TorchScript interpreter
_reserved_namespaces = ['prim']

class Library:
    """
    A class to create libraries that can be used to register new operators or
    override operators in existing libraries from Python.
    A user can optionally pass in a dispatch key name if they only want to register
    kernels corresponding to one specific dispatch key.

    To create a library to override operators in an existing library (with name ns), set the kind to "IMPL".
    To create a new library (with name ns) to register new operators, set the kind to "DEF".

    Args:
        ns: library name
        kind: "DEF", "IMPL" (default: "IMPL")
        dispatch_key: PyTorch dispatch key (default: "")
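
    Example::
        >>> # xdoctest: +SKIP
        >>> # Illustrative sketch only: "foo" is a placeholder namespace, not a
        >>> # library that ships with PyTorch.
        >>> new_lib = Library("foo", "DEF")         # register brand new operators under "foo"
        >>> override_lib = Library("aten", "IMPL")  # override kernels of existing aten operators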
| """ |
| def __init__(self, ns, kind, dispatch_key=""): |
| if kind != "IMPL" and kind != "DEF": |
| raise ValueError("Unsupported kind: ", kind) |
| |
| if ns in _reserved_namespaces and kind == "DEF": |
| raise ValueError(ns, " is a reserved namespace. Please try creating a library with another name.") |

        frame = traceback.extract_stack(limit=3)[0]
        filename, lineno = frame.filename, frame.lineno
        self.m = torch._C._dispatch_library(kind, ns, dispatch_key, filename, lineno)
        self.ns = ns
        self._op_impls = set()
        self.kind = kind
        self.dispatch_key = dispatch_key

    def __repr__(self):
        return "Library(kind={}, ns={}, dispatch_key={})".format(self.kind, self.ns, self.dispatch_key)

    def define(self, schema, alias_analysis=""):
        r'''Defines a new operator and its semantics in the ns namespace.

        Args:
            schema: function schema to define a new operator.
            alias_analysis (optional): Indicates if the aliasing properties of the operator arguments can be
                                       inferred from the schema (default behavior) or not ("CONSERVATIVE").
        Returns:
            name of the operator as inferred from the schema.

        Example::
            >>> my_lib = Library("foo", "DEF")
            >>> my_lib.define("sum(Tensor self) -> Tensor")
        '''
        # This is added because we also want to disallow PURE_FUNCTION alias analysis which is a valid
        # AliasAnalysis type in C++
        if alias_analysis not in ["", "FROM_SCHEMA", "CONSERVATIVE"]:
            raise RuntimeError("Invalid alias_analysis type {}".format(alias_analysis))
        return self.m.define(schema, alias_analysis)

    def impl(self, op_name, fn, dispatch_key=''):
        r'''Registers the function implementation for an operator defined in the library.

        Args:
            op_name: operator name (along with the overload) or OpOverload object.
            fn: function that's the operator implementation for the input dispatch key.
            dispatch_key: dispatch key that the input function should be registered for. By default, it uses
                          the dispatch key that the library was created with.

        Example::
            >>> # xdoctest: +SKIP
            >>> my_lib = Library("aten", "IMPL")
            >>> def div_cpu(self, other):
            >>>     return self * (1 / other)
            >>> my_lib.impl("div.Tensor", div_cpu, "CPU")
        '''
        if not callable(fn):
            raise TypeError("Input function is required to be a callable but found type {}".format(type(fn)))
        if dispatch_key == '':
            dispatch_key = self.dispatch_key

        if isinstance(op_name, str):
            name = op_name
        elif isinstance(op_name, OpOverload):
            name = op_name._schema.name
            overload_name = op_name._schema.overload_name
            if overload_name != '':
                name = name + '.' + overload_name
        else:
            raise RuntimeError("impl should be passed either a name or an OpOverload object as the first argument")

        key = self.ns + "/" + name.split("::")[-1] + "/" + dispatch_key
        if key in _impls:
            # TODO: in future, add more info about where the existing function is registered (this info is
            # today already returned by the C++ warning when impl is called but we error out before that)
            raise RuntimeError("This is not allowed since there's already a kernel registered from python "
                               "overriding {}'s behavior for the {} dispatch key in the {} namespace.".format(
                                   name.split("::")[-1], dispatch_key, self.ns))

        if dispatch_key == "Meta":
            dispatcher_op_name = name
            if '::' not in dispatcher_op_name:
                dispatcher_op_name = f'{self.ns}::{dispatcher_op_name}'

            # Internally, we shouldn't be registering meta kernels for any operators that
            # have CompositeImplicitAutograd kernels.
            # Instead, we should be letting those decompositions run, and writing meta kernels
            # only for the base operators.
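            # For example (an illustrative sketch, not a real operator): if `foo::my_norm`
            # were registered as a CompositeImplicitAutograd kernel that decomposes into
            # aten::sub and aten::div, meta propagation should run that decomposition and hit
            # the meta kernels of sub/div, rather than a meta kernel for my_norm itself.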
            if torch._C._dispatch_has_kernel_for_dispatch_key(dispatcher_op_name, "CompositeImplicitAutograd"):
                raise RuntimeError(
                    f"We should not register a meta kernel directly to the operator '{name}',"
                    " because it has a CompositeImplicitAutograd kernel in core."
                    " Instead we should let the operator decompose, and ensure that we have meta kernels"
                    " for the base ops that it decomposes into.")

        self.m.impl(name, dispatch_key if dispatch_key != "" else "CompositeImplicitAutograd", fn)

        _impls.add(key)
        self._op_impls.add(key)

    def __del__(self):
        # _op_impls might not have been initialized if an error was thrown in __init__
        _op_impls_ = getattr(self, '_op_impls', None)
        if _op_impls_:
            for key in _op_impls_:
                _impls.remove(key)
        # self.m might likewise be missing if __init__ raised before it was assigned
        if hasattr(self, 'm'):
            del self.m

# decorator to register python functions for library ops
# Note: this decorator API should remain consistent with `Library.impl` API
def impl(lib, name, dispatch_key=""):
    def wrap(f):
        lib.impl(name, f, dispatch_key)
        return f
    return wrap
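
# Example usage of the `impl` decorator (an illustrative sketch: the "foo" namespace,
# the schema, and the kernel below are placeholders, not operators that ship with PyTorch):
#
#   foo_lib = Library("foo", "DEF")
#   foo_lib.define("scale(Tensor self, float factor) -> Tensor")
#
#   @impl(foo_lib, "scale", "CPU")
#   def scale_cpu(self, factor):
#       return self * factor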

# decorator to define a new operator from a schema and register the decorated function
# as its implementation in one step
# Note: this decorator API should remain consistent with `Library.define` and `Library.impl` APIs
def define(lib, schema, alias_analysis=""):
    def wrap(f):
        name = lib.define(schema, alias_analysis)
        lib.impl(name, f)
        return f
    return wrap
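
# Example usage of the `define` decorator (an illustrative sketch: the "foo" namespace
# and the op below are placeholders, not part of this module):
#
#   foo_lib = Library("foo", "DEF")
#
#   @define(foo_lib, "twice(Tensor self) -> Tensor")
#   def twice(self):
#       return self * 2
#
#   # the new op can then be called as torch.ops.foo.twice(tensor)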