[dynamo] Refactor handling of state in context managers (#112939)

The prior handling was rather buggy: context managers restored global
state directly in `exit()`, so tracing that aborted partway through a
`with` block could leave that state unrestored, and the inference-mode
and autocast variables stashed their mode objects on `self`, which is
unsafe because VariableTrackers get copied.

This refactor moves restoration into cleanup hooks held in a shared
mutable `ContextManagerState` container that survives VariableTracker
copies. Hooks are registered on the OutputGraph: `exit()` consumes its
hook eagerly, and any hooks still pending run in LIFO order both before
the user compiler is invoked and in the `finally` of `transform`, so
state is restored even on error paths. `wrap_convert_context` is
renamed to `preserve_global_state`, applied to the bytecode `transform`
pass instead of the convert-frame entry point, and now also
saves/restores inference mode. Dynamo's test case additionally warns
and restores grad mode if a test leaks a changed grad mode.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112939
Approved by: https://github.com/voznesenskym, https://github.com/yanboliang
ghstack dependencies: #112897, #112898, #112920, #112899
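
For reference, the core pattern is roughly the following (a minimal, self-contained sketch; `State` and `Output` are illustrative stand-ins for `ContextManagerState` and `OutputGraph`, not the real dynamo classes):

```python
from dataclasses import dataclass
from typing import Callable, List, Optional


@dataclass
class State:
    # Mutable container shared by every copy of a context-manager
    # variable; the copies alias it, so mutating it here is safe.
    cleanup_fn: Optional[Callable[[], None]] = None

    def cleanup(self):
        # Idempotent: the second call is a no-op.
        if self.cleanup_fn is not None:
            self.cleanup_fn()
            self.cleanup_fn = None


class Output:
    def __init__(self):
        self.cleanup_hooks: List[Callable[[], None]] = []

    def add_cleanup_hook(self, fn: Callable[[], None]):
        self.cleanup_hooks.append(fn)

    def call_cleanup_hooks(self):
        # Reverse (LIFO) order mirrors nested `with` blocks unwinding.
        for hook in reversed(self.cleanup_hooks):
            hook()
        self.cleanup_hooks.clear()


output = Output()
state = State()
state.cleanup_fn = lambda: print("restoring initial state")
output.add_cleanup_hook(state.cleanup)

state.cleanup()              # normal path: exit() restores eagerly
output.call_cleanup_hooks()  # end of tracing: already-run hook is a no-op
```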
diff --git a/benchmarks/dynamo/ci_expected_accuracy/inductor_timm_dynamic_training.csv b/benchmarks/dynamo/ci_expected_accuracy/inductor_timm_dynamic_training.csv
index 51c385c..60b69ce 100644
--- a/benchmarks/dynamo/ci_expected_accuracy/inductor_timm_dynamic_training.csv
+++ b/benchmarks/dynamo/ci_expected_accuracy/inductor_timm_dynamic_training.csv
@@ -18,7 +18,7 @@
 
 
 
-convit_base,pass,11
+convit_base,pass,9
 
 
 
diff --git a/benchmarks/dynamo/ci_expected_accuracy/inductor_timm_training.csv b/benchmarks/dynamo/ci_expected_accuracy/inductor_timm_training.csv
index 621f590..4b5b798 100644
--- a/benchmarks/dynamo/ci_expected_accuracy/inductor_timm_training.csv
+++ b/benchmarks/dynamo/ci_expected_accuracy/inductor_timm_training.csv
@@ -18,7 +18,7 @@
 
 
 
-convit_base,pass,11
+convit_base,pass,9
 
 
 
diff --git a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_dynamic_training.csv b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_dynamic_training.csv
index 1207754..a5a7ee0 100644
--- a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_dynamic_training.csv
+++ b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_dynamic_training.csv
@@ -6,7 +6,7 @@
 
 
 
-BERT_pytorch,pass,10
+BERT_pytorch,pass,8
 
 
 
@@ -202,4 +202,4 @@
 
 
 
-yolov3,pass,12
+yolov3,pass,10
diff --git a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_training.csv
index d61cb5f..eede827 100644
--- a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_training.csv
+++ b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_training.csv
@@ -6,7 +6,7 @@
 
 
 
-BERT_pytorch,pass,10
+BERT_pytorch,pass,8
 
 
 
@@ -210,8 +210,8 @@
 
 
 
-vision_maskrcnn,fail_accuracy,39
+vision_maskrcnn,fail_accuracy,37
 
 
 
-yolov3,pass,12
+yolov3,pass,10
diff --git a/test/dynamo/test_aot_autograd.py b/test/dynamo/test_aot_autograd.py
index 58bcb9d..d301339 100644
--- a/test/dynamo/test_aot_autograd.py
+++ b/test/dynamo/test_aot_autograd.py
@@ -957,6 +957,7 @@
             torch.set_grad_enabled(True)
             y = f_compiled(x)
             self.assertEqual(torch.is_grad_enabled(), False)
+            torch.set_grad_enabled(True)
             self.assertEqual(y_ref, y)
 
             self.assertIsNone(y_ref[0].grad_fn)
diff --git a/test/inductor/test_codecache.py b/test/inductor/test_codecache.py
index b2bbf01..bae90ed 100644
--- a/test/inductor/test_codecache.py
+++ b/test/inductor/test_codecache.py
@@ -93,6 +93,7 @@
         cls.tmpdir.cleanup()
 
     def setUp(self):
+        super().setUp()
         counters.clear()
 
     @requires_triton()
diff --git a/test/inductor/test_kernel_benchmark.py b/test/inductor/test_kernel_benchmark.py
index 93d6670..6c0a479 100644
--- a/test/inductor/test_kernel_benchmark.py
+++ b/test/inductor/test_kernel_benchmark.py
@@ -23,6 +23,7 @@
         cls.exit_stack.close()
 
     def setUp(self):
+        super().setUp()
         PyCodeCache.cache.clear()
 
     def get_compiled_module(self):
diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py
index 5b6aa01..6ba1e46 100644
--- a/torch/_dynamo/convert_frame.py
+++ b/torch/_dynamo/convert_frame.py
@@ -122,7 +122,7 @@
     return result
 
 
-def wrap_convert_context(fn):
+def preserve_global_state(fn):
     """
     Context manager to:
         1) Save/restore torch.is_grad_enabled() state
@@ -135,6 +135,7 @@
     def _fn(*args, **kwargs):
         guards = GlobalStateGuard()
         prior_grad_mode = torch.is_grad_enabled()
+        prior_inference_mode = torch.is_inference_mode_enabled()
         prior_deterministic = torch.are_deterministic_algorithms_enabled()
         prior_warn_only = torch.is_deterministic_algorithms_warn_only_enabled()
         py_rng_state = random.getstate()
@@ -149,6 +150,7 @@
         finally:
             cleanup.close()
             torch._C._set_grad_enabled(prior_grad_mode)
+            torch.autograd.grad_mode._enter_inference_mode(prior_inference_mode)
             torch.use_deterministic_algorithms(
                 prior_deterministic, warn_only=prior_warn_only
             )
@@ -425,7 +427,7 @@
         return convert_frame_assert(backend, one_graph, export, export_constraints)
 
     _convert_frame_assert._clone_with_backend = _clone_with_backend  # type: ignore[attr-defined]
-    return wrap_convert_context(_convert_frame_assert)
+    return _convert_frame_assert
 
 
 def maybe_cprofile(func):
@@ -479,6 +481,7 @@
     mutated_closure_cell_contents: Set[str] = set()
     fail_reason: Optional[str] = None
 
+    @preserve_global_state
     def transform(instructions, code_options):
         nonlocal output
         tracer = InstructionTranslator(
@@ -505,6 +508,8 @@
             if translation_validation_enabled():
                 bisect(tracer.output.shape_env)
             raise
+        finally:
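+            # Restore global state mutated by traced context managers, even on error paths.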
+            tracer.output.call_cleanup_hooks()
 
         output = tracer.output
         assert output is not None
diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py
index a16ae90..a90004e 100644
--- a/torch/_dynamo/output_graph.py
+++ b/torch/_dynamo/output_graph.py
@@ -258,6 +258,7 @@
         self.export_constraints = export_constraints
         self.frame_state = frame_state
         self.tensor_weakref_to_sizes_strides: WeakIdKeyDictionary = {}
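+        # Hooks registered by traced context managers to restore global state; see call_cleanup_hooks().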
+        self.cleanup_hooks: List[Callable[[], Any]] = []
 
         # TODO: maybe should just pass the entire f_code in here?  Not
         # sure...
@@ -382,6 +383,14 @@
 
         self.guards.add(GlobalStateSource().make_guard(GuardBuilder.BACKEND_MATCH))
 
+    def add_cleanup_hook(self, fn: Callable[[], Any]):
+        self.cleanup_hooks.append(fn)
+
+    def call_cleanup_hooks(self):
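+        # Run in reverse (LIFO) order so nested context managers unwind correctly.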
+        for hook in reversed(self.cleanup_hooks):
+            hook()
+        self.cleanup_hooks.clear()
+
     @property
     def root_tracer(self):
         return self.tracers[0]
@@ -1027,6 +1036,7 @@
         graph_sizes_log.debug(
             "%s", LazyString(lambda: self.get_graph_sizes_log_str(name))
         )
+        self.call_cleanup_hooks()
         with self.restore_global_state():
             compiled_fn = self.call_user_compiler(gm)
         compiled_fn = disable(compiled_fn)
diff --git a/torch/_dynamo/test_case.py b/torch/_dynamo/test_case.py
index bf425bf..bd21e27 100644
--- a/torch/_dynamo/test_case.py
+++ b/torch/_dynamo/test_case.py
@@ -1,5 +1,6 @@
 import contextlib
 import importlib
+import logging
 import sys
 
 import torch
@@ -13,6 +14,8 @@
 
 from . import config, reset, utils
 
+log = logging.getLogger(__name__)
+
 
 def run_tests(needs=()):
     from torch.testing._internal.common_utils import run_tests
@@ -53,6 +56,7 @@
         )
 
     def setUp(self):
+        self._prior_is_grad_enabled = torch.is_grad_enabled()
         super().setUp()
         reset()
         utils.counters.clear()
@@ -63,3 +67,6 @@
         reset()
         utils.counters.clear()
         super().tearDown()
+        if self._prior_is_grad_enabled is not torch.is_grad_enabled():
+            log.warning("Running test changed grad mode")
+            torch.set_grad_enabled(self._prior_is_grad_enabled)
diff --git a/torch/_dynamo/variables/ctx_manager.py b/torch/_dynamo/variables/ctx_manager.py
index b0b6005..6fb3c1d 100644
--- a/torch/_dynamo/variables/ctx_manager.py
+++ b/torch/_dynamo/variables/ctx_manager.py
@@ -1,5 +1,6 @@
+import dataclasses
 import inspect
-from typing import Dict, List
+from typing import Callable, Dict, List, Optional
 
 import torch._C
 from torch._guards import Guard
@@ -20,20 +21,60 @@
 )
 
 
+@dataclasses.dataclass
+class ContextManagerState:
+    """
+    Mutating `self` on a VariableTracker is not allowed because
+    VariableTrackers get copied.  This is a mutable container shared by
+    all copies of a context manager variable (it is not itself copied),
+    so it is safe to mutate.
+    """
+
+    cleanup_fn: Optional[Callable] = None
+    proxy: Optional[torch.fx.Proxy] = None
+
+    def cleanup(self):
+        if self.cleanup_fn is not None:
+            self.cleanup_fn()
+            self.cleanup_fn = None
+
+    def cleanup_assert(self):
+        assert self.cleanup_fn, "multiple exits?"
+        self.cleanup()
+
+
 class ContextWrappingVariable(VariableTracker):
-    def __init__(self, target_values, initial_values=None, **kwargs):
+    _nonvar_fields = {
+        "cm_obj",
+        "target_values",
+        "initial_values",
+        "state",
+        *VariableTracker._nonvar_fields,
+    }
+
+    def __init__(self, target_values, initial_values=None, *, state=None, **kwargs):
         super().__init__(**kwargs)
         self.target_values = target_values
         self.initial_values = initial_values
+        self.state = ContextManagerState() if state is None else state
 
     def enter(self, tx):
         self._call_func(tx, self.target_values)
+        self.set_cleanup_hook(tx)
         return variables.ConstantVariable.create(
             None, **VariableTracker.propagate(self)
         )
 
+    def set_cleanup_hook(self, tx, fn=None):
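+        # Default hook restores the initial values; callers may pass a custom fn.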
+        if fn is None:
+
+            def fn():
+                self._call_func(tx, self.initial_values)
+
+        self.state.cleanup_fn = fn
+        tx.output.add_cleanup_hook(self.state.cleanup)
+
     def exit(self, tx, *args):
-        self._call_func(tx, self.initial_values)
+        self.state.cleanup_assert()
         return variables.ConstantVariable.create(
             None, **VariableTracker.propagate(self)
         )
@@ -66,8 +107,7 @@
 
 
 class GenericContextWrappingVariable(ContextWrappingVariable):
-    def __init__(self, target_values, initial_values=None, **kwargs):
-        cm_obj = kwargs.pop("cm_obj", None)
+    def __init__(self, target_values, initial_values=None, *, cm_obj=None, **kwargs):
         assert cm_obj is not None
         super().__init__(
             target_values=target_values, initial_values=initial_values, **kwargs
@@ -149,6 +189,12 @@
             None, **VariableTracker.propagate(self)
         )
 
+    def exit(self, tx, *args):
+        self._call_func(tx, self.initial_values)
+        return variables.ConstantVariable.create(
+            None, **VariableTracker.propagate(self)
+        )
+
     def _call_func(self, tx, values):
         assert len(values) == 1
         value = values[0]
@@ -173,35 +219,38 @@
         return var
 
     def __init__(
-        self, target_values, initial_values=torch.is_inference_mode_enabled(), **kwargs
+        self,
+        target_values,
+        initial_values=None,
+        **kwargs,
     ):
-        mode = kwargs.pop("mode", None)
+        if initial_values is None:
+            # This must be called here since function defaults are evaluated at import time
+            initial_values = torch.is_inference_mode_enabled()
         super().__init__(
             target_values=target_values, initial_values=initial_values, **kwargs
         )
         self.target_values = target_values
-        self.mode = mode
 
     def exit(self, tx, *args):
-        self.mode = (
-            torch.autograd.grad_mode._exit_inference_mode(self.mode[0]),
-            tx.output.create_node(
-                "call_function",
-                torch.autograd.grad_mode._exit_inference_mode,
-                (self.mode[1],),
-                {},
-            ),
+        self.state.cleanup_assert()
+        tx.output.create_node(
+            "call_function",
+            torch.autograd.grad_mode._exit_inference_mode,
+            (self.state.proxy,),
+            {},
         )
 
     def enter(self, tx):
-        self.mode = (
-            torch.autograd.grad_mode._enter_inference_mode(self.target_values),
-            tx.output.create_node(
-                "call_function",
-                torch.autograd.grad_mode._enter_inference_mode,
-                (self.target_values,),
-                {},
-            ),
+        ctx = torch.autograd.grad_mode._enter_inference_mode(self.target_values)
+        self.set_cleanup_hook(
+            tx, lambda: torch.autograd.grad_mode._exit_inference_mode(ctx)
+        )
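+        # Stash the enter node in the shared state so exit() can pass it to _exit_inference_mode.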
+        self.state.proxy = tx.output.create_node(
+            "call_function",
+            torch.autograd.grad_mode._enter_inference_mode,
+            (self.target_values,),
+            {},
         )
 
     def module_name(self):
@@ -225,6 +274,7 @@
         )
         # mlazos: I think this is here to make sure we don't reinvoke on clone()
         var._call_func(tx, [False])
+        var.set_cleanup_hook(tx)
         return var
 
     def __init__(self, target_values, initial_values=None, **kwargs):
@@ -258,6 +308,7 @@
             **kwargs,
         )
         var._call_func(tx, [target_value])
+        var.set_cleanup_hook(tx)
         return var
 
     def __init__(self, target_values, initial_values=None, **kwargs):
@@ -299,6 +350,7 @@
             **kwargs,
         )
         var._call_func(tx, [target_value])
+        var.set_cleanup_hook(tx)
         return var
 
     def __init__(self, target_values, initial_values=None, **kwargs):
@@ -373,27 +425,22 @@
         return var
 
     def __init__(self, target_values, initial_values=None, **kwargs):
-        mode = kwargs.pop("mode", None)
         super().__init__(
             target_values=target_values, initial_values=initial_values, **kwargs
         )
         self.target_values = target_values
-        self.mode = mode
 
     def exit(self, tx, *args):
-        self.mode = (
-            torch.amp._exit_autocast(self.mode[0]),
-            tx.output.create_node(
-                "call_function", torch.amp._exit_autocast, (self.mode[1],), {}
-            ),
+        self.state.cleanup_assert()
+        tx.output.create_node(
+            "call_function", torch.amp._exit_autocast, (self.state.proxy,), {}
         )
 
     def enter(self, tx):
-        self.mode = (
-            torch.amp._enter_autocast(*self.target_values),
-            tx.output.create_node(
-                "call_function", torch.amp._enter_autocast, (*self.target_values,), {}
-            ),
+        ctx = torch.amp._enter_autocast(*self.target_values)
+        self.set_cleanup_hook(tx, lambda: torch.amp._exit_autocast(ctx))
+        self.state.proxy = tx.output.create_node(
+            "call_function", torch.amp._enter_autocast, (*self.target_values,), {}
         )
 
     def module_name(self):
@@ -464,7 +511,7 @@
         self.set_stream_id = get_interface_for_device(self.device)._set_stream_by_id
 
     def enter(self, tx):
-        # stream generated inside of traced function
+        # stream generated inside the traced function
         if self.target_values[0].as_proxy() is not None:
             tx.output.create_proxy(
                 "call_function",
@@ -472,7 +519,7 @@
                 (self.target_values[0].as_proxy(),),
                 {},
             )
-        # stream passed from outside of traced function
+        # stream passed from outside the traced function
         else:
             stream = self.target_values[0].value
             tx.output.create_proxy(
@@ -482,6 +529,7 @@
                 {},
             )
         self.set_stream(self.target_values[0].value)
+        self.set_cleanup_hook(tx, lambda: self.set_stream(self.initial_values[0].value))
 
     def exit(self, tx, *args):
         tx.output.create_proxy(
@@ -490,7 +538,7 @@
             (self.initial_values[0].as_proxy(),),
             {},
         )
-        self.set_stream(self.initial_values[0].value)
+        self.state.cleanup_assert()
 
     def module_name(self):
         return "torch." + str(self.device)