| """ |
| PYTEST_DONT_REWRITE (prevents pytest from rewriting assertions, which interferes |
| with test_sym_bool) |
| """ |
| |
| |
| # Owner(s): ["oncall: export"] |
| import copy |
| import io |
| import tempfile |
| import unittest |
| import zipfile |
| from pathlib import Path |
| |
| import torch |
| import torch._dynamo as torchdynamo |
| import torch.export._trace |
| import torch.utils._pytree as pytree |
| from torch._export.db.case import ExportCase, SupportLevel |
| from torch._export.db.examples import all_examples |
| from torch._export.serde.serialize import ( |
| canonicalize, |
| deserialize, |
| ExportedProgramDeserializer, |
| ExportedProgramSerializer, |
| serialize, |
| SerializeError, |
| ) |
| from torch._higher_order_ops.torchbind import enable_torchbind_tracing |
| from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode |
| from torch.export import Dim, export, load, save |
| from torch.fx.experimental.symbolic_shapes import is_concrete_int, ValueRanges |
| from torch.testing._internal.common_utils import ( |
| instantiate_parametrized_tests, |
| IS_WINDOWS, |
| parametrize, |
| run_tests, |
| TemporaryFileName, |
| TestCase, |
| ) |
| from torch.testing._internal.torchbind_impls import init_torchbind_implementations |
| |
| |
| def get_filtered_export_db_tests(): |
| return [ |
| (name, case) |
| for name, case in all_examples().items() |
| if case.support_level == SupportLevel.SUPPORTED |
| ] |
| |
| |
| @unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support") |
| class TestSerialize(TestCase): |
| def test_export_with_extension_op_serialization(self): |
| class TestModule(torch.nn.Module): |
| def forward(self, x): |
| return x + x |
| |
| class FooExtensionOp: |
| def __hash__(self): |
| return 0 |
| |
| def __eq__(self, other): |
| return type(other) == type(self) |
| |
| def __call__(self, *args, **kwargs): |
| return torch.ops.aten.add.Tensor(*args, **kwargs) |
| |
| @property |
| def __name__(self): |
| return "foo.my_op" |
| |
| class ExtensionVerifier(torch._export.verifier.Verifier): |
| dialect = "FOO" |
| |
| def allowed_op_types(self): |
| return super().allowed_op_types() + (FooExtensionOp,) |
| |
| class FooExtensionHandler(torch._export.serde.serialize.ExtensionHandler): |
| @classmethod |
| def namespace(cls): |
| return "foo" |
| |
| @classmethod |
| def to_op_name(cls, op): |
| return "my_op" |
| |
| @classmethod |
| def from_op_name(cls, name: str): |
| self.assertEqual(name, "my_op") |
| return FooExtensionOp() |
| |
| @classmethod |
| def op_schema(cls, op): |
| return torch.ops.aten.add.Tensor._schema |
| |
| inp = (torch.ones(10),) |
| ep = export(TestModule(), inp) |
| |
| # Register the custom op handler. |
| foo_custom_op = FooExtensionOp() |
| torch._export.serde.serialize.register_extension( |
| FooExtensionOp, FooExtensionHandler |
| ) |
| |
| new_gm = copy.deepcopy(ep.graph_module) |
| # Inject the custom operator. |
| for node in new_gm.graph.nodes: |
| if node.name == "add": |
| node.target = foo_custom_op |
| |
| new_ep = ep._update(new_gm, ep.graph_signature, verifiers=[ExtensionVerifier]) |
| serialized = serialize(new_ep) |
| deserialized = deserialize(serialized) |
| self.assertEqual( |
| len( |
| deserialized.graph.find_nodes(op="call_function", target=foo_custom_op) |
| ), |
| 1, |
| ) |
| |
| def test_predispatch_export_with_autograd_op(self): |
| class Foo(torch.nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| |
| def forward(self, x): |
| with torch.enable_grad(): |
| return x + x |
| |
| inp = (torch.ones(10),) |
| with torch.no_grad(): |
| from torch.export._trace import _export |
| |
| ep = _export(Foo(), inp, pre_dispatch=True) |
| |
| buffer = io.BytesIO() |
| torch.export.save(ep, buffer) |
| buffer.seek(0) |
| loaded_ep = torch.export.load(buffer) |
| |
| exp_out = ep.module()(*inp) |
| actual_out = loaded_ep.module()(*inp) |
| self.assertEqual(exp_out, actual_out) |
| self.assertEqual(exp_out.requires_grad, actual_out.requires_grad) |
| |
| def test_export_example_inputs_preserved(self): |
| class MyModule(torch.nn.Module): |
| """A test module with that has multiple args and uses kwargs""" |
| |
| def __init__(self) -> None: |
| super().__init__() |
| self.p = torch.nn.Parameter(torch.ones(2, 3)) |
| |
| def forward(self, x, y, use_p=False): |
| out = x + y |
| if use_p: |
| out += self.p |
| return out |
| |
| model = MyModule().eval() |
| random_inputs = (torch.rand([2, 3]), torch.rand([2, 3])) |
| exp_program = torch.export.export(model, random_inputs, {"use_p": True}) |
| |
| output_buffer = io.BytesIO() |
| # Tests that example inputs are preserved when saving and loading module. |
| torch.export.save(exp_program, output_buffer) |
| loaded_model = torch.export.load(output_buffer) |
| # Extract the example inputs from before and after saving. |
| orig_args, orig_kwargs = exp_program.example_inputs |
| loaded_args, loaded_kwargs = loaded_model.example_inputs |
| # Run both modules and confirm that outputs match. |
| orig_out = exp_program.module()(*orig_args, **orig_kwargs) |
| loaded_out = loaded_model.module()(*loaded_args, **loaded_kwargs) |
| self.assertEqual(orig_out, loaded_out) |
| |
| def test_metadata_parsing_with_layer_split(self): |
| # Tests that modules with more complicated layer patterns can be serialized |
| # and deserialized correctly. |
| class MyModule(torch.nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| self.layers = torch.nn.Sequential( |
| torch.nn.SiLU(), |
| torch.nn.SiLU(), |
| torch.nn.SiLU(), |
| ) |
| |
| def forward(self, x): |
| # Splitting layers of a sequential stack introduces commas and parens |
| # into metadata trace. |
| out_start, out_rest = self.layers[0], self.layers[1:] |
| h = out_start(x) |
| h = out_rest(h) |
| return h |
| |
| inp = (torch.ones(10),) |
| # Module will only be able to roundtrip if metadata |
| # can be correctly parsed. |
| ep = export(MyModule(), inp) |
| buffer = io.BytesIO() |
| save(ep, buffer) |
| loaded_ep = load(buffer) |
| |
| # Check that both modules run to confirm load was successful. |
| exp_out = ep.module()(*inp) |
| actual_out = loaded_ep.module()(*inp) |
| self.assertEqual(exp_out, actual_out) |
| |
| def test_serialize_constant_outputs(self): |
| class MyModule(torch.nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| |
| def forward(self, x): |
| # Along with tensor output, return Nonetype |
| # and constant. Although these outputs aren't |
| # very useful, they do show up in graphs. |
| return x + 1, None, 1024 |
| |
| # Check that module can be roundtripped, thereby confirming proper deserialization. |
| inp = (torch.ones(10),) |
| ep = export(MyModule(), inp) |
| buffer = io.BytesIO() |
| save(ep, buffer) |
| loaded_ep = load(buffer) |
| |
| exp_out = ep.module()(*inp) |
| actual_out = loaded_ep.module()(*inp) |
| self.assertEqual(exp_out, actual_out) |
| |
| def test_serialize_multiple_returns_from_node(self) -> None: |
| class MyModule(torch.nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| |
| def forward(self, x, w, b): |
| return torch.nn.functional.layer_norm( |
| x, |
| x.size()[1:], |
| weight=w, |
| bias=b, |
| eps=1e-5, |
| ) |
| |
| exported_module = export( |
| MyModule(), |
| ( |
| torch.ones([512, 512], requires_grad=True), |
| torch.ones([512]), |
| torch.ones([512]), |
| ), |
| ).run_decompositions() |
| |
| serialized = ExportedProgramSerializer().serialize(exported_module) |
| node = serialized.exported_program.graph_module.graph.nodes[-1] |
| self.assertEqual(node.target, "torch.ops.aten.native_layer_norm.default") |
| # aten::native_layer_norm returns 3 tensors |
| self.assertEqual(len(node.outputs), 3) |
| |
| # check the names are unique |
| seen = set() |
| for output in node.outputs: |
| name = output.as_tensor.name |
| self.assertNotIn(name, seen) |
| seen.add(name) |
| |
| def test_serialize_sym_int(self) -> None: |
| class DynamicShapeSimpleModel(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| |
| def forward(self, a, b, c) -> torch.Tensor: |
| d = (torch.matmul(a, b) + c) / 2 |
| d_s0 = d.shape[0] |
| d_s1 = d.shape[1] |
| d_s3 = d_s0 * d_s1 |
| e = d.view(d_s3) |
| return torch.cat([e, e]) |
| |
| inputs = (torch.randn(2, 4), torch.randn(4, 7), torch.randn(2, 7)) |
| dim0_ac = torch.export.Dim("dim0_ac") |
| dim1_bc = torch.export.Dim("dim1_b") |
| dynamic_shapes = { |
| "a": {0: dim0_ac}, |
| "b": {1: dim1_bc}, |
| "c": {0: dim0_ac, 1: dim1_bc}, |
| } |
| exported_module = export( |
| DynamicShapeSimpleModel(), inputs, dynamic_shapes=dynamic_shapes |
| ).run_decompositions() |
| serialized = ExportedProgramSerializer().serialize(exported_module) |
| sym_size_nodes = [ |
| node |
| for node in serialized.exported_program.graph_module.graph.nodes |
| if node.target == "torch.ops.aten.sym_size.int" |
| ] |
| for node in sym_size_nodes: |
| self.assertEqual(node.inputs[0].name, "self") |
| self.assertEqual(node.inputs[1].name, "dim") |
| |
| def test_serialize_list_returns(self) -> None: |
| class MyModule(torch.nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| |
| def forward(self, x): |
| return torch.split(x, 2) |
| |
| input = torch.arange(10.0).reshape(5, 2) |
| exported_module = export(MyModule(), (input,)).run_decompositions() |
| |
| serialized = ExportedProgramSerializer().serialize(exported_module) |
| node = serialized.exported_program.graph_module.graph.nodes[-1] |
| # split.Tensor gets decomposed to split_with_sizes by the core ATen decomposition table |
| self.assertEqual(node.target, "torch.ops.aten.split_with_sizes.default") |
| self.assertEqual(len(node.outputs), 1) |
| # Input looks like: |
| # tensor([[0, 1], |
| # [2, 3], |
| # [4, 5], |
| # [6, 7], |
| # [8, 9]]) |
| # Output looks like: |
| # (tensor([[0, 1], |
| # [2, 3]]), |
| # tensor([[4, 5], |
| # [6, 7]]), |
| # tensor([[8, 9]])) |
| self.assertEqual(len(node.outputs[0].as_tensors), 3) |
| |
| # check the names are unique |
| seen = set() |
| for output in node.outputs[0].as_tensors: |
| name = output.name |
| self.assertNotIn(name, seen) |
| seen.add(name) |
| |
| def test_multi_return_some_unused(self) -> None: |
| """ |
| Make sure the serialized output matches the op schema, even if some of |
| the arguments are never used in the graph. |
| """ |
| |
| class MyModule(torch.nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| |
| def forward(self, x): |
| return torch.ops.aten.var_mean.correction(x, [1])[0] |
| |
| exported_module = export( |
| MyModule(), |
| (torch.ones([512, 512], requires_grad=True),), |
| ).run_decompositions() |
| |
| serialized = ExportedProgramSerializer().serialize(exported_module) |
| node = serialized.exported_program.graph_module.graph.nodes[-1] |
| self.assertEqual(node.target, "torch.ops.aten.var_mean.correction") |
| self.assertEqual(len(node.outputs), 2) |
| |
| # check the names are unique |
| seen = set() |
| for output in node.outputs: |
| name = output.as_tensor.name |
| self.assertNotIn(name, seen) |
| seen.add(name) |
| |
| def test_rational_ranges(self) -> None: |
| class M(torch.nn.Module): |
| def forward(self, x): |
| return x + x |
| |
| ep = torch.export.export( |
| M(), (torch.randn(4),), dynamic_shapes=({0: Dim("temp")},) |
| ) |
| |
| range_constraints = list(ep.range_constraints.keys()) |
| assert len(range_constraints) == 1 |
| symint = range_constraints[0] |
| |
| import sympy |
| |
| upper_range = sympy.Rational(10, 3) |
| lower_range = sympy.Rational(10, 6) |
| ep.range_constraints[symint] = ValueRanges(lower=lower_range, upper=upper_range) |
| |
| serialized = ExportedProgramSerializer().serialize(ep) |
| self.assertEqual(serialized.exported_program.range_constraints["s0"].min_val, 2) |
| self.assertEqual(serialized.exported_program.range_constraints["s0"].max_val, 3) |
| |
| def test_kwargs_default(self) -> None: |
| """ |
| Tests that the kwargs default values are serialized even if they are not |
| specified |
| """ |
| |
| class Foo(torch.nn.Module): |
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| values = torch.randn(3, 2) |
| return torch.searchsorted(x, values, side="right", right=True) |
| |
| f = Foo() |
| |
| x, _ = torch.sort(torch.randn(3, 4)) |
| exported_module = export(f, (x,)).run_decompositions() |
| serialized = ExportedProgramSerializer().serialize(exported_module) |
| |
| node = serialized.exported_program.graph_module.graph.nodes[-1] |
| self.assertEqual(node.target, "torch.ops.aten.searchsorted.Tensor") |
| self.assertEqual(len(node.inputs), 4) |
| self.assertEqual(node.inputs[2].name, "right") |
| self.assertEqual(node.inputs[2].arg.as_bool, True) |
| self.assertEqual(node.inputs[3].name, "side") |
| self.assertEqual(node.inputs[3].arg.as_string, "right") |
| |
| def test_canonicalize(self) -> None: |
| class Module(torch.nn.Module): |
| def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: |
| a = y + x |
| b = x + y |
| return b + a |
| |
| ep = torch.export.export(Module(), (torch.randn(3, 2), torch.randn(3, 2))) |
| s = ExportedProgramSerializer().serialize(ep) |
| c = canonicalize(s.exported_program) |
| g = c.graph_module.graph |
| self.assertLess( |
| g.nodes[0].inputs[0].arg.as_tensor.name, |
| g.nodes[1].inputs[0].arg.as_tensor.name, |
| ) |
| |
| def test_int_list(self) -> None: |
| class M(torch.nn.Module): |
| def forward(self, x): |
| return torch.ops.aten.sum.dim_IntList(x, []) |
| |
| ep = torch.export.export(M(), (torch.randn(3, 2),)) |
| serialized = ExportedProgramSerializer().serialize(ep) |
| for node in serialized.exported_program.graph_module.graph.nodes: |
| if "aten.sum.dim_IntList" in node.target: |
| self.assertEqual(node.inputs[1].arg.type, "as_ints") |
| |
| |
@unittest.skipIf(IS_WINDOWS, "Windows not supported for this test")
@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support")
class TestDeserialize(TestCase):
    """Round-trip tests: export, serialize, deserialize, then compare outputs and graphs."""

    def setUp(self):
        super().setUp()
        init_torchbind_implementations()

    def _check_graph_nodes(self, gm1, gm2, _check_meta=True):
        """Assert that two graph modules have pairwise matching nodes.

        Nodes are compared on ``op``; ``call_function`` nodes are further
        compared on their "val" and "stack_trace" metadata, and the sub-graphs
        of ``cond`` / ``map_impl`` nodes are checked recursively. When
        ``_check_meta`` is true, "nn_module_stack" and "source_fn_stack" are
        compared as well for nodes that are not get_attr/placeholder/output.
        """
        # TODO: The _check_meta flag bypasses checking for
        # source_fn/nn_module_stack as there is an issue with
        # roundtripping the source_fn value on torch.ops.map nodes
        # original source_fn: <functorch.experimental._map.MapWrapper object at 0x7f80a0549930>
        # deserialized source_fn: 'functorch.experimental._map.map'

        self.assertEqual(len(gm1.graph.nodes), len(gm2.graph.nodes))

        for node1, node2 in zip(gm1.graph.nodes, gm2.graph.nodes):
            self.assertEqual(node1.op, node2.op)
            if node1.op == "call_function":
                # Check "val" metadata
                val1 = node1.meta.get("val", None)
                val2 = node2.meta.get("val", None)
                if val1 is None or val2 is None:
                    # Either both are None
                    self.assertEqual(val1, val2)
                elif isinstance(val1, FakeTensor) and isinstance(val2, FakeTensor):
                    # Or both are fake tensors with the same shape/dtype
                    self.assertEqual(len(val1.shape), len(val2.shape))
                    for s1, s2 in zip(val1.shape, val2.shape):
                        if is_concrete_int(s1) and is_concrete_int(s2):
                            self.assertEqual(s1, s2)
                        else:
                            self.assertEqual(str(s1), str(s2))
                    self.assertEqual(val1.dtype, val2.dtype)
                elif isinstance(val1, (list, tuple)) and isinstance(
                    val2, (list, tuple)
                ):
                    # Or both are fake tensors lists with one element and with the
                    # same shape/dtype
                    for v1, v2 in zip(
                        pytree.tree_leaves(val1), pytree.tree_leaves(val2)
                    ):
                        if isinstance(v1, FakeTensor):
                            self.assertEqual(v1.shape, v2.shape)
                            self.assertEqual(v1.dtype, v2.dtype)
                else:
                    # For expressions like 's0 < 10' can only compare through string
                    self.assertEqual(str(val1), str(val2))

                # Check "stack_trace" metadata
                self.assertEqual(
                    node1.meta.get("stack_trace", None),
                    node2.meta.get("stack_trace", None),
                )

                if node1.target == torch.ops.higher_order.cond:
                    true_graph1 = getattr(gm1, node1.args[1].target)
                    true_graph2 = getattr(gm2, node2.args[1].target)
                    self._check_graph_nodes(true_graph1, true_graph2)

                    false_graph1 = getattr(gm1, node1.args[2].target)
                    false_graph2 = getattr(gm2, node2.args[2].target)
                    self._check_graph_nodes(false_graph1, false_graph2)
                elif node1.target == torch.ops.higher_order.map_impl:
                    map_graph1 = getattr(gm1, node1.args[0].target)
                    map_graph2 = getattr(gm2, node2.args[0].target)
                    self._check_graph_nodes(map_graph1, map_graph2, False)

            if _check_meta and node1.op not in ("get_attr", "placeholder", "output"):
                # Check "nn_module_stack" metadata
                self.assertEqual(
                    node1.meta.get("nn_module_stack", None),
                    node2.meta.get("nn_module_stack", None),
                )
                # Check "source_fn_stack" metadata
                self.assertEqual(
                    node1.meta.get("source_fn_stack", None),
                    node2.meta.get("source_fn_stack", None),
                )

    def check_graph(
        self,
        fn,
        inputs,
        dynamic_shapes=None,
        _check_meta=True,
        use_pre_dispatch=True,
        strict=True,
    ) -> None:
        """Export a graph, serialize it, deserialize it, and compare the results.

        Args:
            fn: module to export.
            inputs: positional example inputs; a fresh deep copy is used for
                every export and every run.
            dynamic_shapes: forwarded to the export call.
            _check_meta: forwarded to ``_check_graph_nodes``.
            use_pre_dispatch: when true, the pre-dispatch export is checked in
                addition to the regular ``torch.export.export`` one.
            strict: forwarded to the export call.
        """

        def _deepcopy_inputs(inputs):
            # copy.deepcopy(deepcopy) can fail if tensor inputs have attribute (i.e. __dict__).
            # we remove __dict__ when deepcopying.
            saved_dicts = {}
            clones = []
            try:
                for idx, item in enumerate(inputs):
                    # Probe the element being copied, not inputs[0]: a
                    # non-tensor first input must not disable the detach.
                    if isinstance(item, torch.Tensor) and hasattr(item, "__dict__"):
                        saved_dicts[idx] = item.__dict__
                        item.__dict__ = {}
                    clones.append(copy.deepcopy(item))
            finally:
                # Add __dict__ back to the originals, even if a copy failed.
                for idx, attrs in saved_dicts.items():
                    inputs[idx].__dict__ = attrs
            for idx, attrs in saved_dicts.items():
                clones[idx].__dict__ = attrs
            return tuple(clones)

        def _check_graph(pre_dispatch):
            if pre_dispatch:
                ep = torch.export._trace._export(
                    fn,
                    _deepcopy_inputs(inputs),
                    {},
                    dynamic_shapes=dynamic_shapes,
                    pre_dispatch=True,
                    strict=strict,
                )
            else:
                ep = torch.export.export(
                    fn,
                    _deepcopy_inputs(inputs),
                    {},
                    dynamic_shapes=dynamic_shapes,
                    strict=strict,
                )
            ep.graph.eliminate_dead_code()

            serialized_artifact = serialize(ep, opset_version={"aten": 0})
            deserialized_ep = deserialize(
                serialized_artifact, expected_opset_version={"aten": 0}
            )
            deserialized_ep.graph.eliminate_dead_code()

            orig_outputs = ep.module()(*_deepcopy_inputs(inputs))
            loaded_outputs = deserialized_ep.module()(*_deepcopy_inputs(inputs))

            flat_orig_outputs = pytree.tree_leaves(orig_outputs)
            flat_loaded_outputs = pytree.tree_leaves(loaded_outputs)

            # zip() stops at the shorter sequence, so a missing output would
            # otherwise go unnoticed.
            self.assertEqual(len(flat_orig_outputs), len(flat_loaded_outputs))
            for orig, loaded in zip(flat_orig_outputs, flat_loaded_outputs):
                self.assertEqual(type(orig), type(loaded))
                if isinstance(orig, torch.Tensor):
                    if orig.is_meta:
                        self.assertEqual(orig, loaded)
                    else:
                        self.assertTrue(torch.allclose(orig, loaded))
                else:
                    self.assertEqual(orig, loaded)
            self._check_graph_nodes(
                ep.graph_module, deserialized_ep.graph_module, _check_meta
            )

        if use_pre_dispatch:
            _check_graph(pre_dispatch=True)
            _check_graph(pre_dispatch=False)
        else:
            _check_graph(pre_dispatch=False)

    def test_optional_tuple(self):
        """Round-trip a custom op whose schema returns ``(Tensor, Tensor?)``."""
        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
            torch.library.define(
                "mylib::foo",
                "(Tensor a, Tensor b, Tensor? c) -> (Tensor, Tensor?)",
                tags=torch.Tag.pt2_compliant_tag,
                lib=lib,
            )

            @torch.library.impl("mylib::foo", "cpu", lib=lib)
            @torch.library.impl_abstract("mylib::foo")
            def foo_impl(a, b, c):
                res2 = None
                if c is not None:
                    res2 = c + a + b
                return a + b, res2

            class M(torch.nn.Module):
                def forward(self, a, b, c):
                    return torch.ops.mylib.foo(a, b, c)

            self.check_graph(M(), (torch.randn(3), torch.randn(3), torch.randn(3)))

    def test_auto_functionalize(self):
        """Round-trip mutating custom ops with one, two and zero return values."""
        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
            torch.library.define(
                "mylib::foo1",
                "(Tensor(a!) x, Tensor[] y, Tensor(b!) z, SymInt w, Tensor n) -> Tensor",
                tags=torch.Tag.pt2_compliant_tag,
                lib=lib,
            )
            torch.library.define(
                "mylib::foo2",
                "(Tensor(a!) x, Tensor[] y, Tensor(b!) z, SymInt w, Tensor n) -> (Tensor, Tensor)",
                tags=torch.Tag.pt2_compliant_tag,
                lib=lib,
            )
            torch.library.define(
                "mylib::foo3",
                "(Tensor(a!) x, Tensor[] y, Tensor(b!) z, SymInt w, Tensor n) -> ()",
                tags=torch.Tag.pt2_compliant_tag,
                lib=lib,
            )

            @torch.library.impl("mylib::foo1", "cpu", lib=lib)
            @torch.library.impl_abstract("mylib::foo1")
            def foo1_impl(x, y, z, w, n):
                x.add_(y[0] + w)
                z.add_(y[1] + n)
                return n + n

            @torch.library.impl("mylib::foo2", "cpu", lib=lib)
            @torch.library.impl_abstract("mylib::foo2")
            def foo2_impl(x, y, z, w, n):
                x.add_(y[0] + w)
                z.add_(y[1] + n)
                return (n + n, n * n)

            @torch.library.impl("mylib::foo3", "cpu", lib=lib)
            @torch.library.impl_abstract("mylib::foo3")
            def foo3_impl(x, y, z, w, n):
                x.add_(y[0] + w)
                z.add_(y[1] + n)
                return

            class M(torch.nn.Module):
                def forward(self, x, y, z, n):
                    n = torch.ops.mylib.foo1(x, y, z, 2, n)
                    torch.ops.mylib.foo3(x, y, z, 2, n)
                    return torch.ops.mylib.foo2(x, y, z, 2, n)

            x = torch.randn(3)
            y = (torch.randn(3), torch.randn(3))
            z = torch.randn(3)
            n = torch.randn(3)
            orig_args = (x, y, z, n)

            # TODO Auto_functionalize is not supported on pre_dispatch IR
            self.check_graph(M(), orig_args, use_pre_dispatch=False)

    def test_multi_return(self) -> None:
        """
        Round-trip a module whose graph may contain a multi-output node
        (here produced by a layer_norm call).
        """

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x, w, b):
                return torch.nn.functional.layer_norm(
                    x,
                    x.size()[1:],
                    weight=w,
                    bias=b,
                    eps=1e-5,
                )

        inputs = (
            torch.ones([512, 512], requires_grad=True),
            torch.ones([512]),
            torch.ones([512]),
        )
        self.check_graph(MyModule(), inputs)

    def test_basic(self) -> None:
        """Round-trip a module made of elementwise ops that returns a tuple."""

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                x = x + x
                x = x * x
                x = x / x
                return x, x.clone()

        inputs = (torch.ones([512], requires_grad=True),)
        self.check_graph(MyModule(), inputs)

    def test_dynamic(self) -> None:
        """Round-trip a module exported with a dynamic dimension shared by two inputs."""

        class DynamicShapeSimpleModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, a, b, c) -> torch.Tensor:
                d = (torch.matmul(a, b) + c) / 2
                d_s0 = d.shape[0]
                d_s1 = d.shape[1]
                d_s3 = d_s0 * d_s1
                e = d.view(d_s3)
                return torch.cat([e, e])

        inputs = (torch.randn(2, 4), torch.randn(4, 7), torch.randn(2, 7))
        dim0_ac = torch.export.Dim("dim0_ac")
        dynamic_shapes = {"a": {0: dim0_ac}, "b": None, "c": {0: dim0_ac}}
        self.check_graph(DynamicShapeSimpleModel(), inputs, dynamic_shapes)

    def test_sym_bool(self):
        """Round-trip a module whose forward asserts on a size-membership check."""

        class Module(torch.nn.Module):
            def forward(self, x, y):
                assert x.size(0) in y
                return x + y

        f = Module()
        self.check_graph(f, (torch.ones(1), torch.ones(3)))

    def test_shape(self):
        """Round-trip a module that returns values computed from dynamic sizes."""

        class Foo(torch.nn.Module):
            def forward(self, x):
                z, y = x.size()
                return z + y + x[0], z

        inputs = (torch.ones(2, 3),)
        dim0_x, dim1_x = torch.export.dims("dim0_x", "dim1_x")
        dynamic_shapes = {"x": (dim0_x, dim1_x)}
        self.check_graph(Foo(), inputs, dynamic_shapes)

    def test_module(self):
        """Round-trip a module built from Linear submodules and a functional relu."""

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear1 = torch.nn.Linear(3, 3)
                self.relu = torch.nn.ReLU()
                self.linear2 = torch.nn.Linear(3, 5)

            def forward(self, x):
                x = self.linear1(x)
                x = self.linear1(x)
                x = torch.nn.functional.relu(x)
                x = self.linear2(x)
                return x

        inputs = (torch.randn(3, 3),)
        self.check_graph(M(), inputs)

    def test_module_meta(self):
        """Round-trip a module whose parameter and input live on the meta device."""

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.p = torch.nn.Parameter(torch.ones(3, 3))

            def forward(self, x):
                return self.p + x

        with torch.device("meta"):
            mod = M()

        inputs = (torch.randn(3, 3, device="meta"),)
        self.check_graph(mod, inputs)

    def test_cond(self):
        """Round-trip a module that uses the ``cond`` control-flow operator."""
        from functorch.experimental.control_flow import cond

        inputs = torch.ones(4, 3), torch.zeros(4, 3)

        class M(torch.nn.Module):
            def forward(self, x, y):
                def t(x, y):
                    return x + y

                def f(x, y):
                    return x - y

                return cond(x[0][0] > 4, t, f, [x, y])

        self.check_graph(M(), inputs)

    def test_map(self):
        """Round-trip a module that uses ``control_flow.map`` (metadata check disabled)."""
        from functorch.experimental import control_flow

        def f(x, y):
            return x + y

        class Module(torch.nn.Module):
            def forward(self, xs, y):
                return control_flow.map(f, xs, y)

        g = Module()
        inputs = (torch.ones(3, 2, 2), torch.ones(2))
        self.check_graph(g, inputs, _check_meta=False)

    def test_tensor_tensor_list(self):
        """Round-trip a custom op whose schema returns ``(Tensor, Tensor[])``."""
        with torch.library._scoped_library("_export", "FRAGMENT") as lib:
            lib.define(
                "_test_tensor_tensor_list_output(Tensor x, Tensor y) -> (Tensor, Tensor[])",
                tags=torch.Tag.pt2_compliant_tag,
            )

            def _test_tensor_tensor_list_output(x, y):
                return y, [x]

            lib.impl(
                "_test_tensor_tensor_list_output",
                _test_tensor_tensor_list_output,
                "CPU",
            )
            lib.impl(
                "_test_tensor_tensor_list_output",
                _test_tensor_tensor_list_output,
                "Meta",
            )

            class M(torch.nn.Module):
                def forward(self, x, y):
                    a, b = torch.ops._export._test_tensor_tensor_list_output.default(
                        x, y
                    )
                    return a + b[0]

            self.check_graph(M(), (torch.rand(3, 2), torch.rand(3, 2)))

    def test_list_of_optional_tensors(self) -> None:
        """Round-trip an index.Tensor call whose index list contains None entries."""

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x, y, z):
                indices = [None, None, torch.tensor([1, 3, 5, 7])]
                indexed = torch.ops.aten.index.Tensor(x + y, indices)
                return indexed + z

        inputs = (torch.rand(8, 8, 8), torch.rand(8, 8, 8), torch.rand(8, 8, 4))
        self.check_graph(MyModule(), inputs)

    def test_sym_ite(self):
        """Round-trip a module that returns the result of ``torch.sym_ite``."""

        class Foo(torch.nn.Module):
            def forward(self, x):
                b = x.shape[0] == 5
                ret = torch.sym_ite(b, x.shape[0], x.shape[1])
                return ret

        dynamic_shapes = {"x": {0: Dim("dim0"), 1: Dim("dim1")}}
        self.check_graph(Foo(), (torch.ones(4, 5),), dynamic_shapes=dynamic_shapes)

    def test_multiple_getitem(self):
        """A duplicated getitem node is deduplicated in the deserialized graph."""

        class M(torch.nn.Module):
            def forward(self, x):
                a, b = torch.topk(x, 2)
                a = a * 2
                return a, b

        ep = torch.export.export(M(), (torch.ones(3),))

        # insert another getitem node
        for node in ep.graph.nodes:
            if node.op == "call_function" and node.target == torch.ops.aten.mul.Tensor:
                getitem_0 = node.args[0]
                with ep.graph.inserting_before(getitem_0):
                    getitem_copy = ep.graph.node_copy(getitem_0)
                    mul_node = ep.graph.call_function(
                        torch.ops.aten.mul.Tensor, (getitem_copy, 2)
                    )
                    mul_node.meta = copy.copy(getitem_copy.meta)
                    node.args = (getitem_0, mul_node)

        deserialized_ep = deserialize(serialize(ep))

        inp = (torch.randn(3),)
        orig_res = ep.module()(*inp)
        res = deserialized_ep.module()(*inp)
        self.assertTrue(torch.allclose(orig_res[0], res[0]))
        self.assertTrue(torch.allclose(orig_res[1], res[1]))

        # The deserialized graph should have deduped getitem calls
        self.assertExpectedInline(
            deserialized_ep.graph_module.code.strip("\n"),
            """\
def forward(self, x):
    topk_default = torch.ops.aten.topk.default(x, 2);  x = None
    getitem = topk_default[0]
    getitem_1 = topk_default[1];  topk_default = None
    mul_tensor = torch.ops.aten.mul.Tensor(getitem, 2)
    mul = torch.ops.aten.mul.Tensor(getitem, mul_tensor);  getitem = mul_tensor = None
    return (mul, getitem_1)
    """,
        )

    @parametrize(
        "name,case",
        get_filtered_export_db_tests(),
        name_fn=lambda name, case: f"case_{name}",
    )
    def test_exportdb_supported(self, name: str, case: ExportCase) -> None:
        """Round-trip every supported export-db example (meta check skipped for map cases)."""
        model = case.model
        _check_meta = "map" not in name
        self.check_graph(model, case.example_args, _check_meta=_check_meta)

    def test_constraints(self):
        """Round-trip a module that builds a tensor sized by ``x.item()``."""

        class Module(torch.nn.Module):
            def forward(self, x, y):
                n = x.item()
                torch._check_is_size(n)
                return y.sum() + torch.ones(n, 5).sum()

        f = Module()
        self.check_graph(f, (torch.tensor(3), torch.randn(4, 5)))

    def test_get_attr(self) -> None:
        """Round-trip a module that adds a tensor constant created in forward."""

        class Module(torch.nn.Module):
            def forward(self, x):
                return x + torch.tensor(3)

        f = Module()
        self.check_graph(f, (torch.tensor(3),))

    def test_get_attr_list(self) -> None:
        """Round-trip a module that concatenates its input with a tensor constant."""

        class Module(torch.nn.Module):
            def forward(self, x):
                return torch.cat([x, torch.tensor([1, 1])])

        f = Module()
        self.check_graph(f, (torch.tensor([1, 1]),))

    @unittest.skipIf(not torch.cuda.is_available(), "Requires cuda")
    def test_device(self) -> None:
        """Round-trip a conv/relu module placed on CUDA."""

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=True)
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                conv = self.conv(x)
                relu = self.relu(conv)
                mul = relu * 0.5
                return mul

        inp = torch.randn((1, 3, 224, 224), dtype=torch.float).to("cuda")
        model = MyModule().eval().cuda()
        self.check_graph(model, (inp,))

    def test_custom_obj_tuple_out(self):
        """Round-trip (non-strict) a torchbind op that returns a tuple."""

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)

            def forward(self, x):
                a = torch.ops._TorchScriptTesting.takes_foo_tuple_return(self.attr, x)
                y = a[0] + a[1]
                b = torch.ops._TorchScriptTesting.takes_foo(self.attr, y)
                return x + b

        m = MyModule()
        inputs = (torch.ones(2, 3),)
        self.check_graph(m, inputs, strict=False)

    def test_custom_obj(self):
        """Round-trip (non-strict) a module that passes a torchbind object to an op."""

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)

            def forward(self, x):
                a = torch.ops._TorchScriptTesting.takes_foo(self.attr, x)
                b = torch.ops._TorchScriptTesting.takes_foo(self.attr, a)
                return x + b

        m = MyModule()
        inputs = (torch.ones(2, 3),)
        self.check_graph(m, inputs, strict=False)

    def test_custom_obj_list_out(self):
        """Round-trip (non-strict) a torchbind op that returns a list."""

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)

            def forward(self, x):
                a = torch.ops._TorchScriptTesting.takes_foo_list_return(self.attr, x)
                y = a[0] + a[1] + a[2]
                b = torch.ops._TorchScriptTesting.takes_foo(self.attr, y)
                return x + b

        m = MyModule()
        inputs = (torch.ones(2, 3),)
        self.check_graph(m, inputs, strict=False)

    def test_export_no_inputs(self):
        """Round-trip a program that takes no inputs and has no example inputs."""

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.p = torch.ones(3, 3)

            def forward(self):
                return self.p * self.p

        ep = torch.export.export(M(), ())
        ep._example_inputs = None
        roundtrip_ep = deserialize(serialize(ep))
        self.assertTrue(torch.allclose(ep.module()(), roundtrip_ep.module()()))
