Disabling amp context when invoking compiler (#138659)

Disabling amp context when invoking compiler (#138624)

Fix for https://github.com/pytorch/pytorch/issues/133974

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138624
Approved by: https://github.com/bdhirsh, https://github.com/drisspg

(cherry picked from commit 5942b2985000e0c69ec955b6c88dee8b5d7e67fd)

Co-authored-by: eellison <[email protected]>
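
The core of the change: AOTAutograd traces the autocast casts directly into the
joint graph, so when the forward/backward compilers are invoked the surrounding
amp context must be disabled, otherwise the compiler could re-apply autocast on
top of the already-cast graph. A minimal sketch of the idea follows (illustrative
only; compile_traced_graph and inner_compiler are hypothetical names, not the
actual AOTAutograd internals):

    import torch

    def compile_traced_graph(gm, example_inputs, inner_compiler):
        # `gm` is assumed to be a graph traced under autocast, so the AMP casts
        # are already materialized as ops in the graph itself.
        # Disable autocast around the inner compiler call so the backend does
        # not see (and re-apply) the ambient amp state.
        with torch._C._DisableAutocast():
            return inner_compiler(gm, example_inputs)

The diff below applies this to the "forward" and "backward" compile paths and
drops the now-redundant conditional disable_amp context around the eager
backward compile.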
diff --git a/test/inductor/test_cpu_repro.py b/test/inductor/test_cpu_repro.py
index f90b9da..34d908d 100644
--- a/test/inductor/test_cpu_repro.py
+++ b/test/inductor/test_cpu_repro.py
@@ -3941,6 +3941,47 @@
         x = torch.randn(1, 4, 2, 2)
         self.common(fn, (x,))
 
+    @parametrize("is_inference", (True, False))
+    def test_disabled_amp(self, is_inference):
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.all_head_size = 12 * 64
+                self.dense = nn.Linear(self.all_head_size, self.all_head_size)
+
+            def forward(self, q, k, v):
+                context_layer = F.scaled_dot_product_attention(
+                    q, k, v, attn_mask=None, dropout_p=0.2
+                )
+                context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+                new_context_layer_shape = context_layer.size()[:-2] + (
+                    self.all_head_size,
+                )
+                context_layer = context_layer.view(new_context_layer_shape)
+                return self.dense(context_layer)
+
+        mod = M().to(torch.bfloat16).eval()
+
+        q = torch.randn((4, 12, 512, 64), dtype=torch.bfloat16) / 10.0
+        k = torch.randn((4, 12, 512, 64), dtype=torch.bfloat16) / 10.0
+        v = torch.randn((4, 12, 512, 64), dtype=torch.bfloat16) / 10.0
+        inputs = (
+            q,
+            k,
+            v,
+        )
+        compiler_mode = torch.compile(mod)
+        from torch.nn.attention import sdpa_kernel, SDPBackend
+
+        context = contextlib.nullcontext if not is_inference else torch.no_grad
+        with config.patch(
+            {"fallback_random": True}
+        ), torch.cpu.amp.autocast(), context(), sdpa_kernel(SDPBackend.MATH):
+            torch.manual_seed(0)
+            eager = mod(*inputs)
+            torch.manual_seed(0)
+            self.assertEqual(compiler_mode(*inputs), eager)
+
     @requires_vectorization
     def test_vec_indirect_load_cse_cache(self):
         # https://github.com/pytorch/pytorch/issues/123502
diff --git a/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py b/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py
index 5dc236f..b86fbad 100644
--- a/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py
+++ b/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py
@@ -555,7 +555,9 @@
                 ),
             )
 
-        with track_graph_compiling(aot_config, "forward"):
+        # AMP is already traced out in the joint graph; we do not wish to re-apply
+        # it accidentally in the compiler.
+        with track_graph_compiling(aot_config, "forward"), torch._C._DisableAutocast():
             # flat_args at this point might still be subclasses-
             # make sure to pass the unwrapped fake tensors into the compiler!
             adjusted_flat_args = joint_inputs[0]
@@ -620,7 +622,7 @@
         # NB: It's important to compile backwards ahead of time, as this may
         # add extra guards which we need to apply to the Dynamo cache at
         # forwards
-        with track_graph_compiling(aot_config, "backward"):
+        with track_graph_compiling(aot_config, "backward"), torch._C._DisableAutocast():
             placeholder_list = fx_placeholder_vals(bw_module)
 
             forward_saved_for_backwards_strides = None
@@ -672,28 +674,24 @@
 
             compiled_bw_func = None
             if num_symints_saved_for_bw > 0:
-                context = torch._C._DisableAutocast if disable_amp else nullcontext
-                with context():
-                    try:
-                        compiled_bw_func = aot_config.bw_compiler(
-                            bw_module, placeholder_list
-                        )
-                    except Exception as e:
-                        exc = e
-                        trace_structured(
-                            "artifact",
-                            metadata_fn=lambda: {
-                                "name": "eager_compile_backwards_failure",
-                                "encoding": "string",
-                            },
-                            payload_fn=lambda: "\n".join(
-                                traceback.format_exception(exc)
-                            ),
-                        )
-                        log.warning(
-                            "failed to eagerly compile backwards for dynamic, suppressing in case backwards not needed",
-                            exc_info=True,
-                        )
+                try:
+                    compiled_bw_func = aot_config.bw_compiler(
+                        bw_module, placeholder_list
+                    )
+                except Exception as e:
+                    exc = e
+                    trace_structured(
+                        "artifact",
+                        metadata_fn=lambda: {
+                            "name": "eager_compile_backwards_failure",
+                            "encoding": "string",
+                        },
+                        payload_fn=lambda: "\n".join(traceback.format_exception(exc)),
+                    )
+                    log.warning(
+                        "failed to eagerly compile backwards for dynamic, suppressing in case backwards not needed",
+                        exc_info=True,
+                    )
             # Compiled autograd will run the bw_module in the backward pass,
             # so recompilation need happen anyway if the backward pass is ever
             # called.