# Owner(s): ["oncall: export"]
# flake8: noqa
import copy
import dataclasses
import unittest
from contextlib import contextmanager
from dataclasses import dataclass
from re import escape
from typing import Any, List

import torch
import torch._dynamo as torchdynamo
from functorch.experimental.control_flow import cond, map
from torch import Tensor
from torch._export.utils import (
    get_buffer,
    get_param,
    is_buffer,
    is_param,
    register_dataclass_as_pytree_node,
)
from torch._higher_order_ops.torchbind import enable_torchbind_tracing
from torch.export import Constraint, Dim, export, FlatArgsAdapter, unflatten
from torch.export._trace import DEFAULT_EXPORT_DYNAMO_CONFIG
from torch.export.unflatten import _disable_interpreter
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing import FileCheck
from torch.testing._internal.common_utils import (
    find_library_location,
    IS_FBCODE,
    IS_MACOS,
    IS_SANDCASTLE,
    IS_WINDOWS,
    run_tests,
    skipIfTorchDynamo,
    TestCase,
)
from torch.testing._internal.torchbind_impls import init_torchbind_implementations
from torch.utils._pytree import (
    LeafSpec,
    tree_flatten,
    tree_unflatten,
    TreeSpec,
    treespec_dumps,
    treespec_loads,
)


@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't supported")
class TestUnflatten(TestCase):
    def compare_outputs(self, eager, unflattened, args):
        orig_output = eager(*args)
        unflattened_output = unflattened(*args)
        self.assertTrue(torch.allclose(orig_output, unflattened_output))

    def test_unflatten_nested(self):
        class NestedChild(torch.nn.Module):
            def forward(self, x):
                return x / x

        class Child1(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.nested = NestedChild()
                self.register_parameter(
                    "child1param", torch.nn.Parameter(torch.ones(2, 3))
                )

            def forward(self, x):
                x = self.nested(x)
                return x + self.child1param

        class Child2(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.child2buffer = torch.nn.Buffer(torch.ones(2, 3))

            def forward(self, x):
                return x - self.child2buffer

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo = Child1()
                self.bar = Child2()
                self.register_parameter(
                    "rootparam", torch.nn.Parameter(torch.ones(2, 3))
                )

            def forward(self, x):
                x = x * self.rootparam
                x = self.foo(x)
                x = self.bar(x)
                return x

        orig_eager = MyModule()
        export_module = export(orig_eager, (torch.rand(2, 3),), {})
        unflattened = unflatten(export_module)

        inputs = (torch.rand(2, 3),)

        # Compare the root modules and all submodules
        self.compare_outputs(orig_eager, unflattened, inputs)
        self.compare_outputs(orig_eager.foo, unflattened.foo, inputs)
        self.compare_outputs(orig_eager.bar, unflattened.bar, inputs)
        self.compare_outputs(orig_eager.foo.nested, unflattened.foo.nested, inputs)

        # Check state dicts are equal
        orig_state_dict = orig_eager.state_dict()
        exported_state_dict = unflattened.state_dict()
        for name, value in orig_state_dict.items():
            self.assertTrue(torch.allclose(value, exported_state_dict[name]))

    def test_unflatten_buffer_mutation(self):
        class Child(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.child2buffer = torch.nn.Buffer(torch.ones(2, 3))

            def forward(self, x):
                self.child2buffer.add_(x)
                return x - self.child2buffer

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo = Child()
                self.register_parameter(
                    "rootparam", torch.nn.Parameter(torch.ones(2, 3))
                )

            def forward(self, x):
                x = self.foo(x)
                return x * self.rootparam

        eager_module = MyModule()
        export_module = export(eager_module, (torch.rand(2, 3),), {})
        unflattened_module = unflatten(export_module)

        # The eager and unflattened buffers should match both before and after a run
        eager_buffer = eager_module.foo.child2buffer
        unflattened_buffer = unflattened_module.foo.child2buffer
        self.assertTrue(torch.allclose(eager_buffer, unflattened_buffer))

        inputs = (torch.rand(2, 3),)
        eager_module(*inputs)
        unflattened_module(*inputs)
        self.assertTrue(torch.allclose(eager_buffer, unflattened_buffer))

    def test_unflatten_nested_access(self):
        class Child(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.child2buffer = torch.nn.Buffer(torch.ones(2, 3))

            def forward(self, x):
                return x - self.child2buffer

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo = Child()
                self.register_parameter(
                    "rootparam", torch.nn.Parameter(torch.ones(2, 3))
                )

            def forward(self, x):
                x = x + self.foo.child2buffer
                x = self.foo(x)
                return x

        eager_module = MyModule()
        export_module = export(eager_module, (torch.rand(2, 3),), {})
        unflattened_module = unflatten(export_module)

        inputs = (torch.rand(2, 3),)
        self.compare_outputs(eager_module, unflattened_module, inputs)

    def test_unflatten_shared_submodule(self):
        class Shared(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                layernorm = torch.nn.LayerNorm(10)
                self.sub_net = torch.nn.Sequential(
                    layernorm,
                    torch.nn.ReLU(),
                    layernorm,
                    torch.nn.ReLU(),
                )

            def forward(self, x):
                return self.sub_net(x)

        eager_module = Shared()
        inps = (torch.rand(10),)
        export_module = export(eager_module, inps, {})
        unflattened_module = unflatten(export_module)
        self.compare_outputs(eager_module, unflattened_module, inps)
        self.assertTrue(hasattr(unflattened_module, "sub_net"))
        for i in range(len(eager_module.sub_net)):
            self.assertTrue(hasattr(unflattened_module.sub_net, str(i)))
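        # The reused LayerNorm should unflatten to a single shared module instance.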
        self.assertEqual(
            id(getattr(unflattened_module.sub_net, "0")),
            id(getattr(unflattened_module.sub_net, "2")),
        )

    @unittest.skipIf(IS_WINDOWS, "Windows not supported for this test")
    @skipIfTorchDynamo("Non strict mode is not meant to run with dynamo")
    def test_unflatten_preserve_signature(self):
        class NestedChild(torch.nn.Module):
            def forward(self, zx, y):
                return {"x": y["key"] + zx[1], "w": y["key"] * zx[1]}

        class Child1(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.nested = NestedChild()

            def forward(self, x, y):
                z = torch.ones_like(x)
                xw = self.nested((z, x), y={"key": y})
                return xw["w"] + z - xw["x"]

        class Child2(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                return x - 1

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo = Child1()
                self.bar = Child2()

            def forward(self, x, y):
                x = self.foo(x, y)
                x = self.bar(x)
                return x

        orig_eager = MyModule()
        inps = torch.rand(2, 3), torch.rand(2, 3)
        for strict in [True, False]:
            export_module = export(
                orig_eager,
                inps,
                {},
                preserve_module_call_signature=("foo.nested",),
                strict=strict,
            )
            unflattened = unflatten(export_module)
            self.compare_outputs(export_module.module(), unflattened, inps)
            unflattened.foo.nested = NestedChild()
            self.compare_outputs(export_module.module(), unflattened, inps)

            # Test tree spec mismatched input
            orig_outs = export_module.module()(*inps)
            new_inps = *inps, torch.rand(2, 3)
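            # The expected message below matches the library's error text
            # verbatim (including its "sepcified" spelling).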
            with self.assertRaisesRegex(
                TypeError,
                "There is no flat args adapter sepcified. Are you sure you are calling this with the right arguments?",
            ):
                unflattened(new_inps)

            # With a flat args adapter that drops the extra trailing input
            class KeepTwoFlatArgsAdapter(FlatArgsAdapter):
                def adapt(
                    self,
                    target_spec: TreeSpec,
                    input_spec: TreeSpec,
                    input_args: List[Any],
                ) -> List[Any]:
                    while len(input_args) > 2:
                        input_args.pop(-1)
                    return input_args

            unflattened = unflatten(export_module, KeepTwoFlatArgsAdapter())
            new_outs = unflattened(*new_inps)
            self.assertTrue(torch.allclose(orig_outs, new_outs))

    def test_unflatten_param_list_dict(self):
        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param_list = torch.nn.ParameterList()
                self.param_dict = torch.nn.ParameterDict()
                for i in range(2):
                    self.param_list.append(torch.nn.Parameter(torch.randn((2, 3))))
                    self.param_dict[f"key_{i}"] = torch.nn.Parameter(
                        torch.randn((2, 3))
                    )

            def forward(self, x):
                for i in range(2):
                    x = x + self.param_list[i]
                    x = x + self.param_dict[f"key_{i}"]
                return x

        export_module = torch.export.export(Mod(), (torch.randn((2, 3)),))
        unflattened = unflatten(export_module)

        self.compare_outputs(
            export_module.module(), unflattened, (torch.randn((2, 3)),)
        )

    @unittest.skipIf(IS_WINDOWS, "Windows not supported for this test")
    def test_unflatten_preserve_with_unused_input(self):
        class M1(torch.nn.Module):
            def forward(self, x, a, b):
                return x + a, b

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.m1 = M1()

            def forward(self, x, y):
                a, b = torch.topk(y, 2)
                return self.m1(x, a, b)[0]

        ep = torch.export.export(
            M(),
            (torch.randn(2), torch.randn(5)),
            preserve_module_call_signature=("m1",),
            strict=False,
        )
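        # m1 returns an extra, unused output; dead-code elimination removes the
        # dangling nodes before unflattening with a preserved call signature.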
        ep.graph.eliminate_dead_code()
        unflattened = unflatten(ep)
        self.compare_outputs(ep.module(), unflattened, (torch.randn(2), torch.randn(5)))

    def test_unflatten_wrong_input(self):
        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param_list = torch.nn.ParameterList()
                self.param_dict = torch.nn.ParameterDict()
                for i in range(2):
                    self.param_list.append(torch.nn.Parameter(torch.randn((2, 3))))
                    self.param_dict[f"key_{i}"] = torch.nn.Parameter(
                        torch.randn((2, 3))
                    )

            def forward(self, x):
                a = x.sum()
                for i in range(2):
                    a = a + self.param_list[i].sum()
                    a = a + self.param_dict[f"key_{i}"].sum()
                return a

        export_module = torch.export.export(Mod(), (torch.randn((2, 3)),))
        with self.assertRaisesRegex(
            RuntimeError,
            escape("Expected input at *args[0].shape[0] to be equal to 2, but got 6"),
        ):
            export_module.module()(torch.randn(6, 6))

        unflattened = unflatten(export_module)
        with self.assertRaisesRegex(
            RuntimeError,
            escape("Expected input at *args[0].shape[0] to be equal to 2, but got 6"),
        ):
            unflattened(torch.randn(6, 6))

    @unittest.skipIf(IS_WINDOWS, "Windows not supported for this test")
    def test_unflatten_with_inplace_compile(self):
        class NestedChild(torch.nn.Module):
            def forward(self, x):
                return x / x

        class Child1(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.nested = NestedChild()
                self.register_parameter(
                    "child1param", torch.nn.Parameter(torch.ones(2, 3))
                )

            def forward(self, x):
                x = self.nested(x)
                return x + self.child1param

        class Child2(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.child2buffer = torch.nn.Buffer(torch.ones(2, 3))

            def forward(self, x):
                return x - self.child2buffer

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo = Child1()
                self.bar = Child2()
                self.register_parameter(
                    "rootparam", torch.nn.Parameter(torch.ones(2, 3))
                )

            def forward(self, x):
                x = x * self.rootparam
                x = self.foo(x)
                x = self.bar(x)
                return x

        orig_eager = MyModule()
        export_module = torch.export.export(orig_eager, (torch.rand(2, 3),), {})
        unflattened = unflatten(export_module)

        # in-place compilation should work. Pass fullgraph to ensure no graph breaks.
        from torch._dynamo.backends.debugging import ExplainWithBackend

        eb = ExplainWithBackend("inductor")
        unflattened.foo.compile(backend=eb, fullgraph=True)
        inputs = (torch.randn(2, 3),)
        self.compare_outputs(orig_eager, unflattened, inputs)
        self.assertEqual(len(eb.graphs), 1)

    def test_fx_trace(self):
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x, y):
                x = x[0] + x[1]
                x = x + y["foo"]
                return x

        orig_eager = MyModule()
        inputs = ((torch.rand(2, 3), torch.rand(2, 3)), {"foo": torch.rand(2, 3)})
        export_module = export(orig_eager, inputs, {})

        unflattened = unflatten(export_module)
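        # torch.fx.PH marks each flattened input as a placeholder so that
        # symbolic_trace does not specialize on concrete tensor values.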
        torch.fx.symbolic_trace(
            unflattened, concrete_args=(torch.fx.PH, torch.fx.PH, torch.fx.PH)
        )

    def test_double_nested_submodule(self):
        class SubSubMod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                return x * x

        class SubMod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.subsubmod = SubSubMod()

            def forward(self, x):
                return x - x

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.submod = SubMod()

            def forward(self, x):
                return x + self.submod.subsubmod(x)

        orig_eager = MyModule()
        export_module = torch.export.export(orig_eager, (torch.rand(2, 3),), {})
        unflattened = unflatten(export_module)

        inputs = (torch.rand(2, 3),)
        self.compare_outputs(orig_eager, unflattened, inputs)

    def test_unflatten_container_type(self):
        class Leaf(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, x):
                return self.linear(x)

        class Bar(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.leaf = Leaf()
                self.buffer = torch.nn.Buffer(torch.randn(4, 4))

            def forward(self, x, z):
                return self.buffer.sum() + self.leaf(x).sum() + z[0].sum() + z[1].sum()

        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.bar = Bar()

            def forward(self, x, z):
                y = self.bar.buffer + x + z[0] + z[1]
                return self.bar(x, z) + y.sum()

        inp = (torch.randn(4, 4), [torch.randn(4, 4), torch.randn(4, 4)])
        mod = Foo()
        ep_strict = torch.export.export(mod, inp)
        ep_non_strict = torch.export.export(mod, inp, strict=False)

        gm_unflat_non_strict = unflatten(ep_non_strict)
        ep = torch.export.export(gm_unflat_non_strict, inp, strict=False)
        self.assertTrue(torch.allclose(ep.module()(*inp), mod(*inp)))

    def test_unflattened_module_nodes_has_meta_val(self):
        class SubMod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                return x + x, x * x

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.submod = SubMod()

            def forward(self, x):
                return x + sum(self.submod(x))

        orig_eager = MyModule()
        export_module = torch.export.export(orig_eager, (torch.rand(2, 3),), {})
        unflattened = unflatten(export_module)

        inputs = (torch.rand(2, 3),)
        self.compare_outputs(orig_eager, unflattened, inputs)

        def check_meta(gm):
            for n in gm.graph.nodes:
                if n.op == "output":
                    continue
                self.assertTrue(n.meta.get("val") is not None)

        for m in unflattened.modules():
            check_meta(m)

    def test_unflatten_requires_grad_param(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.p = torch.nn.Parameter(torch.ones(3, 3), requires_grad=False)

            def forward(self, x):
                return self.p + x

        with torch.device("meta"):
            mod = M()

        inputs = (torch.randn(3, 3, device="meta"),)
        ep = export(mod, inputs)
        unflattened = unflatten(ep)
        self.assertTrue(unflattened.state_dict()["p"].requires_grad is False)
        self.assertTrue(unflattened.p.requires_grad is False)

    def test_placeholder_and_get_attr_ordering_after_unflattened(self):
        class TransposeModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 1, 3, stride=2)

            def forward(self, x):
                x = self.conv(x)
                return x.transpose(0, 1)

        x = torch.randn(32, 3, 64, 64)
        exported_program = export(TransposeModule(), args=(x,))
        unflattened_module = unflatten(exported_program)

        # Check the inputs of the created call_module node are in order
        call_module_input_order = []
        for node in unflattened_module.graph.nodes:
            if node.op == "call_module":
                transpose_module = unflattened_module.get_submodule(node.target)
                for sub_node in transpose_module.graph.nodes:
                    if sub_node.op == "placeholder" or sub_node.op == "get_attr":
                        call_module_input_order.append(sub_node.op)
        self.assertEqual(
            call_module_input_order, ["placeholder", "get_attr", "get_attr"]
        )

    def test_unflatten_constant_tensor(self):
        class SubMod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.initializer = 0.1

            def forward(self, x):
                return x + torch.tensor(self.initializer)

        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.submod = SubMod()

            def forward(self, x):
                return x + self.submod(x)

        export_module = torch.export.export(Mod(), (torch.randn((2, 3)),))
        unflattened = unflatten(export_module)

        self.compare_outputs(
            export_module.module(), unflattened, (torch.randn((2, 3)),)
        )

    @skipIfTorchDynamo("custom objects not supported in dynamo yet")
    def test_unflatten_constant_obj(self):
        init_torchbind_implementations()

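        # Register a Python fake class so the torchbind object can be traced
        # without invoking the real C++ implementation.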
        @torch._library.register_fake_class("_TorchScriptTesting::_Foo")
        class FakeFoo:
            def __init__(self, x: int, y: int):
                self.x = x
                self.y = y

            @classmethod
            def __obj_unflatten__(cls, flat_ctx):
                return cls(**dict(flat_ctx))

            def add_tensor(self, z):
                return (self.x + self.y) * z

        class SubMod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)

            def forward(self, x):
                return x + self.attr.add_tensor(x)

        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.submod = SubMod()

            def forward(self, x):
                return x + self.submod(x)

        with enable_torchbind_tracing():
            export_module = torch.export.export(
                Mod(), (torch.randn((2, 3)),), strict=False
            )
        unflattened = unflatten(export_module)

        self.compare_outputs(
            export_module.module(), unflattened, (torch.randn((2, 3)),)
        )

    # calls that skip across the module hierarchy (C calling a.d) are not supported yet
    @unittest.expectedFailure
    def test_unflatten_skipped_call_module(self):
        class C(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return a.d(x.cos())

        class B(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.c = C()

            def forward(self, x):
                return self.c(x) + x

        class D(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return x.sin()

        class A(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.b = B()
                self.d = D()

            def forward(self, x):
                return self.b(x)

        a = A()

        # The call chain looks like this:
        # A -> B -> C -> A.d
        ep = torch.export.export(a, (torch.randn(3),), strict=False)
        unflattened = unflatten(ep)

    def test_nested_leaf_non_strict(self):
        class Leaf(torch.nn.Module):
            def forward(self, x):
                return x + 1

        class Nested(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.leaf = Leaf()

            def forward(self, x):
                return self.leaf(x) + 2

        class TopLevel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.nested = Nested()

            def forward(self, x):
                return self.nested(x) + 3

        ep = torch.export.export(
            TopLevel(),
            (torch.randn(3),),
            strict=False,
            preserve_module_call_signature=("nested",),
        )

        torch.export.unflatten(ep)

    def test_unflatten_submodule_ordering(self):
        class Module2(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.buffer = torch.nn.Buffer(torch.rand(3, 4))
                self.register_parameter("param", torch.nn.Parameter(torch.rand(3, 4)))

            def forward(self, x):
                return x + self.buffer + self.param

        class Module1(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.buffer = torch.nn.Buffer(torch.rand(3, 4))
                self.register_parameter("param", torch.nn.Parameter(torch.rand(3, 4)))

            def forward(self, x):
                return x + self.buffer + self.param

        class Module(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.mod2 = Module2()
                self.mod3 = self.mod2
                self.mod1 = Module1()

            def forward(self, x):
                return self.mod3(self.mod2(self.mod1(x)))

        mod = Module()

        ep = torch.export.export(mod, (torch.randn(3, 4),))

        unflattened = torch.export.unflatten(ep)
        fqn_list = [x for x, _ in unflattened.named_modules(remove_duplicate=False)]
        self.assertEqual(len(fqn_list), 4)
        self.assertEqual(
            [x for x, _ in mod.named_modules(remove_duplicate=False)],
            fqn_list,
        )

    def test_duplicate_placeholder(self):
        N, C, H, W = 1, 2, 2, 3

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                layer = torch.nn.LayerNorm([C, H, W])
                self.norms = torch.nn.ModuleList(
                    [
                        layer,  # reuse layer norm
                        layer,
                        layer,
                    ]
                )

            def forward(self, input_):
                for i in range(len(self.norms)):
                    output = self.norms[i](input_)
                    input_ = output
                return output

        mod = MyModule()
        input_ = torch.randn(N, C, H, W)

        ep_strict = export(copy.deepcopy(mod), (input_,), strict=True)
        umod = unflatten(ep_strict)
        self.assertTrue(torch.allclose(umod(input_), mod(input_)))

        ep_non_strict = export(copy.deepcopy(mod), (input_,), strict=False)
        umod = unflatten(ep_non_strict)
        self.assertTrue(torch.allclose(umod(input_), mod(input_)))

    def test_simple_alias(self):
        # handle weight sharing, check tensor ids after unflattening
        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                # alias param
                self.bias = torch.nn.Parameter(torch.randn(4))
                self.m = torch.nn.Linear(4, 4)
                self.m.bias = self.bias

            def forward(self, x):
                return self.m(x) + self.bias

        m = Foo()
        inps = (torch.randn(4, 4),)
        ep = export(m, inps)
        unep = unflatten(ep)
        self.assertTrue(id(unep.m.bias) == id(unep.bias))

        # handle aliasing where one alias is unused
        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.bias = torch.nn.Parameter(torch.randn(4))
                self.m = torch.nn.Linear(4, 4)
                # self.bias is not used in forward; aliasing should still be handled
                self.m.bias = self.bias

            def forward(self, x):
                return self.m(x)

        m = Foo()
        inps = (torch.randn(4, 4),)
        ep = export(m, inps)
        unep = unflatten(ep)
        self.assertTrue(torch.allclose(unep(*inps), m(*inps)))

    def test_attr_as_submod_input(self):
        class Layer(torch.nn.Module):
            def forward(self, x, const) -> torch.Tensor:
                return x + const

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.const = torch.nn.Buffer(torch.ones(4, 8))
                self.layers = torch.nn.ModuleList([Layer() for _ in range(2)])

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                for layer in self.layers:
                    x = layer(x, self.const)
                return x

        mod = M()
        x = torch.randn(4, 8)
        ep = export(mod, (x,))
        unflattened = unflatten(ep)
        torch.testing.assert_close(unflattened(x), mod(x))

    def test_dedup_sym_size(self):
        # Here, sym_size & floordiv are used in 3 subgraphs (top-level, m1, m2),
        # but only one copy of sym_size is created in the initial export graph.
        # For m1, sym_size & floordiv should be recomputed (copied into its subgraph)
        # since we preserve its call signature, but for m2 floordiv should be passed
        # in as a placeholder. Test that this is preserved and that the unflattened
        # module runs correctly.
        class M1(torch.nn.Module):
            def forward(self, x, y):
                d = x.size(0) // 2
                return y[:d]

        class M2(torch.nn.Module):
            def forward(self, x, y):
                d = x.size(0) // 2
                return y[:d]

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.m1 = M1()
                self.m2 = M2()

            def forward(self, x, y):
                d = x.size(0) // 2
                m1_res = self.m1(x, y)
                m2_res = self.m2(x, y)
                return y[d:] + m1_res + m2_res

        inputs = (torch.ones(10), torch.ones(10))
        d_ = torch.export.Dim("foo", max=2048)
        d = 2 * d_
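        # The derived dim 2*d_ constrains the dynamic size to even values, so
        # x.size(0) // 2 stays exact.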
        ep = torch.export.export(
            M(),
            inputs,
            dynamic_shapes=((d,), (d,)),
            strict=False,
            preserve_module_call_signature=("m1",),
        )
        unflat = unflatten(ep)
        unflat(*inputs)

        fn_count_sym_size = lambda graph: [node.target for node in graph.nodes].count(
            torch.ops.aten.sym_size.int
        )
        self.assertEqual(fn_count_sym_size(unflat.graph), 1)
        self.assertEqual(fn_count_sym_size(unflat.m1.graph), 1)
        self.assertEqual(fn_count_sym_size(unflat.m2.graph), 0)

    def test_unflatten_eager(self):
        class NestedChild(torch.nn.Module):
            def forward(self, x):
                return x / x

        class Child1(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.nested = NestedChild()
                self.register_parameter(
                    "child1param", torch.nn.Parameter(torch.ones(2, 3))
                )

            def forward(self, x):
                x = self.nested(x)
                return x + self.child1param

        class Child2(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.child2buffer = torch.nn.Buffer(torch.ones(2, 3))

            def forward(self, x):
                return x - self.child2buffer

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo = Child1()
                self.bar = Child2()
                self.register_parameter(
                    "rootparam", torch.nn.Parameter(torch.ones(2, 3))
                )

            def forward(self, x):
                x = x * self.rootparam
                x = self.foo(x)
                x = self.bar(x)
                return x

        orig_eager = MyModule()
        export_module = export(orig_eager, (torch.rand(2, 3),), {})
        with _disable_interpreter():
            unflattened = unflatten(export_module)

        self.assertEqual(unflattened._run_with_interpeter, False)
        self.assertEqual(unflattened.foo._run_with_interpeter, False)

        inputs = (torch.rand(2, 3),)

        # Compare the root modules and all submodules
        self.compare_outputs(orig_eager, unflattened, inputs)
        self.compare_outputs(orig_eager.foo, unflattened.foo, inputs)
        self.compare_outputs(orig_eager.bar, unflattened.bar, inputs)
        self.compare_outputs(orig_eager.foo.nested, unflattened.foo.nested, inputs)

        # Check state dicts are equal
        orig_state_dict = orig_eager.state_dict()
        exported_state_dict = unflattened.state_dict()
        for name, value in orig_state_dict.items():
            self.assertTrue(torch.allclose(value, exported_state_dict[name]))

        # Check composability with symbolic_trace, since torchrec DDP uses the
        # symbolic tracer
        symbolic_traced = torch.fx.symbolic_trace(unflattened, concrete_args=inputs)
        self.assertTrue(torch.allclose(orig_eager(*inputs), symbolic_traced(*inputs)))

        # torch.compile submodule
        unflattened.foo = torch.compile(unflattened.foo, fullgraph=True)
        self.compare_outputs(orig_eager, unflattened, inputs)


if __name__ == "__main__":
    run_tests()