| |
| |
# Process the @parametrize-decorated tests of TestDeserialize; the generated
# test names (e.g. test_exportdb_supported_case_<name>) are referenced below.
instantiate_parametrized_tests(TestDeserialize)
| |
| |
| @unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support") |
| class TestSchemaVersioning(TestCase): |
| def test_error(self): |
| class Module(torch.nn.Module): |
| def forward(self, x): |
| return x + x |
| |
| f = Module() |
| ep = export(f, (torch.randn(1, 3),)) |
| |
| serialized_program = ExportedProgramSerializer().serialize(ep) |
| serialized_program.exported_program.schema_version.major = -1 |
| with self.assertRaisesRegex( |
| SerializeError, r"Serialized schema version .* does not match our current" |
| ): |
| ExportedProgramDeserializer().deserialize( |
| serialized_program.exported_program, |
| serialized_program.state_dict, |
| serialized_program.constants, |
| serialized_program.example_inputs, |
| ) |
| |
| |
# We didn't set up kwargs input yet, so the generated exportdb test for the
# "fn_with_kwargs" case is marked as an expected failure.
unittest.expectedFailure(TestDeserialize.test_exportdb_supported_case_fn_with_kwargs)
| |
| |
| @unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support") |
| class TestSaveLoad(TestCase): |
| def test_save_buffer(self): |
| inp = (torch.tensor([0.1, 0.1]),) |
| |
| class Module(torch.nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| self.linear = torch.nn.Linear(2, 2) |
| |
| def forward(self, x): |
| x = x + 1 |
| y = x.t() |
| y = y.relu() |
| y = self.linear(y) |
| return y |
| |
| ep = export(Module(), inp) |
| |
| buffer = io.BytesIO() |
| save(ep, buffer) |
| buffer.seek(0) |
| loaded_ep = load(buffer) |
| |
| self.assertTrue(torch.allclose(ep.module()(*inp), loaded_ep.module()(*inp))) |
| |
| def test_save_file(self): |
| class Foo(torch.nn.Module): |
| def forward(self, x): |
| return x * x |
| |
| f = Foo() |
| |
| inp = (torch.randn(2, 2),) |
| ep = export(f, inp) |
| |
| with tempfile.NamedTemporaryFile() as f: |
| save(ep, f) |
| f.seek(0) |
| loaded_ep = load(f) |
| |
| self.assertTrue(torch.allclose(ep.module()(*inp), loaded_ep.module()(*inp))) |
| |
| def test_save_path(self): |
| class Foo(torch.nn.Module): |
| def forward(self, x, y): |
| return x + y |
| |
| f = Foo() |
| |
| inp = (torch.tensor([6]), torch.tensor([7])) |
| ep = export(f, inp) |
| |
| with TemporaryFileName() as fname: |
| path = Path(fname) |
| save(ep, path) |
| loaded_ep = load(path) |
| |
| self.assertTrue(torch.allclose(ep.module()(*inp), loaded_ep.module()(*inp))) |
| |
| def test_save_extra(self): |
| inp = (torch.tensor([0.1, 0.1]),) |
| |
| class Foo(torch.nn.Module): |
| def forward(self, x): |
| return x * x + x |
| |
| f = Foo() |
| |
| ep = export(f, inp) |
| |
| buffer = io.BytesIO() |
| save(ep, buffer, extra_files={"extra.txt": "moo"}) |
| buffer.seek(0) |
| extra_files = {"extra.txt": ""} |
| loaded_ep = load(buffer, extra_files=extra_files) |
| |
| self.assertTrue(torch.allclose(ep.module()(*inp), loaded_ep.module()(*inp))) |
| self.assertEqual(extra_files["extra.txt"], "moo") |
| |
| def test_version_error(self): |
| class Foo(torch.nn.Module): |
| def forward(self, x): |
| return x + x |
| |
| f = Foo() |
| |
| ep = export(f, (torch.randn(1, 3),)) |
| |
| with tempfile.NamedTemporaryFile() as f: |
| save(ep, f) |
| f.seek(0) |
| |
| # Modify the version |
| with zipfile.ZipFile(f, "a") as zipf: |
| zipf.writestr("version", "-1.1") |
| |
| with self.assertRaisesRegex( |
| RuntimeError, r"Serialized version .* does not match our current" |
| ): |
| f.seek(0) |
| load(f) |
| |
| def test_save_constants(self): |
| class Foo(torch.nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| self.a = torch.tensor(3) |
| |
| def forward(self, x): |
| list_tensor = [torch.tensor(3), torch.tensor(4)] |
| return x + self.a + list_tensor[0] + list_tensor[1] |
| |
| ep = export(Foo(), (torch.tensor(1),)) |
| buffer = io.BytesIO() |
| save(ep, buffer) |
| buffer.seek(0) |
| loaded_ep = load(buffer) |
| |
| inp = (torch.tensor(1),) |
| self.assertTrue(torch.allclose(ep.module()(*inp), loaded_ep.module()(*inp))) |
| |
| |
| @unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support") |
| class TestSerializeCustomClass(TestCase): |
| def setUp(self): |
| super().setUp() |
| init_torchbind_implementations() |
| |
| def test_custom_class(self): |
| custom_obj = torch.classes._TorchScriptTesting._PickleTester([3, 4]) |
| |
| class Foo(torch.nn.Module): |
| def forward(self, x): |
| return x + x |
| |
| f = Foo() |
| |
| inputs = (torch.zeros(4, 4),) |
| ep = export(f, inputs) |
| |
| # Replace one of the values with an instance of our custom class |
| for node in ep.graph.nodes: |
| if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor: |
| with ep.graph.inserting_before(node): |
| custom_node = ep.graph.call_function( |
| torch.ops._TorchScriptTesting.take_an_instance.default, |
| (custom_obj,), |
| ) |
| custom_node.meta["val"] = torch.ones(4, 4) |
| custom_node.meta["torch_fn"] = ( |
| "take_an_instance", |
| "take_an_instance", |
| ) |
| arg0, _ = node.args |
| node.args = (arg0, custom_node) |
| |
| serialized_vals = serialize(ep) |
| |
| ep_str = serialized_vals.exported_program.decode("utf-8") |
| assert "class_fqn" in ep_str |
| assert custom_obj._type().qualified_name() in ep_str |
| |
| deserialized_ep = deserialize(serialized_vals) |
| |
| for node in deserialized_ep.graph.nodes: |
| if ( |
| node.op == "call_function" |
| and node.target |
| == torch.ops._TorchScriptTesting.take_an_instance.default |
| ): |
| arg = node.args[0] |
| self.assertTrue(isinstance(arg, torch._C.ScriptObject)) |
| self.assertEqual(arg._type(), custom_obj._type()) |
| self.assertEqual(arg.__getstate__(), custom_obj.__getstate__()) |
| self.assertEqual(arg.top(), 7) |
| |
| def test_custom_class_containing_fake_tensor(self): |
| class Foo(torch.nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| self.custom_obj = torch.classes._TorchScriptTesting._ContainsTensor( |
| torch.rand(2, 3) |
| ) |
| |
| def forward(self, x): |
| return x + self.custom_obj.get() |
| |
| with FakeTensorMode(): |
| f = Foo() |
| |
| inputs = (torch.zeros(2, 3),) |
| with enable_torchbind_tracing(): |
| ep = export(f, inputs, strict=False) |
| |
| serialized_vals = serialize(ep) |
| ep = deserialize(serialized_vals) |
| self.assertTrue(isinstance(ep.constants["custom_obj"].get(), FakeTensor)) |
| |
| def test_custom_tag_metadata_serialization(self): |
| class Foo(torch.nn.Module): |
| def forward(self, x): |
| return x + x |
| |
| f = Foo() |
| |
| inputs = (torch.zeros(4, 4),) |
| ep = export(f, inputs) |
| |
| new_gm = copy.deepcopy(ep.graph_module) |
| new_gm.meta["custom"] = {} |
| new_gm.meta["custom"]["f"] = "bar" |
| |
| for node in new_gm.graph.nodes: |
| if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor: |
| node.meta["custom"] = {} |
| node.meta["custom"]["quantization_tag"] = "foo" |
| |
| new_ep = ep._update(new_gm, ep.graph_signature) |
| serialized_vals = serialize(new_ep) |
| new_ep = deserialize(serialized_vals) |
| |
| self.assertEqual(new_ep.graph_module.meta["custom"]["f"], "bar") |
| counter = 0 |
| for node in new_ep.graph.nodes: |
| if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor: |
| counter += 1 |
| self.assertTrue(node.meta["custom"]["quantization_tag"] == "foo") |
| self.assertEqual(counter, 1) |
| |
| def test_custom_tag_metadata_decomp(self): |
| class Foo(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.linear = torch.nn.Linear(2, 2) |
| |
| def forward(self, x): |
| return self.linear(x) |
| |
| f = Foo() |
| |
| inputs = (torch.ones(2, 2),) |
| ep = export(f, inputs) |
| |
| new_gm = copy.deepcopy(ep.graph_module) |
| new_gm.meta["custom"] = {} |
| new_gm.meta["custom"]["f"] = "bar" |
| |
| counter = 0 |
| for node in new_gm.graph.nodes: |
| if ( |
| node.op == "call_function" |
| and node.target == torch.ops.aten.linear.default |
| ): |
| counter += 1 |
| node.meta["custom"] = {} |
| node.meta["custom"]["quantization_tag"] = "foo" |
| self.assertEqual(counter, 1) |
| |
| new_ep = ep._update(new_gm, ep.graph_signature) |
| new_ep = new_ep.run_decompositions() |
| |
| self.assertEqual(new_ep.graph_module.meta["custom"]["f"], "bar") |
| counter = 0 |
| for node in new_ep.graph.nodes: |
| if node.op == "call_function": |
| counter += 1 |
| self.assertTrue(node.meta["custom"]["quantization_tag"] == "foo") |
| self.assertTrue(counter > 1) |
| |
| # TODO For some reason, this doesn't work on Windows ONLY. |
| # def test_custom_tag_metadata_reexport(self): |
| # class Foo(torch.nn.Module): |
| # def forward(self, x): |
| # return x + x |
| # |
| # f = Foo() |
| # |
| # inputs = (torch.zeros(4, 4),) |
| # ep = export(f, inputs) |
| # |
| # new_gm = copy.deepcopy(ep.graph_module) |
| # new_gm.meta["custom"] = {} |
| # new_gm.meta["custom"]["f"] = "bar" |
| # |
| # for node in new_gm.graph.nodes: |
| # if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor: |
| # node.meta["custom"] = {} |
| # node.meta["custom"]["quantization_tag"] = "foo" |
| # |
| # new_ep = ep._update(new_gm, ep.graph_signature) |
| # new_ep = torch.export.export(new_ep.module(), inputs) |
| # |
| # self.assertEqual(new_ep.graph_module.meta["custom"]["f"], "bar") |
| # counter = 0 |
| # for node in new_ep.graph.nodes: |
| # if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor: |
| # counter += 1 |
| # self.assertTrue(node.meta["custom"]["quantization_tag"] == "foo") |
| # self.assertEqual(counter, 1) |
| |
| def test_custom_tag_metadata_copy(self): |
| class Foo(torch.nn.Module): |
| def forward(self, x): |
| return x + x |
| |
| f = Foo() |
| |
| inputs = (torch.zeros(4, 4),) |
| ep = export(f, inputs) |
| |
| new_gm = copy.deepcopy(ep.graph_module) |
| new_gm.meta["custom"] = {} |
| new_gm.meta["custom"]["f"] = "bar" |
| |
| for node in new_gm.graph.nodes: |
| if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor: |
| node.meta["custom"] = {} |
| node.meta["custom"]["quantization_tag"] = "foo" |
| |
| new_gm = copy.deepcopy(new_gm) |
| |
| self.assertEqual(new_gm.meta["custom"]["f"], "bar") |
| counter = 0 |
| for node in new_gm.graph.nodes: |
| if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor: |
| counter += 1 |
| self.assertTrue(node.meta["custom"]["quantization_tag"] == "foo") |
| self.assertEqual(counter, 1) |
| |
| |
# Run the tests in this file when it is executed directly as a script.
if __name__ == "__main__":
    run_tests()