| torch.utils.checkpoint |
| ====================== |
| |
| .. note:: |
| Checkpointing is implemented by rerunning a forward-pass segment for |
| each checkpointed segment during backward propagation. This can cause persistent |
states like the RNG state to be more advanced than they would be without
| checkpointing. By default, checkpointing includes logic to juggle |
| the RNG state such that checkpointed passes making use of RNG |
| (through dropout for example) have deterministic output as |
| compared to non-checkpointed passes. The logic to stash and restore |
| RNG states can incur a moderate performance hit depending on the runtime |
| of checkpointed operations. If deterministic output compared to |
| non-checkpointed passes is not required, supply ``preserve_rng_state=False`` |
| to ``checkpoint`` or ``checkpoint_sequential`` to omit stashing and |
| restoring the RNG state during each checkpoint. |
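
For example, a checkpointed segment that uses dropout can be run with or
without the RNG bookkeeping. A minimal sketch (the module, shapes, and
``run_fn`` here are illustrative, not part of the API):

.. code-block:: python

    import torch
    from torch.utils.checkpoint import checkpoint

    drop = torch.nn.Dropout(p=0.5)

    def run_fn(x):
        return drop(x)

    x = torch.randn(4, 4, requires_grad=True)

    # Default: the RNG state is stashed and restored, so the segment
    # rerun during backward sees the same dropout mask as the forward.
    out = checkpoint(run_fn, x, use_reentrant=False)
    out.sum().backward()

    # Skip the RNG stash/restore when bitwise parity with a
    # non-checkpointed pass is not required.
    out = checkpoint(run_fn, x, use_reentrant=False,
                     preserve_rng_state=False)
    out.sum().backward()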
| |
The stashing logic saves and restores the RNG state for the CPU and for one
other device type (the device type is inferred by ``_infer_device_type``
from the non-CPU Tensor arguments passed to ``run_fn``). If there are
Tensors on multiple devices, state is saved only for devices of a single
device type, and the remaining devices are ignored. Consequently, if any
checkpointed functions involve randomness, this may result in incorrect
gradients. (Note that if CUDA devices are among the device types detected,
CUDA is prioritized; otherwise, the first device type encountered is
selected.) If there are no non-CPU Tensor arguments, the state of the
default device type is saved and restored (the default device type is
``cuda``, and it can be set to another device type via ``DefaultDeviceType``).
However, the logic has no way to anticipate whether the user will move
Tensors to a new device within ``run_fn`` itself. Therefore, if you move
| Tensors to a new device ("new" meaning not belonging to the set of |
| [current device + devices of Tensor arguments]) within ``run_fn``, deterministic |
| output compared to non-checkpointed passes is never guaranteed. |
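
If ``run_fn`` receives only CPU Tensors but uses randomness on another
device internally, you can tell the stashing logic which device type's
state to save. A minimal sketch (``"xpu"`` is an illustrative choice,
assuming a build with that backend):

.. code-block:: python

    from torch.utils.checkpoint import DefaultDeviceType

    # With no non-CPU Tensor arguments, the stashing logic falls back
    # to the default device type, which is "cuda" unless overridden.
    DefaultDeviceType.set_device_type("xpu")
    print(DefaultDeviceType.get_device_type())  # "xpu"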
| |
| .. currentmodule:: torch.utils.checkpoint |
| .. autofunction:: checkpoint |
| .. autofunction:: checkpoint_sequential |
| .. autofunction:: set_checkpoint_debug_enabled |
| .. autoclass:: CheckpointPolicy |
| .. autoclass:: SelectiveCheckpointContext |
| .. autofunction:: create_selective_checkpoint_contexts |
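
A short sketch of selective activation checkpointing, combining
``create_selective_checkpoint_contexts`` and ``CheckpointPolicy`` with
``checkpoint`` (the policy, function, and shapes are illustrative):

.. code-block:: python

    from functools import partial

    import torch
    from torch.utils.checkpoint import (
        CheckpointPolicy,
        checkpoint,
        create_selective_checkpoint_contexts,
    )

    def policy_fn(ctx, op, *args, **kwargs):
        # Save matmul outputs during forward; recompute everything else.
        if op == torch.ops.aten.mm.default:
            return CheckpointPolicy.MUST_SAVE
        return CheckpointPolicy.PREFER_RECOMPUTE

    def run_fn(x, w):
        return torch.sigmoid(x @ w).sum()

    x = torch.randn(4, 4, requires_grad=True)
    w = torch.randn(4, 4, requires_grad=True)

    context_fn = partial(create_selective_checkpoint_contexts, policy_fn)
    out = checkpoint(run_fn, x, w, use_reentrant=False,
                     context_fn=context_fn)
    out.backward()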