Fix DDPOptimizer fake_mode execution (#92986)
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):
* __->__ #92986
When running compiled submods to produce outputs to pass to the compilation
step for the next submod, we use fake parameters and assume fake inputs, but we
forgot to activate our fake_mode during execution. This broke edge cases where
tensors other than activations or parameters are created during execution, such
as the scalar->tensor expansion performed when executing
torch.where(tensor, scalar, scalar): those freshly created tensors were real
tensors instead of FakeTensors.
Also add a test and clarify the behavior of DDPOptimizer via comments.
Fixes #92941
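To make the failure mode concrete, here is a minimal sketch (illustration only, not part of this patch): tensors created while a FakeTensorMode is active come out as FakeTensors, whereas the same factory call outside the mode produces a real tensor, which is what went wrong when the submod was executed without activating fake_mode.

```python
# Minimal sketch, not part of this patch: tensor creation inside vs. outside
# an active FakeTensorMode.
import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode

fake_mode = FakeTensorMode()

# Outside the mode: a freshly created tensor is a real tensor.
real_scalar = torch.tensor(0.5)
assert not isinstance(real_scalar, FakeTensor)

# Inside the mode: the same factory call yields a FakeTensor, so anything a
# submodule materializes during execution (e.g. the scalar->tensor expansion
# in torch.where(tensor, scalar, scalar)) stays fake.
with fake_mode:
    fake_scalar = torch.tensor(0.5)
assert isinstance(fake_scalar, FakeTensor)
```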
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92986
Approved by: https://github.com/bdhirsh
diff --git a/test/distributed/test_dynamo_distributed.py b/test/distributed/test_dynamo_distributed.py
index ade7d92..59fe020 100644
--- a/test/distributed/test_dynamo_distributed.py
+++ b/test/distributed/test_dynamo_distributed.py
@@ -66,7 +66,10 @@
self.weight = nn.Parameter(torch.randn(512, 512))
def forward(self, x):
- return torch.mm(x, self.weight.t())
+ tmp = torch.mm(x, self.weight.t())
+ # test an edge case where the scalar overload of torch.where is decomposed to aten.where.self(tensor, tensor, tensor)
+ # and the scalar tensors it creates (T(0.3) and T(0.6) here) were not wrapped in FakeTensors during DDPOptimizer compilation
+ return tmp + torch.where(tmp < 0.5, 0.3, 0.6)
class MyLinear(torch.nn.Module):
def __init__(self):
diff --git a/torch/_dynamo/optimizations/distributed.py b/torch/_dynamo/optimizations/distributed.py
index 32f5aaf..23f0f01 100644
--- a/torch/_dynamo/optimizations/distributed.py
+++ b/torch/_dynamo/optimizations/distributed.py
@@ -296,8 +296,6 @@
assert isinstance(args, tuple)
assert isinstance(kwargs, dict)
- # modify the currently running FX graph
- # maybe this isn't sound in general, but only changing the target of a node might be ok?
if n.op == "call_module":
real_mod = self.fetch_attr(n.target)
if fake_mode:
@@ -308,15 +306,28 @@
log.debug(
f"\n---{n.target} graph---\n" + str(curr_submod.graph)
)
+
+ # When calling the compiler on the submod, inputs (new_args) are expected to
+ # be FakeTensors already since Dynamo would have made them FakeTensors in the
+ # non-DDP flow. However, the parameters are _not_ expected to be FakeTensors,
+ # since that wrapping happens during compilation.
compiled_submod_real = self.compile_submod(
real_mod, new_args, kwargs
)
+
+ # We update the original (outer) graph with a call into the compiled module
+ # instead of the uncompiled one.
self.module.delete_submodule(n.target)
n.target = "compiled_" + n.target
self.module.add_submodule(n.target, compiled_submod_real)
- return curr_submod(*new_args, **kwargs)
- # then we execute the modified node using the usual logic
- return getattr(self, n.op)(n.target, new_args, kwargs)
+
+ # Finally, we have to produce inputs for use in compiling the next submodule,
+ # and these need to be FakeTensors, so we execute the module under fake_mode.
+ with fake_mode:
+ return curr_submod(*new_args, **kwargs)
+ else:
+ # placeholder or output nodes don't need to get compiled, just executed
+ return getattr(self, n.op)(n.target, new_args, kwargs)
submod_compiler = SubmodCompiler(split_gm, self.backend_compile_fn)
submod_compiler.run(*example_inputs)
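For context, the control flow the hunks above modify can be sketched in isolation as follows. This is a simplified illustration, not the DDPOptimizer source: the names FakeAwareRunner, KeepSubsTracer, and the identity compile_fn are made up, real parameters are allowed through via allow_non_fake_inputs instead of being deep-copied onto fake ones, and no backend compiler is actually invoked. It only shows the pattern of "compiling" each call_module node and then executing it under the active fake_mode so the next submodule sees fake inputs.

```python
import torch
import torch.fx as fx
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode


class Sub(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(8, 8)

    def forward(self, x):
        return torch.relu(self.lin(x))


class Top(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.sub0 = Sub()
        self.sub1 = Sub()

    def forward(self, x):
        return self.sub1(self.sub0(x))


class KeepSubsTracer(fx.Tracer):
    # Keep each Sub as a single call_module node instead of tracing into it,
    # loosely mimicking the per-bucket submodules DDPOptimizer splits out.
    def is_leaf_module(self, m, qualname):
        return isinstance(m, Sub) or super().is_leaf_module(m, qualname)


class FakeAwareRunner(fx.Interpreter):
    def __init__(self, gm, fake_mode, compile_fn):
        super().__init__(gm)
        self.fake_mode = fake_mode
        self.compile_fn = compile_fn

    def call_module(self, target, args, kwargs):
        submod = self.fetch_attr(target)
        # "Compile" the submodule with its (already fake) inputs; here the
        # compiler is a stand-in that returns the module unchanged.
        compiled = self.compile_fn(submod, args)
        # Execute under fake_mode so tensors materialized during execution are
        # FakeTensors, giving the next submodule fake inputs to compile against.
        with self.fake_mode:
            return compiled(*args, **kwargs)


top = Top()
gm = fx.GraphModule(top, KeepSubsTracer().trace(top))

# allow_non_fake_inputs lets this toy model's real parameters pass through;
# DDPOptimizer instead deep-copies each submodule onto fake parameters.
fake_mode = FakeTensorMode(allow_non_fake_inputs=True)
fake_inp = fake_mode.from_tensor(torch.randn(2, 8))

out = FakeAwareRunner(gm, fake_mode, lambda mod, args: mod).run(fake_inp)
assert isinstance(out, FakeTensor)
```

The real implementation additionally deletes the original submodule and retargets the graph node to the compiled one, as shown in the hunk above.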