# blob: 14c5c35ed45d880806682aed71350fdb2aefc658 [file] [log] [blame]
import contextlib
import warnings
from torch._C import _len_torch_dispatch_stack, _get_dispatch_stack_at,\
_pop_torch_dispatch_stack, _push_on_torch_dispatch_stack
# TODO: Limitations and things about enable_torch_dispatch_mode we should fix before exposing it:
# - We need a better user-facing api for _DisableTorchDispatch that
# is able to selectively disable __torch_dispatch__ of a particular class.
# - It doesn't work with the tensor constructors (torch.tensor, torch.Tensor)
# - Better name (see https://github.com/pytorch/pytorch/pull/63496#discussion_r694091694)
class TorchDispatchMode:
    """
    A ``TorchDispatchMode`` allows you to override the meaning of all
    ``__torch_dispatch__`` overrideable functions within a dynamic scope,
    without having to actually create a tensor subclass or manually
    monkey-patch functions in the PyTorch API.  Some common situations
    where you should use a mode:

        * You want to override the meaning of factory functions, or other
          functions that do not otherwise take a tensor as an argument
          (these cannot be overridden with tensor subclasses).

        * You want to override the behavior of all functions without needing
          to wrap your inputs in tensor subclasses; e.g., if you are just
          interested in logging intermediate computations.

        * You want to control the order of execution of various tensor
          subclasses explicitly, rather than implicitly via the return of
          ``NotImplemented``.

    Independent subclasses of :class:`TorchDispatchMode` are compositional:
    modes can be pushed onto a stack using ``with MyMode():``.
    When you call functions in the PyTorch API inside your
    ``__torch_dispatch__`` implementation, by default, they will forward on to
    the next mode on the mode stack.  If you want to recursively call back
    into your current ``__torch_dispatch__`` implementation, either explicitly
    invoke ``self.__torch_dispatch__(...)``, or use the context manager
    ``__torch_dispatch__(self)`` to make PyTorch
    API self-referential (beware of infinite loops, in this case!)
    """

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        """Handle an intercepted call; subclasses are expected to override this.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError()

    def __enter__(self):
        """Push this mode via ``_push_mode`` and return it for ``with ... as``."""
        _push_mode(self)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Call ``_pop_mode()`` on leaving the ``with`` block.

        Returns ``None``, so an exception raised inside the block is not
        suppressed.
        """
        _pop_mode()

    @classmethod
    def push(cls, *args, **kwargs):
        """Legacy constructor: warn, then return ``cls(*args, **kwargs)``.

        The new instance is only constructed here, not pushed; use it in a
        ``with`` statement to activate it.
        """
        # stacklevel=2 attributes the warning to the code that called
        # ``push()`` rather than to this line inside the library.
        warnings.warn(
            "`Mode.push()` is no longer necessary and can be replaced with just `with Mode()`",
            stacklevel=2,
        )
        return cls(*args, **kwargs)
def _get_current_dispatch_mode():
    """Return the entry at the last index of the dispatch mode stack.

    Returns ``None`` when the stack reports a length of zero.
    """
    depth = _len_torch_dispatch_stack()
    if depth <= 0:
        return None
    return _get_dispatch_stack_at(depth - 1)
def _get_current_dispatch_mode_stack():
    """Return a new list of every stack entry, ordered by index ``0 .. len - 1``."""
    indices = range(_len_torch_dispatch_stack())
    return list(map(_get_dispatch_stack_at, indices))
def _push_mode(mode):
    """Hand ``mode`` to ``_push_on_torch_dispatch_stack``; returns ``None``."""
    _push_on_torch_dispatch_stack(mode)
    return None
def _pop_mode():
    """Return whatever ``_pop_torch_dispatch_stack()`` returns."""
    popped = _pop_torch_dispatch_stack()
    return popped
@contextlib.contextmanager
def _pop_mode_temporarily():
    """Context manager that pops one mode for the duration of the block.

    The value returned by ``_pop_mode()`` is yielded to the caller and is
    handed back to ``_push_mode`` on exit, whether or not the block raised.
    """
    popped = _pop_mode()
    try:
        yield popped
    finally:
        # Restore the mode even if the body raised.
        _push_mode(popped)
@contextlib.contextmanager
def _disable_current_modes():
    """Context manager that pops every mode for the duration of the block.

    ``_pop_mode()`` is called once per entry reported by
    ``_len_torch_dispatch_stack()``; the popped modes are yielded as a list
    in pop order.  On exit the list is walked in reverse and each mode is
    passed to ``_push_mode``, whether or not the block raised.
    """
    popped = []
    for _ in range(_len_torch_dispatch_stack()):
        popped.append(_pop_mode())
    try:
        yield popped
    finally:
        # Push in reverse pop order so the first mode popped is pushed last.
        for mode in popped[::-1]:
            _push_mode(mode)
class BaseTorchDispatchMode(TorchDispatchMode):
    """Mode whose ``__torch_dispatch__`` simply calls ``func`` with the given arguments."""

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        """Call ``func(*args, **kwargs)`` and return its result.

        ``kwargs=None`` is treated as "no keyword arguments"; ``types`` is
        accepted but not used.
        """
        call_kwargs = {} if kwargs is None else kwargs
        return func(*args, **call_kwargs)