# blob: 1a8263a79f93d7324baa2c271e74965ba66560b3 [file] [log] [blame]  (source-viewer header, not part of the module)
# Owner(s): ["module: unknown"]
import collections
import unittest
import torch
from torch.testing._internal.common_utils import TestCase, run_tests
from torch.testing._internal.autocast_test_lists import AutocastCPUTestLists
from torch.utils._python_dispatch import TorchDispatchMode
class TestAutocastCPU(TestCase):
def setUp(self):
super(TestAutocastCPU, self).setUp()
self.autocast_lists = AutocastCPUTestLists(torch.device('cpu'))
def tearDown(self):
del self.autocast_lists
super(TestAutocastCPU, self).tearDown()
def _run_autocast_outofplace(self, op, args, run_as_type, out_type=None, module=torch, add_kwargs=None):
# helper to cast args
def cast(val, to_type):
if isinstance(val, torch.Tensor):
return val.to(to_type) if val.is_floating_point() else val
elif isinstance(val, collections.abc.Iterable):
return type(val)(cast(v, to_type) for v in val)
else:
return val
if add_kwargs is None:
add_kwargs = {}
self.assertFalse(torch.is_autocast_cpu_enabled())
with torch.cpu.amp.autocast():
self.assertTrue(torch.is_autocast_cpu_enabled())
out_type = out_type if out_type is not None else run_as_type
output = output_method = None
# Try module.* variant, if requested:
if module is not None and hasattr(module, op):
output = getattr(module, op)(*args, **add_kwargs)
if isinstance(output, torch.Tensor):
self.assertTrue(out_type == output.dtype,
"autocast for torch.{} produced {}, should produce {}"
.format(op, output.dtype, out_type))
# Try Tensor.* variant:
if hasattr(torch.Tensor, op):
output_method = getattr(args[0], op)(*args[1:], **add_kwargs)
if isinstance(output_method, torch.Tensor):
self.assertTrue(out_type == output_method.dtype,
"autocast for torch.{} produced {}, should produce torch.{}"
.format(op, output_method.dtype, out_type))
self.assertTrue((output is not None) or (output_method is not None),
"{} not found as an attribute on either Tensor or the requested module {}".format(
op, module))
# Accounts for ops that return Tensors, iterables, and other non-Tensors.
# For example, lstm_cell returns a tuple and equal returns bool.
def compare(first, second):
if isinstance(first, torch.Tensor):
return torch.equal(first, second)
elif isinstance(first, collections.abc.Iterable):
return all(compare(f, s) for f, s in zip(first, second))
else:
return first == second
# If both torch.* and Tensor.* variants were found, check outputs are identical
if (output is not None) and (output_method is not None):
self.assertTrue(type(output) == type(output_method))
comparison = compare(output, output_method)
self.assertTrue(comparison, "torch.{0} result did not match Tensor.{0} result".format(op))
# Compare numerics to Python-side "autocasting" that (we expect) does the same thing
# as the C++-side autocasting, and should be bitwise accurate.
output_to_compare = output if output is not None else output_method
with torch.cpu.amp.autocast(enabled=False):
self.assertFalse(torch.is_autocast_cpu_enabled())
if module is not None and hasattr(module, op):
control = getattr(module, op)(*cast(args, run_as_type), **add_kwargs)
else:
control = getattr(args[0].to(run_as_type), op)(*cast(args[1:], run_as_type), **add_kwargs)
self.assertTrue(type(output_to_compare) == type(control))
comparison = compare(output_to_compare, control)
self.assertTrue(comparison, "torch.{} result did not match control".format(op))
self.assertTrue(torch.is_autocast_cpu_enabled())
self.assertFalse(torch.is_autocast_cpu_enabled())
def args_maybe_kwargs(self, op_with_args):
if len(op_with_args) == 2:
return op_with_args[0], op_with_args[1], {}
else:
return op_with_args[0], op_with_args[1], op_with_args[2]
def test_autocast_torch_expect_builtin_promote(self):
for op, args, out_type in self.autocast_lists.torch_expect_builtin_promote:
self._run_autocast_outofplace(op, args, torch.float32, out_type=out_type)
def test_autocast_methods_expect_builtin_promote(self):
for op, args, out_type in self.autocast_lists.methods_expect_builtin_promote:
self._run_autocast_outofplace(op, args, torch.float32, module=None, out_type=out_type)
def test_autocast_torch_bf16(self):
for op_with_args in self.autocast_lists.torch_bf16:
op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
self._run_autocast_outofplace(op, args, torch.bfloat16, add_kwargs=maybe_kwargs)
def test_autocast_nn_bf16(self):
for op_with_args in self.autocast_lists.nn_bf16:
op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
self._run_autocast_outofplace(op, args, torch.bfloat16, module=torch._C._nn, add_kwargs=maybe_kwargs)
def test_autocast_torch_fp32(self):
for op_with_args in self.autocast_lists.torch_fp32:
op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
self._run_autocast_outofplace(op, args, torch.float32, add_kwargs=maybe_kwargs)
def test_autocast_nn_fp32(self):
for op_with_args in self.autocast_lists.nn_fp32:
op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
self._run_autocast_outofplace(op, args, torch.float32, module=torch._C._nn, add_kwargs=maybe_kwargs)
def test_autocast_torch_need_autocast_promote(self):
for op, args in self.autocast_lists.torch_need_autocast_promote:
self._run_autocast_outofplace(op, args, torch.float32)
@unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
class TestAutocastGPU(TestCase):
    """CUDA autocast tests; the whole class is skipped when CUDA is unavailable."""

    def test_cast_cache_is_global(self):
        """
        Verifies that the autocast cache is global. This is done by
        mocking out cache clearing at the end of the forward pass,
        running forward+backward with an explicit call to autocast in the
        backward, and verifying that the weight only gets cast to float16 once.
        """

        class CustomLinear(torch.autograd.Function):
            """Linear op whose backward re-enters autocast explicitly."""

            @staticmethod
            def forward(ctx, x, w_t):
                ctx.save_for_backward(x, w_t)
                return torch.nn.functional.linear(x, w_t)

            @staticmethod
            def backward(ctx, grad_output):
                x, w_t = ctx.saved_tensors
                with torch.autocast(device_type='cuda'):
                    dL_dX = torch.matmul(grad_output, w_t)
                    dL_dW = torch.matmul(x.transpose(0, 1), grad_output).transpose(0, 1)
                return dL_dX, dL_dW

        data = torch.randn(2, 3).cuda()
        weight = torch.nn.Parameter(torch.randn(4, 3).cuda())
        weight_dtype_cast_counter = 0

        class WeightDTypeCastCounterMode(TorchDispatchMode):
            """Counts float16 ``_to_copy`` calls on ``weight`` and disables cache clearing while active."""

            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                nonlocal weight_dtype_cast_counter
                # kwargs defaults to None; normalise so it can be queried and splatted.
                kwargs = {} if kwargs is None else kwargs
                if (
                    func is torch.ops.aten._to_copy.default and
                    len(args) > 0 and
                    args[0] is weight and
                    kwargs.get('dtype') is torch.float16
                ):
                    weight_dtype_cast_counter += 1
                return func(*args, **kwargs)

            def __enter__(self):
                # Replace cache clearing with a no-op so cached casts stay
                # available after the forward autocast region exits.
                self.old_clear_cache = torch.clear_autocast_cache
                torch.clear_autocast_cache = lambda: None
                try:
                    return super().__enter__()
                except BaseException:
                    # Do not leave the global patched if entering the mode fails.
                    torch.clear_autocast_cache = self.old_clear_cache
                    raise

            def __exit__(self, exc_type, exc_val, exc_tb):
                torch.clear_autocast_cache = self.old_clear_cache
                return super().__exit__(exc_type, exc_val, exc_tb)

        with WeightDTypeCastCounterMode():
            with torch.autocast(device_type='cuda'):
                output = CustomLinear.apply(data, weight)
                s = output.sum()
            # Backward runs outside the forward autocast region; CustomLinear.backward
            # opens its own autocast region.
            s.backward()

        self.assertEqual(weight_dtype_cast_counter, 1)
class TestTorchAutocast(TestCase):
    """Checks the dtypes reported by the autocast dtype getters."""

    def test_autocast_fast_dtype(self):
        """Assert the GPU getter reports half and the CPU getter reports bfloat16."""
        # Query both getters up front, then verify each against its expected dtype.
        reported_and_expected = (
            (torch.get_autocast_gpu_dtype(), torch.half),
            (torch.get_autocast_cpu_dtype(), torch.bfloat16),
        )
        for reported, expected in reported_and_expected:
            self.assertEqual(reported, expected)
# Script entry point: when executed directly, hand control to run_tests
# (imported from torch.testing._internal.common_utils).
if __name__ == '__main__':
    run_tests()