import torch
import functools
import warnings

from typing import Any, Optional
from torch.types import _dtype

__all__ = ['autocast_decorator', 'autocast']

def autocast_decorator(autocast_instance, func):
    @functools.wraps(func)
    def decorate_autocast(*args, **kwargs):
        with autocast_instance:
            return func(*args, **kwargs)
    decorate_autocast.__script_unsupported = '@autocast() decorator is not supported in script mode'  # type: ignore[attr-defined]
    return decorate_autocast

class autocast(object):
    r"""
    Instances of :class:`autocast` serve as context managers or decorators that
    allow regions of your script to run in mixed precision.

    In these regions, ops run in an op-specific dtype chosen by autocast
    to improve performance while maintaining accuracy.
    See the :ref:`Autocast Op Reference<autocast-op-reference>` for details.

    When entering an autocast-enabled region, Tensors may be any type.
    You should not call ``half()`` or ``bfloat16()`` on your model(s) or inputs when using autocasting.

    :class:`autocast` should wrap only the forward pass(es) of your network, including the loss
    computation(s). Backward passes under autocast are not recommended.
    Backward ops run in the same type that autocast used for the corresponding forward ops.

    Example for CUDA Devices::

        # Creates model and optimizer in default precision
        model = Net().cuda()
        optimizer = optim.SGD(model.parameters(), ...)

        for input, target in data:
            optimizer.zero_grad()

            # Enables autocasting for the forward pass (model + loss)
            with torch.autocast(device_type="cuda"):
                output = model(input)
                loss = loss_fn(output, target)

            # Exits the context manager before backward()
            loss.backward()
            optimizer.step()

    See the :ref:`CUDA Automatic Mixed Precision examples<amp-examples>` for usage (along with gradient scaling)
    in more complex scenarios (e.g., gradient penalty, multiple models/losses, custom autograd functions).

    :class:`autocast` can also be used as a decorator, e.g., on the ``forward`` method of your model::

        class AutocastModel(nn.Module):
            ...
            @torch.autocast(device_type="cuda")
            def forward(self, input):
                ...

    Floating-point Tensors produced in an autocast-enabled region may be ``float16``.
    After returning to an autocast-disabled region, using them with floating-point
    Tensors of different dtypes may cause type mismatch errors. If so, cast the Tensor(s)
    produced in the autocast region back to ``float32`` (or another dtype if desired).
    If a Tensor from the autocast region is already ``float32``, the cast is a no-op
    and incurs no additional overhead.

    CUDA Example::

        # Creates some tensors in default dtype (here assumed to be float32)
        a_float32 = torch.rand((8, 8), device="cuda")
        b_float32 = torch.rand((8, 8), device="cuda")
        c_float32 = torch.rand((8, 8), device="cuda")
        d_float32 = torch.rand((8, 8), device="cuda")

        with torch.autocast(device_type="cuda"):
            # torch.mm is on autocast's list of ops that should run in float16.
            # Inputs are float32, but the op runs in float16 and produces float16 output.
            # No manual casts are required.
            e_float16 = torch.mm(a_float32, b_float32)
            # Also handles mixed input types
            f_float16 = torch.mm(d_float32, e_float16)

        # After exiting autocast, calls f_float16.float() to use with d_float32
        g_float32 = torch.mm(d_float32, f_float16.float())

    CPU Training Example::

        # Creates model and optimizer in default precision
        model = Net()
        optimizer = optim.SGD(model.parameters(), ...)

        for epoch in epochs:
            for input, target in data:
                optimizer.zero_grad()

                # Runs the forward pass with autocasting.
                with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
                    output = model(input)
                    loss = loss_fn(output, target)

                loss.backward()
                optimizer.step()

    CPU Inference Example::

        # Creates model in default precision
        model = Net().eval()

        with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
            for input in data:
                # Runs the forward pass with autocasting.
                output = model(input)

    CPU Inference Example with Jit Trace::

        class TestModel(nn.Module):
            def __init__(self, input_size, num_classes):
                super(TestModel, self).__init__()
                self.fc1 = nn.Linear(input_size, num_classes)
            def forward(self, x):
                return self.fc1(x)

        input_size = 2
        num_classes = 2
        model = TestModel(input_size, num_classes).eval()

        # For now, we suggest disabling the JIT Autocast Pass;
        # see the issue: https://github.com/pytorch/pytorch/issues/75956
        torch._C._jit_set_autocast_mode(False)

        with torch.cpu.amp.autocast(cache_enabled=False):
            model = torch.jit.trace(model, torch.randn(1, input_size))
        model = torch.jit.freeze(model)
        # Run the model
        for _ in range(3):
            model(torch.randn(1, input_size))

    Type mismatch errors *in* an autocast-enabled region are a bug; if this is what you observe,
    please file an issue.

    ``autocast(enabled=False)`` subregions can be nested in autocast-enabled regions.
    Locally disabling autocast can be useful, for example, if you want to force a subregion
    to run in a particular ``dtype``. Disabling autocast gives you explicit control over
    the execution type. In the subregion, inputs from the surrounding region
    should be cast to ``dtype`` before use::

        # Creates some tensors in default dtype (here assumed to be float32)
        a_float32 = torch.rand((8, 8), device="cuda")
        b_float32 = torch.rand((8, 8), device="cuda")
        c_float32 = torch.rand((8, 8), device="cuda")
        d_float32 = torch.rand((8, 8), device="cuda")

        with torch.autocast(device_type="cuda"):
            e_float16 = torch.mm(a_float32, b_float32)
            with torch.autocast(device_type="cuda", enabled=False):
                # Calls e_float16.float() to ensure float32 execution
                # (necessary because e_float16 was created in an autocasted region)
                f_float32 = torch.mm(c_float32, e_float16.float())

            # No manual casts are required when re-entering the autocast-enabled region.
            # torch.mm again runs in float16 and produces float16 output, regardless of input types.
            g_float16 = torch.mm(d_float32, f_float32)

    The autocast state is thread-local. If you want it enabled in a new thread, the context manager or decorator
    must be invoked in that thread. This affects :class:`torch.nn.DataParallel` and
    :class:`torch.nn.parallel.DistributedDataParallel` when used with more than one GPU per process
    (see :ref:`Working with Multiple GPUs<amp-multigpu>`).
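
    For example, one way to keep autocast enabled in the worker threads spawned by
    :class:`torch.nn.DataParallel` is to invoke it inside the wrapped module's ``forward``
    (a minimal sketch; ``MyModel`` is illustrative, not part of any API)::

        class MyModel(nn.Module):
            @torch.autocast(device_type="cuda")
            def forward(self, input):
                # The decorator re-enters autocast in whichever thread runs this
                # forward pass (e.g., a DataParallel worker thread).
                ...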

    Args:
        device_type(str, required): Device type to use. Possible values are 'cuda', 'cpu', 'xpu' and 'hpu'.
        enabled(bool, optional): Whether autocasting should be enabled in the region.
            Default: ``True``
        dtype(torch.dtype, optional): Data type for ops to run in, e.g. ``torch.float16`` or
            ``torch.bfloat16``. If not given, the current default autocast dtype for
            ``device_type`` is used.
        cache_enabled(bool, optional): Whether the weight cache inside autocast should be enabled.
            Default: ``True``
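
    A brief sketch combining these arguments (the particular values are chosen only for
    illustration)::

        # Runs the region in float16 on CUDA with the weight cache disabled;
        # `model` and `input` are as in the training examples above.
        with torch.autocast(device_type="cuda", dtype=torch.float16,
                            enabled=True, cache_enabled=False):
            output = model(input)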
| """ |
    def __init__(self, device_type: str,
                 dtype: Optional[_dtype] = None,
                 enabled: bool = True,
                 cache_enabled: Optional[bool] = None):
        if torch._jit_internal.is_scripting():
            self._enabled = enabled
            self.device = device_type
            self.fast_dtype = dtype
            # TODO: support get_autocast_gpu/cpu_dtype
            assert dtype is not None
            return
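        # Eager-mode path: start from the process-wide default autocast dtype for the
        # chosen device type; an explicit dtype argument overrides it further below.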
        self.device = device_type
        if self.device == 'cuda':
            self.fast_dtype = torch.get_autocast_gpu_dtype()
        elif self.device == 'cpu':
            self.fast_dtype = torch.get_autocast_cpu_dtype()
        elif self.device == 'xpu':
            self.fast_dtype = torch.xpu.get_autocast_xpu_dtype()  # type: ignore[attr-defined]
        elif self.device == 'hpu':
            self.fast_dtype = torch.hpu.get_autocast_hpu_dtype()  # type: ignore[attr-defined]
        else:
            raise RuntimeError('User specified autocast device_type must be \'cuda\', \'cpu\', \'xpu\' or \'hpu\'')
        self._cache_enabled = torch.is_autocast_cache_enabled()
        if enabled and torch.cuda.amp.common.amp_definitely_not_available() and self.device == 'cuda':
            warnings.warn('User provided device_type of \'cuda\', but CUDA is not available. Disabling')
            enabled = False
        if dtype is not None:
            self.fast_dtype = dtype
        if cache_enabled is not None:
            self._cache_enabled = cache_enabled

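        # Check the resolved dtype against what the chosen device type supports; an
        # unsupported combination disables autocast with a warning (or raises, for
        # bfloat16 on CUDA devices without bf16 support).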
        if self.device == 'cpu':
            supported_dtype = [torch.bfloat16]
            if self.fast_dtype not in supported_dtype:
                error_message = 'In CPU autocast, but the target dtype is not supported. Disabling autocast.\n'
                error_message += 'CPU Autocast only supports dtype of torch.bfloat16 currently.'
                warnings.warn(error_message)
                enabled = False
        elif self.device == 'xpu':
            supported_dtype = [torch.bfloat16, torch.float16]
            if self.fast_dtype not in supported_dtype:
                error_message = 'In XPU autocast, but the target dtype is not supported. Disabling autocast.\n'
                error_message += 'XPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently.'
                warnings.warn(error_message)
                enabled = False
        elif self.device == 'hpu':
            supported_dtype = [torch.bfloat16, torch.float16]
            if self.fast_dtype not in supported_dtype:
                error_message = 'In HPU autocast, but the target dtype is not supported. Disabling autocast.\n'
                error_message += 'HPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently.'
                warnings.warn(error_message)
                enabled = False
        elif self.device == 'cuda':
            if self.fast_dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported():
                raise RuntimeError('Current CUDA Device does not support bfloat16. Please switch dtype to float16.')
        self._enabled = enabled

    def __enter__(self):
        if torch._jit_internal.is_scripting():
            assert self.fast_dtype is not None
            return self

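        # Save the current global autocast state (the per-device enabled flag and
        # dtype, plus the shared cache flag) so that __exit__ can restore it.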
        self.prev_cache_enabled = torch.is_autocast_cache_enabled()
        if self.device == 'cpu':
            self.prev = torch.is_autocast_cpu_enabled()
            self.prev_fastdtype = torch.get_autocast_cpu_dtype()
            torch.set_autocast_cpu_enabled(self._enabled)
            torch.set_autocast_cpu_dtype(self.fast_dtype)  # type: ignore[arg-type]
            torch.autocast_increment_nesting()
        elif self.device == 'xpu':
            self.prev = torch.xpu.is_autocast_xpu_enabled()  # type: ignore[attr-defined]
            self.prev_fastdtype = torch.xpu.get_autocast_xpu_dtype()  # type: ignore[attr-defined]
            torch.xpu.set_autocast_xpu_enabled(self._enabled)  # type: ignore[attr-defined]
            torch.xpu.set_autocast_xpu_dtype(self.fast_dtype)  # type: ignore[attr-defined]
            torch.autocast_increment_nesting()
        elif self.device == 'hpu':
            self.prev = torch.hpu.is_autocast_hpu_enabled()  # type: ignore[attr-defined]
            self.prev_fastdtype = torch.hpu.get_autocast_hpu_dtype()  # type: ignore[attr-defined]
            torch.hpu.set_autocast_hpu_enabled(self._enabled)  # type: ignore[attr-defined]
            torch.hpu.set_autocast_hpu_dtype(self.fast_dtype)  # type: ignore[attr-defined]
            torch.autocast_increment_nesting()
        else:
            self.prev = torch.is_autocast_enabled()
            self.prev_fastdtype = torch.get_autocast_gpu_dtype()
            torch.set_autocast_gpu_dtype(self.fast_dtype)  # type: ignore[arg-type]
            torch.set_autocast_enabled(self._enabled)
            torch.autocast_increment_nesting()
        torch.set_autocast_cache_enabled(self._cache_enabled)

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any):  # type: ignore[override]
        if torch._jit_internal.is_scripting():
            return

        # Drop the cache when we exit to a nesting level that's outside any instance of autocast.
        if self.device == 'cpu':
            if torch.autocast_decrement_nesting() == 0:
                torch.clear_autocast_cache()
            torch.set_autocast_cpu_enabled(self.prev)
            torch.set_autocast_cpu_dtype(self.prev_fastdtype)
        elif self.device == 'xpu':
            if torch.autocast_decrement_nesting() == 0:
                torch.clear_autocast_cache()
            torch.xpu.set_autocast_xpu_enabled(self.prev)  # type: ignore[attr-defined]
            torch.xpu.set_autocast_xpu_dtype(self.prev_fastdtype)  # type: ignore[attr-defined]
        elif self.device == 'hpu':
            if torch.autocast_decrement_nesting() == 0:
                torch.clear_autocast_cache()
            torch.hpu.set_autocast_hpu_enabled(self.prev)  # type: ignore[attr-defined]
            torch.hpu.set_autocast_hpu_dtype(self.prev_fastdtype)  # type: ignore[attr-defined]
        else:
            if torch.autocast_decrement_nesting() == 0:
                torch.clear_autocast_cache()
            torch.set_autocast_enabled(self.prev)
            torch.set_autocast_gpu_dtype(self.prev_fastdtype)
        torch.set_autocast_cache_enabled(self.prev_cache_enabled)
        return False

    def __call__(self, func):
        if torch._jit_internal.is_scripting():
            return func
        return autocast_decorator(self, func)