import contextlib
import functools
from dataclasses import dataclass
from typing import Iterator, Set

from torch._C import _get_torch_dispatch_mode, _set_torch_dispatch_mode
from torch.utils._mode_utils import _enable_mode, _push_mode, _ModeInfo, _wrap_init, _restore_mode
| |
| |
| @dataclass |
| class TorchDispatchModeInfo(_ModeInfo): |
| def __init__(self): |
| super().__init__(mode_name="torch_dispatch", mode_class=TorchDispatchMode, |
| base_mode_class=BaseTorchDispatchMode) |
| |
| def get_mode(self): |
| return _get_torch_dispatch_mode() |
| |
| def set_mode(self, mode): |
| return _set_torch_dispatch_mode(mode) |
| |
| |
| # TODO: Limitations and things about enable_torch_dispatch_mode we should fix before exposing it: |
| # - We need a better user-facing api for torch._C._DisableTorchDispatch that |
| # is able to selectively disable __torch_dispatch__ of a particular class. |
| # - It doesn't work with the tensor constructors (torch.tensor, torch.Tensor) |
| # - Better name (see https://github.com/pytorch/pytorch/pull/63496#discussion_r694091694) |
| @contextlib.contextmanager |
| def enable_torch_dispatch_mode(mode, *, replace=None, ignore_preexisting=False) -> Iterator[None]: |
| """ |
    Context manager that causes all PyTorch operators to dispatch to the passed-in
    mode's ``__torch_dispatch__`` method, including operations that accept no
    tensors but return a tensor.

    This function is non-compositional; if there is already an existing mode,
    it will raise an error unless you pass ``replace`` or ``ignore_preexisting``
    (see below).
| |
| This function is safe to use inside a ``__torch_dispatch__`` mode handler, |
| as the mode is guaranteed to be disabled in this context. You can use |
| this context manager to reinstate the mode so that calls to overridable |
| APIs recursively call back into your mode handler (this can easily cause |
    infinite loops, so use with care!).
| |
    ``enable_torch_dispatch_mode`` is affected by ``torch._C._DisableTorchDispatch``.
| |
| Args: |
| mode (:class:`TorchDispatchMode`, Tensor-like class, or None): the |
| mode to set as current mode. If you pass a Tensor-like class, |
| it will be treated as a non-compositional mode with no state, |
| which is convenient if you have an existing tensor subclass |
| that you'd like to apply globally in a quick and dirty way. |
| Passing None will disable the current mode. |
| replace (:class:`TorchDispatchMode` or Tensor-like class): the |
| mode to replace. You can use this argument to change the mode in |
| a situation where you know what the current mode is (and you are |
| intentionally overwriting it.) If you don't know what the current |
| mode is, use ``ignore_preexisting`` instead. |
| ignore_preexisting (bool): if True, ignore any preexisting mode |
| and overwrite it with the passed mode. |
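
    Example (a sketch; ``LoggingTensor`` is a hypothetical tensor subclass
    that defines ``__torch_dispatch__``)::

        >>> with enable_torch_dispatch_mode(LoggingTensor):
        ...     # even factory functions with no tensor arguments now route
        ...     # through LoggingTensor.__torch_dispatch__
        ...     x = torch.ones(2)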
| """ |
| |
| return _enable_mode(mode, mode_info=TorchDispatchModeInfo(), replace=replace, ignore_preexisting=ignore_preexisting) |
| |
| |
def _wrap_torch_dispatch(f):
    # Fail at class definition time, not on first dispatch.
    if isinstance(f, classmethod):
        raise RuntimeError("TorchDispatchMode's __torch_dispatch__ function "
                           "should be a normal method not a classmethod")

    @functools.wraps(f)
    def wrapped(self, *args, **kwargs):
        # Re-enable the next mode on the stack (if any) while the user's
        # __torch_dispatch__ body runs, so that PyTorch calls made inside it
        # compose with the rest of the mode stack by default.
        inner = getattr(self, "inner", None)
        with enable_torch_dispatch_mode(inner):
            return f(self, *args, **kwargs)
    return wrapped
| |
| |
# Implementation note: since this is based on TorchFunctionMode, it faced the
# same dilemma: how much of the mode stack should be implemented in Python
# versus in C++?  At the time of writing, I did not care too much about
# implementation efficiency; however, I do care about making it hard for users
# to implement modes in the wrong way.  In the end, it turned out to be
# possible to implement mode stacks entirely from userland, with the C++ API
# providing only _get_torch_dispatch_mode() and _set_torch_dispatch_mode(), so
# I opted to provide some unsafe C++ bindings and have the bulk of the logic
# for managing the stack in Python, which helped simplify the C++ API surface.
# It would also have been valid to build the notion of a mode stack directly
# into C++, but in that design it would have been substantially more difficult
# to interact with TorchDispatchModeMeta.
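#
# To see why the two bindings suffice, here is a sketch of the stack
# discipline they support (hypothetical helper names; the real bookkeeping
# lives in TorchDispatchMode.__enter__/__exit__ below):
#
#     def _push(mode):
#         mode.inner = _get_torch_dispatch_mode()  # remember the previous top
#         _set_torch_dispatch_mode(mode)           # install the new top
#
#     def _pop(mode):
#         _set_torch_dispatch_mode(mode.inner)     # reinstate the previous top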
| |
| class TorchDispatchModeMeta(type): |
| """ |
| Metaclass for :class:`TorchDispatchMode`; it does two things: |
| |
| * Adds an implicit ``inner`` kwarg to ``__init__``, to |
| allow the modes to be chained together to form a stack. |
| |
| * Reenables the inner mode, so that by default PyTorch API calls |
| will compositionally proceed to the next mode on the stack. |
| |
    The default behavior for the second bullet is important, as it is easy to
    accidentally write ``__torch_dispatch__`` implementations that are not
    compositional, and the wrapping here makes the obvious code do the
    right thing (aka, this is why there is a metaclass).
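
    For example (a sketch; ``ModeA`` and ``ModeB`` are hypothetical subclasses
    whose ``__torch_dispatch__`` bodies simply call through to ``func``),
    nesting two modes composes them with no extra plumbing::

        >>> with ModeA():
        ...     with ModeB():
        ...         # ModeB is the current mode; any PyTorch call made inside
        ...         # its __torch_dispatch__ forwards to ModeA by default.
        ...         torch.randn(3)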
| """ |
| def __new__(metacls, name, bases, dct): |
| if '__init__' in dct: |
| dct['__init__'] = _wrap_init(dct['__init__']) |
| if '__torch_dispatch__' in dct: |
| dct['__torch_dispatch__'] = _wrap_torch_dispatch(dct['__torch_dispatch__']) |
| return super().__new__(metacls, name, bases, dct) |
| |
| |
| class TorchDispatchMode(metaclass=TorchDispatchModeMeta): |
| """ |
| A ``TorchDispatchMode`` allows you to override the meaning of all |
| ``__torch_dispatch__`` overrideable functions within a dynamic scope, |
| without having to actually create a tensor subclass or manually |
| monkey-patch functions in the PyTorch API. Some common situations |
| where you should use a mode: |
| |
| * You want to override the meaning of factory functions, or other |
| functions that do not otherwise take a tensor as an argument |
| (these cannot be overridden with tensor subclasses). |
| |
| * You want to override the behavior of all functions without needing |
| to wrap your inputs in tensor subclasses; e.g., if you are just |
| interested in logging intermediate computations. |
| |
| * You want to control the order of execution of various tensor |
| subclasses explicitly, rather than implicitly via the return of |
| ``NotImplemented``. |
| |
    Independent subclasses of :class:`TorchDispatchMode` are compositional:
    modes can be pushed onto a stack with :func:`push_torch_dispatch_mode`.
    When you call functions in the PyTorch API inside your
    ``__torch_dispatch__`` implementation, by default, they will forward on to
    the next mode on the mode stack.  If you want to recursively call back into
    your current ``__torch_dispatch__`` implementation, either explicitly
    invoke ``self.__torch_dispatch__(...)``, or use the context manager
    ``enable_torch_dispatch_mode(self, replace=self.inner)`` to make the
    PyTorch API self-referential (beware of infinite loops, in this case!).
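
    Example (a minimal sketch; ``LoggingMode`` is a hypothetical subclass
    that prints every operator it intercepts and otherwise calls through)::

        >>> class LoggingMode(TorchDispatchMode):
        ...     def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        ...         print(func)
        ...         return func(*args, **(kwargs or {}))
        >>> with LoggingMode():
        ...     torch.empty(2)  # prints each dispatched aten op, then runs it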
| """ |
| # Force metaclass to generate constructor at the base of the hierarchy |
| def __init__(self): |
| self.ancestors: Set[TorchDispatchMode] |
| |
| def __torch_dispatch__(self, func, types, args=(), kwargs=None): |
| raise NotImplementedError() |
| |
    def __enter__(self):
        old = _get_torch_dispatch_mode()
        if hasattr(self, "inner"):
            raise RuntimeError(f"{self} has already been used as a mode. Please use a fresh version or use restore")
        else:
            self.inner = old
            if old is None:
                self.ancestors = set()
            else:
                # Record every mode sitting below us on the stack; this
                # bookkeeping is used when the mode is later reinstated
                # via restore().
                self.ancestors = self.inner.ancestors.union({self.inner})
        _set_torch_dispatch_mode(self)
        return self
| |
| def __exit__(self, exc_type, exc_val, exc_tb): |
| _set_torch_dispatch_mode(self.inner) |
| |
    @contextlib.contextmanager
    def restore(self):
        """
        Context manager that re-enters a mode instance that has already been
        used once (``__enter__`` refuses to reuse an instance directly).
        """
        return _restore_mode(self, mode_info=TorchDispatchModeInfo())
| |
| @classmethod |
| def push(cls, *args, **kwargs): |
| return push_torch_dispatch_mode(functools.partial(cls, *args, **kwargs)) |
| |
| |
class BaseTorchDispatchMode(TorchDispatchMode):
    """
    Trivial mode whose ``__torch_dispatch__`` simply runs ``func`` with the
    given arguments, deferring entirely to the normal dispatch behavior.
    """
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        return func(*args, **kwargs)
| |

@contextlib.contextmanager
def push_torch_dispatch_mode(ctor) -> Iterator[object]:
    """
    Context manager that constructs a mode by calling ``ctor`` (typically a
    :class:`TorchDispatchMode` subclass, possibly partially applied), pushes
    it onto the mode stack for the duration of the block, and yields the
    constructed mode.
    """
    return _push_mode(ctor, mode_info=TorchDispatchModeInfo())
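
# Example usage of the push API (a sketch; ``MyMode`` is a hypothetical
# TorchDispatchMode subclass whose constructor takes a ``verbose`` flag):
#
#     with MyMode.push(verbose=True) as mode:
#         torch.zeros(3)  # dispatched through mode.__torch_dispatch__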