# Owner(s): ["oncall: export"]
# flake8: noqa
import copy
import dataclasses
import unittest
from contextlib import contextmanager
from dataclasses import dataclass
from re import escape
from typing import Any, List
import torch
import torch._dynamo as torchdynamo
from functorch.experimental.control_flow import cond, map
from torch import Tensor
from torch._export.utils import (
get_buffer,
get_param,
is_buffer,
is_param,
register_dataclass_as_pytree_node,
)
from torch._higher_order_ops.torchbind import enable_torchbind_tracing
from torch.export import (
Constraint,
Dim,
dynamic_dim,
export,
FlatArgsAdapter,
unflatten,
)
from torch.export._trace import DEFAULT_EXPORT_DYNAMO_CONFIG
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing import FileCheck
from torch.testing._internal.common_utils import (
find_library_location,
IS_FBCODE,
IS_MACOS,
IS_SANDCASTLE,
IS_WINDOWS,
run_tests,
skipIfTorchDynamo,
TestCase,
)
from torch.testing._internal.torchbind_impls import init_torchbind_implementations
from torch.utils._pytree import (
LeafSpec,
tree_flatten,
tree_unflatten,
TreeSpec,
treespec_dumps,
treespec_loads,
)
@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't supported")
class TestUnflatten(TestCase):
def compare_outputs(self, eager, unflattened, args):
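        # Helper: run both modules on the same args and check that the outputs
        # are numerically close.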
orig_output = eager(*args)
unflattened_output = unflattened(*args)
self.assertTrue(torch.allclose(orig_output, unflattened_output))
def test_unflatten_nested(self):
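        # Unflattening an exported module with nested children should preserve
        # per-submodule behavior and produce an equivalent state dict.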
class NestedChild(torch.nn.Module):
def forward(self, x):
return x / x
class Child1(torch.nn.Module):
def __init__(self):
super().__init__()
self.nested = NestedChild()
self.register_parameter(
"child1param", torch.nn.Parameter(torch.ones(2, 3))
)
def forward(self, x):
x = self.nested(x)
return x + self.child1param
class Child2(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("child2buffer", torch.ones(2, 3))
def forward(self, x):
return x - self.child2buffer
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.foo = Child1()
self.bar = Child2()
self.register_parameter(
"rootparam", torch.nn.Parameter(torch.ones(2, 3))
)
def forward(self, x):
x = x * self.rootparam
x = self.foo(x)
x = self.bar(x)
return x
orig_eager = MyModule()
export_module = export(orig_eager, (torch.rand(2, 3),), {})
unflattened = unflatten(export_module)
inputs = (torch.rand(2, 3),)
# Compare the root modules and all submodules
self.compare_outputs(orig_eager, unflattened, inputs)
self.compare_outputs(orig_eager.foo, unflattened.foo, inputs)
self.compare_outputs(orig_eager.bar, unflattened.bar, inputs)
self.compare_outputs(orig_eager.foo.nested, unflattened.foo.nested, inputs)
# Check state dicts are equal
orig_state_dict = orig_eager.state_dict()
exported_state_dict = unflattened.state_dict()
for name, value in orig_state_dict.items():
self.assertTrue(torch.allclose(value, exported_state_dict[name]))
def test_unflatten_buffer_mutation(self):
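        # A buffer mutated in the child's forward should stay in sync between the
        # eager module and the unflattened module across calls.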
class Child(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("child2buffer", torch.ones(2, 3))
def forward(self, x):
self.child2buffer.add_(x)
return x - self.child2buffer
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.foo = Child()
self.register_parameter(
"rootparam", torch.nn.Parameter(torch.ones(2, 3))
)
def forward(self, x):
x = self.foo(x)
return x * self.rootparam
eager_module = MyModule()
export_module = export(eager_module, (torch.rand(2, 3),), {})
unflattened_module = unflatten(export_module)
# Buffer should look the same before and after one run
eager_buffer = eager_module.foo.child2buffer
unflattened_buffer = unflattened_module.foo.child2buffer
self.assertTrue(torch.allclose(eager_buffer, unflattened_buffer))
inputs = (torch.rand(2, 3),)
eager_module(*inputs)
unflattened_module(*inputs)
self.assertTrue(torch.allclose(eager_buffer, unflattened_buffer))
def test_unflatten_nested_access(self):
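        # The root forward reads a child's buffer directly; the unflattened module
        # should still match the eager outputs.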
class Child(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("child2buffer", torch.ones(2, 3))
def forward(self, x):
return x - self.child2buffer
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.foo = Child()
self.register_parameter(
"rootparam", torch.nn.Parameter(torch.ones(2, 3))
)
def forward(self, x):
x = x + self.foo.child2buffer
x = self.foo(x)
return x
eager_module = MyModule()
export_module = export(eager_module, (torch.rand(2, 3),), {})
unflattened_module = unflatten(export_module)
inputs = (torch.rand(2, 3),)
self.compare_outputs(eager_module, unflattened_module, inputs)
def test_unflatten_shared_submodule(self):
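        # A layer reused inside nn.Sequential should unflatten to a single shared
        # submodule (entries "0" and "2" are the same object).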
class Shared(torch.nn.Module):
def __init__(self):
super().__init__()
layernorm = torch.nn.LayerNorm(10)
self.sub_net = torch.nn.Sequential(
layernorm,
torch.nn.ReLU(),
layernorm,
torch.nn.ReLU(),
)
def forward(self, x):
return self.sub_net(x)
eager_module = Shared()
inps = (torch.rand(10),)
export_module = export(eager_module, inps, {})
unflattened_module = unflatten(export_module)
self.compare_outputs(eager_module, unflattened_module, inps)
self.assertTrue(hasattr(unflattened_module, "sub_net"))
for i in range(len(eager_module.sub_net)):
self.assertTrue(hasattr(unflattened_module.sub_net, str(i)))
self.assertEqual(
id(getattr(unflattened_module.sub_net, "0")),
id(getattr(unflattened_module.sub_net, "2")),
)
@unittest.skipIf(IS_WINDOWS, "Windows not supported for this test")
@skipIfTorchDynamo("Non strict mode is not meant to run with dynamo")
def test_unflatten_preserve_signature(self):
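        # With preserve_module_call_signature, the nested submodule keeps its
        # original calling convention after unflattening and can be swapped out;
        # inputs with a mismatched tree spec require a FlatArgsAdapter.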
class NestedChild(torch.nn.Module):
def forward(self, zx, y):
return {"x": y["key"] + zx[1], "w": y["key"] * zx[1]}
class Child1(torch.nn.Module):
def __init__(self):
super().__init__()
self.nested = NestedChild()
def forward(self, x, y):
z = torch.ones_like(x)
xw = self.nested((z, x), y={"key": y})
return xw["w"] + z - xw["x"]
class Child2(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
return x - 1
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.foo = Child1()
self.bar = Child2()
def forward(self, x, y):
x = self.foo(x, y)
x = self.bar(x)
return x
orig_eager = MyModule()
inps = torch.rand(2, 3), torch.rand(2, 3)
for strict in [True, False]:
export_module = export(
orig_eager,
inps,
{},
preserve_module_call_signature=("foo.nested",),
strict=strict,
)
unflattened = unflatten(export_module)
self.compare_outputs(export_module.module(), unflattened, inps)
unflattened.foo.nested = NestedChild()
self.compare_outputs(export_module.module(), unflattened, inps)
            # Test tree-spec-mismatched input: without a flat args adapter, passing
            # an extra input should raise. The regex below matches the runtime
            # error message verbatim, including its 'sepcified' typo.
orig_outs = export_module.module()(*inps)
new_inps = *inps, torch.rand(2, 3)
with self.assertRaisesRegex(
TypeError,
"There is no flat args adapter sepcified. Are you sure you are calling this with the right arguments?",
):
unflattened(new_inps)
# With flat args adapter
class KeepTwoFlatArgsAdapter(FlatArgsAdapter):
def adapt(
self,
target_spec: TreeSpec,
input_spec: TreeSpec,
input_args: List[Any],
) -> List[Any]:
while len(input_args) > 2:
input_args.pop(-1)
return input_args
unflattened = unflatten(export_module, KeepTwoFlatArgsAdapter())
new_outs = unflattened(*new_inps)
self.assertTrue(torch.allclose(orig_outs, new_outs))
def test_unflatten_param_list_dict(self):
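        # Parameters stored in ParameterList and ParameterDict should survive
        # export + unflatten.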
class Mod(torch.nn.Module):
def __init__(self):
super().__init__()
self.param_list = torch.nn.ParameterList()
self.param_dict = torch.nn.ParameterDict()
for i in range(2):
self.param_list.append(torch.nn.Parameter(torch.randn((2, 3))))
self.param_dict[f"key_{i}"] = torch.nn.Parameter(
torch.randn((2, 3))
)
def forward(self, x):
for i in range(2):
x = x + self.param_list[i]
x = x + self.param_dict[f"key_{i}"]
return x
export_module = torch.export.export(Mod(), (torch.randn((2, 3)),))
unflattened = unflatten(export_module)
self.compare_outputs(
export_module.module(), unflattened, (torch.randn((2, 3)),)
)
def test_unflatten_wrong_input(self):
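        # Shape guards from export should also fire on the unflattened module when
        # it is called with a wrongly shaped input.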
class Mod(torch.nn.Module):
def __init__(self):
super().__init__()
self.param_list = torch.nn.ParameterList()
self.param_dict = torch.nn.ParameterDict()
for i in range(2):
self.param_list.append(torch.nn.Parameter(torch.randn((2, 3))))
self.param_dict[f"key_{i}"] = torch.nn.Parameter(
torch.randn((2, 3))
)
def forward(self, x):
a = x.sum()
for i in range(2):
a = a + self.param_list[i].sum()
a = a + self.param_dict[f"key_{i}"].sum()
return a
export_module = torch.export.export(Mod(), (torch.randn((2, 3)),))
with self.assertRaisesRegex(
RuntimeError,
escape("Expected input at *args[0].shape[0] to be equal to 2, but got 6"),
):
export_module.module()(torch.randn(6, 6))
unflattened = unflatten(export_module)
with self.assertRaisesRegex(
RuntimeError,
escape("Expected input at *args[0].shape[0] to be equal to 2, but got 6"),
):
unflattened(torch.randn(6, 6))
def test_unflatten_with_inplace_compile(self):
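        # Compiling a submodule of the unflattened module in place should keep the
        # outputs unchanged.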
class NestedChild(torch.nn.Module):
def forward(self, x):
return x / x
class Child1(torch.nn.Module):
def __init__(self):
super().__init__()
self.nested = NestedChild()
self.register_parameter(
"child1param", torch.nn.Parameter(torch.ones(2, 3))
)
def forward(self, x):
x = self.nested(x)
return x + self.child1param
class Child2(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("child2buffer", torch.ones(2, 3))
def forward(self, x):
return x - self.child2buffer
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.foo = Child1()
self.bar = Child2()
self.register_parameter(
"rootparam", torch.nn.Parameter(torch.ones(2, 3))
)
def forward(self, x):
x = x * self.rootparam
x = self.foo(x)
x = self.bar(x)
return x
orig_eager = MyModule()
export_module = torch.export.export(orig_eager, (torch.rand(2, 3),), {})
unflattened = unflatten(export_module)
        # In-place compilation should work. Pass fullgraph=True to ensure there are no graph breaks.
unflattened.foo.compile(fullgraph=True)
inputs = (torch.rand(2, 3),)
self.compare_outputs(orig_eager, unflattened, inputs)
def test_fx_trace(self):
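        # The unflattened module should be traceable with torch.fx.symbolic_trace.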
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x, y):
x = x[0] + x[1]
x = x + y["foo"]
return x
orig_eager = MyModule()
inputs = ((torch.rand(2, 3), torch.rand(2, 3)), {"foo": torch.rand(2, 3)})
export_module = export(orig_eager, inputs, {})
unflattened = unflatten(export_module)
torch.fx.symbolic_trace(
unflattened, concrete_args=(torch.fx.PH, torch.fx.PH, torch.fx.PH)
)
def test_double_nested_submodule(self):
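        # A grandchild submodule called directly from the root forward should still
        # unflatten correctly.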
class SubSubMod(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
return x * x
class SubMod(torch.nn.Module):
def __init__(self):
super().__init__()
self.subsubmod = SubSubMod()
def forward(self, x):
return x - x
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.submod = SubMod()
def forward(self, x):
return x + self.submod.subsubmod(x)
orig_eager = MyModule()
export_module = torch.export.export(orig_eager, (torch.rand(2, 3),), {})
unflattened = unflatten(export_module)
inputs = (torch.rand(2, 3),)
self.compare_outputs(orig_eager, unflattened, inputs)
def test_unflatten_container_type(self):
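        # Container-typed inputs (a list of tensors) should round-trip through
        # export -> unflatten -> re-export.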
class Leaf(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(4, 4)
def forward(self, x):
return self.linear(x)
class Bar(torch.nn.Module):
def __init__(self):
super().__init__()
self.leaf = Leaf()
self.register_buffer("buffer", torch.randn(4, 4))
def forward(self, x, z):
return self.buffer.sum() + self.leaf(x).sum() + z[0].sum() + z[1].sum()
class Foo(torch.nn.Module):
def __init__(self):
super().__init__()
self.bar = Bar()
def forward(self, x, z):
y = self.bar.buffer + x + z[0] + z[1]
return self.bar(x, z) + y.sum()
inp = (torch.randn(4, 4), [torch.randn(4, 4), torch.randn(4, 4)])
mod = Foo()
ep_strict = torch.export.export(mod, inp)
ep_non_strict = torch.export.export(mod, inp, strict=False)
gm_unflat_non_strict = unflatten(ep_non_strict)
ep = torch.export.export(gm_unflat_non_strict, inp, strict=False)
self.assertTrue(torch.allclose(ep.module()(*inp), mod(*inp)))
def test_unflattened_module_nodes_has_meta_val(self):
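        # Every non-output node in the unflattened submodule graphs should carry a
        # "val" entry in node.meta.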
class SubMod(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
return x + x, x * x
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.submod = SubMod()
def forward(self, x):
return x + sum(self.submod(x))
orig_eager = MyModule()
export_module = torch.export.export(orig_eager, (torch.rand(2, 3),), {})
unflattened = unflatten(export_module)
inputs = (torch.rand(2, 3),)
self.compare_outputs(orig_eager, unflattened, inputs)
def check_meta(gm):
for n in gm.graph.nodes:
if n.op == "output":
continue
self.assertTrue(n.meta.get("val") is not None)
for m in unflattened.modules():
check_meta(m)
def test_placeholder_and_get_attr_ordering_after_unflattened(self):
class TransposeModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv2d(3, 1, 3, stride=2)
def forward(self, x):
x = self.conv(x)
return x.transpose(0, 1)
x = torch.randn(32, 3, 64, 64)
exported_program = export(TransposeModule(), args=(x,))
unflattened_module = unflatten(exported_program)
        # Check that the placeholder and get_attr inputs of the created
        # call_module's submodule graph appear in the expected order.
call_module_input_order = []
for node in unflattened_module.graph.nodes:
if node.op == "call_module":
transpose_module = unflattened_module.get_submodule(node.target)
for sub_node in transpose_module.graph.nodes:
if sub_node.op == "placeholder" or sub_node.op == "get_attr":
call_module_input_order.append(sub_node.op)
self.assertEqual(
call_module_input_order, ["placeholder", "get_attr", "get_attr"]
)
def test_unflatten_constant_tensor(self):
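        # A constant tensor built from a plain Python attribute inside a submodule
        # should be handled by unflatten.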
class SubMod(torch.nn.Module):
def __init__(self):
super().__init__()
self.initializer = 0.1
def forward(self, x):
return x + torch.tensor(self.initializer)
class Mod(torch.nn.Module):
def __init__(self):
super().__init__()
self.submod = SubMod()
def forward(self, x):
return x + self.submod(x)
export_module = torch.export.export(Mod(), (torch.randn((2, 3)),))
unflattened = unflatten(export_module)
self.compare_outputs(
export_module.module(), unflattened, (torch.randn((2, 3)),)
)
@skipIfTorchDynamo("custom objects not supported in dynamo yet")
def test_unflatten_constant_obj(self):
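        # A torchbind script object stored as a submodule attribute should be
        # preserved through export + unflatten.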
init_torchbind_implementations()
@torch._library.register_fake_class("_TorchScriptTesting::_Foo")
class FakeFoo:
def __init__(self, x: int, y: int):
self.x = x
self.y = y
@classmethod
def __obj_unflatten__(cls, flat_ctx):
return cls(**dict(flat_ctx))
def add_tensor(self, z):
return (self.x + self.y) * z
class SubMod(torch.nn.Module):
def __init__(self):
super().__init__()
self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)
def forward(self, x):
return x + self.attr.add_tensor(x)
class Mod(torch.nn.Module):
def __init__(self):
super().__init__()
self.submod = SubMod()
def forward(self, x):
return x + self.submod(x)
with enable_torchbind_tracing():
export_module = torch.export.export(
Mod(), (torch.randn((2, 3)),), strict=False
)
unflattened = unflatten(export_module)
self.compare_outputs(
export_module.module(), unflattened, (torch.randn((2, 3)),)
)
def test_nested_leaf_non_strict(self):
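        # Non-strict export with preserve_module_call_signature on a nested module
        # should still be unflattenable.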
class Leaf(torch.nn.Module):
def forward(self, x):
return x + 1
class Nested(torch.nn.Module):
def __init__(self):
super().__init__()
self.leaf = Leaf()
def forward(self, x):
return self.leaf(x) + 2
class TopLevel(torch.nn.Module):
def __init__(self):
super().__init__()
self.nested = Nested()
def forward(self, x):
return self.nested(x) + 3
ep = torch.export.export(
TopLevel(),
(torch.randn(3),),
strict=False,
preserve_module_call_signature=("nested",),
)
torch.export.unflatten(ep)
def test_unflatten_submodule_ordering(self):
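        # Unflattening should preserve the original named_modules ordering, even
        # when a submodule is aliased under two attribute names.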
class Module2(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("buffer", torch.rand(3, 4))
self.register_parameter("param", torch.nn.Parameter(torch.rand(3, 4)))
def forward(self, x):
return x + self.buffer + self.param
class Module1(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("buffer", torch.rand(3, 4))
self.register_parameter("param", torch.nn.Parameter(torch.rand(3, 4)))
def forward(self, x):
return x + self.buffer + self.param
class Module(torch.nn.Module):
def __init__(self):
super().__init__()
self.mod2 = Module2()
self.mod3 = self.mod2
self.mod1 = Module1()
def forward(self, x):
return self.mod3(self.mod2(self.mod1(x)))
mod = Module()
ep = torch.export.export(mod, (torch.randn(3, 4),))
unflattened = torch.export.unflatten(ep)
fqn_list = [x for x, _ in unflattened.named_modules(remove_duplicate=False)]
self.assertEqual(len(fqn_list), 4)
self.assertEqual(
[x for x, _ in mod.named_modules(remove_duplicate=False)],
fqn_list,
)
def test_duplicate_placeholder(self):
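        # A single LayerNorm reused three times in a ModuleList should still give
        # correct results after unflattening, in both strict and non-strict mode.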
N, C, H, W = 1, 2, 2, 3
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
layer = torch.nn.LayerNorm([C, H, W])
self.norms = torch.nn.ModuleList(
[
layer, # reuse layer norm
layer,
layer,
]
)
def forward(self, input_):
for i in range(len(self.norms)):
output = self.norms[i](input_)
input_ = output
return output
mod = MyModule()
input_ = torch.randn(N, C, H, W)
ep_strict = export(copy.deepcopy(mod), (input_,), strict=True)
umod = unflatten(ep_strict)
self.assertTrue(torch.allclose(umod(input_), mod(input_)))
ep_non_strict = export(copy.deepcopy(mod), (input_,), strict=False)
umod = unflatten(ep_non_strict)
self.assertTrue(torch.allclose(umod(input_), mod(input_)))
if __name__ == "__main__":
run_tests()