| import contextlib |
| import functools |
| import warnings |
| from typing import Callable, Optional |
| |
| import torch |
| from torch._library.utils import Kernel, RegistrationHandle |
| |
| |
| class AbstractImplHolder: |
| """A holder where one can register an abstract impl to.""" |
| |
| def __init__(self, qualname: str): |
| self.qualname: str = qualname |
| self.kernel: Optional[Kernel] = None |
| self.lib: Optional[torch.library.Library] = None |
| |
| def register(self, func: Callable, source: str) -> RegistrationHandle: |
| """Register an abstract impl. |
| |
| Returns a RegistrationHandle that one can use to de-register this |
| abstract impl. |
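
        Example (a hypothetical sketch; ``mylib::my_op`` and the kernel below
        are illustrative and assume the operator has already been defined via
        torch.library)::

            >>> holder = AbstractImplHolder("mylib::my_op")
            >>> def my_op_abstract(x):
            >>>     # Shape propagation only; an abstract impl may not read data.
            >>>     return x.new_empty(x.shape)
            >>> handle = holder.register(my_op_abstract, source="mylib/ops.py")
            >>> # Later, de-register via the returned handle.
            >>> handle.destroy()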
| """ |
| if self.kernel is not None: |
| raise RuntimeError( |
| f"impl_abstract(...): the operator {self.qualname} " |
| f"already has an abstract impl registered at " |
| f"{self.kernel.source}." |
| ) |
| if torch._C._dispatch_has_kernel_for_dispatch_key(self.qualname, "Meta"): |
| raise RuntimeError( |
| f"impl_abstract(...): the operator {self.qualname} " |
| f"already has an DispatchKey::Meta implementation via a " |
| f"pre-existing torch.library or TORCH_LIBRARY registration. " |
| f"Please either remove that registration or don't call " |
| f"impl_abstract." |
| ) |
| |
| if torch._C._dispatch_has_kernel_for_dispatch_key( |
| self.qualname, "CompositeImplicitAutograd" |
| ): |
| raise RuntimeError( |
| f"impl_abstract(...): the operator {self.qualname} " |
| f"already has an implementation for this device type via a " |
| f"pre-existing registration to " |
| f"DispatchKey::CompositeImplicitAutograd." |
| f"CompositeImplicitAutograd operators do not need an abstract " |
| f"impl; " |
| f"instead, the operator will decompose into its constituents " |
| f"and those " |
| f"can have abstract impls defined on them." |
| ) |
| |
| # Store the kernel in this holder |
| self.kernel = Kernel(func, source) |
| |
        # Also register the abstract impl under the Meta dispatch key
| if self.lib is None: |
| ns = self.qualname.split("::")[0] |
| self.lib = torch.library.Library(ns, "FRAGMENT") |
| meta_kernel = construct_meta_kernel(self.qualname, self) |
| self.lib.impl(self.qualname, meta_kernel, "Meta") |
| |
| def deregister_abstract_impl(): |
| if self.lib: |
| self.lib._destroy() |
| self.lib = None |
| self.kernel = None |
| |
| return RegistrationHandle(deregister_abstract_impl) |
| |
| |
| def construct_meta_kernel( |
| qualname: str, abstract_impl_holder: AbstractImplHolder |
| ) -> Callable: |
| assert abstract_impl_holder.kernel is not None |
| |
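    # The Meta-key kernel forwards to the registered abstract impl, but with
    # get_ctx() disabled: meta tensors carry no ShapeEnv, so data-dependent
    # output shapes cannot be expressed on the Meta key.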
| @functools.wraps(abstract_impl_holder.kernel.func) |
| def meta_kernel(*args, **kwargs): |
| assert abstract_impl_holder.kernel is not None |
| source = abstract_impl_holder.kernel.source |
| |
| def error_on_ctx(): |
| raise RuntimeError( |
| f"Attempted to call get_ctx() for the meta implementation " |
| f"for {qualname} (implemented at {source})" |
| f"You have presumably called get_ctx() because the operator " |
| f"has a data-dependent output shape; if so, there is no " |
| f"such meta implementation and this error is the correct " |
| f"behavior." |
| ) |
| |
| with set_ctx_getter(error_on_ctx): |
| return abstract_impl_holder.kernel(*args, **kwargs) |
| |
| return meta_kernel |
| |
| |
| def get_none(): |
| return None |
| |
| |
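# The getter used to service get_ctx() calls. It defaults to ``get_none`` so
# that no context is available outside of an abstract impl invocation.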
| global_ctx_getter: Callable = get_none |
| |
| |
| @contextlib.contextmanager |
| def set_ctx_getter(ctx_getter): |
| global global_ctx_getter |
| prev = global_ctx_getter |
| try: |
| global_ctx_getter = ctx_getter |
| yield |
| finally: |
| global_ctx_getter = prev |
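
# Usage sketch (illustrative, not a public API): a caller that evaluates an
# abstract impl under a ShapeEnv is expected to install a getter returning an
# AbstractImplCtx so that get_ctx() works inside the impl, e.g.
#
#     ctx = AbstractImplCtx(shape_env, op)
#     with set_ctx_getter(lambda: ctx):
#         result = abstract_impl(*args, **kwargs)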
| |
| |
| class AbstractImplCtx: |
| """ |
| Context object for writing abstract implementations for custom operators. |
| """ |
| |
| def __init__(self, _shape_env, _op): |
| self._shape_env = _shape_env |
| self._op = _op |
| |
| def create_unbacked_symint(self, *, min=2, max=None) -> torch.SymInt: |
| warnings.warn( |
| "create_unbacked_symint is deprecated, please use new_dynamic_size instead" |
| ) |
| return self.new_dynamic_size(min=min, max=max) |
| |
| def new_dynamic_size(self, *, min=0, max=None) -> torch.SymInt: |
| """Constructs a new symint (symbolic int) representing a data-dependent value. |
| |
| This is useful for writing the abstract implementation (which is necessary |
| for torch.compile) for a CustomOp where an output Tensor has a size |
| that depends on the data of the input Tensors. |
| |
| Args: |
| min (int): A statically known inclusive lower bound for this symint. Default: 0 |
| max (Optional[int]): A statically known inclusive upper bound for this |
| symint. Default: None |
| |
        .. warning::

            It is important that the ``min`` and ``max`` (if not None) values are
            set correctly; otherwise, there will be undefined behavior under
            torch.compile. Note that torch.compile specializes on sizes 0 and 1,
            which is why the deprecated ``create_unbacked_symint`` defaulted
            ``min`` to 2.

            You must also verify that your implementation on concrete Tensors
            (e.g. CPU/CUDA) only returns Tensors whose size corresponding to the
            symint respects these constraints. The easiest way to do this is to
            add an assertion in the CPU/CUDA/etc implementation that the size
            follows these bounds.
| |
| Example:: |
| |
| >>> # An operator with data-dependent output shape |
            >>> import numpy as np
            >>> lib = torch.library.Library("mymodule", "FRAGMENT")
| >>> lib.define("mymodule::custom_nonzero(Tensor x) -> Tensor") |
| >>> |
| >>> @torch.library.impl_abstract("mymodule::custom_nonzero") |
| >>> def custom_nonzero_abstract(x): |
| >>> # Number of nonzero-elements is data-dependent. |
| >>> # Since we cannot peek at the data in an abstract impl, |
| >>> # we use the ctx object to construct a new symint that |
| >>> # represents the data-dependent size. |
| >>> ctx = torch.library.get_ctx() |
| >>> nnz = ctx.new_dynamic_size() |
| >>> shape = [nnz, x.dim()] |
| >>> result = x.new_empty(shape, dtype=torch.int64) |
| >>> return result |
| >>> |
| >>> @torch.library.impl(lib, "custom_nonzero", "CPU") |
| >>> def custom_nonzero_cpu(x): |
| >>> x_np = x.numpy() |
| >>> res = np.stack(np.nonzero(x_np), axis=1) |
| >>> return torch.tensor(res, device=x.device) |
| |
| """ |
| if ( |
| self._shape_env is None |
| or not self._shape_env.allow_dynamic_output_shape_ops |
| ): |
| raise torch._subclasses.fake_tensor.DynamicOutputShapeException(self._op) |
| |
| if isinstance(min, torch.SymInt) or isinstance(max, torch.SymInt): |
| raise ValueError( |
| f"ctx.new_dynamic_size(min={min}, max={max}): expected " |
| f"min and max to be statically known ints but got SymInt. " |
| f"This is not supported." |
| ) |
| |
| if min < 0: |
| raise ValueError( |
| f"ctx.new_dynamic_size(min={min}, ...): expected min to be " |
| f"greater than or equal to 0: this API can only create " |
| f"non-negative sizes." |
| ) |
| |
| result = self._shape_env.create_unbacked_symint() |
| torch.fx.experimental.symbolic_shapes._constrain_range_for_size( |
| result, min=min, max=max |
| ) |
| return result |