[dynamo] Refactor handling of state in context managers (#112939)
The prior handling was rather buggy: context managers restored global state from `exit()` and stashed the restore info in fields on VariableTracker objects, which get copied, so a restore could be lost with a discarded copy or run at the wrong time. This PR moves that mutable state into a shared `ContextManagerState` container, registers cleanup hooks on the OutputGraph that run exactly once (in LIFO order), and applies a `preserve_global_state` decorator around bytecode transformation so grad mode, inference mode, deterministic-algorithm flags, and RNG state are saved and restored around compilation.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112939
Approved by: https://github.com/voznesenskym, https://github.com/yanboliang
ghstack dependencies: #112897, #112898, #112920, #112899
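The core of the refactor: VariableTrackers are copied freely, so per-context-manager mutable state now lives in a shared container that every copy points to, and restore logic runs through cleanup hooks exactly once. Below is a minimal sketch of that pattern, with hypothetical names (`_State`, `Tracker`) standing in for the dynamo classes changed in this diff:

    import dataclasses
    from typing import Callable, Optional

    @dataclasses.dataclass
    class _State:
        # Shared and mutable; never copied along with the tracker.
        cleanup_fn: Optional[Callable[[], None]] = None

        def cleanup(self):
            # Run the saved restore function at most once.
            if self.cleanup_fn is not None:
                self.cleanup_fn()
                self.cleanup_fn = None

    class Tracker:
        def __init__(self, state: Optional[_State] = None):
            # Copies receive the same _State instance, so mutating it is
            # visible to every copy while Tracker itself stays immutable.
            self.state = _State() if state is None else state

        def clone(self) -> "Tracker":
            return Tracker(state=self.state)

    t = Tracker()
    t.state.cleanup_fn = lambda: print("restored")
    t2 = t.clone()
    t2.state.cleanup()  # runs once, visible through both copies
    t.state.cleanup()   # no-op: already ran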
diff --git a/benchmarks/dynamo/ci_expected_accuracy/inductor_timm_dynamic_training.csv b/benchmarks/dynamo/ci_expected_accuracy/inductor_timm_dynamic_training.csv
index 51c385c..60b69ce 100644
--- a/benchmarks/dynamo/ci_expected_accuracy/inductor_timm_dynamic_training.csv
+++ b/benchmarks/dynamo/ci_expected_accuracy/inductor_timm_dynamic_training.csv
@@ -18,7 +18,7 @@
-convit_base,pass,11
+convit_base,pass,9
diff --git a/benchmarks/dynamo/ci_expected_accuracy/inductor_timm_training.csv b/benchmarks/dynamo/ci_expected_accuracy/inductor_timm_training.csv
index 621f590..4b5b798 100644
--- a/benchmarks/dynamo/ci_expected_accuracy/inductor_timm_training.csv
+++ b/benchmarks/dynamo/ci_expected_accuracy/inductor_timm_training.csv
@@ -18,7 +18,7 @@
-convit_base,pass,11
+convit_base,pass,9
diff --git a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_dynamic_training.csv b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_dynamic_training.csv
index 1207754..a5a7ee0 100644
--- a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_dynamic_training.csv
+++ b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_dynamic_training.csv
@@ -6,7 +6,7 @@
-BERT_pytorch,pass,10
+BERT_pytorch,pass,8
@@ -202,4 +202,4 @@
-yolov3,pass,12
+yolov3,pass,10
diff --git a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_training.csv
index d61cb5f..eede827 100644
--- a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_training.csv
+++ b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_training.csv
@@ -6,7 +6,7 @@
-BERT_pytorch,pass,10
+BERT_pytorch,pass,8
@@ -210,8 +210,8 @@
-vision_maskrcnn,fail_accuracy,39
+vision_maskrcnn,fail_accuracy,37
-yolov3,pass,12
+yolov3,pass,10
diff --git a/test/dynamo/test_aot_autograd.py b/test/dynamo/test_aot_autograd.py
index 58bcb9d..d301339 100644
--- a/test/dynamo/test_aot_autograd.py
+++ b/test/dynamo/test_aot_autograd.py
@@ -957,6 +957,7 @@
torch.set_grad_enabled(True)
y = f_compiled(x)
self.assertEqual(torch.is_grad_enabled(), False)
+ torch.set_grad_enabled(True)
self.assertEqual(y_ref, y)
self.assertIsNone(y_ref[0].grad_fn)
diff --git a/test/inductor/test_codecache.py b/test/inductor/test_codecache.py
index b2bbf01..bae90ed 100644
--- a/test/inductor/test_codecache.py
+++ b/test/inductor/test_codecache.py
@@ -93,6 +93,7 @@
cls.tmpdir.cleanup()
def setUp(self):
+ super().setUp()
counters.clear()
@requires_triton()
diff --git a/test/inductor/test_kernel_benchmark.py b/test/inductor/test_kernel_benchmark.py
index 93d6670..6c0a479 100644
--- a/test/inductor/test_kernel_benchmark.py
+++ b/test/inductor/test_kernel_benchmark.py
@@ -23,6 +23,7 @@
cls.exit_stack.close()
def setUp(self):
+ super().setUp()
PyCodeCache.cache.clear()
def get_compiled_module(self):
diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py
index 5b6aa01..6ba1e46 100644
--- a/torch/_dynamo/convert_frame.py
+++ b/torch/_dynamo/convert_frame.py
@@ -122,7 +122,7 @@
return result
-def wrap_convert_context(fn):
+def preserve_global_state(fn):
"""
Decorator to:
1) Save/restore torch.is_grad_enabled() state
@@ -135,6 +135,7 @@
def _fn(*args, **kwargs):
guards = GlobalStateGuard()
prior_grad_mode = torch.is_grad_enabled()
+ prior_inference_mode = torch.is_inference_mode_enabled()
prior_deterministic = torch.are_deterministic_algorithms_enabled()
prior_warn_only = torch.is_deterministic_algorithms_warn_only_enabled()
py_rng_state = random.getstate()
@@ -149,6 +150,7 @@
finally:
cleanup.close()
torch._C._set_grad_enabled(prior_grad_mode)
+ torch.autograd.grad_mode._enter_inference_mode(prior_inference_mode)
torch.use_deterministic_algorithms(
prior_deterministic, warn_only=prior_warn_only
)
@@ -425,7 +427,7 @@
return convert_frame_assert(backend, one_graph, export, export_constraints)
_convert_frame_assert._clone_with_backend = _clone_with_backend # type: ignore[attr-defined]
- return wrap_convert_context(_convert_frame_assert)
+ return _convert_frame_assert
def maybe_cprofile(func):
@@ -479,6 +481,7 @@
mutated_closure_cell_contents: Set[str] = set()
fail_reason: Optional[str] = None
+ @preserve_global_state
def transform(instructions, code_options):
nonlocal output
tracer = InstructionTranslator(
@@ -505,6 +508,8 @@
if translation_validation_enabled():
bisect(tracer.output.shape_env)
raise
+ finally:
+ tracer.output.call_cleanup_hooks()
output = tracer.output
assert output is not None
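The renamed `preserve_global_state` decorator snapshots global state before tracing and restores it afterward, even on error. A minimal sketch of the save/restore shape, covering only grad mode and the Python RNG for brevity (the real decorator above also handles inference mode, deterministic-algorithm flags, and the torch RNG state):

    import functools
    import random
    import torch

    def preserve_global_state_sketch(fn):
        @functools.wraps(fn)
        def _fn(*args, **kwargs):
            prior_grad_mode = torch.is_grad_enabled()
            py_rng_state = random.getstate()
            try:
                return fn(*args, **kwargs)
            finally:
                # Restore even if tracing raised.
                torch._C._set_grad_enabled(prior_grad_mode)
                random.setstate(py_rng_state)
        return _fn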
diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py
index a16ae90..a90004e 100644
--- a/torch/_dynamo/output_graph.py
+++ b/torch/_dynamo/output_graph.py
@@ -258,6 +258,7 @@
self.export_constraints = export_constraints
self.frame_state = frame_state
self.tensor_weakref_to_sizes_strides: WeakIdKeyDictionary = {}
+ self.cleanup_hooks: List[Callable[[], Any]] = []
# TODO: maybe should just pass the entire f_code in here? Not
# sure...
@@ -382,6 +383,14 @@
self.guards.add(GlobalStateSource().make_guard(GuardBuilder.BACKEND_MATCH))
+ def add_cleanup_hook(self, fn: Callable[[], Any]):
+ self.cleanup_hooks.append(fn)
+
+ def call_cleanup_hooks(self):
+ for hook in reversed(self.cleanup_hooks):
+ hook()
+ self.cleanup_hooks.clear()
+
@property
def root_tracer(self):
return self.tracers[0]
@@ -1027,6 +1036,7 @@
graph_sizes_log.debug(
"%s", LazyString(lambda: self.get_graph_sizes_log_str(name))
)
+ self.call_cleanup_hooks()
with self.restore_global_state():
compiled_fn = self.call_user_compiler(gm)
compiled_fn = disable(compiled_fn)
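The new hook API runs hooks in LIFO order (note the `reversed`), mirroring how nested `with` blocks unwind innermost-first, and clears the list so each hook fires at most once. A self-contained illustration, with a hypothetical `_Hooks` class standing in for OutputGraph:

    from typing import Any, Callable, List

    class _Hooks:
        def __init__(self):
            self.cleanup_hooks: List[Callable[[], Any]] = []

        def add_cleanup_hook(self, fn: Callable[[], Any]):
            self.cleanup_hooks.append(fn)

        def call_cleanup_hooks(self):
            # Reverse registration order: innermost context exits first.
            for hook in reversed(self.cleanup_hooks):
                hook()
            self.cleanup_hooks.clear()

    hooks = _Hooks()
    hooks.add_cleanup_hook(lambda: print("exit outer ctx"))
    hooks.add_cleanup_hook(lambda: print("exit inner ctx"))
    hooks.call_cleanup_hooks()  # prints inner first, then outer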
diff --git a/torch/_dynamo/test_case.py b/torch/_dynamo/test_case.py
index bf425bf..bd21e27 100644
--- a/torch/_dynamo/test_case.py
+++ b/torch/_dynamo/test_case.py
@@ -1,5 +1,6 @@
import contextlib
import importlib
+import logging
import sys
import torch
@@ -13,6 +14,8 @@
from . import config, reset, utils
+log = logging.getLogger(__name__)
+
def run_tests(needs=()):
from torch.testing._internal.common_utils import run_tests
@@ -53,6 +56,7 @@
)
def setUp(self):
+ self._prior_is_grad_enabled = torch.is_grad_enabled()
super().setUp()
reset()
utils.counters.clear()
@@ -63,3 +67,6 @@
reset()
utils.counters.clear()
super().tearDown()
+ if self._prior_is_grad_enabled is not torch.is_grad_enabled():
+ log.warning("Running test changed grad mode")
+ torch.set_grad_enabled(self._prior_is_grad_enabled)
diff --git a/torch/_dynamo/variables/ctx_manager.py b/torch/_dynamo/variables/ctx_manager.py
index b0b6005..6fb3c1d 100644
--- a/torch/_dynamo/variables/ctx_manager.py
+++ b/torch/_dynamo/variables/ctx_manager.py
@@ -1,5 +1,6 @@
+import dataclasses
import inspect
-from typing import Dict, List
+from typing import Callable, Dict, List, Optional
import torch._C
from torch._guards import Guard
@@ -20,20 +21,60 @@
)
+@dataclasses.dataclass
+class ContextManagerState:
+ """
+ Mutating `self` on a VariableTracker is not allowed because trackers
+ get copied. This mutable container is shared by all copies of a context
+ manager and is never itself copied, so it is safe to mutate.
+ """
+
+ cleanup_fn: Optional[Callable] = None
+ proxy: Optional[torch.fx.Proxy] = None
+
+ def cleanup(self):
+ if self.cleanup_fn is not None:
+ self.cleanup_fn()
+ self.cleanup_fn = None
+
+ def cleanup_assert(self):
+ assert self.cleanup_fn, "multiple exits?"
+ self.cleanup()
+
+
class ContextWrappingVariable(VariableTracker):
- def __init__(self, target_values, initial_values=None, **kwargs):
+ _nonvar_fields = {
+ "cm_obj",
+ "target_values",
+ "initial_values",
+ "state",
+ *VariableTracker._nonvar_fields,
+ }
+
+ def __init__(self, target_values, initial_values=None, *, state=None, **kwargs):
super().__init__(**kwargs)
self.target_values = target_values
self.initial_values = initial_values
+ self.state = ContextManagerState() if state is None else state
def enter(self, tx):
self._call_func(tx, self.target_values)
+ self.set_cleanup_hook(tx)
return variables.ConstantVariable.create(
None, **VariableTracker.propagate(self)
)
+ def set_cleanup_hook(self, tx, fn=None):
+ if fn is None:
+
+ def fn():
+ self._call_func(tx, self.initial_values)
+
+ self.state.cleanup_fn = fn
+ tx.output.add_cleanup_hook(self.state.cleanup)
+
def exit(self, tx, *args):
- self._call_func(tx, self.initial_values)
+ self.state.cleanup_assert()
return variables.ConstantVariable.create(
None, **VariableTracker.propagate(self)
)
@@ -66,8 +107,7 @@
class GenericContextWrappingVariable(ContextWrappingVariable):
- def __init__(self, target_values, initial_values=None, **kwargs):
- cm_obj = kwargs.pop("cm_obj", None)
+ def __init__(self, target_values, initial_values=None, *, cm_obj=None, **kwargs):
assert cm_obj is not None
super().__init__(
target_values=target_values, initial_values=initial_values, **kwargs
@@ -149,6 +189,12 @@
None, **VariableTracker.propagate(self)
)
+ def exit(self, tx, *args):
+ self._call_func(tx, self.initial_values)
+ return variables.ConstantVariable.create(
+ None, **VariableTracker.propagate(self)
+ )
+
def _call_func(self, tx, values):
assert len(values) == 1
value = values[0]
@@ -173,35 +219,38 @@
return var
def __init__(
- self, target_values, initial_values=torch.is_inference_mode_enabled(), **kwargs
+ self,
+ target_values,
+ initial_values=None,
+ **kwargs,
):
- mode = kwargs.pop("mode", None)
+ if initial_values is None:
+ # This must be called here, not as a default value: function defaults are evaluated once, at import time
+ initial_values = torch.is_inference_mode_enabled()
super().__init__(
target_values=target_values, initial_values=initial_values, **kwargs
)
self.target_values = target_values
- self.mode = mode
def exit(self, tx, *args):
- self.mode = (
- torch.autograd.grad_mode._exit_inference_mode(self.mode[0]),
- tx.output.create_node(
- "call_function",
- torch.autograd.grad_mode._exit_inference_mode,
- (self.mode[1],),
- {},
- ),
+ self.state.cleanup_assert()
+ tx.output.create_node(
+ "call_function",
+ torch.autograd.grad_mode._exit_inference_mode,
+ (self.state.proxy,),
+ {},
)
def enter(self, tx):
- self.mode = (
- torch.autograd.grad_mode._enter_inference_mode(self.target_values),
- tx.output.create_node(
- "call_function",
- torch.autograd.grad_mode._enter_inference_mode,
- (self.target_values,),
- {},
- ),
+ ctx = torch.autograd.grad_mode._enter_inference_mode(self.target_values)
+ self.set_cleanup_hook(
+ tx, lambda: torch.autograd.grad_mode._exit_inference_mode(ctx)
+ )
+ self.state.proxy = tx.output.create_node(
+ "call_function",
+ torch.autograd.grad_mode._enter_inference_mode,
+ (self.target_values,),
+ {},
)
def module_name(self):
@@ -225,6 +274,7 @@
)
# mlazos: I think this is here to make sure we don't reinvoke on clone()
var._call_func(tx, [False])
+ var.set_cleanup_hook(tx)
return var
def __init__(self, target_values, initial_values=None, **kwargs):
@@ -258,6 +308,7 @@
**kwargs,
)
var._call_func(tx, [target_value])
+ var.set_cleanup_hook(tx)
return var
def __init__(self, target_values, initial_values=None, **kwargs):
@@ -299,6 +350,7 @@
**kwargs,
)
var._call_func(tx, [target_value])
+ var.set_cleanup_hook(tx)
return var
def __init__(self, target_values, initial_values=None, **kwargs):
@@ -373,27 +425,22 @@
return var
def __init__(self, target_values, initial_values=None, **kwargs):
- mode = kwargs.pop("mode", None)
super().__init__(
target_values=target_values, initial_values=initial_values, **kwargs
)
self.target_values = target_values
- self.mode = mode
def exit(self, tx, *args):
- self.mode = (
- torch.amp._exit_autocast(self.mode[0]),
- tx.output.create_node(
- "call_function", torch.amp._exit_autocast, (self.mode[1],), {}
- ),
+ self.state.cleanup_assert()
+ tx.output.create_node(
+ "call_function", torch.amp._exit_autocast, (self.state.proxy,), {}
)
def enter(self, tx):
- self.mode = (
- torch.amp._enter_autocast(*self.target_values),
- tx.output.create_node(
- "call_function", torch.amp._enter_autocast, (*self.target_values,), {}
- ),
+ ctx = torch.amp._enter_autocast(*self.target_values)
+ self.set_cleanup_hook(tx, lambda: torch.amp._exit_autocast(ctx))
+ self.state.proxy = tx.output.create_node(
+ "call_function", torch.amp._enter_autocast, (*self.target_values,), {}
)
def module_name(self):
@@ -464,7 +511,7 @@
self.set_stream_id = get_interface_for_device(self.device)._set_stream_by_id
def enter(self, tx):
- # stream generated inside of traced function
+ # stream generated inside the traced function
if self.target_values[0].as_proxy() is not None:
tx.output.create_proxy(
"call_function",
@@ -472,7 +519,7 @@
(self.target_values[0].as_proxy(),),
{},
)
- # stream passed from outside of traced function
+ # stream passed from outside the traced function
else:
stream = self.target_values[0].value
tx.output.create_proxy(
@@ -482,6 +529,7 @@
{},
)
self.set_stream(self.target_values[0].value)
+ self.set_cleanup_hook(tx, lambda: self.set_stream(self.initial_values[0].value))
def exit(self, tx, *args):
tx.output.create_proxy(
@@ -490,7 +538,7 @@
(self.initial_values[0].as_proxy(),),
{},
)
- self.set_stream(self.initial_values[0].value)
+ self.state.cleanup_assert()
def module_name(self):
return "torch." + str(self.device)