Revert "[Dynamo] Support torch.{cuda/cpu}.amp.autocast (#95416)"

This reverts commit c88aa336aa0734f42b4d9db7f624d6cfd9b5065e.

Reverted https://github.com/pytorch/pytorch/pull/95416 on behalf of https://github.com/huydhn due to the smoke test failure appearing to be related to this PR: it started failing consistently in trunk, see https://hud.pytorch.org/hud/pytorch/pytorch/master/1?per_page=50&name_filter=inductor_torchbench_smoketest_perf
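
For context, the feature being reverted let Dynamo trace the device-specific autocast context managers by rewriting them into torch.amp.autocast with an explicit device_type. A minimal sketch of the usage pattern exercised by the removed tests (illustrative only, reconstructed from the deleted test_is_autocast_cpu_enabled; with this revert applied it is expected to fall back to the pre-#95416 behavior):

    import torch
    import torch._dynamo

    def fn(a, b):
        # Before this revert, Dynamo handled torch.cpu.amp.autocast by adding
        # device_type="cpu" and creating an AutocastModeVariable, mirroring
        # torch.amp.autocast(device_type="cpu", dtype=...).
        with torch.cpu.amp.autocast(dtype=torch.bfloat16):
            return torch.mm(a, b)

    opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
    out = opt_fn(torch.rand(8, 8), torch.rand(8, 8))
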
diff --git a/functorch/experimental/_cond.py b/functorch/experimental/_cond.py
index 8f88f4a..de1c382 100644
--- a/functorch/experimental/_cond.py
+++ b/functorch/experimental/_cond.py
@@ -247,4 +247,3 @@
 cond.fallthrough(DispatchKey.PythonTLSSnapshot)
 cond.fallthrough(DispatchKey.ADInplaceOrView)
 cond.fallthrough(DispatchKey.BackendSelect)
-cond.fallthrough(DispatchKey.AutocastCPU)
diff --git a/functorch/experimental/_map.py b/functorch/experimental/_map.py
index 6f895d9..5ec3a16 100644
--- a/functorch/experimental/_map.py
+++ b/functorch/experimental/_map.py
@@ -133,4 +133,3 @@
 map.fallthrough(DispatchKey.PythonTLSSnapshot)
 map.fallthrough(DispatchKey.ADInplaceOrView)
 map.fallthrough(DispatchKey.BackendSelect)
-map.fallthrough(DispatchKey.AutocastCPU)
diff --git a/test/dynamo/test_dynamic_shapes.py b/test/dynamo/test_dynamic_shapes.py
index 1e0c639..b77dfaf 100644
--- a/test/dynamo/test_dynamic_shapes.py
+++ b/test/dynamo/test_dynamic_shapes.py
@@ -26,6 +26,7 @@
 
 ALL_DYNAMIC_XFAILS = {
     "MiscTests": [
+        "test_autocast_sdpa",
         "test_parsing_sdpa",
     ],
     "ReproTests": [
diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py
index 920d96a..a2202f6 100644
--- a/test/dynamo/test_misc.py
+++ b/test/dynamo/test_misc.py
@@ -3288,51 +3288,10 @@
         self.assertEqual(exported.device.index, 0)
         self.assertEqual(exported.dtype, torch.bfloat16)
 
-    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
-    def test_cuda_amp_autocast(self):
-        class MyModule(torch.nn.Module):
-            def forward(self, x):
-                a_float32 = torch.rand((8, 8), device="cuda")
-                b_float32 = torch.rand((8, 8), device="cuda")
-
-                with torch.cuda.amp.autocast(dtype=torch.torch.float64):
-                    c_float64 = torch.mm(a_float32, b_float32)
-                return c_float64
-
-        module = MyModule()
-        real = module(torch.tensor([0.5]))
-        real_device = real.device
-        real_dtype = real.dtype
-
-        graph, _ = torch._dynamo.export(module, torch.tensor([[0.0, 0], [0, 0]]))
-        exported = graph(torch.tensor([0.5]))
-        self.assertEqual(exported.device, real_device)
-        self.assertEqual(exported.dtype, real_dtype)
-
-        self.assertEqual(exported.device.type, "cuda")
-        self.assertEqual(exported.device.index, 0)
-        self.assertEqual(exported.dtype, torch.float64)
-
-    def test_is_autocast_cpu_enabled(self):
-        def fn(a_float32, b_float32):
-            with torch.cpu.amp.autocast(dtype=torch.bfloat16):
-                c_float16 = torch.mm(a_float32, b_float32)
-                if torch.is_autocast_cpu_enabled():
-                    c_float16 = c_float16 + 1
-            return c_float16
-
-        a = torch.rand((8, 8))
-        b = torch.rand((8, 8))
-        ref = fn(a, b)
-        opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
-        res = opt_fn(a, b)
-        self.assertTrue(same(ref, res))
-
     @unittest.skipIf(
         not PLATFORM_SUPPORTS_FUSED_SDPA or not SM80OrLater,
         "Can't run fused SDPA on this platform",
     )
-    @patch.object(torch._dynamo.config, "dynamic_shapes", False)
     def test_autocast_sdpa(self):
         class MyModule(torch.nn.Module):
             def forward(self, query, key, value):
diff --git a/test/test_jit_autocast.py b/test/test_jit_autocast.py
index c71d0cc..6fbb04b 100644
--- a/test/test_jit_autocast.py
+++ b/test/test_jit_autocast.py
@@ -7,13 +7,12 @@
 import unittest
 from test_jit import JitTestCase
 from torch.testing._internal.common_cuda import TEST_CUDA
-from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo
+from torch.testing._internal.common_utils import run_tests
 from torch.testing import FileCheck
 from jit.test_models import MnistNet
 
 TEST_BFLOAT16 = TEST_CUDA and torch.cuda.is_bf16_supported()
 
-@skipIfTorchDynamo("Not a TorchDynamo suitable test")
 class TestAutocast(JitTestCase):
     def setUp(self):
         # common input tensors
@@ -758,7 +757,6 @@
     def forward(self, x):
         return self.bn(self.conv(x))
 
-@skipIfTorchDynamo("Not a TorchDynamo suitable test")
 class TestJitTraceAutocast(JitTestCase):
     def setUp(self):
         super().setUp()
diff --git a/torch/_dynamo/allowed_functions.py b/torch/_dynamo/allowed_functions.py
index 869b67a..5440521 100644
--- a/torch/_dynamo/allowed_functions.py
+++ b/torch/_dynamo/allowed_functions.py
@@ -96,6 +96,8 @@
         torch.autograd.grad,
         torch.clear_autocast_cache,
         torch.cuda.current_device,
+        torch.cuda.amp.autocast_mode.autocast,
+        torch.cpu.amp.autocast_mode.autocast,
         torch.distributions.constraints.is_dependent,
         torch.distributions.normal.Normal,
         torch.inference_mode,
diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py
index 616b2ab..584b275 100644
--- a/torch/_dynamo/variables/misc.py
+++ b/torch/_dynamo/variables/misc.py
@@ -280,19 +280,13 @@
         self.mode = mode
 
     def exit(self, tx, *args):
-        self.mode = (
-            exit_functional_autocast(self.mode[0]),
-            tx.output.create_node(
-                "call_function", exit_functional_autocast, (self.mode[1],), {}
-            ),
+        self.mode = tx.output.create_node(
+            "call_function", exit_functional_autocast, (self.mode,), {}
         )
 
     def enter(self, tx):
-        self.mode = (
-            enter_functional_autocast(*self.target_values),
-            tx.output.create_node(
-                "call_function", enter_functional_autocast, (*self.target_values,), {}
-            ),
+        self.mode = tx.output.create_node(
+            "call_function", enter_functional_autocast, (*self.target_values,), {}
         )
 
     def module_name(self):
diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py
index 3bfd8b9..edf726a 100644
--- a/torch/_dynamo/variables/torch.py
+++ b/torch/_dynamo/variables/torch.py
@@ -64,9 +64,6 @@
     torch.finfo,
     torch.get_default_dtype,
     torch.iinfo,
-    torch.is_autocast_cache_enabled,
-    torch.is_autocast_cpu_enabled,
-    torch.is_autocast_enabled,
     torch.is_floating_point,
     torch.nn.functional._Reduction.get_enum,
 ]
@@ -327,13 +324,6 @@
             )
         elif self.value is torch.amp.autocast_mode.autocast:
             return AutocastModeVariable.create(target_values=args, kwargs=kwargs)
-        elif self.value in [torch.cuda.amp.autocast, torch.cpu.amp.autocast]:
-            assert "device_type" not in kwargs
-            if self.value is torch.cuda.amp.autocast:
-                kwargs.update({"device_type": ConstantVariable("cuda")})
-            else:
-                kwargs.update({"device_type": ConstantVariable("cpu")})
-            return AutocastModeVariable.create(target_values=args, kwargs=kwargs)
         elif self.value in (
             torch.profiler.profile,
             torch.profiler.record_function,