from typing import Any, Optional

import torch

from torch.utils._contextlib import (
    _DecoratorContextManager,
    _NoParamDecoratorContextManager,
    F,
)

__all__ = [
    "no_grad",
    "enable_grad",
    "set_grad_enabled",
    "inference_mode",
    "set_multithreading_enabled",
]


class no_grad(_NoParamDecoratorContextManager):
    r"""Context-manager that disables gradient calculation.

    Disabling gradient calculation is useful for inference, when you are sure
    that you will not call :meth:`Tensor.backward()`. It will reduce memory
    consumption for computations that would otherwise have ``requires_grad=True``.

    In this mode, the result of every computation will have
    ``requires_grad=False``, even when the inputs have ``requires_grad=True``.
    There is an exception! All factory functions, or functions that create
    a new Tensor and take a ``requires_grad`` kwarg, will NOT be affected by
    this mode.

    This context manager is thread local; it will not affect computation
    in other threads.

    Also functions as a decorator.

    .. note::
        No-grad is one of several mechanisms that can enable or
        disable gradients locally; see :ref:`locally-disable-grad-doc` for
        more information on how they compare.

    .. note::
        This API does not apply to :ref:`forward-mode AD <forward-mode-ad>`.
        If you want to disable forward AD for a computation, you can unpack
        your dual tensors.

    Example::
        >>> # xdoctest: +SKIP
        >>> x = torch.tensor([1.], requires_grad=True)
        >>> with torch.no_grad():
        ...     y = x * 2
        >>> y.requires_grad
        False
        >>> @torch.no_grad()
        ... def doubler(x):
        ...     return x * 2
        >>> z = doubler(x)
        >>> z.requires_grad
        False
        >>> @torch.no_grad
        ... def tripler(x):
        ...     return x * 3
        >>> z = tripler(x)
        >>> z.requires_grad
        False
        >>> # factory function exception
        >>> with torch.no_grad():
        ...     a = torch.nn.Parameter(torch.rand(10))
        >>> a.requires_grad
        True
    """

    def __init__(self) -> None:
        if not torch._jit_internal.is_scripting():
            super().__init__()
        self.prev = False

    def __enter__(self) -> None:
        self.prev = torch.is_grad_enabled()
        torch.set_grad_enabled(False)

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        torch.set_grad_enabled(self.prev)


class enable_grad(_NoParamDecoratorContextManager):
    r"""Context-manager that enables gradient calculation.

    Enables gradient calculation, if it has been disabled via :class:`~no_grad`
    or :class:`~set_grad_enabled`.

    This context manager is thread local; it will not affect computation
    in other threads.

    Also functions as a decorator.

    .. note::
        enable_grad is one of several mechanisms that can enable or
        disable gradients locally; see :ref:`locally-disable-grad-doc` for
        more information on how they compare.

    .. note::
        This API does not apply to :ref:`forward-mode AD <forward-mode-ad>`.

    Example::
        >>> # xdoctest: +SKIP
        >>> x = torch.tensor([1.], requires_grad=True)
        >>> with torch.no_grad():
        ...     with torch.enable_grad():
        ...         y = x * 2
        >>> y.requires_grad
        True
        >>> y.backward()
        >>> x.grad
        tensor([2.])
        >>> @torch.enable_grad()
        ... def doubler(x):
        ...     return x * 2
        >>> with torch.no_grad():
        ...     z = doubler(x)
        >>> z.requires_grad
        True
        >>> @torch.enable_grad
        ... def tripler(x):
        ...     return x * 3
        >>> with torch.no_grad():
        ...     z = tripler(x)
        >>> z.requires_grad
        True

    """

    def __enter__(self) -> None:
        self.prev = torch.is_grad_enabled()
        torch._C._set_grad_enabled(True)

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        torch._C._set_grad_enabled(self.prev)


class set_grad_enabled(_DecoratorContextManager):
    r"""Context-manager that sets gradient calculation on or off.

    ``set_grad_enabled`` will enable or disable grads based on its argument :attr:`mode`.
    It can be used as a context-manager or as a function.

    This context manager is thread local; it will not affect computation
    in other threads.

    Args:
        mode (bool): Flag whether to enable grad (``True``), or disable
                     (``False``). This can be used to conditionally enable
                     gradients.

    .. note::
        set_grad_enabled is one of several mechanisms that can enable or
        disable gradients locally; see :ref:`locally-disable-grad-doc` for
        more information on how they compare.

    .. note::
        This API does not apply to :ref:`forward-mode AD <forward-mode-ad>`.

    Example::
        >>> # xdoctest: +SKIP
        >>> x = torch.tensor([1.], requires_grad=True)
        >>> is_train = False
        >>> with torch.set_grad_enabled(is_train):
        ...     y = x * 2
        >>> y.requires_grad
        False
        >>> _ = torch.set_grad_enabled(True)
        >>> y = x * 2
        >>> y.requires_grad
        True
        >>> _ = torch.set_grad_enabled(False)
        >>> y = x * 2
        >>> y.requires_grad
        False

    """

    def __init__(self, mode: bool) -> None:
        self.prev = torch.is_grad_enabled()
        self.mode = mode
        torch._C._set_grad_enabled(mode)

    def __call__(self, orig_func: F) -> F:
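        # Decorator usage: undo the mode change eagerly applied in __init__; the
        # requested mode is re-applied around each call via __enter__/__exit__.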
        torch._C._set_grad_enabled(self.prev)
        return super().__call__(orig_func)

    def __enter__(self) -> None:
        torch._C._set_grad_enabled(self.mode)

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        torch._C._set_grad_enabled(self.prev)

    def clone(self) -> "set_grad_enabled":
        return self.__class__(self.mode)


class inference_mode(_DecoratorContextManager):
    r"""Context-manager that enables or disables inference mode.

    InferenceMode is a context manager analogous to :class:`~no_grad`
    to be used when you are certain your operations will have no interactions
    with autograd (e.g., you are running inference rather than training a model).
    Code run under this mode gets better performance by disabling view tracking
    and version counter bumps. Note that unlike some other mechanisms that
    locally enable or disable grad, entering inference_mode also disables
    :ref:`forward-mode AD <forward-mode-ad>`.

    This context manager is thread local; it will not affect computation
    in other threads.

    Also functions as a decorator.

    .. note::
        Inference mode is one of several mechanisms that can enable or
        disable gradients locally; see :ref:`locally-disable-grad-doc` for
        more information on how they compare.

    Args:
        mode (bool or function): Either a boolean flag for whether to enable or
            disable inference mode, or a Python function to decorate with
            inference mode enabled.

    Example::
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
        >>> import torch
        >>> x = torch.ones(1, 2, 3, requires_grad=True)
        >>> with torch.inference_mode():
        ...     y = x * x
        >>> y.requires_grad
        False
        >>> # xdoctest: +SKIP("want string isn't quite right")
        >>> y._version
        Traceback (most recent call last):
        File "<stdin>", line 1, in <module>
        RuntimeError: Inference tensors do not track version counter.
        >>> @torch.inference_mode()
        ... def func(x):
        ...     return x * x
        >>> out = func(x)
        >>> out.requires_grad
        False
        >>> @torch.inference_mode
        ... def doubler(x):
        ...     return x * 2
        >>> out = doubler(x)
        >>> out.requires_grad
        False

    """

    def __init__(self, mode: bool = True) -> None:
        if not torch._jit_internal.is_scripting():
            super().__init__()
        # Holds the RAII guard that enables or disables inference mode while
        # this context manager is entered.
        self._inference_mode_context: Optional[torch._C._InferenceMode] = None
        self.mode = mode

    def __new__(cls, mode=True):
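        # Support bare decorator usage (``@torch.inference_mode`` without parentheses):
        # if ``mode`` is a function rather than a bool, decorate it with the default mode.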
        if isinstance(mode, bool):
            return super().__new__(cls)
        return cls()(mode)

    def __enter__(self) -> None:
        self._inference_mode_context = torch._C._InferenceMode(self.mode)
        self._inference_mode_context.__enter__()

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        self._inference_mode_context.__exit__(exc_type, exc_value, traceback)

    def clone(self) -> "inference_mode":
        return self.__class__(self.mode)


def _enter_inference_mode(mode):
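    # Create an ``_InferenceMode`` RAII guard, enter it, and return it so the
    # caller can later hand it to ``_exit_inference_mode``.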
    mode_context = torch._C._InferenceMode(mode)
    mode_context.__enter__()
    return mode_context


def _exit_inference_mode(mode):
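    # ``mode`` here is the guard object previously returned by ``_enter_inference_mode``.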
    mode.__exit__(None, None, None)


class set_multithreading_enabled(_DecoratorContextManager):
    r"""Context-manager that sets multithreaded backwards on or off.

    ``set_multithreading_enabled`` will enable or disable multithreaded backwards based on its argument :attr:`mode`.
    It can be used as a context-manager or as a function.

    This context manager is thread local; it will not affect computation
    in other threads.

    Args:
        mode (bool): Flag whether to enable multithreaded backwards (``True``), or disable
                     (``False``).

    .. note::
        This API does not apply to :ref:`forward-mode AD <forward-mode-ad>`.

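    Example (an illustrative sketch, assuming access via ``torch.autograd``)::
        >>> # xdoctest: +SKIP
        >>> x = torch.tensor([1.], requires_grad=True)
        >>> with torch.autograd.set_multithreading_enabled(False):
        ...     (x * 2).sum().backward()
        >>> x.grad
        tensor([2.])
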
| """ |
| |
| def __init__(self, mode: bool) -> None: |
| self.prev = torch._C._is_multithreading_enabled() |
| torch._C._set_multithreading_enabled(mode) |
| self.mode = mode |
| |
| def __enter__(self) -> None: |
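        # The multithreading flag was already set in __init__; nothing to do on enter.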
        pass

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        torch._C._set_multithreading_enabled(self.prev)

    def clone(self) -> "set_multithreading_enabled":
        return self.__class__(self.mode)


class _force_original_view_tracking(_DecoratorContextManager):
    r"""Context-manager that sets whether or not to always enable view-replay in autograd.

    This context manager will enable or disable view-replay based on its argument :attr:`mode`.
    It can be used as a context-manager or as a function.

    This context manager is thread local; it will not affect computation
    in other threads.

    When a tensor view is mutated, the autograd engine needs to decide whether
    to regenerate the "updated view" by replaying the chain of views from the
    updated base, or with a single call to ``as_strided``.

    If :attr:`mode` is set to ``True``, autograd will always use view replay.
    Otherwise, it will fall back to its existing logic.

    Args:
        mode (bool): Flag whether to enable view-replay (``True``), or disable
                     (``False``).

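    Example (an illustrative sketch, assuming access via ``torch.autograd``)::
        >>> # xdoctest: +SKIP
        >>> base = torch.randn(4, requires_grad=True).clone()
        >>> view = base[:2]
        >>> with torch.autograd._force_original_view_tracking(True):
        ...     view.mul_(2)  # the updated view is regenerated via view replay
        >>> base.sum().backward()
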
| """ |
| |
| def __init__(self, mode: bool) -> None: |
| self.prev = torch._C._is_view_replay_enabled() |
| torch._C._set_view_replay_enabled(mode) |
| self.mode = mode |
| |
| def __enter__(self) -> None: |
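        # View-replay was already toggled in __init__; nothing to do on enter.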
        pass

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        torch._C._set_view_replay_enabled(self.prev)

    def clone(self):
        return self.__class__(self.mode)


class _unsafe_preserve_version_counter(_DecoratorContextManager):
    r"""DO NOT USE THIS UNLESS YOU KNOW EXACTLY WHAT YOU'RE DOING.

    This context manager can lead to arbitrary silent-correctness issues in any other part of your code
    (even the ones not touched directly by the context manager)!

    Ordinarily, autograd will track mutations to a tensor by incrementing its ``._version`` attribute.
    This is generally important for correctness: mutating a tensor that autograd has saved
    for the backward pass can result in incorrect gradients, and autograd uses the version counter
    to detect this situation and raise an error.

    However, there are rare instances where it might be useful to hide mutations from autograd. For example,
    if a tensor is very large, you may want to free its memory by storing its contents elsewhere and
    re-populate the tensor right before it is needed by autograd.

    Args:
        tensor (torch.Tensor): the tensor whose version counter you would like to preserve.

    .. note::
        This API does not apply to :ref:`forward-mode AD <forward-mode-ad>`.

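    Example (an illustrative sketch, assuming access via ``torch.autograd``)::
        >>> # xdoctest: +SKIP
        >>> t = torch.zeros(2)
        >>> v = t._version
        >>> with torch.autograd._unsafe_preserve_version_counter(t):
        ...     t.add_(1)  # this mutation will not be visible to autograd
        >>> t._version == v
        True
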
| """ |
| |
| def __init__(self, tensor: torch.Tensor) -> None: |
| self.tensor = tensor |
| self.prev_version = tensor._version |
| |
| def __enter__(self) -> None: |
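        # The original version was captured in __init__; it is restored in __exit__.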
        pass

    def __exit__(self, *args) -> None:
        torch._C._autograd._unsafe_set_version_counter(self.tensor, self.prev_version)