# Owner(s): ["oncall: export"]
# flake8: noqa
import copy
import dataclasses
import unittest
from contextlib import contextmanager
from dataclasses import dataclass
from re import escape
from typing import Any, List
import torch
import torch._dynamo as torchdynamo
from functorch.experimental.control_flow import cond, map
from torch import Tensor
from torch._export.utils import (
get_buffer,
get_param,
is_buffer,
is_param,
register_dataclass_as_pytree_node,
)
from torch._higher_order_ops.torchbind import enable_torchbind_tracing
from torch.export import Constraint, Dim, export, FlatArgsAdapter, unflatten
from torch.export._trace import DEFAULT_EXPORT_DYNAMO_CONFIG
from torch.export.unflatten import _disable_interpreter
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing import FileCheck
from torch.testing._internal.common_utils import (
find_library_location,
IS_FBCODE,
IS_MACOS,
IS_SANDCASTLE,
IS_WINDOWS,
run_tests,
skipIfTorchDynamo,
TestCase,
)
from torch.testing._internal.torchbind_impls import init_torchbind_implementations
from torch.utils._pytree import (
LeafSpec,
tree_flatten,
tree_unflatten,
TreeSpec,
treespec_dumps,
treespec_loads,
)
@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't supported")
class TestUnflatten(TestCase):
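    # Helper: run the eager module and the unflattened module on the same args
    # and check that their outputs match.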
def compare_outputs(self, eager, unflattened, args):
orig_output = eager(*args)
unflattened_output = unflattened(*args)
self.assertTrue(torch.allclose(orig_output, unflattened_output))
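
    # Unflattening should rebuild the original module hierarchy: outputs,
    # submodule access, and state_dict entries should all match the eager module.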
def test_unflatten_nested(self):
class NestedChild(torch.nn.Module):
def forward(self, x):
return x / x
class Child1(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.nested = NestedChild()
self.register_parameter(
"child1param", torch.nn.Parameter(torch.ones(2, 3))
)
def forward(self, x):
x = self.nested(x)
return x + self.child1param
class Child2(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.child2buffer = torch.nn.Buffer(torch.ones(2, 3))
def forward(self, x):
return x - self.child2buffer
class MyModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.foo = Child1()
self.bar = Child2()
self.register_parameter(
"rootparam", torch.nn.Parameter(torch.ones(2, 3))
)
def forward(self, x):
x = x * self.rootparam
x = self.foo(x)
x = self.bar(x)
return x
orig_eager = MyModule()
export_module = export(orig_eager, (torch.rand(2, 3),), {})
unflattened = unflatten(export_module)
inputs = (torch.rand(2, 3),)
# Compare the root modules and all submodules
self.compare_outputs(orig_eager, unflattened, inputs)
self.compare_outputs(orig_eager.foo, unflattened.foo, inputs)
self.compare_outputs(orig_eager.bar, unflattened.bar, inputs)
self.compare_outputs(orig_eager.foo.nested, unflattened.foo.nested, inputs)
# Check state dicts are equal
orig_state_dict = orig_eager.state_dict()
exported_state_dict = unflattened.state_dict()
for name, value in orig_state_dict.items():
self.assertTrue(torch.allclose(value, exported_state_dict[name]))
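
    # A buffer mutated inside a child's forward should stay in sync between the
    # eager module and the unflattened module after a call.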
def test_unflatten_buffer_mutation(self):
class Child(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.child2buffer = torch.nn.Buffer(torch.ones(2, 3))
def forward(self, x):
self.child2buffer.add_(x)
return x - self.child2buffer
class MyModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.foo = Child()
self.register_parameter(
"rootparam", torch.nn.Parameter(torch.ones(2, 3))
)
def forward(self, x):
x = self.foo(x)
return x * self.rootparam
eager_module = MyModule()
export_module = export(eager_module, (torch.rand(2, 3),), {})
unflattened_module = unflatten(export_module)
# Buffer should look the same before and after one run
eager_buffer = eager_module.foo.child2buffer
unflattened_buffer = unflattened_module.foo.child2buffer
self.assertTrue(torch.allclose(eager_buffer, unflattened_buffer))
inputs = (torch.rand(2, 3),)
eager_module(*inputs)
unflattened_module(*inputs)
self.assertTrue(torch.allclose(eager_buffer, unflattened_buffer))
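
    # The parent's forward reads a child's buffer directly; the unflattened
    # module should still produce matching outputs.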
def test_unflatten_nested_access(self):
class Child(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.child2buffer = torch.nn.Buffer(torch.ones(2, 3))
def forward(self, x):
return x - self.child2buffer
class MyModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.foo = Child()
self.register_parameter(
"rootparam", torch.nn.Parameter(torch.ones(2, 3))
)
def forward(self, x):
x = x + self.foo.child2buffer
x = self.foo(x)
return x
eager_module = MyModule()
export_module = export(eager_module, (torch.rand(2, 3),), {})
unflattened_module = unflatten(export_module)
inputs = (torch.rand(2, 3),)
self.compare_outputs(eager_module, unflattened_module, inputs)
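
    # A module instance reused inside nn.Sequential should unflatten to aliased
    # submodules (same object identity), not independent copies.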
def test_unflatten_shared_submodule(self):
class Shared(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
layernorm = torch.nn.LayerNorm(10)
self.sub_net = torch.nn.Sequential(
layernorm,
torch.nn.ReLU(),
layernorm,
torch.nn.ReLU(),
)
def forward(self, x):
return self.sub_net(x)
eager_module = Shared()
inps = (torch.rand(10),)
export_module = export(eager_module, inps, {})
unflattened_module = unflatten(export_module)
self.compare_outputs(eager_module, unflattened_module, inps)
self.assertTrue(hasattr(unflattened_module, "sub_net"))
for i in range(len(eager_module.sub_net)):
self.assertTrue(hasattr(unflattened_module.sub_net, str(i)))
self.assertEqual(
id(getattr(unflattened_module.sub_net, "0")),
id(getattr(unflattened_module.sub_net, "2")),
)
@unittest.skipIf(IS_WINDOWS, "Windows not supported for this test")
@skipIfTorchDynamo("Non strict mode is not meant to run with dynamo")
def test_unflatten_preserve_signature(self):
class NestedChild(torch.nn.Module):
def forward(self, zx, y):
return {"x": y["key"] + zx[1], "w": y["key"] * zx[1]}
class Child1(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.nested = NestedChild()
def forward(self, x, y):
z = torch.ones_like(x)
xw = self.nested((z, x), y={"key": y})
return xw["w"] + z - xw["x"]
class Child2(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x):
return x - 1
class MyModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.foo = Child1()
self.bar = Child2()
def forward(self, x, y):
x = self.foo(x, y)
x = self.bar(x)
return x
orig_eager = MyModule()
inps = torch.rand(2, 3), torch.rand(2, 3)
for strict in [True, False]:
export_module = export(
orig_eager,
inps,
{},
preserve_module_call_signature=("foo.nested",),
strict=strict,
)
unflattened = unflatten(export_module)
self.compare_outputs(export_module.module(), unflattened, inps)
unflattened.foo.nested = NestedChild()
self.compare_outputs(export_module.module(), unflattened, inps)
# Test tree spec mismatched input
orig_outs = export_module.module()(*inps)
new_inps = *inps, torch.rand(2, 3)
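            # The regex below mirrors the library's error message verbatim
            # (including its spelling).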
with self.assertRaisesRegex(
TypeError,
"There is no flat args adapter sepcified. Are you sure you are calling this with the right arguments?",
):
unflattened(new_inps)
# With flat args adapter
class KeepTwoFlatArgsAdapter(FlatArgsAdapter):
def adapt(
self,
target_spec: TreeSpec,
input_spec: TreeSpec,
input_args: List[Any],
) -> List[Any]:
while len(input_args) > 2:
input_args.pop(-1)
return input_args
unflattened = unflatten(export_module, KeepTwoFlatArgsAdapter())
new_outs = unflattened(*new_inps)
self.assertTrue(torch.allclose(orig_outs, new_outs))
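
    # nn.ParameterList and nn.ParameterDict entries should survive
    # export + unflatten.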
def test_unflatten_param_list_dict(self):
class Mod(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.param_list = torch.nn.ParameterList()
self.param_dict = torch.nn.ParameterDict()
for i in range(2):
self.param_list.append(torch.nn.Parameter(torch.randn((2, 3))))
self.param_dict[f"key_{i}"] = torch.nn.Parameter(
torch.randn((2, 3))
)
def forward(self, x):
for i in range(2):
x = x + self.param_list[i]
x = x + self.param_dict[f"key_{i}"]
return x
export_module = torch.export.export(Mod(), (torch.randn((2, 3)),))
unflattened = unflatten(export_module)
self.compare_outputs(
export_module.module(), unflattened, (torch.randn((2, 3)),)
)
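
    # Preserving the call signature of a submodule with an unused input should
    # still unflatten correctly after dead-code elimination on the graph.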
@unittest.skipIf(IS_WINDOWS, "Windows not supported for this test")
def test_unflatten_preserve_with_unused_input(self):
class M1(torch.nn.Module):
def forward(self, x, a, b):
return x + a, b
class M(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.m1 = M1()
def forward(self, x, y):
a, b = torch.topk(y, 2)
return self.m1(x, a, b)[0]
ep = torch.export.export(
M(),
(torch.randn(2), torch.randn(5)),
preserve_module_call_signature=("m1",),
strict=False,
)
ep.graph.eliminate_dead_code()
unflattened = unflatten(ep)
self.compare_outputs(ep.module(), unflattened, (torch.randn(2), torch.randn(5)))
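
    # Shape guards recorded at export time should also fire when the unflattened
    # module is called with mismatched input shapes.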
def test_unflatten_wrong_input(self):
class Mod(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.param_list = torch.nn.ParameterList()
self.param_dict = torch.nn.ParameterDict()
for i in range(2):
self.param_list.append(torch.nn.Parameter(torch.randn((2, 3))))
self.param_dict[f"key_{i}"] = torch.nn.Parameter(
torch.randn((2, 3))
)
def forward(self, x):
a = x.sum()
for i in range(2):
a = a + self.param_list[i].sum()
a = a + self.param_dict[f"key_{i}"].sum()
return a
export_module = torch.export.export(Mod(), (torch.randn((2, 3)),))
with self.assertRaisesRegex(
RuntimeError,
escape("Expected input at *args[0].shape[0] to be equal to 2, but got 6"),
):
export_module.module()(torch.randn(6, 6))
unflattened = unflatten(export_module)
with self.assertRaisesRegex(
RuntimeError,
escape("Expected input at *args[0].shape[0] to be equal to 2, but got 6"),
):
unflattened(torch.randn(6, 6))
@unittest.skipIf(IS_WINDOWS, "Windows not supported for this test")
def test_unflatten_with_inplace_compile(self):
class NestedChild(torch.nn.Module):
def forward(self, x):
return x / x
class Child1(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.nested = NestedChild()
self.register_parameter(
"child1param", torch.nn.Parameter(torch.ones(2, 3))
)
def forward(self, x):
x = self.nested(x)
return x + self.child1param
class Child2(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.child2buffer = torch.nn.Buffer(torch.ones(2, 3))
def forward(self, x):
return x - self.child2buffer
class MyModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.foo = Child1()
self.bar = Child2()
self.register_parameter(
"rootparam", torch.nn.Parameter(torch.ones(2, 3))
)
def forward(self, x):
x = x * self.rootparam
x = self.foo(x)
x = self.bar(x)
return x
orig_eager = MyModule()
export_module = torch.export.export(orig_eager, (torch.rand(2, 3),), {})
unflattened = unflatten(export_module)
# in-place compilation should work. Pass fullgraph to ensure no graph breaks.
from torch._dynamo.backends.debugging import ExplainWithBackend
eb = ExplainWithBackend("inductor")
unflattened.foo.compile(backend=eb, fullgraph=True)
inputs = (torch.randn(2, 3),)
self.compare_outputs(orig_eager, unflattened, inputs)
self.assertEqual(len(eb.graphs), 1)
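
    # The unflattened module should remain traceable with torch.fx.symbolic_trace.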
def test_fx_trace(self):
class MyModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x, y):
x = x[0] + x[1]
x = x + y["foo"]
return x
orig_eager = MyModule()
inputs = ((torch.rand(2, 3), torch.rand(2, 3)), {"foo": torch.rand(2, 3)})
export_module = export(orig_eager, inputs, {})
unflattened = unflatten(export_module)
torch.fx.symbolic_trace(
unflattened, concrete_args=(torch.fx.PH, torch.fx.PH, torch.fx.PH)
)
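
    # Calling a grandchild submodule directly from the root forward should be
    # handled by unflatten.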
def test_double_nested_submodule(self):
class SubSubMod(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x):
return x * x
class SubMod(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.subsubmod = SubSubMod()
def forward(self, x):
return x - x
class MyModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.submod = SubMod()
def forward(self, x):
return x + self.submod.subsubmod(x)
orig_eager = MyModule()
export_module = torch.export.export(orig_eager, (torch.rand(2, 3),), {})
unflattened = unflatten(export_module)
inputs = (torch.rand(2, 3),)
self.compare_outputs(orig_eager, unflattened, inputs)
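
    # Container-typed inputs (a list of tensors) should round-trip through
    # export -> unflatten -> re-export.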
def test_unflatten_container_type(self):
class Leaf(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(4, 4)
def forward(self, x):
return self.linear(x)
class Bar(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.leaf = Leaf()
self.buffer = torch.nn.Buffer(torch.randn(4, 4))
def forward(self, x, z):
return self.buffer.sum() + self.leaf(x).sum() + z[0].sum() + z[1].sum()
class Foo(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.bar = Bar()
def forward(self, x, z):
y = self.bar.buffer + x + z[0] + z[1]
return self.bar(x, z) + y.sum()
inp = (torch.randn(4, 4), [torch.randn(4, 4), torch.randn(4, 4)])
mod = Foo()
ep_strict = torch.export.export(mod, inp)
ep_non_strict = torch.export.export(mod, inp, strict=False)
gm_unflat_non_strict = unflatten(ep_non_strict)
ep = torch.export.export(gm_unflat_non_strict, inp, strict=False)
self.assertTrue(torch.allclose(ep.module()(*inp), mod(*inp)))
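
    # Every non-output node in every unflattened submodule graph should carry a
    # "val" entry in its meta.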
def test_unflattened_module_nodes_has_meta_val(self):
class SubMod(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x):
return x + x, x * x
class MyModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.submod = SubMod()
def forward(self, x):
return x + sum(self.submod(x))
orig_eager = MyModule()
export_module = torch.export.export(orig_eager, (torch.rand(2, 3),), {})
unflattened = unflatten(export_module)
inputs = (torch.rand(2, 3),)
self.compare_outputs(orig_eager, unflattened, inputs)
def check_meta(gm):
for n in gm.graph.nodes:
if n.op == "output":
continue
self.assertTrue(n.meta.get("val") is not None)
for m in unflattened.modules():
check_meta(m)
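
    # requires_grad=False on a parameter should be preserved through
    # unflattening, including for meta-device modules.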
def test_unflatten_requires_grad_param(self):
class M(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.p = torch.nn.Parameter(torch.ones(3, 3), requires_grad=False)
def forward(self, x):
return self.p + x
with torch.device("meta"):
mod = M()
inputs = (torch.randn(3, 3, device="meta"),)
ep = export(mod, inputs)
unflattened = unflatten(ep)
self.assertTrue(unflattened.state_dict()["p"].requires_grad is False)
self.assertTrue(unflattened.p.requires_grad is False)
def test_placeholder_and_get_attr_ordering_after_unflattened(self):
class TransposeModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.conv = torch.nn.Conv2d(3, 1, 3, stride=2)
def forward(self, x):
x = self.conv(x)
return x.transpose(0, 1)
x = torch.randn(32, 3, 64, 64)
exported_program = export(TransposeModule(), args=(x,))
unflattened_module = unflatten(exported_program)
# Check the inputs of the created call_module node are in order
call_module_input_order = []
for node in unflattened_module.graph.nodes:
if node.op == "call_module":
transpose_module = unflattened_module.get_submodule(node.target)
for sub_node in transpose_module.graph.nodes:
if sub_node.op == "placeholder" or sub_node.op == "get_attr":
call_module_input_order.append(sub_node.op)
self.assertEqual(
call_module_input_order, ["placeholder", "get_attr", "get_attr"]
)
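
    # A scalar attribute materialized as a constant tensor inside forward should
    # unflatten correctly.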
def test_unflatten_constant_tensor(self):
class SubMod(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.initializer = 0.1
def forward(self, x):
return x + torch.tensor(self.initializer)
class Mod(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.submod = SubMod()
def forward(self, x):
return x + self.submod(x)
export_module = torch.export.export(Mod(), (torch.randn((2, 3)),))
unflattened = unflatten(export_module)
self.compare_outputs(
export_module.module(), unflattened, (torch.randn((2, 3)),)
)
@skipIfTorchDynamo("custom objects not supported in dynamo yet")
def test_unflatten_constant_obj(self):
init_torchbind_implementations()
@torch._library.register_fake_class("_TorchScriptTesting::_Foo")
class FakeFoo:
def __init__(self, x: int, y: int):
self.x = x
self.y = y
@classmethod
def __obj_unflatten__(cls, flat_ctx):
return cls(**dict(flat_ctx))
def add_tensor(self, z):
return (self.x + self.y) * z
class SubMod(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)
def forward(self, x):
return x + self.attr.add_tensor(x)
class Mod(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.submod = SubMod()
def forward(self, x):
return x + self.submod(x)
with enable_torchbind_tracing():
export_module = torch.export.export(
Mod(), (torch.randn((2, 3)),), strict=False
)
unflattened = unflatten(export_module)
self.compare_outputs(
export_module.module(), unflattened, (torch.randn((2, 3)),)
)
    # Skip connections are not supported yet
@unittest.expectedFailure
def test_unflatten_skipped_call_module(self):
class C(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
return a.d(x.cos())
class B(torch.nn.Module):
def __init__(self):
super().__init__()
self.c = C()
def forward(self, x):
return self.c(x) + x
class D(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
return x.sin()
class A(torch.nn.Module):
def __init__(self):
super().__init__()
self.b = B()
self.d = D()
def forward(self, x):
return self.b(x)
a = A()
# The call chain looks like this:
# A -> B -> C -> A.d
ep = torch.export.export(a, (torch.randn(3),), strict=False)
unflattened = unflatten(ep)
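
    # Non-strict export with a preserved nested call signature should unflatten
    # without error.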
def test_nested_leaf_non_strict(self):
class Leaf(torch.nn.Module):
def forward(self, x):
return x + 1
class Nested(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.leaf = Leaf()
def forward(self, x):
return self.leaf(x) + 2
class TopLevel(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.nested = Nested()
def forward(self, x):
return self.nested(x) + 3
ep = torch.export.export(
TopLevel(),
(torch.randn(3),),
strict=False,
preserve_module_call_signature=("nested",),
)
torch.export.unflatten(ep)
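
    # named_modules() order and aliasing (mod3 is mod2) should match the eager
    # module after unflattening.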
def test_unflatten_submodule_ordering(self):
class Module2(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.buffer = torch.nn.Buffer(torch.rand(3, 4))
self.register_parameter("param", torch.nn.Parameter(torch.rand(3, 4)))
def forward(self, x):
return x + self.buffer + self.param
class Module1(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.buffer = torch.nn.Buffer(torch.rand(3, 4))
self.register_parameter("param", torch.nn.Parameter(torch.rand(3, 4)))
def forward(self, x):
return x + self.buffer + self.param
class Module(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.mod2 = Module2()
self.mod3 = self.mod2
self.mod1 = Module1()
def forward(self, x):
return self.mod3(self.mod2(self.mod1(x)))
mod = Module()
ep = torch.export.export(mod, (torch.randn(3, 4),))
unflattened = torch.export.unflatten(ep)
fqn_list = [x for x, _ in unflattened.named_modules(remove_duplicate=False)]
self.assertEqual(len(fqn_list), 4)
self.assertEqual(
[x for x, _ in mod.named_modules(remove_duplicate=False)],
fqn_list,
)
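
    # Reusing the same LayerNorm instance across a ModuleList should still give
    # correct outputs after unflattening, in both strict and non-strict modes.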
def test_duplicate_placeholder(self):
N, C, H, W = 1, 2, 2, 3
class MyModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
layer = torch.nn.LayerNorm([C, H, W])
self.norms = torch.nn.ModuleList(
[
layer, # reuse layer norm
layer,
layer,
]
)
def forward(self, input_):
for i in range(len(self.norms)):
output = self.norms[i](input_)
input_ = output
return output
mod = MyModule()
input_ = torch.randn(N, C, H, W)
ep_strict = export(copy.deepcopy(mod), (input_,), strict=True)
umod = unflatten(ep_strict)
self.assertTrue(torch.allclose(umod(input_), mod(input_)))
ep_non_strict = export(copy.deepcopy(mod), (input_,), strict=False)
umod = unflatten(ep_non_strict)
self.assertTrue(torch.allclose(umod(input_), mod(input_)))
def test_simple_alias(self):
# handle weight sharing, check tensor ids after unflattening
class Foo(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
# alias param
self.bias = torch.nn.Parameter(torch.randn(4))
self.m = torch.nn.Linear(4, 4)
self.m.bias = self.bias
def forward(self, x):
return self.m(x) + self.bias
m = Foo()
inps = (torch.randn(4, 4),)
ep = export(m, inps)
unep = unflatten(ep)
self.assertTrue(id(unep.m.bias) == id(unep.bias))
# handle aliasing where one alias is unused
class Foo(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.bias = torch.nn.Parameter(torch.randn(4))
self.m = torch.nn.Linear(4, 4)
self.m.bias = (
self.bias
) # self.bias is unused, aliasing should be handled
def forward(self, x):
return self.m(x)
m = Foo()
inps = (torch.randn(4, 4),)
ep = export(m, inps)
unep = unflatten(ep)
self.assertTrue(torch.allclose(unep(*inps), m(*inps)))
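
    # A root-level buffer passed as an explicit argument to child modules should
    # be threaded through correctly after unflattening.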
def test_attr_as_submod_input(self):
class layer(torch.nn.Module):
def forward(self, x, const) -> torch.Tensor:
return x + const
class M(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.const = torch.nn.Buffer(torch.ones(4, 8))
self.layers = torch.nn.ModuleList([layer() for _ in range(2)])
def forward(self, x: torch.Tensor) -> torch.Tensor:
for layer in self.layers:
x = layer(x, self.const)
return x
mod = M()
x = torch.randn(4, 8)
ep = export(mod, (x,))
unflattened = unflatten(ep)
torch.testing.assert_close(unflattened(x), mod(x))
def test_dedup_sym_size(self):
# Here, sym_size & floor div are used in 3 subgraphs (top-level, m1, m2),
# but only one copy of sym_size is created in the initial export graph.
# For m1, sym_size & floordiv should be copied as recompute since we preserve the call signature,
# but for m2 floordiv should be passed in as a placeholder.
# Test that this is preserved, and the unflattened module runs correctly.
class M1(torch.nn.Module):
def forward(self, x, y):
d = x.size(0) // 2
return y[:d]
class M2(torch.nn.Module):
def forward(self, x, y):
d = x.size(0) // 2
return y[:d]
class M(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.m1 = M1()
self.m2 = M2()
def forward(self, x, y):
d = x.size(0) // 2
m1_res = self.m1(x, y)
m2_res = self.m2(x, y)
return y[d:] + m1_res + m2_res
inputs = (torch.ones(10), torch.ones(10))
d_ = torch.export.Dim("foo", max=2048)
d = 2 * d_
ep = torch.export.export(
M(),
inputs,
dynamic_shapes=((d,), (d,)),
strict=False,
preserve_module_call_signature=("m1",),
)
unflat = unflatten(ep)
unflat(*inputs)
fn_count_sym_size = lambda graph: [node.target for node in graph.nodes].count(
torch.ops.aten.sym_size.int
)
self.assertEqual(fn_count_sym_size(unflat.graph), 1)
self.assertEqual(fn_count_sym_size(unflat.m1.graph), 1)
self.assertEqual(fn_count_sym_size(unflat.m2.graph), 0)
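
    # With _disable_interpreter(), unflattened submodules should run as plain
    # GraphModules rather than through the graph interpreter, and should stay
    # composable with fx tracing and torch.compile.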
def test_unflatten_eager(self):
class NestedChild(torch.nn.Module):
def forward(self, x):
return x / x
class Child1(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.nested = NestedChild()
self.register_parameter(
"child1param", torch.nn.Parameter(torch.ones(2, 3))
)
def forward(self, x):
x = self.nested(x)
return x + self.child1param
class Child2(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.child2buffer = torch.nn.Buffer(torch.ones(2, 3))
def forward(self, x):
return x - self.child2buffer
class MyModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.foo = Child1()
self.bar = Child2()
self.register_parameter(
"rootparam", torch.nn.Parameter(torch.ones(2, 3))
)
def forward(self, x):
x = x * self.rootparam
x = self.foo(x)
x = self.bar(x)
return x
orig_eager = MyModule()
export_module = export(orig_eager, (torch.rand(2, 3),), {})
with _disable_interpreter():
unflattened = unflatten(export_module)
self.assertEqual(unflattened._run_with_interpeter, False)
self.assertEqual(unflattened.foo._run_with_interpeter, False)
inputs = (torch.rand(2, 3),)
# Compare the root modules and all submodules
self.compare_outputs(orig_eager, unflattened, inputs)
self.compare_outputs(orig_eager.foo, unflattened.foo, inputs)
self.compare_outputs(orig_eager.bar, unflattened.bar, inputs)
self.compare_outputs(orig_eager.foo.nested, unflattened.foo.nested, inputs)
# Check state dicts are equal
orig_state_dict = orig_eager.state_dict()
exported_state_dict = unflattened.state_dict()
for name, value in orig_state_dict.items():
self.assertTrue(torch.allclose(value, exported_state_dict[name]))
# Check composability with symbolic trace, as torchrec ddp uses symbolic
# tracer
symbolic_traced = torch.fx.symbolic_trace(unflattened, concrete_args=inputs)
self.assertTrue(torch.allclose(orig_eager(*inputs), symbolic_traced(*inputs)))
# torch.compile submodule
unflattened.foo = torch.compile(unflattened.foo, fullgraph=True)
self.compare_outputs(orig_eager, unflattened, inputs)
if __name__ == "__main__":
run_tests()