Revert "[Dynamo] Support torch.{cuda/cpu}.amp.autocast (#95416)"
This reverts commit c88aa336aa0734f42b4d9db7f624d6cfd9b5065e.
Reverted https://github.com/pytorch/pytorch/pull/95416 on behalf of https://github.com/huydhn due to: Sorry for reverting your PR, but it seems that the smoke test issue is related, as it starts to fail consistently in trunk: https://hud.pytorch.org/hud/pytorch/pytorch/master/1?per_page=50&name_filter=inductor_torchbench_smoketest_perf
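For context, a minimal sketch (not part of this patch) of the usage pattern that #95416 enabled and that this revert removes: with the PR in place, TorchDynamo could trace through the device-specific autocast context managers (torch.cpu.amp.autocast / torch.cuda.amp.autocast) by rewriting them to torch.amp.autocast with the matching device_type; after this revert they trigger a graph break again, so the nopython=True run below would raise. The function and shapes are illustrative only.

    import torch
    import torch._dynamo

    def fn(a, b):
        # Device-specific autocast wrapper; #95416 let Dynamo trace this
        # as torch.amp.autocast(device_type="cpu", dtype=torch.bfloat16).
        with torch.cpu.amp.autocast(dtype=torch.bfloat16):
            return torch.mm(a, b)

    a = torch.rand(8, 8)
    b = torch.rand(8, 8)
    ref = fn(a, b)

    # With #95416: compiles without graph breaks.
    # After this revert: the autocast call graph-breaks, which errors under nopython=True.
    opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
    res = opt_fn(a, b)
    assert torch.allclose(ref.float(), res.float(), rtol=1e-2, atol=1e-2)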
diff --git a/functorch/experimental/_cond.py b/functorch/experimental/_cond.py
index 8f88f4a..de1c382 100644
--- a/functorch/experimental/_cond.py
+++ b/functorch/experimental/_cond.py
@@ -247,4 +247,3 @@
cond.fallthrough(DispatchKey.PythonTLSSnapshot)
cond.fallthrough(DispatchKey.ADInplaceOrView)
cond.fallthrough(DispatchKey.BackendSelect)
-cond.fallthrough(DispatchKey.AutocastCPU)
diff --git a/functorch/experimental/_map.py b/functorch/experimental/_map.py
index 6f895d9..5ec3a16 100644
--- a/functorch/experimental/_map.py
+++ b/functorch/experimental/_map.py
@@ -133,4 +133,3 @@
map.fallthrough(DispatchKey.PythonTLSSnapshot)
map.fallthrough(DispatchKey.ADInplaceOrView)
map.fallthrough(DispatchKey.BackendSelect)
-map.fallthrough(DispatchKey.AutocastCPU)
diff --git a/test/dynamo/test_dynamic_shapes.py b/test/dynamo/test_dynamic_shapes.py
index 1e0c639..b77dfaf 100644
--- a/test/dynamo/test_dynamic_shapes.py
+++ b/test/dynamo/test_dynamic_shapes.py
@@ -26,6 +26,7 @@
ALL_DYNAMIC_XFAILS = {
"MiscTests": [
+ "test_autocast_sdpa",
"test_parsing_sdpa",
],
"ReproTests": [
diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py
index 920d96a..a2202f6 100644
--- a/test/dynamo/test_misc.py
+++ b/test/dynamo/test_misc.py
@@ -3288,51 +3288,10 @@
self.assertEqual(exported.device.index, 0)
self.assertEqual(exported.dtype, torch.bfloat16)
- @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
- def test_cuda_amp_autocast(self):
- class MyModule(torch.nn.Module):
- def forward(self, x):
- a_float32 = torch.rand((8, 8), device="cuda")
- b_float32 = torch.rand((8, 8), device="cuda")
-
- with torch.cuda.amp.autocast(dtype=torch.torch.float64):
- c_float64 = torch.mm(a_float32, b_float32)
- return c_float64
-
- module = MyModule()
- real = module(torch.tensor([0.5]))
- real_device = real.device
- real_dtype = real.dtype
-
- graph, _ = torch._dynamo.export(module, torch.tensor([[0.0, 0], [0, 0]]))
- exported = graph(torch.tensor([0.5]))
- self.assertEqual(exported.device, real_device)
- self.assertEqual(exported.dtype, real_dtype)
-
- self.assertEqual(exported.device.type, "cuda")
- self.assertEqual(exported.device.index, 0)
- self.assertEqual(exported.dtype, torch.float64)
-
- def test_is_autocast_cpu_enabled(self):
- def fn(a_float32, b_float32):
- with torch.cpu.amp.autocast(dtype=torch.bfloat16):
- c_float16 = torch.mm(a_float32, b_float32)
- if torch.is_autocast_cpu_enabled():
- c_float16 = c_float16 + 1
- return c_float16
-
- a = torch.rand((8, 8))
- b = torch.rand((8, 8))
- ref = fn(a, b)
- opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
- res = opt_fn(a, b)
- self.assertTrue(same(ref, res))
-
@unittest.skipIf(
not PLATFORM_SUPPORTS_FUSED_SDPA or not SM80OrLater,
"Can't run fused SDPA on this platform",
)
- @patch.object(torch._dynamo.config, "dynamic_shapes", False)
def test_autocast_sdpa(self):
class MyModule(torch.nn.Module):
def forward(self, query, key, value):
diff --git a/test/test_jit_autocast.py b/test/test_jit_autocast.py
index c71d0cc..6fbb04b 100644
--- a/test/test_jit_autocast.py
+++ b/test/test_jit_autocast.py
@@ -7,13 +7,12 @@
import unittest
from test_jit import JitTestCase
from torch.testing._internal.common_cuda import TEST_CUDA
-from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo
+from torch.testing._internal.common_utils import run_tests
from torch.testing import FileCheck
from jit.test_models import MnistNet
TEST_BFLOAT16 = TEST_CUDA and torch.cuda.is_bf16_supported()
-@skipIfTorchDynamo("Not a TorchDynamo suitable test")
class TestAutocast(JitTestCase):
def setUp(self):
# common input tensors
@@ -758,7 +757,6 @@
def forward(self, x):
return self.bn(self.conv(x))
-@skipIfTorchDynamo("Not a TorchDynamo suitable test")
class TestJitTraceAutocast(JitTestCase):
def setUp(self):
super().setUp()
diff --git a/torch/_dynamo/allowed_functions.py b/torch/_dynamo/allowed_functions.py
index 869b67a..5440521 100644
--- a/torch/_dynamo/allowed_functions.py
+++ b/torch/_dynamo/allowed_functions.py
@@ -96,6 +96,8 @@
torch.autograd.grad,
torch.clear_autocast_cache,
torch.cuda.current_device,
+ torch.cuda.amp.autocast_mode.autocast,
+ torch.cpu.amp.autocast_mode.autocast,
torch.distributions.constraints.is_dependent,
torch.distributions.normal.Normal,
torch.inference_mode,
diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py
index 616b2ab..584b275 100644
--- a/torch/_dynamo/variables/misc.py
+++ b/torch/_dynamo/variables/misc.py
@@ -280,19 +280,13 @@
self.mode = mode
def exit(self, tx, *args):
- self.mode = (
- exit_functional_autocast(self.mode[0]),
- tx.output.create_node(
- "call_function", exit_functional_autocast, (self.mode[1],), {}
- ),
+ self.mode = tx.output.create_node(
+ "call_function", exit_functional_autocast, (self.mode,), {}
)
def enter(self, tx):
- self.mode = (
- enter_functional_autocast(*self.target_values),
- tx.output.create_node(
- "call_function", enter_functional_autocast, (*self.target_values,), {}
- ),
+ self.mode = tx.output.create_node(
+ "call_function", enter_functional_autocast, (*self.target_values,), {}
)
def module_name(self):
diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py
index 3bfd8b9..edf726a 100644
--- a/torch/_dynamo/variables/torch.py
+++ b/torch/_dynamo/variables/torch.py
@@ -64,9 +64,6 @@
torch.finfo,
torch.get_default_dtype,
torch.iinfo,
- torch.is_autocast_cache_enabled,
- torch.is_autocast_cpu_enabled,
- torch.is_autocast_enabled,
torch.is_floating_point,
torch.nn.functional._Reduction.get_enum,
]
@@ -327,13 +324,6 @@
)
elif self.value is torch.amp.autocast_mode.autocast:
return AutocastModeVariable.create(target_values=args, kwargs=kwargs)
- elif self.value in [torch.cuda.amp.autocast, torch.cpu.amp.autocast]:
- assert "device_type" not in kwargs
- if self.value is torch.cuda.amp.autocast:
- kwargs.update({"device_type": ConstantVariable("cuda")})
- else:
- kwargs.update({"device_type": ConstantVariable("cpu")})
- return AutocastModeVariable.create(target_values=args, kwargs=kwargs)
elif self.value in (
torch.profiler.profile,
torch.profiler.record_function,