| # Owner(s): ["oncall: jit"] |
| |
| import torch |
| from torch.cuda.amp import autocast |
| from typing import Optional, Tuple |
| |
| import unittest |
| from test_jit import JitTestCase |
| from torch.testing._internal.common_cuda import TEST_CUDA |
| from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo |
| from torch.testing import FileCheck |
| from jit.test_models import MnistNet |
| |
| TEST_BFLOAT16 = TEST_CUDA and torch.cuda.is_bf16_supported() |
| |
| @skipIfTorchDynamo("Not a TorchDynamo suitable test") |
| class TestAutocast(JitTestCase): |
| def setUp(self): |
| # common input tensors |
| if TEST_CUDA: |
| self.a_fp16 = torch.rand((2, 2), dtype=torch.float16, device='cuda') |
| self.b_fp16 = torch.rand((2, 2), dtype=torch.float16, device='cuda') |
| self.c_fp16 = torch.rand((2, 2), dtype=torch.float16, device='cuda') |
| self.d_fp16 = torch.rand((2, 2), dtype=torch.float16, device='cuda') |
| self.a_fp32 = torch.rand((2, 2), dtype=torch.float32, device='cuda') |
| self.b_fp32 = torch.rand((2, 2), dtype=torch.float32, device='cuda') |
| self.c_fp32 = torch.rand((2, 2), dtype=torch.float32, device='cuda') |
| self.d_fp32 = torch.rand((2, 2), dtype=torch.float32, device='cuda') |
| self.old_value = torch._C._jit_set_autocast_mode(True) |
| super().setUp() |
| |
| def tearDown(self): |
| torch._C._jit_set_autocast_mode(self.old_value) |
| super().tearDown() |
| |
| @unittest.skipIf(not TEST_CUDA, "No cuda") |
| def test_jit_generic_autocast(self): |
| @torch.jit.script |
| def fn_cuda_autocast(a, b): |
| with autocast(): |
| x = torch.mm(a, b) |
| y = torch.sum(x) |
| return x, y |
| |
| @torch.jit.script |
| def fn_generic_autocast(a, b): |
| with torch.amp.autocast(device_type='cuda'): |
| x = torch.mm(a, b) |
| y = torch.sum(x) |
| return x, y |
| self.assertEqual(fn_cuda_autocast(self.a_fp32, self.b_fp32), fn_generic_autocast(self.a_fp32, self.b_fp32)) |
| |
| @unittest.skipIf(not TEST_CUDA, "No cuda") |
| def test_minimal(self): |
| @torch.jit.script |
| def fn(a, b): |
| with autocast(): |
| x = torch.mm(a, b) |
| y = torch.sum(x) |
| return x, y |
| x, y = fn(self.a_fp32, self.b_fp32) |
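        # under autocast, torch.mm produces fp16 while the sum reduction stays fp32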
| self.assertEqual(x.dtype, torch.float16) |
| self.assertEqual(y.dtype, torch.float32) |
| |
| @unittest.skipIf(not TEST_CUDA or not TEST_BFLOAT16, "No cuda bfloat16 support") |
| def test_linear_bf16(self): |
| @torch.jit.script |
| def fn(a, b): |
| with autocast(dtype=torch.bfloat16): |
| x = torch.mm(a, b) |
| y = torch.sum(x) |
| return x, y |
| x, y = fn(self.a_fp32, self.b_fp32) |
| self.assertEqual(x.dtype, torch.bfloat16) |
| self.assertEqual(y.dtype, torch.float32) |
| |
| @unittest.skipIf(not TEST_CUDA, "No cuda") |
| def test_minimal_cpu(self): |
| @torch.jit.script |
| def fn(a, b): |
| with autocast(): |
| return torch.mm(a, b) |
| result = fn(self.a_fp32.to('cpu'), self.b_fp32.to('cpu')) |
| self.assertEqual(result.dtype, torch.float32) |
| |
| @unittest.skipIf(not TEST_CUDA, "No cuda") |
| def test_minimal_off(self): |
| @torch.jit.script |
| def fn(a, b): |
| with autocast(enabled=False): |
| return torch.mm(a, b) |
| result = fn(self.a_fp32, self.b_fp32) |
| self.assertEqual(result.dtype, torch.float32) |
| |
| @unittest.skipIf(not TEST_CUDA, "No cuda") |
| def test_runtime_autocast_state(self): |
| @torch.jit.script |
| def fn(a, b, use_amp: bool): |
| with autocast(enabled=use_amp): |
| return torch.mm(a, b) |
        # runtime values for the autocast enabled argument are not supported
| with self.assertRaises(RuntimeError): |
| fn(self.a_fp32, self.b_fp32, True) |
| |
| @unittest.skipIf(not TEST_CUDA, "No cuda") |
| def test_runtime_autocast_state_expr(self): |
| @torch.jit.script |
| def fn(a, b): |
| with autocast(enabled=True if a[0][0] > 0.5 else False): |
| return torch.mm(a, b) |
        # runtime values for the autocast enabled argument are not supported
| with self.assertRaises(RuntimeError): |
| fn(self.a_fp32, self.b_fp32) |
| |
| @unittest.skipIf(not TEST_CUDA, "No cuda") |
| def test_explicit_casts(self): |
| @torch.jit.script |
| def fn(a, b, c, d): |
| with autocast(): |
| e = torch.mm(a.double(), b.double()).float() |
| f = torch.mm(c, d).double() |
| g = torch.mm(c.double(), f) |
| return e, f, g |
| e, f, g = fn(self.a_fp32, self.b_fp32, self.c_fp32, self.d_fp32) |
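        # explicit casts inside the autocast region override the autocast dtype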
| self.assertEqual(e.dtype, torch.float32) |
| self.assertEqual(f.dtype, torch.float64) |
| self.assertEqual(g.dtype, torch.float64) |
| |
| # multiple uses of the same input value |
| @unittest.skipIf(not TEST_CUDA, "No cuda") |
| def test_duplicate_inputs(self): |
| @torch.jit.script |
| def fn(a, b): |
| with autocast(): |
| e = torch.mm(a, a) |
| f = torch.mm(e, e) |
| return e, f |
| e, f = fn(self.a_fp32, self.b_fp32) |
| self.assertEqual(e.dtype, torch.float16) |
| self.assertEqual(f.dtype, torch.float16) |
| |
| @unittest.skipIf(not TEST_CUDA, "No cuda") |
| def test_fp32_policy(self): |
| @torch.jit.script |
| def fn(a): |
| with autocast(enabled=True): |
| return torch.log(a) |
| result = fn(self.a_fp16) |
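        # torch.log falls under the fp32 policy, so the fp16 input is widened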
| self.assertEqual(result.dtype, torch.float32) |
| |
| @unittest.skipIf(not TEST_CUDA, "No cuda") |
| def test_fp32_policy_with_fp64(self): |
| @torch.jit.script |
| def fn(a): |
| with autocast(enabled=True): |
| return torch.log(a) |
| # fp32 policy should not narrow fp64 to fp32! |
| result = fn(self.a_fp32.double()) |
| self.assertEqual(result.dtype, torch.float64) |
| |
| @unittest.skipIf(not TEST_CUDA, "No cuda") |
| def test_promote_policy(self): |
| @torch.jit.script |
| def fn(a, b, c, d): |
| with autocast(): |
| e = torch.mm(a, b) |
| f = torch.addcmul(e, c, d, value=0.1) |
| return e, f |
| e, f = fn(self.a_fp32, self.b_fp32, self.c_fp32, self.d_fp32) |
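        # addcmul promotes: mixing the fp16 mm result with fp32 inputs yields fp32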
| self.assertEqual(e.dtype, torch.float16) |
| self.assertEqual(f.dtype, torch.float32) |
| |
| @unittest.skipIf(not TEST_CUDA, "No cuda") |
| def test_promote_policy_fp64(self): |
| @torch.jit.script |
| def fn(a, b): |
| with autocast(enabled=True): |
| return torch.addcmul(a, a, b, value=0.1) |
| result = fn(self.a_fp32.double(), self.b_fp32.double()) |
| self.assertEqual(result.dtype, torch.float64) |
| |
| @unittest.skipIf(not TEST_CUDA, "No cuda") |
| def test_fp32_set_opt_dtype_policy(self): |
| @torch.jit.script |
| def fn(a, b, c, d, dtype: Optional[int]): |
| with autocast(enabled=True): |
| x = torch.softmax(a, 0) |
| y = torch.softmax(b, 0, None) |
| z = torch.softmax(c, 0, torch.float64) |
| w = torch.softmax(d, 0, dtype) |
| return x, y, z, w |
| x, y, z, w = fn(self.a_fp16, self.b_fp16, self.c_fp16, self.d_fp16, None) |
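        # x and y are widened to fp32, the explicit float64 dtype is honored,
        # and the runtime-provided dtype leaves w at its input precision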
| self.assertEqual(x.dtype, torch.float32) |
| self.assertEqual(y.dtype, torch.float32) |
| self.assertEqual(z.dtype, torch.float64) |
| self.assertEqual(w.dtype, torch.float16) |
| |
| @unittest.skipIf(not TEST_CUDA, "No cuda") |
| def test_fp32_set_opt_dtype_policy_fp64(self): |
| @torch.jit.script |
| def fn(a, b, c, d, dtype: Optional[int]): |
| with autocast(enabled=True): |
| x = torch.softmax(a, 0) |
| y = torch.softmax(b, 0, None) |
| z = torch.softmax(c, 0, torch.float64) |
| w = torch.softmax(d, 0, dtype) |
| return x, y, z, w |
| x, y, z, w = fn(self.a_fp32.double(), self.b_fp32.double(), self.c_fp32.double(), self.d_fp32.double(), None) |
| self.assertEqual(x.dtype, torch.float64) |
| self.assertEqual(y.dtype, torch.float64) |
| self.assertEqual(z.dtype, torch.float64) |
| self.assertEqual(w.dtype, torch.float64) |
| |
| @unittest.skipIf(True, "broken due to lack of type propagation") |
| @unittest.skipIf(not TEST_CUDA, "No cuda") |
| def test_control_flow(self): |
| @torch.jit.script |
| def fn(a, b, c, d): |
| with autocast(): |
| if a[0][0] > 0.5: |
| e = torch.mm(a, b) |
| x = 1 |
| else: |
| e = torch.mm(c, d) |
| x = 2 |
| f = torch.mm(d, e) * x |
| return e, f |
| e, f = fn(self.a_fp32, self.b_fp32, self.c_fp32, self.d_fp32) |
| self.assertEqual(e.dtype, torch.float16) |
| self.assertEqual(f.dtype, torch.float16) |
| |
    # this works fine in regular Python, but it creates a delicate
    # situation in TorchScript where the types are not consistent across
    # the then/else branches
| @unittest.skipIf(not TEST_CUDA, "No cuda") |
| def test_divergent_types(self): |
| @torch.jit.script |
| def fn(a, b, c, d): |
| with autocast(): |
| if a[0][0] > 0.5: |
| e = torch.mm(a, b) |
| f = torch.mm(a, b).float() |
| else: |
| e = torch.mm(c, d).float() |
| f = torch.mm(a, b) |
| return torch.mm(e.float(), f.float()) |
| result = fn(self.a_fp32, self.b_fp32, self.c_fp32, self.d_fp32) |
| self.assertEqual(result.dtype, torch.float32) |
| |
| # another, more complex case of divergent types |
| @unittest.skipIf(not TEST_CUDA, "No cuda") |
| def test_divergent_autocast(self): |
| @torch.jit.script |
| def fn(a, b, c, d): |
| autocast_on = autocast(enabled=True) |
| autocast_off = autocast(enabled=False) |
| if a[0][0] > 0.5: |
| with autocast_on: |
| e = torch.mm(a, b) |
| else: |
| with autocast_off: |
| e = torch.mm(c, d) |
| return torch.mm(e, e) |
| fn(self.a_fp32, self.b_fp32, self.c_fp32, self.d_fp32) |
| |
| @unittest.skipIf(not TEST_CUDA, "No cuda") |
| def test_conditional_autocast(self): |
| @torch.jit.script |
| def fn(a, b): |
| autocast_on = autocast(enabled=True) |
| autocast_off = autocast(enabled=False) |
| with autocast_on if a[0][0] > 0.5 else autocast_off: |
| return torch.mm(a, b) |
| # conditional autocast expressions are not supported |
| with self.assertRaises(RuntimeError): |
| fn(self.a_fp32, self.b_fp32) |
| |
| @unittest.skipIf(not TEST_CUDA, "No cuda") |
| def test_nested_autocast(self): |
| @torch.jit.script |
| def fn(a, b, c, d): |
| with autocast(enabled=False): |
| e = torch.mm(a, b) |
| with autocast(enabled=True): |
| f = torch.mm(e, c) |
| with autocast(enabled=False): |
| g = torch.mm(e, d) |
| return e, f, g |
| e, f, g = fn(self.a_fp32, self.b_fp32, self.c_fp32, self.d_fp32) |
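        # each nested context takes effect: only the mm under the enabled block runs in fp16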
| self.assertEqual(e.dtype, torch.float32) |
| self.assertEqual(f.dtype, torch.float16) |
| self.assertEqual(g.dtype, torch.float32) |
| |
| @unittest.skipIf(not TEST_CUDA, "No cuda") |
| def test_implicitly_nested_autocast(self): |
| @torch.jit.script |
| def fn(a, b): |
| with autocast(enabled=False), autocast(enabled=True): |
| return torch.mm(a, b) |
| result = fn(self.a_fp32, self.b_fp32) |
| self.assertEqual(result.dtype, torch.float16) |
| |
| @unittest.skipIf(not TEST_CUDA, "No cuda") |
| def test_reused_autocast(self): |
| @torch.jit.script |
| def fn(a, b, c, d): |
| autocast_instance = autocast(enabled=True) |
| with autocast_instance: |
| e = torch.mm(a, b) |
| with autocast_instance: |
| e = torch.mm(c, d) |
| f = torch.mm(d, e) |
| g = torch.mm(e, f) |
| return e, f, g |
| e, f, g = fn(self.a_fp32, self.b_fp32, self.c_fp32, self.d_fp32) |
| self.assertEqual(e.dtype, torch.float16) |
| self.assertEqual(f.dtype, torch.float16) |
| self.assertEqual(g.dtype, torch.float16) |
| |
| # TODO: fix and enable this test? |
| # (we could technically fix this, but is it really worth it?) |
    @unittest.skipIf(True, "unsupported autocast syntax")
| def test_reused_autocast_expr(self): |
| @torch.jit.script |
| def fn(a, b, c, d): |
| with autocast(enabled=True) as autocast_instance: |
| e = torch.mm(a, b) |
| with autocast_instance: |
| e = torch.mm(c, d) |
| f = torch.mm(d, e) |
| g = torch.mm(e, f) |
| return e, f, g |
| e, f, g = fn(self.a_fp32, self.b_fp32, self.c_fp32, self.d_fp32) |
| self.assertEqual(e.dtype, torch.float16) |
| self.assertEqual(f.dtype, torch.float16) |
| self.assertEqual(g.dtype, torch.float16) |
| |
| @unittest.skipIf(not TEST_CUDA, "No cuda") |
| def test_callees(self): |
| def helper(a, b): |
| return torch.mm(a, b) |
| |
| @torch.jit.script |
| def fn(a, b): |
| with autocast(enabled=True): |
| tmp = helper(a, b) |
| tmp = helper(tmp, tmp) |
| tmp = helper(tmp, tmp) |
| tmp = helper(tmp, tmp) |
| return helper(tmp, b) |
| |
| result = fn(self.a_fp32, self.b_fp32) |
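        # autocast applies through the helper callees as well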
| self.assertEqual(result.dtype, torch.float16) |
| |
| @unittest.skipIf(not TEST_CUDA, "No cuda") |
| def test_callees_with_autocast_on(self): |
| def helper(a, b): |
| with autocast(enabled=True): |
| return torch.mm(a, b) |
| |
| @torch.jit.script |
| def fn(a, b): |
| with autocast(enabled=False): |
| return helper(a, b) |
| |
| result = fn(self.a_fp32, self.b_fp32) |
| self.assertEqual(result.dtype, torch.float16) |
| |
| @unittest.skipIf(not TEST_CUDA, "No cuda") |
| def test_callees_with_autocast_off(self): |
| def helper(a, b): |
| with autocast(enabled=False): |
| return torch.mm(a, b) |
| |
| @torch.jit.script |
| def fn(a, b): |
| with autocast(enabled=True): |
| return helper(a, b) |
| |
| result = fn(self.a_fp32, self.b_fp32) |
| self.assertEqual(result.dtype, torch.float32) |
| |
| # scripting inside eager autocast |
| @unittest.skipIf(not TEST_CUDA, "No cuda") |
| def test_eager_and_script(self): |
| @torch.jit.script |
| def fn(a, b): |
| return torch.mm(a, b) |
| for i in range(8): |
| use_autocast = (i % 2 == 0) |
| expected_dtype = torch.float16 if use_autocast else torch.float32 |
| with autocast(enabled=use_autocast): |
| result = fn(self.a_fp32, self.b_fp32) |
| self.assertEqual(result.dtype, expected_dtype) |
| |
| # traced inside scripting |
| @unittest.skipIf(not TEST_CUDA, "No cuda") |
| def test_script_and_tracing(self): |
| def helper(a, b): |
| return torch.mm(a, b) |
| |
| traced = torch.jit.trace(helper, (self.a_fp32, self.a_fp32)) |
| |
| @torch.jit.script |
| def fn(a, b): |
| with autocast(enabled=True): |
| return traced(a, b) |
| |
| result = fn(self.a_fp32, self.b_fp32) |
| self.assertEqual(result.dtype, torch.float16) |
| |
| # traced with autocast inside scripting |
| @unittest.skipIf(True, "autocast(False) is ignored inside traced functions") |
| @unittest.skipIf(not TEST_CUDA, "No cuda") |
| def test_script_and_tracing_with_autocast(self): |
| def helper(a, b): |
| with autocast(enabled=False): |
| return torch.mm(a, b) * 2.0 |
| |
| traced = torch.jit.trace(helper, (self.a_fp32, self.a_fp32)) |
| |
| @torch.jit.script |
| def fn(a, b): |
| with autocast(enabled=True): |
| return traced(a, b) |
| |
| result = fn(self.a_fp32, self.b_fp32) |
| self.assertEqual(result.dtype, torch.float32) |
| |
| # scripted called from traced |
| @unittest.skipIf(not TEST_CUDA, "No cuda") |
| def test_tracing_and_script(self): |
| @torch.jit.script |
| def fn(a, b): |
| with autocast(): |
| return torch.mm(a, b) |
| |
| def traced(a, b): |
| return fn(a, b) |
| |
| traced = torch.jit.trace(traced, (self.a_fp32, self.b_fp32)) |
| result = traced(self.a_fp32, self.b_fp32) |
| self.assertEqual(result.dtype, torch.float16) |
| |
| # scripted called from traced with autocast |
    @unittest.skipIf(True, "calling scripted functions from traced TorchScript is not yet working")
| @unittest.skipIf(not TEST_CUDA, "No cuda") |
| def test_tracing_with_autocast_and_script(self): |
| @torch.jit.script |
| def fn(a, b): |
| return torch.mm(a, b) |
| |
| def traced(a, b): |
| with autocast(enabled=True): |
| return fn(a, b) |
| |
| traced = torch.jit.trace(traced, (self.a_fp32, self.b_fp32)) |
| result = traced(self.a_fp32, self.b_fp32) |
| self.assertEqual(result.dtype, torch.float16) |
| |
| @unittest.skipIf(not TEST_CUDA, "No cuda") |
| def test_script_module(self): |
| class TestModule(torch.nn.Module): |
| def __init__(self, N, M): |
| super().__init__() |
| self.weight = torch.nn.Parameter(torch.rand((N, M), dtype=torch.float32)) |
| self.linear = torch.nn.Linear(N, M).float() |
| |
| def forward(self, input): |
| with autocast(enabled=True): |
| output = self.weight.mv(input) |
| output = self.linear(output) |
| return output |
| |
| scripted_module = torch.jit.script(TestModule(2, 3)).cuda() |
| input = torch.rand(3, dtype=torch.float32, device='cuda') |
| result = scripted_module(input) |
| self.assertEqual(result.dtype, torch.float16) |
| |
| @unittest.skipIf(True, "autocast decorators not supported") |
| @unittest.skipIf(not TEST_CUDA, "No cuda") |
| def test_autocast_decorator(self): |
| @torch.jit.script |
| @autocast(enabled=True) |
| def fn(a, b): |
| return torch.mm(a, b) |
| result = fn(self.a_fp32, self.b_fp32) |
| self.assertEqual(result.dtype, torch.float16) |
| |
    # this is equivalent to running scripted functions inside autocast
| # (see also test_eager_and_script) |
| @unittest.skipIf(not TEST_CUDA, "No cuda") |
| def test_autocast_decorator_outside_jit(self): |
| @autocast(enabled=True) |
| @torch.jit.script |
| def fn(a, b): |
| return torch.mm(a, b) |
| result = fn(self.a_fp32, self.b_fp32) |
| self.assertEqual(result.dtype, torch.float16) |
| |
| @unittest.skipIf(not TEST_CUDA, "No cuda") |
| def test_inplace(self): |
| @torch.jit.script |
| def fn(a, b, c): |
| with autocast(enabled=True): |
| x = torch.addmm(a, b, c) |
| y = torch.addmm(a, b, c, out=a) |
| z = a.addmm_(b, c) |
| return x, y, z |
| x, y, z = fn(self.a_fp32, self.b_fp32, self.c_fp32) |
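        # only the out-of-place addmm is autocast to fp16; the out= and
        # in-place variants keep the destination's fp32 dtype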
| self.assertEqual(x.dtype, torch.float16) |
| self.assertEqual(y.dtype, torch.float32) |
| self.assertEqual(z.dtype, torch.float32) |
| |
| def _test_autocast(self, func, cast_op, *args): |
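        # run the function eagerly and scripted; if cast_op is given, check that
        # it shows up in the optimized graph, then compare output dtypes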
| jit_func = torch.jit.script(func) |
| o = func(*args) |
| jit_o = jit_func(*args) |
| if cast_op is not None: |
| FileCheck().check(cast_op).run(jit_func.graph_for(*args)) |
| for o0, o1 in zip(o, jit_o): |
| self.assertEqual(o0.dtype, o1.dtype) |
| |
| @unittest.skipIf(not TEST_CUDA, "No cuda") |
| def test_autocast_api(self): |
| |
| def t_autocast_cpu(x, y): |
| with torch.autocast("cpu", dtype=torch.bfloat16): |
| return torch.mm(x, y) |
| |
| def t_autocast_cuda(x, y): |
| with torch.autocast("cuda", dtype=torch.half): |
| return torch.mm(x, y) |
| |
| def t_cuda_amp_autocast(x, y): |
| with torch.cuda.amp.autocast(): |
| return torch.mm(x, y) |
| |
| def t_cpu_amp_autocast(x, y): |
| with torch.cpu.amp.autocast(): |
| return torch.mm(x, y) |
| |
| x = torch.randn(5, 5, device="cuda", dtype=torch.float32) |
| y = torch.randn(5, 5, device="cuda", dtype=torch.float32) |
| self._test_autocast(t_autocast_cpu, "aten::_autocast_to_reduced_precision", x, y) |
| self._test_autocast(t_autocast_cuda, "aten::_autocast_to_reduced_precision", x, y) |
| self._test_autocast(t_cuda_amp_autocast, "aten::_autocast_to_reduced_precision", x, y) |
| self._test_autocast(t_cpu_amp_autocast, "aten::_autocast_to_reduced_precision", x, y) |
| |
    @unittest.skipIf(True, "a dtype argument is currently required")
| @unittest.skipIf(not TEST_CUDA, "No cuda") |
| def test_autocast_api_not_supported(self): |
| |
| def t_autocast_cpu(x, y): |
            # omitting the dtype argument is not currently supported
| with torch.autocast("cpu"): |
| return torch.mm(x, y) |
| |
| def t_autocast_cuda(x, y): |
            # omitting the dtype argument is not currently supported
| with torch.autocast("cuda"): |
| return torch.mm(x, y) |
| |
| x = torch.randn(5, 5, device="cuda", dtype=torch.float32) |
| y = torch.randn(5, 5, device="cuda", dtype=torch.float32) |
| self._test_autocast(t_autocast_cpu, "aten::_autocast_to_reduced_precision", x, y) |
| self._test_autocast(t_autocast_cuda, "aten::_autocast_to_reduced_precision", x, y) |
| |
| @unittest.skipIf(not TEST_CUDA, "No cuda") |
| def test_autocast_mixed_dtypes(self): |
| |
| def t(cpu0, cpu1, cuda0, cuda1): |
| with torch.autocast("cpu", torch.bfloat16): |
| with torch.autocast("cuda", torch.float16): |
| cpu_o = torch.mm(cpu0, cpu1) |
| cuda_o = torch.mm(cuda0, cuda1) |
| return cpu_o, cuda_o |
| |
| jit_t = torch.jit.script(t) |
| cpu0 = torch.randn(5, 5, device="cpu", dtype=torch.float32) |
| cpu1 = torch.randn(5, 5, device="cpu", dtype=torch.float32) |
| cuda0 = torch.randn(5, 5, device="cuda", dtype=torch.float32) |
| cuda1 = torch.randn(5, 5, device="cuda", dtype=torch.float32) |
| self._test_autocast(t, "aten::_autocast_to_reduced_precision", cpu0, cpu1, cuda0, cuda1) |
| |
| @unittest.skipIf(not TEST_CUDA, "No cuda") |
| def test_jit_executor_under_autocast(self): |
| |
| def t(cpu0, cpu1, cuda0, cuda1): |
| cpu_o = torch.mm(cpu0, cpu1) |
| cuda_o = torch.mm(cuda0, cuda1) |
| return cpu_o, cuda_o |
| |
| jit_t = torch.jit.script(t) |
| cpu0 = torch.randn(5, 5, device="cpu", dtype=torch.float32) |
| cpu1 = torch.randn(5, 5, device="cpu", dtype=torch.float32) |
| cuda0 = torch.randn(5, 5, device="cuda", dtype=torch.float32) |
| cuda1 = torch.randn(5, 5, device="cuda", dtype=torch.float32) |
| |
| with torch.autocast("cpu", torch.bfloat16): |
| with torch.autocast("cuda", torch.float16): |
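                # with eager autocast active, the scripted graph should contain cast ops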
| self._test_autocast(t, "aten::_autocast_to_reduced_precision", cpu0, cpu1, cuda0, cuda1) |
| |
| with torch.autocast("cpu", torch.bfloat16): |
| self._test_autocast(t, "aten::_autocast_to_reduced_precision", cpu0, cpu1, cuda0, cuda1) |
| |
| with torch.autocast("cuda", torch.float16): |
| self._test_autocast(t, "aten::_autocast_to_reduced_precision", cpu0, cpu1, cuda0, cuda1) |
| |
| # no cast op should be observed when executing outside autocast context |
| self._test_autocast(t, None, cpu0, cpu1, cuda0, cuda1) |
| |
| @unittest.skipIf(not TEST_CUDA, "No cuda") |
| def test_autocast_autodiff(self): |
| def t(t0, t1): |
| o = torch.mm(t0, t1) |
| return o.relu() |
| |
| jit_t = torch.jit.script(t) |
| t0 = torch.randn(5, 5, device="cuda", dtype=torch.float32).requires_grad_() |
| t1 = torch.randn(5, 5, device="cuda", dtype=torch.float32).requires_grad_() |
| |
| # run optimization |
| for i in range(5): |
| with torch.autocast("cuda", torch.float16): |
| jit_o = jit_t(t0, t1) |
| jit_o.sum().backward() |
| |
| t0.grad = None |
| t1.grad = None |
| ref_t0 = t0.detach().requires_grad_() |
| ref_t1 = t1.detach().requires_grad_() |
| |
| with torch.autocast("cuda", torch.float16): |
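            # compare eager and scripted outputs and gradients under autocast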
| o = t(ref_t0, ref_t1) |
| jit_o = jit_t(t0, t1) |
| jit_o.sum().backward() |
| o.sum().backward() |
| self.assertEqual(o, jit_o) |
| self.assertEqual(t0.grad, ref_t0.grad) |
| self.assertEqual(t1.grad, ref_t1.grad) |
| self.assertEqual(o.dtype, jit_o.dtype) |
| self.assertEqual(t0.grad.dtype, ref_t0.grad.dtype) |
| self.assertEqual(t1.grad.dtype, ref_t1.grad.dtype) |
| |
| @unittest.skipIf(not TEST_CUDA, "No cuda") |
| def test_jit_call_method_under_autocast(self): |
| @torch.jit.interface |
| class Iface(torch.nn.Module): |
| def forward(self, x, y) -> torch.Tensor: |
| pass |
| |
| class Impl(Iface): |
| def forward(self, x, y): |
| return torch.mm(x, y) |
| |
| class Thing1(torch.nn.Module): |
| impl: Iface |
| |
| def forward(self, x, y): |
| with torch.cuda.amp.autocast(): |
| a = torch.mm(x, y) |
| b = self.impl.forward(a, x) |
| return b |
| |
| scripted_impl = torch.jit.script(Impl()) |
| thing1 = Thing1() |
| thing1.impl = scripted_impl |
| scripted_thing1 = torch.jit.script(thing1) |
| x = torch.rand([2, 2]) |
| y = torch.rand([2, 2]) |
| |
| # make sure this doesn't throw an error |
| with torch.cuda.amp.autocast(): |
| ans = scripted_thing1.forward(x, y) |
| self.assertEqual(torch.mm(torch.mm(x, y), x), ans) |
| |
        # sanity check: this isn't currently supported when global autocasting
        # isn't enabled
| self.assertRaises(RuntimeError, lambda: scripted_thing1.forward(x, y)) |
| |
| @unittest.skipIf(not TEST_CUDA, "No cuda") |
| def test_jit_freeze_autocast_basic(self): |
| class TestModule(torch.nn.Module): |
| def forward(self, x, y): |
| with torch.cuda.amp.autocast(): |
| return torch.mm(x, y) |
| |
| x = torch.rand((3, 4), dtype=torch.float).cuda() |
| y = torch.rand((4, 5), dtype=torch.float).cuda() |
| |
| mod = TestModule().eval() |
| |
| # sanity check |
| self._test_autocast(mod, "aten::_autocast_to_reduced_precision", x, y) |
| |
| frozen_mod = torch.jit.freeze(torch.jit.script(mod).eval()) |
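        # both cast nodes should survive freezing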
| FileCheck().check_count("aten::_autocast_to_reduced_precision", 2, True).run(frozen_mod.graph) |
| |
| # make sure that the runtime pass doesn't duplicate autocast nodes |
| frozen_mod(x, y) |
| optimized_graph = frozen_mod.graph_for(x, y) |
| FileCheck().check_count("aten::_autocast_to_reduced_precision", 2, True).run(optimized_graph) |
| |
| @unittest.skipIf(not TEST_CUDA, "No cuda") |
| def test_jit_freeze_autocast_constants(self): |
| class TestModule(torch.nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| self.x = torch.rand((3, 4), dtype=torch.float).cuda() |
| |
| def forward(self, y): |
| with torch.cuda.amp.autocast(): |
| return torch.mm(self.x, y) |
| |
| y = torch.rand((4, 5), dtype=torch.float).cuda() |
| mod = TestModule().eval() |
| |
| frozen_mod = torch.jit.freeze(torch.jit.script(mod).eval()) |
| # freezing should pre-cast the constant self.x to remove one autocast call |
| FileCheck().check_count("aten::_autocast_to_reduced_precision", 1, True).run(frozen_mod.graph) |
| |
| # the runtime autocasting pass will re-insert the second autocast call, |
| # but constant propagation will merge it with the constant that it's casting. |
| frozen_mod(y) |
| optimized_graph = frozen_mod.graph_for(y) |
| FileCheck().check_count("aten::_autocast_to_reduced_precision", 1, True).run(optimized_graph) |
| |
| @unittest.skipIf(TEST_CUDA, "CPU-only test") |
| def test_jit_autocast_softmax_cpu(self): |
| def fn(x): |
| with torch.cpu.amp.autocast(): |
| return torch.nn.functional.softmax(x, dim=0) |
| |
| fn_s = torch.jit.script(fn) |
| x = torch.rand((2, 2), dtype=torch.bfloat16) |
| fn_s(x) |
| y = fn_s(x) |
| |
| self.assertTrue(y.dtype == torch.bfloat16) |
| |
| @unittest.skipIf(not TEST_CUDA, "No cuda") |
| def test_jit_autocast_softmax_gpu(self): |
| def fn(x): |
| with torch.cuda.amp.autocast(): |
| return torch.nn.functional.softmax(x, dim=0) |
| |
| fn_s = torch.jit.script(fn) |
| x = torch.rand((2, 2), dtype=torch.half).cuda() |
| fn_s(x) |
| y = fn_s(x) |
| |
| self.assertTrue(y.dtype == torch.float) |
| |
| def test_ignore_amp(self): |
| @torch.jit.script |
| def foo(x): |
| return torch.mm(x, x) |
| |
| inp = torch.rand([10, 10], dtype=torch.float) |
| foo._set_ignore_amp(True) |
| with torch.cpu.amp.autocast(): |
| foo(inp) |
| foo(inp) |
| |
| g = torch.jit.last_executed_optimized_graph() |
| FileCheck().check_not("_autocast_to_reduced").run(g) |
| |
| class convbn(torch.nn.Module): |
| def __init__(self, bias_enabled=True): |
| super().__init__() |
| self.conv = torch.nn.Conv2d(3, 64, 7, stride=2, bias=bias_enabled) |
| self.bn = torch.nn.BatchNorm2d(64) |
| |
| def forward(self, x): |
| return self.bn(self.conv(x)) |
| |
| @skipIfTorchDynamo("Not a TorchDynamo suitable test") |
| class TestJitTraceAutocast(JitTestCase): |
| def setUp(self): |
| super().setUp() |
| self.previous_default_dtype = torch.get_default_dtype() |
| torch.set_default_dtype(torch.float32) |
| self.models = [MnistNet(), |
| convbn(bias_enabled=True), |
| convbn(bias_enabled=False)] |
| self.inputs = [torch.randn(5, 1, 28, 28, device='cpu'), |
| torch.randn(32, 3, 224, 224, device='cpu'), |
| torch.randn(32, 3, 224, 224, device='cpu')] |
| self.previous_jit_autocast_pass = torch._C._jit_set_autocast_mode(False) |
| |
| def tearDown(self): |
| torch._C._jit_set_autocast_mode(self.previous_jit_autocast_pass) |
| torch.set_default_dtype(self.previous_default_dtype) |
| super().tearDown() |
| |
| def test_generate_autocast_jit_trace_model(self): |
| def test_generate_autocast_jit_trace_model(model, x): |
| model.eval() |
| with torch.cpu.amp.autocast(cache_enabled=False), torch.no_grad(): |
| traced_model = torch.jit.trace(model, x) |
| traced_model = torch.jit.freeze(traced_model) |
        for i in range(len(self.models)):
| test_generate_autocast_jit_trace_model(self.models[i], self.inputs[i]) |
| |
| def test_nchw_autocast_jit_trace_model(self): |
| def test_nchw_autocast_jit_trace_model(model, x): |
| model.eval() |
| with torch.cpu.amp.autocast(cache_enabled=False), torch.no_grad(): |
| traced_model = torch.jit.trace(model, x) |
| traced_model = torch.jit.freeze(traced_model) |
| with torch.no_grad(): |
| y = traced_model(x.clone()) |
| with torch.cpu.amp.autocast(), torch.no_grad(): |
| y2 = model(x.clone()) |
| torch.testing.assert_close(y.double(), y2.double(), rtol=1e-03, atol=1e-03) |
        for i in range(len(self.models)):
| test_nchw_autocast_jit_trace_model(self.models[i], self.inputs[i]) |
| |
| def test_nhwc_autocast_jit_trace_model(self): |
| def test_nhwc_autocast_jit_trace_model(model, x): |
| model = model.to(memory_format=torch.channels_last) |
| model.eval() |
| with torch.cpu.amp.autocast(cache_enabled=False), torch.no_grad(): |
| traced_model = torch.jit.trace(model, x.to(memory_format=torch.channels_last)) |
| traced_model = torch.jit.freeze(traced_model) |
| with torch.no_grad(): |
| y = traced_model(x.clone().to(memory_format=torch.channels_last)) |
| with torch.cpu.amp.autocast(), torch.no_grad(): |
| y2 = model(x.clone().to(memory_format=torch.channels_last)) |
| torch.testing.assert_close(y.double(), y2.double(), rtol=1e-03, atol=1e-03) |
        for i in range(len(self.models)):
            if self.inputs[i].dim() == 5:
                # the NHWC 3D case is not supported yet
| continue |
| test_nhwc_autocast_jit_trace_model(self.models[i], self.inputs[i]) |
| |
| def test_cat_promote(self): |
| class TestModel(torch.nn.Module): |
| def forward(self, a, b): |
| return torch.cat([a, b], 0) |
| |
| with torch.jit.fuser("none"): |
            # This test checks that cat promotes mixed-dtype inputs under AMP.
            # The fuser is disabled here to avoid TE fusion groups.
| for jit_freeze_or_not in [False, True]: |
| test_model = TestModel().eval() |
| with torch.cpu.amp.autocast(cache_enabled=False, dtype=torch.bfloat16), torch.no_grad(): |
| a = torch.rand(24, 128, 128) |
| b = torch.rand(24, 128, 128, dtype=torch.bfloat16) |
| c = test_model(a, b) |
| traced = torch.jit.trace(test_model, (a, b)) |
| if jit_freeze_or_not: |
| traced = torch.jit.freeze(traced) |
| for _ in range(3): |
| c2 = traced(a, b) |
                self.assertEqual(c.dtype, torch.float32)
                self.assertEqual(c2.dtype, torch.float32)
| traced_graph = traced.graph_for(a, b) |
| self.assertTrue(any(n.kind() == "aten::to" for n in traced_graph.nodes())) |
| |
| def test_script_autocast_cpu(self): |
| def fn(x): |
| if torch.is_autocast_cpu_enabled(): |
| return x.relu() |
| else: |
| return x.sin() |
| |
| fn_s = torch.jit.script(fn) |
| |
| x = torch.rand((4, 4)) - 0.5 |
| with torch.cpu.amp.autocast(): |
| self.assertEqual(fn_s(x), fn(x)) |
| |
| with torch.cpu.amp.autocast(enabled=True): |
| self.assertEqual(fn_s(x), fn(x)) |
| |
| self.assertTrue(any("is_autocast_cpu_enabled" in x.kind() for x in fn_s.graph.nodes())) |
| |
| @unittest.skipIf(not TEST_CUDA, "No cuda") |
| def test_script_autocast_cuda(self): |
| def fn(x): |
| if torch.is_autocast_enabled(): |
| return x.relu() |
| else: |
| return x.sin() |
| |
| fn_s = torch.jit.script(fn) |
| |
| x = torch.rand((4, 4)) - 0.5 |
| with torch.cpu.amp.autocast(): |
| self.assertEqual(fn_s(x), fn(x)) |
| |
| with torch.cuda.amp.autocast(enabled=True): |
| self.assertEqual(fn_s(x), fn(x)) |
| |
| self.assertTrue(any("is_autocast_enabled" in x.kind() for x in fn_s.graph.nodes())) |
| |
| |
| def test_scripted_aliasing(self): |
        # the torch.is_autocast_enabled call should not be movable into the autocast context
| def fn(x): |
| if torch.is_autocast_enabled(): |
| y = True |
| else: |
| y = False |
| with torch.cuda.amp.autocast(enabled=True): |
| z = x.relu() |
| return y, z |
| |
| fn_s = torch.jit.script(fn) |
| graph = fn_s.graph |
| |
| aliasdb = graph.alias_db() |
| |
| is_enabled_nodes = graph.findAllNodes("aten::is_autocast_enabled") |
| enter_nodes = graph.findAllNodes("prim::Enter") |
| |
| self.assertEqual(len(is_enabled_nodes), 1) |
| self.assertEqual(len(enter_nodes), 1) |
| |
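        # alias analysis must not allow the query to be reordered past the
        # prim::Enter of the autocast context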
| self.assertFalse(aliasdb.move_after_topologically_valid(is_enabled_nodes[0], enter_nodes[0])) |
| |
| |
| def test_script_autocast_enable_and_check(self): |
| def fn(x, y) -> Tuple[torch.Tensor, bool, torch.Tensor, bool, torch.Tensor, bool]: |
| b1 = torch.is_autocast_cpu_enabled() |
| v1 = torch.mm(x, y) |
| with torch.cpu.amp.autocast(enabled=True): |
| b2 = torch.is_autocast_cpu_enabled() |
| v2 = torch.mm(x, y) |
| with torch.cpu.amp.autocast(enabled=False): |
| b3 = torch.is_autocast_cpu_enabled() |
| v3 = torch.mm(x, y) |
| return (v1, b1, v2, b2, v3, b3) |
| |
        # each b_i returned by is_autocast_cpu_enabled() should be False exactly
        # when the corresponding v_i = mm(x, y) stays float32
| def check_fn_results(arr): |
| [v1, b1, v2, b2, v3, b3] = arr |
| self.assertTrue((v1.dtype == torch.float) != b1) |
| self.assertTrue((v2.dtype == torch.float) != b2) |
| self.assertTrue((v3.dtype == torch.float) != b3) |
| |
| x = torch.rand((2, 2), dtype=torch.float) |
| y = torch.rand((2, 2), dtype=torch.float) |
| |
| fn_s = torch.jit.script(fn) |
| |
| with torch.cpu.amp.autocast(enabled=False): |
| check_fn_results(fn(x, y)) |
| check_fn_results(fn_s(x, y)) |
| |
| with torch.cpu.amp.autocast(enabled=True): |
| check_fn_results(fn(x, y)) |
| check_fn_results(fn_s(x, y)) |
| |
| |
| if __name__ == "__main__": |
| run_tests() |