Disabling amp context when invoking compiler (#138659)
Disabling amp context when invoking compiler (#138624)
Fix for https://github.com/pytorch/pytorch/issues/133974
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138624
Approved by: https://github.com/bdhirsh, https://github.com/drisspg
(cherry picked from commit 5942b2985000e0c69ec955b6c88dee8b5d7e67fd)
Co-authored-by: eellison <[email protected]>
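
The gist: AOTAutograd already traces autocast behavior into the joint graph, so the
forward/backward compiler invocations below are now wrapped in torch._C._DisableAutocast()
to avoid applying AMP a second time inside the compiler. A minimal sketch of the
user-facing pattern this is meant to fix (the module, shapes, and tolerance here are
illustrative only, not taken from the issue):

    import torch
    import torch.nn as nn

    class Proj(nn.Module):
        def __init__(self):
            super().__init__()
            self.dense = nn.Linear(64, 64)

        def forward(self, x):
            return self.dense(x)

    mod = Proj().to(torch.bfloat16).eval()
    x = torch.randn(4, 64, dtype=torch.bfloat16)
    compiled = torch.compile(mod)

    # Compile and run under an ambient CPU autocast context; with this fix the
    # compiled module should match eager (up to bfloat16 tolerances) instead of
    # having autocast re-applied while the compiler itself runs.
    with torch.cpu.amp.autocast(), torch.no_grad():
        eager_out = mod(x)
        compiled_out = compiled(x)
    torch.testing.assert_close(compiled_out, eager_out)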
diff --git a/test/inductor/test_cpu_repro.py b/test/inductor/test_cpu_repro.py
index f90b9da..34d908d 100644
--- a/test/inductor/test_cpu_repro.py
+++ b/test/inductor/test_cpu_repro.py
@@ -3941,6 +3941,47 @@
         x = torch.randn(1, 4, 2, 2)
         self.common(fn, (x,))
 
+    @parametrize("is_inference", (True, False))
+    def test_disabled_amp(self, is_inference):
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.all_head_size = 12 * 64
+                self.dense = nn.Linear(self.all_head_size, self.all_head_size)
+
+            def forward(self, q, k, v):
+                context_layer = F.scaled_dot_product_attention(
+                    q, k, v, attn_mask=None, dropout_p=0.2
+                )
+                context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+                new_context_layer_shape = context_layer.size()[:-2] + (
+                    self.all_head_size,
+                )
+                context_layer = context_layer.view(new_context_layer_shape)
+                return self.dense(context_layer)
+
+        mod = M().to(torch.bfloat16).eval()
+
+        q = torch.randn((4, 12, 512, 64), dtype=torch.bfloat16) / 10.0
+        k = torch.randn((4, 12, 512, 64), dtype=torch.bfloat16) / 10.0
+        v = torch.randn((4, 12, 512, 64), dtype=torch.bfloat16) / 10.0
+        inputs = (
+            q,
+            k,
+            v,
+        )
+        compiler_mode = torch.compile(mod)
+        from torch.nn.attention import sdpa_kernel, SDPBackend
+
+        context = contextlib.nullcontext if not is_inference else torch.no_grad
+        with config.patch(
+            {"fallback_random": True}
+        ), torch.cpu.amp.autocast(), context(), sdpa_kernel(SDPBackend.MATH):
+            torch.manual_seed(0)
+            eager = mod(*inputs)
+            torch.manual_seed(0)
+            self.assertEqual(compiler_mode(*inputs), eager)
+
     @requires_vectorization
     def test_vec_indirect_load_cse_cache(self):
         # https://github.com/pytorch/pytorch/issues/123502
diff --git a/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py b/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py
index 5dc236f..b86fbad 100644
--- a/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py
+++ b/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py
@@ -555,7 +555,9 @@
                 ),
             )
 
-        with track_graph_compiling(aot_config, "forward"):
+        # AMP is already traced out in the joint graph. We do not wish to reapply it accidentally
+        # in the compiler.
+        with track_graph_compiling(aot_config, "forward"), torch._C._DisableAutocast():
             # flat_args at this point might still be subclasses-
             # make sure to pass the unwrapped fake tensors into the compiler!
             adjusted_flat_args = joint_inputs[0]
@@ -620,7 +622,7 @@
         # NB: It's important to compile backwards ahead of time, as this may
         # add extra guards which we need to apply to the Dynamo cache at
         # forwards
-        with track_graph_compiling(aot_config, "backward"):
+        with track_graph_compiling(aot_config, "backward"), torch._C._DisableAutocast():
             placeholder_list = fx_placeholder_vals(bw_module)
 
             forward_saved_for_backwards_strides = None
@@ -672,28 +674,24 @@
 
             compiled_bw_func = None
             if num_symints_saved_for_bw > 0:
-                context = torch._C._DisableAutocast if disable_amp else nullcontext
-                with context():
-                    try:
-                        compiled_bw_func = aot_config.bw_compiler(
-                            bw_module, placeholder_list
-                        )
-                    except Exception as e:
-                        exc = e
-                        trace_structured(
-                            "artifact",
-                            metadata_fn=lambda: {
-                                "name": "eager_compile_backwards_failure",
-                                "encoding": "string",
-                            },
-                            payload_fn=lambda: "\n".join(
-                                traceback.format_exception(exc)
-                            ),
-                        )
-                        log.warning(
-                            "failed to eagerly compile backwards for dynamic, suppressing in case backwards not needed",
-                            exc_info=True,
-                        )
+                try:
+                    compiled_bw_func = aot_config.bw_compiler(
+                        bw_module, placeholder_list
+                    )
+                except Exception as e:
+                    exc = e
+                    trace_structured(
+                        "artifact",
+                        metadata_fn=lambda: {
+                            "name": "eager_compile_backwards_failure",
+                            "encoding": "string",
+                        },
+                        payload_fn=lambda: "\n".join(traceback.format_exception(exc)),
+                    )
+                    log.warning(
+                        "failed to eagerly compile backwards for dynamic, suppressing in case backwards not needed",
+                        exc_info=True,
+                    )
             # Compiled autograd will run the bw_module in the backward pass,
             # so recompilation need happen anyway if the backward pass is ever
             # called.
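
For context, torch._C._DisableAutocast() used in the hunks above is the internal guard that
turns autocast dispatch off for the duration of a with-block. A quick illustrative sketch,
not part of the patch, assuming the usual CPU autocast behavior of running float32 matmuls
in bfloat16:

    import torch

    a = torch.randn(8, 8)
    b = torch.randn(8, 8)

    with torch.cpu.amp.autocast(dtype=torch.bfloat16):
        # Under CPU autocast, a float32 matmul is executed in bfloat16.
        assert torch.mm(a, b).dtype == torch.bfloat16
        with torch._C._DisableAutocast():
            # With autocast dispatch disabled, the same op keeps its input precision,
            # which is the state the fw/bw compiler calls above now run in.
            assert torch.mm(a, b).dtype == torch.float32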