import torch
from typing import Callable, Any


class saved_tensors_hooks:
| """Context-manager that sets a pair of pack / unpack hooks for saved tensors. |
| |
| Use this context-manager to define how intermediary results of an operation |
| should be packed before saving, and unpacked on retrieval. |
| |
| In that context, the ``pack_hook`` function will be called everytime an |
| operation saves a tensor for backward (this includes intermediary results |
| saved using |
| :func:`~torch.autograd.function._ContextMethodMixin.save_for_backward` but |
| also those recorded by a PyTorch-defined operation). The output of |
| ``pack_hook`` is then stored in the computation graph instead of the |
| original tensor. |
| |
| The ``unpack_hook`` is called when the saved tensor needs to be accessed, |
| namely when executing :func:`torch.Tensor.backward()` or |
| :func:`torch.autograd.grad()`. It takes as argument the *packed* object |
| returned by ``pack_hook`` and should return a tensor which has the same |
| content as the original tensor (passed as input to the corresponding |
| ``pack_hook``). |
| |
| The hooks should have the following signatures: |
| |
| pack_hook(tensor: Tensor) -> Any |
| |
| unpack_hook(Any) -> Tensor |
| |
| where the return value of ``pack_hook`` is a valid input to ``unpack_hook``. |
| |
| In general, you want ``unpack_hook(pack_hook(t))`` to be equal to ``t`` in terms |
| of value, size, dtype and device. |
| |
| Example:: |
| |
| >>> def pack_hook(x): |
| ... print("Packing", x) |
| ... return x |
| >>> |
| >>> def unpack_hook(x): |
| ... print("Unpacking", x) |
| ... return x |
| >>> |
| >>> a = torch.ones(5, requires_grad=True) |
| >>> b = torch.ones(5, requires_grad=True) * 2 |
| >>> with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook): |
| ... y = a * b |
| Packing tensor([1., 1., 1., 1., 1.], requires_grad=True) |
| Packing tensor([2., 2., 2., 2., 2.], grad_fn=<MulBackward0>) |
| >>> y.sum().backward() |
| Unpacking tensor([1., 1., 1., 1., 1.], requires_grad=True) |
| Unpacking tensor([2., 2., 2., 2., 2.], grad_fn=<MulBackward0>) |
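
    As a further illustrative sketch (the ``pack_to_disk`` / ``unpack_from_disk``
    helpers below are hypothetical names used only for this example), hooks can
    also offload saved tensors to disk and reload them during backward::

        >>> import os, tempfile, uuid
        >>> def pack_to_disk(tensor):
        ...     # write the tensor to a temporary file and keep only its path
        ...     # in the autograd graph
        ...     path = os.path.join(tempfile.gettempdir(), str(uuid.uuid4()))
        ...     torch.save(tensor, path)
        ...     return path
        >>>
        >>> def unpack_from_disk(path):
        ...     # read the tensor back from disk when backward needs it
        ...     return torch.load(path)
        >>>
        >>> x = torch.randn(5, requires_grad=True)
        >>> with torch.autograd.graph.saved_tensors_hooks(pack_to_disk, unpack_from_disk):
        ...     y = x.pow(2)  # x is written to disk instead of being kept in memory
        >>> y.sum().backward()  # x is read back from disk to compute the gradient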

    .. warning ::
        Performing an inplace operation on the input to either hook may lead
        to undefined behavior.

    .. warning ::
        Only one pair of hooks is allowed at a time. When recursively nesting this
        context-manager, only the innermost pair of hooks will be applied.
    """
    def __init__(self, pack_hook: Callable[[torch.Tensor], Any], unpack_hook: Callable[[Any], torch.Tensor]):
        self.pack_hook = pack_hook
        self.unpack_hook = unpack_hook

    def __enter__(self):
        torch._C._autograd._push_saved_tensors_default_hooks(self.pack_hook, self.unpack_hook)

    def __exit__(self, *args: Any):
        torch._C._autograd._pop_saved_tensors_default_hooks()


class save_on_cpu(saved_tensors_hooks):
| """Context-manager under which tensors saved by the forward pass will be |
| stored on cpu, then retrieved for backward. |
| |
| When performing operations within this context manager, intermediary |
| results saved in the graph during the forward pass will be moved to CPU, |
| then copied back to the original device when needed for the backward pass. |
| If the graph was already on CPU, no tensor copy is performed. |
| |
| Use this context-manager to trade compute for GPU memory usage (e.g. |
| when your model doesn't fit in GPU memory during training). |
| |
| Args: |
| pin_memory (bool): If ``True`` tensors will be saved to CPU pinned memory |
| during packing and copied to GPU asynchronously during unpacking. |
| Defaults to ``False``. |
| Also see :ref:`cuda-memory-pinning`. |
| |
| |
| Example:: |
| |
| >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) |
| >>> a = torch.randn(5, requires_grad=True, device="cuda") |
| >>> b = torch.randn(5, requires_grad=True, device="cuda") |
| >>> c = torch.randn(5, requires_grad=True, device="cuda") |
| >>> |
| >>> def f(a, b, c): |
| ... prod_1 = a * b # a and b are saved on GPU |
| ... with torch.autograd.graph.save_on_cpu(): |
| ... prod_2 = prod_1 * c # prod_1 and c are saved on CPU |
| ... y = prod_2 * a # prod_2 and a are saved on GPU |
| ... return y |
| >>> |
| >>> y = f(a, b, c) |
| >>> del a, b, c # for illustration only |
| >>> # the content of a, b, and prod_2 are still alive on GPU |
| >>> # the content of prod_1 and c only live on CPU |
| >>> y.sum().backward() # all CPU tensors are moved back to GPU, for backward |
| >>> # all intermediary tensors are released (deleted) after the call to backward |
| |
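
    If host-to-device copies dominate the backward pass, pinned memory can be
    enabled; a minimal sketch (the actual benefit depends on the workload)::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> x = torch.randn(5, requires_grad=True, device="cuda")
        >>> with torch.autograd.graph.save_on_cpu(pin_memory=True):
        ...     y = x.pow(2)  # x is saved to pinned CPU memory
        >>> y.sum().backward()  # x is copied back to the GPU asynchronously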
| """ |
| def __init__(self, pin_memory=False): |
| def pack_to_cpu(tensor): |
| if not pin_memory: |
| return (tensor.device, tensor.cpu()) |
| |
| packed = torch.empty( |
| tensor.size(), |
| dtype=tensor.dtype, |
| layout=tensor.layout, |
| pin_memory=(torch.cuda.is_available() and not tensor.is_sparse)) |
| packed.copy_(tensor) |
| return (tensor.device, packed) |
| |
| def unpack_from_cpu(packed): |
| device, tensor = packed |
| return tensor.to(device, non_blocking=pin_memory) |
| |
| super().__init__(pack_to_cpu, unpack_from_cpu) |