| # Owner(s): ["module: fx"] |
| |
| import functools |
| import math |
| import numbers |
| import operator |
| import pickle |
| import sys |
| import sympy |
| import tempfile |
| import unittest |
| from types import BuiltinFunctionType |
| from typing import Callable, Dict, List, NamedTuple, Optional, Tuple, Union |
| |
| import torch |
| import torch.fx.experimental.meta_tracer |
| import torch.fx.experimental.optimization as optimization |
| from torch.fx._symbolic_trace import symbolic_trace |
| from torch.fx.experimental import merge_matmul |
| from torch.fx.experimental.accelerator_partitioner import Partitioner |
| from torch.fx.experimental.normalize import NormalizeArgs, NormalizeOperators |
| from torch.fx.experimental.partitioner_utils import ( |
| Device, |
| get_latency_of_partitioned_graph, |
| get_partition_to_latency_mapping, |
| NodeLatency, |
| PartitionerConfig, |
| PartitionMode, |
| ) |
| from torch.fx.experimental.rewriter import RewritingTracer |
| from torch.fx.experimental.schema_type_annotation import AnnotateTypesWithSchema |
| from torch.fx.graph_module import GraphModule |
| from torch.fx.node import Node |
| from torch.fx.operator_schemas import ( |
| _torchscript_type_to_python_type, |
| create_type_hint, |
| normalize_function, |
| normalize_module, |
| type_matches, |
| ) |
| from torch.fx.passes import graph_manipulation |
| from torch.fx.passes.param_fetch import lift_lowering_attrs_to_nodes |
| from torch.fx.passes.shape_prop import ShapeProp |
| from torch.fx.passes.split_module import split_module |
| from torch.fx.passes.annotate_getitem_nodes import annotate_getitem_nodes |
| from torch.testing._internal.common_device_type import ( |
| instantiate_device_type_tests, |
| onlyCPU, |
| ops, |
| ) |
| from torch.testing._internal.common_methods_invocations import op_db |
| from torch.testing._internal.common_nn import module_tests, new_module_tests |
| from torch.testing._internal.common_utils import TEST_Z3, run_tests, TestCase |
| from torch.testing._internal.jit_utils import JitTestCase |
| import torch.utils._pytree as pytree |
| |
| try: |
| import torchvision.models |
| from torchvision.models import resnet18 |
| |
| HAS_TORCHVISION = True |
| except ImportError: |
| HAS_TORCHVISION = False |
| skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision") |
| skipIfNoMkldnn = unittest.skipIf( |
| not (torch.backends.mkldnn.enabled and torch.backends.mkldnn.is_available()), |
| "no MKLDNN", |
| ) |
| |
| |
| def symbolic_trace_with_rewrite(root: Union[torch.nn.Module, Callable]) -> GraphModule: |
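    """Trace `root` with RewritingTracer, which rewrites Python `assert`
    statements into `torch._assert` calls so they appear in the traced graph.
    """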
| return GraphModule( |
| root if isinstance(root, torch.nn.Module) else torch.nn.Module(), |
| RewritingTracer().trace(root), |
| ) |
| |
| |
| class TestFXExperimental(JitTestCase): |
| def test_find_single_partition(self): |
| class TestModule(torch.nn.Module): |
| def forward(self, a, b): |
| return a + b |
| |
| m = TestModule() |
| traced = symbolic_trace(m) |
| a = torch.rand(1) |
| b = torch.rand(1) |
| graph_manipulation.get_size_of_all_nodes(traced, [a, b]) |
| partitioner = Partitioner() |
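        # Each Device is (name, available memory in bytes, logical device id).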
| devices = [ |
| Device("dev_0", 125, 0), |
| Device("dev_1", 150, 1), |
| Device("dev_2", 125, 2), |
| ] |
| partitioner_config = PartitionerConfig(devices) |
| ret = partitioner.partition_graph(traced, m, partitioner_config) |
| module_with_submodules = ret.module_with_submodules |
| dag = ret.dag |
| self.assertEqual(traced(a, b), module_with_submodules(a, b)) |
| assert dag.nodes[0].logical_device_ids == [1] |
| |
| def test_lack_of_devices(self): |
| class TestModule(torch.nn.Module): |
| def forward(self, a, b): |
| return a + b |
| |
| m = TestModule() |
| traced = symbolic_trace(m) |
| a = torch.rand(4) |
| b = torch.rand(4) |
| graph_manipulation.get_size_of_all_nodes(traced, [a, b]) |
| partitioner = Partitioner() |
| devices = [Device("dev_0", 4, 0), Device("dev_1", 4, 1)] |
| partitioner_config = PartitionerConfig(devices, PartitionMode.size_based) |
        with self.assertRaises(RuntimeError):
            partitioner.partition_graph(traced, m, partitioner_config)
| |
| def test_large_node_error(self): |
| class TestModule(torch.nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| self.linear = torch.nn.Linear(4, 4) |
| |
| def forward(self, a): |
| linear = self.linear(a) |
| add = linear + a |
| return add |
| |
| m = TestModule() |
| traced = symbolic_trace(m) |
| a = torch.rand(4) |
| graph_manipulation.get_size_of_all_nodes(traced, [a]) |
| partitioner = Partitioner() |
| devices = [ |
| Device("dev_0", 40, 0), |
| Device("dev_1", 40, 0), |
| Device("dev_2", 40, 0), |
| Device("dev_3", 40, 0), |
| Device("dev_4", 40, 0), |
| ] |
| partitioner_config = PartitionerConfig(devices, PartitionMode.size_based) |
        with self.assertRaises(RuntimeError):
            partitioner.partition_graph(traced, m, partitioner_config)
| |
| def test_partition_node_manipulation(self): |
| class TestModule(torch.nn.Module): |
| def forward(self, a, b): |
| add_1 = a + b |
| add_2 = add_1 + torch.rand(4) |
| add_3 = add_2 + torch.rand(4) |
| return add_3 |
| |
| m = TestModule() |
| traced = symbolic_trace(m) |
| a, b = torch.rand(4), torch.rand(4) |
| graph_manipulation.get_size_of_all_nodes(traced, [a, b]) |
| partitioner = Partitioner() |
| devices = [Device("dev_0", 1000, 0)] |
| partitioner_config = PartitionerConfig(devices) |
| ret = partitioner.partition_graph(traced, m, partitioner_config) |
| partition = partitioner.partitions[0] |
| assert partition.used_mem_bytes == 112 |
| # Select add_2 node to remove |
| selected_node = None |
| for node in partition.nodes: |
| if node.name == "add_2": |
| selected_node = node |
| partition.remove_node(selected_node) |
| assert partition.used_mem_bytes == 80 |
| |
| def test_size_based_partition(self): |
| class TestModule(torch.nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| self.linear = torch.nn.Linear(4, 4) |
| self.c = torch.rand(4) |
| |
| def forward(self, a, b): |
| add_1 = a + b |
| linear = self.linear(add_1) |
| add_2 = linear + self.c |
| return add_2 |
| |
| m = TestModule() |
| traced = symbolic_trace(m) |
| a = torch.rand(4) |
| b = torch.rand(4) |
| graph_manipulation.get_size_of_all_nodes(traced, [a, b]) |
| partitioner = Partitioner() |
| devices = [ |
| Device("dev_0", 125, 0), |
| Device("dev_1", 125, 1), |
| Device("dev_2", 125, 2), |
| ] |
| partitioner_config = PartitionerConfig(devices, PartitionMode.size_based) |
| ret = partitioner.partition_graph(traced, m, partitioner_config) |
| module_with_submodules = ret.module_with_submodules |
| dag = ret.dag |
| self.assertEqual(traced(a, b), module_with_submodules(a, b)) |
| for i, node in enumerate(dag.nodes): |
| assert node.logical_device_ids == [i] |
| |
| def test_partition_device_mapping(self): |
| class TestModule(torch.nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| self.linear = torch.nn.Linear(4, 4) |
| |
| def forward(self, a): |
| b = torch.rand(4) |
| add_1 = a + b |
| linear_1 = self.linear(add_1) |
| add_2 = torch.rand(4) + a |
| add_3 = add_2 + linear_1 |
| return add_3 |
| |
| m = TestModule() |
| traced = symbolic_trace(m) |
| a = torch.rand(4) |
| graph_manipulation.get_size_of_all_nodes(traced, [a]) |
| partitioner = Partitioner() |
| devices = [Device("dev_0", 120, 0), Device("dev_1", 160, 1)] |
| partitioner_config = PartitionerConfig(devices, PartitionMode.size_based) |
| ret = partitioner.partition_graph(traced, m, partitioner_config) |
| module_with_submodules = ret.module_with_submodules |
| dag = ret.dag |
| self.assertEqual(traced(a), module_with_submodules(a)) |
| for i, node in enumerate(dag.nodes): |
| if i == 1: |
| assert node.logical_device_ids == [1] |
| else: |
| assert node.logical_device_ids == [0] |
| |
| def test_sparse_nn_partition(self): |
| class MyRecommendationModule(torch.nn.Module): |
| def create_mlp(self, num_of_layers: int, input_size: int, output_size: int): |
| layers = torch.nn.ModuleList() |
| for _ in range(num_of_layers): |
| ll = torch.nn.Linear(input_size, output_size) |
| layers.append(ll) |
| layers.append(torch.nn.ReLU()) |
| return layers |
| |
| def __init__(self) -> None: |
| super().__init__() |
| layers = self.create_mlp(4, 4, 4) |
| self.bottom_layers = torch.nn.Sequential(*layers) |
| layers = self.create_mlp(3, 24, 24) |
| self.top_layers = torch.nn.Sequential(*layers) |
| self.embedding_layers = torch.nn.ModuleList() |
| el = torch.nn.EmbeddingBag(500000, 4, mode="sum", sparse=True) |
| self.embedding_layers.append(el) |
| for i in range(3): |
| el = torch.nn.EmbeddingBag(1000000, 4, mode="sum", sparse=True) |
| self.embedding_layers.append(el) |
| el = torch.nn.EmbeddingBag(500000, 4, mode="sum", sparse=True) |
| self.embedding_layers.append(el) |
| |
| def forward(self, a, b, offset): |
| x = self.bottom_layers(a) |
| y = [] |
| c = [] |
| for i in range(len(self.embedding_layers)): |
| temp = torch.randint(10, (8,)) |
| c.append(temp + b) |
| for i in range(len(self.embedding_layers)): |
| if i % 2 == 0: |
| y.append(self.embedding_layers[i](c[i], offset)) |
| else: |
| y.append( |
| self.embedding_layers[i](torch.randint(10, (8,)), offset) |
| ) |
| z = torch.cat([x] + y, dim=1) |
| p = self.top_layers(z) |
| return p |
| |
| m = MyRecommendationModule() |
| a = torch.rand(2, 4) |
| b = torch.randint(10, (8,)) |
| offset = torch.randint(1, (2,)) |
| traced = symbolic_trace(m) |
| graph_manipulation.get_size_of_all_nodes(traced, [a, b, offset]) |
| devices = [ |
| Device("dev_0", 33000000, 0), |
| Device("dev_1", 33000000, 1), |
| Device("dev_2", 33000000, 2), |
| ] |
| partitioner_config = PartitionerConfig(devices, PartitionMode.sparse_nn) |
| partitioner = Partitioner() |
| ret = partitioner.partition_graph(traced, m, partitioner_config) |
| module_with_submodules = ret.module_with_submodules |
| dag = ret.dag |
| self.assertEqual(traced(a, b, offset), module_with_submodules(a, b, offset)) |
| assert len(module_with_submodules.graph.nodes) == 24 |
| |
| def test_partition_latency(self): |
| class TestModule(torch.nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| self.linear = torch.nn.Linear(4, 4) |
| |
| def forward(self, a): |
| add_1 = a + torch.rand(4) |
| add_2 = add_1 + torch.rand(4) |
| linear_1 = self.linear(add_1) |
| add_3 = add_2 + linear_1 |
| add_4 = add_2 + add_3 |
| return add_4 |
| |
| def get_node_to_latency_mapping(fx_module: GraphModule): |
| """Given a fx module, generate node latency for each node |
| based on the size of each node |
| """ |
| node_to_latency_mapping: Dict[Node, NodeLatency] = {} |
| for node in fx_module.graph.nodes: |
| if node.op not in {"output", "placeholder", "get_attr"}: |
| if node.size_bytes.total_size == node.size_bytes.output_size: |
| node_to_latency_mapping[node] = NodeLatency( |
| node.size_bytes.total_size, 2.0 * node.size_bytes.total_size |
| ) |
| else: |
| node_to_latency_mapping[node] = NodeLatency( |
| node.size_bytes.total_size, node.size_bytes.output_size |
| ) |
| return node_to_latency_mapping |
| |
| m = TestModule() |
| traced = symbolic_trace(m) |
| a = torch.rand(4) |
| graph_manipulation.get_size_of_all_nodes(traced, [a]) |
| node_to_latency_mapping = get_node_to_latency_mapping(traced) |
| devices = [Device("dev_0", 200, 0), Device("dev_1", 200, 1)] |
| partitioner = Partitioner() |
| partitioner_config = PartitionerConfig(devices) |
| ret = partitioner.partition_graph(traced, m, partitioner_config) |
| module_with_submodules = ret.module_with_submodules |
| self.assertEqual(traced(a), module_with_submodules(a)) |
| partitions = partitioner.partitions |
| partition_to_latency_mapping = get_partition_to_latency_mapping( |
| partitions, node_to_latency_mapping |
| ) |
| for p in partition_to_latency_mapping: |
| if p.partition_id == 0: |
| assert partition_to_latency_mapping[p] == (128.0, 80.0, 160.0) |
| else: |
| assert partition_to_latency_mapping[p] == (16.0, 32.0, 32.0) |
| transfer_rate_bytes_per_sec = 2 |
| critical_path_latency_sec = get_latency_of_partitioned_graph( |
| partitions, partition_to_latency_mapping, transfer_rate_bytes_per_sec |
| ) |
| assert critical_path_latency_sec == 208.0 |
| |
| def test_cost_aware_partition(self): |
| class MyModule(torch.nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| self.linear = torch.nn.Linear(4, 4) |
| |
| def forward(self, a): |
| add_1 = a + torch.rand(4) |
| add_2 = add_1 + torch.rand(4) |
| linear_1 = self.linear(add_1) |
| add_3 = add_2 + torch.rand(4) |
| add_4 = add_2 + linear_1 |
| add_5 = add_3 + add_4 |
| return add_5 |
| |
| def get_node_to_latency_mapping(fx_module: GraphModule): |
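            """Assign a NodeLatency to every compute node based on its size metadata."""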
| node_to_latency_mapping: Dict[Node, NodeLatency] = {} |
| for node in fx_module.graph.nodes: |
| if node.op not in {"output", "placeholder", "get_attr"}: |
| if node.size_bytes.total_size == node.size_bytes.output_size: |
| node_to_latency_mapping[node] = NodeLatency( |
| node.size_bytes.total_size, 1 |
| ) |
| else: |
| node_to_latency_mapping[node] = NodeLatency( |
| node.size_bytes.total_size, node.size_bytes.output_size |
| ) |
| return node_to_latency_mapping |
| |
| m = MyModule() |
| traced = symbolic_trace(m) |
| a = torch.rand(4) |
| graph_manipulation.get_size_of_all_nodes(traced, [a]) |
| devices = [ |
| Device("dev_0", 125, 0), |
| Device("dev_1", 125, 1), |
| Device("dev_2", 125, 2), |
| Device("dev_3", 125, 3), |
| ] |
| node_to_latency_mapping = get_node_to_latency_mapping(traced) |
| partitioner_config = PartitionerConfig( |
| devices, |
| mode=PartitionMode.cost_aware, |
| transfer_rate_bytes_per_sec=2, |
| node_to_latency_mapping=node_to_latency_mapping, |
| ) |
| partitioner = Partitioner() |
| ret = partitioner.partition_graph(traced, m, partitioner_config) |
| module_with_submodules = ret.module_with_submodules |
| dag = ret.dag |
| self.assertEqual(traced(a), module_with_submodules(a)) |
| partitions = partitioner.partitions |
| partition_to_latency_mapping = get_partition_to_latency_mapping( |
| partitions, node_to_latency_mapping |
| ) |
| critical_path_latency_sec = get_latency_of_partitioned_graph( |
| partitions, |
| partition_to_latency_mapping, |
| partitioner_config.transfer_rate_bytes_per_sec, |
| ) |
| assert critical_path_latency_sec == 160.0 |
| |
| def test_aot_based_partition(self): |
| class TestModule(torch.nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| self.b = torch.rand(4) |
| self.c = torch.rand(4) |
| |
| def forward(self, a): |
| add_1 = a + self.b |
| add_2 = self.c + add_1 |
| return add_2 |
| |
| m = TestModule() |
| traced = symbolic_trace(m) |
| a = torch.rand(4) |
| node_to_partition_id = {} |
| partition_to_logical_devices = {} |
| count = 0 |
| graph_manipulation.get_size_of_all_nodes(traced, [a]) |
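        # Pre-assign every compute node to its own partition, all pinned to logical
        # device 0, so the AOT-based partitioner simply honors this mapping.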
| for node in traced.graph.nodes: |
| if node.op not in {"placeholder", "get_attr", "output"}: |
| node_to_partition_id[node] = count |
| partition_to_logical_devices[count] = [0] |
| count += 1 |
| devices = [Device("dev_0", 200, 0)] |
| partitioner_config = PartitionerConfig( |
| devices=devices, |
| mode=PartitionMode.aot_based, |
| node_to_partition_mapping=node_to_partition_id, |
| partition_to_logical_device_mapping=partition_to_logical_devices, |
| ) |
| partitioner = Partitioner() |
| ret = partitioner.partition_graph(traced, m, partitioner_config) |
| module_with_submodules = ret.module_with_submodules |
| dag = ret.dag |
| self.assertEqual(module_with_submodules(a), traced(a)) |
| for node in dag.nodes: |
| assert node.size_bytes == 48 |
| assert node.logical_device_ids == [0] |
| |
| def test_replace_target_nodes_with(self): |
| class testModule(torch.nn.Module): |
| def forward(self, a, b): |
| return a + b |
| |
| m = testModule() |
| traced = symbolic_trace(m) |
| input1 = torch.randn(1) |
| input2 = torch.randn(1) |
| assert (input1 + input2) == traced(input1, input2) |
| graph_manipulation.replace_target_nodes_with( |
| fx_module=traced, |
| old_op="call_function", |
| old_target=operator.add, |
| new_op="call_function", |
| new_target=operator.mul, |
| ) |
| assert (input1 * input2) == traced(input1, input2) |
| |
| def test_saturate_host(self): |
| class TestModule(torch.nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| self.linear = torch.nn.Linear(4, 4) |
| |
| def forward(self, a): |
| add_1 = a + torch.rand(4) |
| add_2 = add_1 + torch.rand(4) |
| linear_1 = self.linear(add_1) |
| add_3 = add_2 + linear_1 |
| add_4 = add_2 + add_3 |
| return add_4 |
| |
| m = TestModule() |
| traced = symbolic_trace(m) |
| a = torch.rand(4) |
| graph_manipulation.get_size_of_all_nodes(traced, [a]) |
| devices = [ |
| Device("dev_0", 200, 0), |
| Device("dev_1", 200, 1), |
| Device("dev_2", 100, 2), |
| Device("dev_3", 100, 3), |
| Device("dev_4", 200, 4), |
| Device("dev_5", 100, 5), |
| ] |
| partitioner = Partitioner() |
| # Without host saturation, the model will be split into two partitions. |
| # dev_0 holds partition 0 of 192 bytes and dev_1 holds partition 1 of 48 bytes. |
| partitioner_config = PartitionerConfig(devices, saturate_host=True) |
| ret = partitioner.partition_graph(traced, m, partitioner_config) |
| module_with_submodules = ret.module_with_submodules |
| self.assertEqual(traced(a), module_with_submodules(a)) |
| |
| partitions = partitioner.partitions |
| self.assertEqual(len(partitions), 2) |
        # With host saturation, partition 0 is additionally replicated onto dev_4, and
        # partition 1 is replicated onto dev_2.
| self.assertEqual(partitions[0].logical_device_ids, [0, 4]) |
| self.assertEqual(partitions[1].logical_device_ids, [1, 2]) |
| |
| @skipIfNoTorchVision |
| def test_conv_bn_fusion(self): |
| rn18 = resnet18().eval() |
| traced = symbolic_trace(rn18) |
| fused = optimization.fuse(traced) |
| |
| self.assertTrue( |
| all(not isinstance(m, torch.nn.BatchNorm2d) for m in fused.modules()) |
| ) |
| |
| N, C, H, W = 20, 3, 224, 224 |
| inp = torch.randn(N, C, H, W) |
| |
| self.assertEqual(fused(inp), rn18(inp)) |
| |
| def test_conv_bn_fusion_not_running_state(self): |
| class M(torch.nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| self.conv = torch.nn.Conv2d(32, 64, 3, stride=2) |
| self.bn = torch.nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False) |
| |
| def forward(self, x): |
| x = self.conv(x) |
| x = self.bn(x) |
| return x |
| |
| model = M().eval() |
| |
| traced = symbolic_trace(model) |
| fused = optimization.fuse(traced) |
| inp = torch.randn([1, 32, 50, 50]) |
| |
        # The BatchNorm should not be folded into the Conv, since it does not track running stats.
| self.assertTrue( |
| any(isinstance(m, torch.nn.BatchNorm2d) for m in fused.modules()) |
| ) |
| self.assertEqual(fused(inp), model(inp)) |
| |
| def test_conv_bn_fusion_mixed_dtype(self): |
| class M(torch.nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| self.conv = torch.nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False, dtype=torch.bfloat16) |
| self.bn = torch.nn.BatchNorm2d(16, eps=0.001, momentum=0.1, affine=True, track_running_stats=True) |
| |
| def forward(self, x): |
| x = self.conv(x) |
| x = self.bn(x) |
| return x |
| |
| model = M().eval() |
| |
| traced = symbolic_trace(model) |
| fused = optimization.fuse(traced) |
| inp = torch.randn(1, 3, 64, 64, dtype=torch.bfloat16) |
| |
| self.assertTrue( |
| all(not isinstance(m, torch.nn.BatchNorm2d) for m in fused.modules()) |
| ) |
| self.assertEqual(fused(inp), model(inp)) |
| |
| def test_call_to_assert_no_msg(self): |
| class M(torch.nn.Module): |
| def forward(self, a, b): |
| assert a == b |
| return a + b |
| |
| m = M() |
| traced = symbolic_trace_with_rewrite(m) |
| |
| # Make sure the graph is well-formed |
| traced.graph.lint() |
| |
        # Check the IR to make sure there's a call_function node with target == torch._assert
| self.assertTrue( |
| any( |
| node.op == "call_function" and node.target == torch._assert |
| for node in traced.graph.nodes |
| ) |
| ) |
| |
| # Ensure that the assert throws when it's supposed to and doesn't throw when it's not supposed to |
| traced(3, 3) |
| with self.assertRaisesRegex(AssertionError, ""): |
| traced(3, 5) |
| |
| # Confirm that the output is correct |
| self.assertEqual(traced(3, 3), m(3, 3)) |
| |
| def test_meta_tracer(self): |
| class MetaTracerTestModule(torch.nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| self.emb = torch.nn.Embedding(num_embeddings=42, embedding_dim=16) |
| self.layernorm = torch.nn.LayerNorm(16) |
| |
| def forward(self, x): |
| emb = self.emb(x) |
| emb = emb + torch.arange(emb.shape[-1], dtype=torch.float, device=emb.device) |
| lol = self.layernorm(emb) |
| return torch.relu(lol) if lol.shape[0] < 30 else torch.sigmoid(lol) |
| |
| mttm = MetaTracerTestModule() |
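        # Tracing with meta tensors lets the shape-dependent branch in forward()
        # be resolved concretely for each batch size.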
| for BS in [15, 35]: |
| x = torch.zeros(BS, dtype=torch.long).random_(42) |
| meta_args = {'x' : x.to(device='meta')} |
| gm = torch.fx.experimental.meta_tracer.symbolic_trace(mttm, meta_args=meta_args) |
| torch.testing.assert_close(gm(x), mttm(x)) |
| |
| # Test serialization/deserialization |
| with tempfile.TemporaryDirectory() as tmp_dir: |
| with open(f'{tmp_dir}/meta_module.pkl', 'wb') as f: |
| pickle.dump(gm, f) |
| |
| with open(f'{tmp_dir}/meta_module.pkl', 'rb') as f: |
| loaded = pickle.load(f) |
| |
| torch.testing.assert_close(loaded(x), mttm(x)) |
| |
| |
| def test_call_to_assert_with_msg(self): |
| class M(torch.nn.Module): |
| def forward(self, a, b): |
| assert a == b, "test message" |
| return a + b |
| |
| m = M() |
| traced = symbolic_trace_with_rewrite(m) |
| |
| # Make sure the graph is well-formed |
| traced.graph.lint() |
| |
        # Check the IR to make sure there's a call_function node with target == torch._assert
| self.assertTrue( |
| any( |
| node.op == "call_function" and node.target == torch._assert |
| for node in traced.graph.nodes |
| ) |
| ) |
| |
| # Ensure that the assert throws when it's supposed to and doesn't throw when it's not supposed to |
| traced(3, 3) |
| with self.assertRaisesRegex(AssertionError, "test message"): |
| traced(3, 5) |
| |
| # Confirm that the output is correct |
| self.assertEqual(traced(3, 3), m(3, 3)) |
| |
| def test_call_to_assert_with_empty_msg(self): |
| class M(torch.nn.Module): |
| def forward(self, a, b): |
| assert a == b, "" |
| return a + b |
| |
| m = M() |
| traced = symbolic_trace_with_rewrite(m) |
| |
| # Make sure the graph is well-formed |
| traced.graph.lint() |
| |
        # Check the IR to make sure there's a call_function node with target == torch._assert
| self.assertTrue( |
| any( |
| node.op == "call_function" and node.target == torch._assert |
| for node in traced.graph.nodes |
| ) |
| ) |
| |
| # Ensure that the assert throws when it's supposed to and doesn't throw when it's not supposed to |
| traced(3, 3) |
| with self.assertRaisesRegex(AssertionError, ""): |
| traced(3, 5) |
| |
| # Confirm that the output is correct |
| self.assertEqual(traced(3, 3), m(3, 3)) |
| |
| def test_call_to_assert_with_multiline_message(self): |
| class M(torch.nn.Module): |
| def forward(self, a, b): |
| error_msg = """ |
| An error message with |
| terrible spacing |
| """ |
| assert a == b, error_msg |
| return a + b |
| |
| m = M() |
| traced = symbolic_trace_with_rewrite(m) |
| |
| # Make sure the graph is well-formed |
| traced.graph.lint() |
| |
        # Check the IR to make sure there's a call_function node with target == torch._assert
| self.assertTrue( |
| any( |
| node.op == "call_function" and node.target == torch._assert |
| for node in traced.graph.nodes |
| ) |
| ) |
| |
| # Ensure that the assert throws when it's supposed to and doesn't throw when it's not supposed to |
| error_msg = """ |
| An error message with |
| terrible spacing |
| """ |
| traced(3, 3) |
| with self.assertRaisesRegex(AssertionError, error_msg): |
| traced(3, 5) |
| |
| # Confirm that the output is correct |
| self.assertEqual(traced(3, 3), m(3, 3)) |
| |
| def test_subgraph_creation(self): |
| class MyModule(torch.nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| self.param = torch.nn.Parameter(torch.rand(3, 4)) |
| self.linear = torch.nn.Linear(4, 5) |
| |
| def forward(self, x, y): |
| z = self.linear(x + self.param).clamp(min=0.0, max=1.0) |
| w = self.linear(y).clamp(min=0.0, max=1.0) |
| return z + w |
| |
| # symbolically trace model |
| my_module = MyModule() |
| my_module_traced = symbolic_trace(my_module) |
| |
        # round-robin partitioning (node index modulo NPARTITIONS)
| partition_counter = 0 |
| NPARTITIONS = 3 |
| |
| # Add some random meta info to make sure it is kept around. |
| for node in my_module_traced.graph.nodes: |
| if node.op != "output": |
| node.meta["test_meta_info"] = True |
| |
| def mod_partition(node: Node): |
| nonlocal partition_counter |
| partition = partition_counter % NPARTITIONS |
| partition_counter = (partition_counter + 1) % NPARTITIONS |
| return partition |
| |
        # split the traced module into a module with submodules
| module_with_submodules = split_module( |
| my_module_traced, my_module, mod_partition |
| ) |
| |
| # Check that test_meta_info was still on all nodes. |
| submodules = dict(module_with_submodules.named_modules()) |
| for node in module_with_submodules.graph.nodes: |
| if node.op == "call_module": |
| submod = submodules[node.target] |
| self.assertTrue(isinstance(submod, torch.fx.GraphModule)) |
| for submod_node in submod.graph.nodes: |
| if submod_node.op != "output": |
| stored_op = submod_node.meta.get("test_meta_info") |
| self.assertTrue(stored_op is not None and stored_op) |
| |
| x = torch.rand(3, 4) |
| y = torch.rand(3, 4) |
| |
| orig_out = my_module_traced(x, y) |
| submodules_out = module_with_submodules(x, y) |
| |
| self.assertEqual(orig_out, submodules_out) |
| |
| def test_split_module_dead_code(self): |
| class ModWithDeadCode(torch.nn.Module): |
| def forward(self, x): |
| output = x * 2 # we want this |
| dead_line = x + 2 # this is dead |
| return output |
| |
| mod = ModWithDeadCode() |
| traced = torch.fx.symbolic_trace(mod) |
| |
        # split into before (0), target (1), and after (2)
| saw_mul = False |
| |
| def split_callback(n): |
| nonlocal saw_mul |
| if n.target == operator.mul: |
| saw_mul = True |
| return 1 |
| |
| if not saw_mul: |
| return 0 |
| if saw_mul: |
| return 2 |
| |
| split = split_module(traced, mod, split_callback) |
| |
| x = torch.randn((5,)) |
| torch.testing.assert_close( |
| split(x), traced(x) |
| ) |
| |
| |
| def test_split_module_kwargs_expansion(self): |
| class ModuleWithKwargsExpansion(torch.nn.Module): |
| def forward(self, x, **kwargs): |
| return x + kwargs['foo'] |
| |
| mod = ModuleWithKwargsExpansion() |
| traced = torch.fx.symbolic_trace(mod) |
| |
| seen_getitem = False |
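        # Place everything up to and including the kwargs getitem in partition 0
        # and the remaining nodes in partition 1.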
| |
| def split_callback(n): |
| nonlocal seen_getitem |
| split_idx = int(seen_getitem) |
| if n.target == operator.getitem: |
| seen_getitem = True |
| return split_idx |
| |
| split = split_module(traced, mod, split_callback) |
| |
| x = torch.randn(5, 3) |
| foo = torch.randn(5, 3) |
| torch.testing.assert_close(split(x, foo=foo), traced(x, foo=foo)) |
| |
| @skipIfNoTorchVision |
| def test_subgraph_trivial_resnet(self): |
        # Smoke test that trivially splitting ResNet into a single partition works.
        # A previous bug caused submodule names to be aliased.
| m = resnet18() |
| traced = symbolic_trace(m) |
| a = torch.rand(64, 3, 7, 7) |
| module_with_submodules = split_module(traced, m, lambda node: 0) |
| module_with_submodules(a) |
| |
| def test_split_module_default_arg(self): |
| class ModelToTrace(torch.nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| self.lin = torch.nn.Linear(512, 512) |
| |
| def forward(self, x, targets=None): |
| x = self.lin(x) |
| |
| if targets is not None: |
| x = x + targets |
| |
| return x |
| |
| mtt = ModelToTrace() |
| traced = torch.fx.symbolic_trace(mtt, concrete_args={'targets': None}) |
| |
| split = split_module(traced, mtt, lambda node: 0) |
| |
| x = torch.randn(50, 512) |
| torch.testing.assert_close(split(x), traced(x)) |
| |
| def test_normalize_binary_operators(self): |
| ops_to_test = { |
| torch.add, |
| torch.mul, |
| torch.sub, |
| torch.div, |
| torch.floor_divide, |
| torch.remainder, |
| torch.eq, |
| torch.ne, |
| torch.lt, |
| torch.le, |
| torch.gt, |
| torch.ge, |
| } |
| |
| # Test Tensor/Tensor callsite |
| for op in ops_to_test: |
| |
| class WrapperMod(torch.nn.Module): |
| def forward(self, x, y): |
| return op(x, y) |
| |
| traced = symbolic_trace(WrapperMod()) |
| normalized = NormalizeOperators(traced).transform() |
| x, y = torch.randn(3, 4), torch.randn(3, 4) |
| torch.testing.assert_close(traced(x, y), normalized(x, y)) |
| self.assertFalse( |
| any(n.target in ops_to_test for n in normalized.graph.nodes) |
| ) |
| |
| # Test Tensor/scalar callsite |
| for op in ops_to_test: |
| |
| class WrapperMod(torch.nn.Module): |
| def forward(self, x): |
| return op(x, 42) |
| |
| traced = symbolic_trace(WrapperMod()) |
| normalized = NormalizeOperators(traced).transform() |
| x = torch.randn(3, 4) |
| torch.testing.assert_close(traced(x), normalized(x)) |
| self.assertFalse( |
| any(n.target in ops_to_test for n in normalized.graph.nodes) |
| ) |
| |
| @skipIfNoTorchVision |
| def test_normalize_args(self): |
| m = resnet18() |
| |
| class FunctionalTracer(torch.fx.Tracer): |
| def is_leaf_module( |
| self, m: torch.nn.Module, module_qualified_name: str |
| ) -> bool: |
| # `leaves` contains the set of standard `nn.Modules` that are not |
| # currently symbolically traceable. Ideally this set would be empty |
| leaves = {torch.nn.BatchNorm2d} |
| return type(m) in leaves |
| |
| traced = torch.fx.GraphModule(m, FunctionalTracer().trace(m)) |
| |
| input = torch.randn(5, 3, 224, 224) |
| ref_outs = traced(input) |
| |
| ShapeProp(traced).propagate(input) |
| traced = NormalizeArgs(traced).transform() |
| |
| modules = dict(traced.named_modules()) |
| |
| for node in traced.graph.nodes: |
| if node.op == "call_function" and node.target != operator.add: |
| self.assertEqual(len(node.args), 0) |
| elif node.op == "call_module": |
| submod_class = modules[node.target].__class__ |
| nn_class = getattr(torch.nn, submod_class.__name__) |
| if submod_class == nn_class: |
| self.assertEqual(len(node.args), 0) |
| traced(input) |
| self.assertEqual(traced(input), ref_outs) |
| |
| def test_normalize_modules_exhaustive(self): |
| """ |
| Exhaustively test `Node.normalized_arguments` on all standard |
| torch.nn Module classes |
| """ |
| for test_params in module_tests + new_module_tests: |
| if "constructor" not in test_params: |
| constructor = getattr(torch.nn, test_params["module_name"]) |
| else: |
| constructor = test_params["constructor"] |
| |
| if "constructor_args" not in test_params: |
| args = () |
| else: |
| args = test_params["constructor_args"] |
| |
| mod = constructor(*args) |
| # Skip modules that are not standard `torch.nn` |
| # instances, including functionals. (functionals |
| # are tested in test_normalize_args) |
| if mod.__class__.__name__ not in dir(torch.nn): |
| continue |
| |
| if "input_fn" not in test_params: |
| inputs = torch.randn(test_params["input_size"]) |
| else: |
| inputs = test_params["input_fn"]() |
| |
| if not isinstance(inputs, (tuple, list)): |
| inputs = (inputs,) |
| |
| params = ", ".join(f"v{i}" for i in range(len(inputs))) |
| |
| # Generate a class to wrap this standard `nn.Module` instance |
| test_classname = f"Test{mod.__class__.__name__}" |
| test_mod_code = f""" |
| class {test_classname}(torch.nn.Module): |
| def __init__(self, mod): |
| super().__init__() |
| self.mod = mod |
| |
| def forward(self, {params}): |
| return self.mod({params}) |
| """ |
| |
| gbls = {"torch": torch} |
| exec(test_mod_code, gbls) |
| |
| test_instance = gbls[test_classname](mod) |
| traced = symbolic_trace(test_instance) |
| |
| # Use `Node.normalized_arguments` to get a new set of arguments |
| # to feed to the Module. Then, rewrite the node to only take |
| # in those arguments as kwargs |
| modules = dict(traced.named_modules()) |
| for node in traced.graph.nodes: |
| if node.op == "call_module": |
| submod_class = modules[node.target].__class__ |
| nn_class = getattr(torch.nn, submod_class.__name__) |
| if submod_class == nn_class: |
| normalized_args = node.normalized_arguments(traced) |
| normalized_args2 = normalize_module( |
| traced, node.target, node.args, node.kwargs |
| ) |
| assert normalized_args == normalized_args2 |
| assert normalized_args |
| node.args = normalized_args.args |
| node.kwargs = normalized_args.kwargs |
| |
| traced.recompile() |
| |
            # These modules use an RNG in their forward pass, so comparing
            # outputs is not a valid correctness check. Skip it for them.
| stochastic_modules = {"FractionalMaxPool2d", "FractionalMaxPool3d", "RReLU"} |
| |
| if mod.__class__.__name__ not in stochastic_modules: |
| self.assertEqual(traced(*inputs), mod(*inputs)) |
| |
| traced = NormalizeArgs(symbolic_trace(test_instance)).transform() |
| modules = dict(traced.named_modules()) |
| for node in traced.graph.nodes: |
| if node.op == "call_module": |
| submod_class = modules[node.target].__class__ |
| nn_class = getattr(torch.nn, submod_class.__name__) |
| if submod_class == nn_class: |
| self.assertEqual(len(node.args), 0) |
| |
| def test_normalize_args_preserve_meta(self): |
| class MyModule(torch.nn.Module): |
| def forward(self, a): |
| return torch.add(a, 3) |
| |
| m = MyModule() |
| traced = symbolic_trace(m) |
| |
| for node in traced.graph.nodes: |
| if node.op == "call_function" and node.target == torch.add: |
| node.meta["my_key"] = 7 |
| break |
| else: |
| self.fail("Didn't find call_function torch.add") |
| |
| input = torch.randn(2, 3) |
| ShapeProp(traced).propagate(input) |
| traced = NormalizeArgs(traced).transform() |
| |
| for node in traced.graph.nodes: |
| if node.op == "call_function" and node.target == torch.add: |
| self.assertTrue("my_key" in node.meta) |
| self.assertEqual(node.meta["my_key"], 7) |
| break |
| else: |
| self.fail("Didn't find call_function torch.add") |
| |
    def test_normalize_args_preserve_type(self):
| class MyModule(torch.nn.Module): |
| def forward(self, a: List[torch.Tensor]): |
| return torch.add(a[0], a[1]) |
| |
| m = MyModule() |
| traced = symbolic_trace(m) |
| traced = NormalizeArgs(traced).transform() |
| |
| for node in traced.graph.nodes: |
| if node.op == "placeholder": |
| self.assertEqual(node.type, List[torch.Tensor]) |
| |
| @skipIfNoTorchVision |
| def test_annotate_returns_with_schema(self): |
| m = resnet18() |
| |
| traced_modules = symbolic_trace(m) |
| traced_modules_annotated = AnnotateTypesWithSchema(traced_modules).transform() |
| for node in traced_modules_annotated.graph.nodes: |
| if node.type is None: |
| check = (node.op, node.target) |
| self.assertIn( |
| check, |
| { |
| ("placeholder", "x"), |
| ("call_module", "maxpool"), |
| ("call_function", operator.add), |
| ("call_function", torch.flatten), |
| ("output", "output"), |
| } |
| ) |
| |
| # Smoke test torchscript compilation since now we're emitting type annotations |
| torch.jit.script(traced_modules_annotated) |
| |
| class FunctionalTracer(torch.fx.Tracer): |
| def is_leaf_module( |
| self, m: torch.nn.Module, module_qualified_name: str |
| ) -> bool: |
| # `leaves` contains the set of standard `nn.Modules` that are not |
| # currently symbolically traceable. Ideally this set would be empty |
| leaves = {torch.nn.BatchNorm2d} |
| return type(m) in leaves |
| |
| traced_functionals = torch.fx.GraphModule(m, FunctionalTracer().trace(m)) |
| |
| traced_functionals_annotated = AnnotateTypesWithSchema( |
| traced_functionals |
| ).transform() |
| for node in traced_functionals_annotated.graph.nodes: |
| if node.type is None: |
| check = (node.op, node.target) |
| excluded_nodes = { |
| ("placeholder", "x"), |
| # Return type differs based on boolean dispatch :( |
| ("call_function", torch.nn.functional.max_pool2d), |
| ("output", "output"), |
| } |
| # AnnotateTypesWithSchema doesn't work with bound C++ functions |
| if not isinstance(node.target, BuiltinFunctionType): |
| self.assertIn(check, excluded_nodes) |
| |
| # Smoke test torchscript compilation since now we're emitting type annotations |
| torch.jit.script(traced_functionals_annotated) |
| |
| def test_annotate_getitem_node(self): |
| class CustomType: |
| pass |
| |
| class CustomNamedTuple(NamedTuple): |
| x: int |
| y: float |
| |
| class MyModule(torch.nn.Module): |
| def forward(self, inp: Tuple[CustomType, torch.Tensor], inp2: List[CustomType], inp3: CustomNamedTuple): |
| inp_0 = inp[0] |
| inp_1 = inp[1] |
| inp2_0 = inp2[0] |
| inp3_x = inp3.x |
| inp3_y = inp3.y |
| return inp_0 + inp_1 + inp2_0 + inp3_x + inp3_y |
| |
| my_module = MyModule() |
| my_module_traced = torch.fx.symbolic_trace(my_module) |
| |
        # By default, FX tracing loses the type annotations of getitem nodes.
| for node in my_module_traced.graph.nodes: |
| if node.target == operator.getitem: |
| assert node.type is None |
| |
| annotate_getitem_nodes(my_module_traced.graph) |
| |
| for node in my_module_traced.graph.nodes: |
| if node.target == operator.getitem: |
| self.assertIsNotNone(node.type, f"Node {node} should be annotated but is not.") |
| |
| def test_subgraph_uniquename(self): |
| class MyModule(torch.nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| self.linear = torch.nn.Linear(4, 4) |
| |
| def forward(self, a, b, c, d): |
| add_1 = a + b |
| add_2 = add_1 + c |
| linear_1 = self.linear(add_1) |
| add_3 = add_2 + d |
| add_4 = add_2 + linear_1 |
| add_5 = add_3 + add_4 |
| return add_5 |
| |
| a, b, c, d = torch.ones(4), torch.ones(4), torch.ones(4), torch.ones(4) |
| mm = MyModule() |
| traced = symbolic_trace(mm) |
| |
| def split_cb(node: torch.fx.Node): |
| if node.name == "a" or node.name == "b" or node.name == "add": |
| return 0 |
| else: |
| return 1 |
| |
| module_with_submodule = split_module(traced, mm, split_cb) |
| self.assertEqual(module_with_submodule(a, b, c, d), traced(a, b, c, d)) |
| |
| def test_split_qualname_mapping(self): |
| d_hid = 4 |
| |
| class ExampleCode(torch.nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| self.mm_param = torch.nn.Parameter(torch.randn(d_hid, d_hid)) |
| self.mm_param2 = torch.nn.Parameter(torch.randn(d_hid, d_hid)) |
| self.lin = torch.nn.Linear(d_hid, d_hid) |
| |
| def forward(self, x): |
| x = torch.mm(x, self.mm_param) |
| x = torch.relu(x) |
| x = torch.mm(x, self.mm_param) |
| x = self.lin(x) |
| x = torch.relu(x) |
| x = torch.mm(x, self.mm_param2) |
| x = self.lin(x) |
| return x |
| |
| my_module = ExampleCode() |
| my_module_traced = symbolic_trace(my_module) |
| |
| part_idx = 0 |
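        # Start a new partition at every `lin` call so the shared Linear module is
        # duplicated into multiple submodules, exercising the qualname mapping.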
| |
| def split_callback(n : torch.fx.Node): |
| nonlocal part_idx |
| if (n.op, n.target) == ('call_module', 'lin'): |
| part_idx += 1 |
| return part_idx |
| |
        # split the traced module into a module with submodules
| qualname_map : Dict[str, str] = {} |
| module_with_submodules = split_module( |
| my_module_traced, my_module, split_callback, qualname_map |
| ) |
| expected_qualname_map = { |
| 'submod_1.lin': 'lin', 'submod_2.lin': 'lin' |
| } |
| self.assertEqual(qualname_map, expected_qualname_map) |
| |
| def test_traceable_function_with_nonstandard_name(self): |
| def foo(x): |
| return torch.relu(x) |
| |
| traced = symbolic_trace_with_rewrite(foo) |
| |
| def test_to_folder(self): |
| class Test(torch.nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| self.W = torch.nn.Parameter(torch.randn(2)) |
| self.seq = torch.nn.Sequential(torch.nn.BatchNorm1d(2, 2)) |
| self.linear = torch.nn.Linear(2, 2) |
| self.attr = torch.randn(2) |
| self.attr2 = torch.nn.Buffer(torch.randn(2)) |
| self.attr3 = torch.nn.Buffer(torch.ones(2, dtype=torch.int32)) |
| |
| def forward(self, x): |
| return self.linear(self.seq(self.W + self.attr + self.attr2 + self.attr3 + x)) |
| |
| mod = symbolic_trace(Test()) |
| module_name = "Foo" |
        from pathlib import Path
| |
| with tempfile.TemporaryDirectory() as tmp_dir: |
| tmp_dir = Path(tmp_dir) |
| mod.to_folder(tmp_dir, module_name) |
| # Recipe taken from here: |
| # https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly |
| import importlib.util |
| |
| spec = importlib.util.spec_from_file_location( |
| module_name, tmp_dir / "__init__.py" |
| ) |
| module = importlib.util.module_from_spec(spec) |
| sys.modules[module_name] = module |
| spec.loader.exec_module(module) |
| t = torch.randn(2, 2) |
| self.assertEqual(module.Foo()(t), mod(t)) |
| |
| def test_fetch(self): |
| attrs_for_lowering: Dict[str, List[str]] = { |
| "torch.nn.modules.conv.Conv2d": [ |
| "weight", |
| "bias", |
| "kernel_size", |
| "stride", |
| "padding", |
| "dilation", |
| "groups", |
| "padding_mode", |
| ], |
| "torch.nn.modules.batchnorm.BatchNorm2d": [ |
| "weight", |
| "bias", |
| "running_mean", |
| "running_var", |
| "eps", |
| ], |
| } |
| |
| class TestModule(torch.nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| self.conv = torch.nn.Conv2d(3, 3, 2) |
| self.bn = torch.nn.BatchNorm2d(3) |
| |
| def forward(self, a): |
| a = self.conv(a) |
| a += a |
| return self.bn(a) |
| |
| mod = TestModule() |
| traced = symbolic_trace(mod) |
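        # Attach each submodule's lowering-relevant attributes to its call_module node
        # as `attrs_for_lowering`.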
| lift_lowering_attrs_to_nodes(traced) |
| |
| for node in traced.graph.nodes: |
| if node.op == "call_module": |
| assert hasattr(node, "attrs_for_lowering") |
| para_list = attrs_for_lowering[node.attrs_for_lowering["name"]] |
| |
                # node.attrs_for_lowering has one additional field for the class name
| assert len(para_list) + 1 == len(node.attrs_for_lowering) |
| for p_name in para_list: |
| assert p_name in node.attrs_for_lowering |
| |
| def test_merge_matmuls(self): |
| """ |
| A collection of test cases for torch.fx.experimental.merge_matmul, |
| a graph transformation that merges matrix multiplication operations. |
| """ |
| # Utility function for counting matmuls for test assertions. |
| def _count_matmuls(mod): |
| gm = torch.fx.symbolic_trace(mod) |
| |
| num_matmuls = 0 |
| for node in gm.graph.nodes: |
| if node.target == torch.matmul: |
| num_matmuls += 1 |
| |
| return num_matmuls |
| |
| # Simple test case in which there are two matmuls of the same size to merge. |
| class SimpleMergeMatmulModule(torch.nn.Module): |
| def __init__(self, rhs): |
| super().__init__() |
| self.rhs = rhs |
| |
| def forward(self, x, y): |
| a = torch.matmul(x, self.rhs) |
| b = torch.matmul(y, self.rhs) |
| return a + b |
| |
| # Initialize inputs. |
| a = torch.randn(3, 3) |
| b = torch.randn(3, 3) |
| |
| # Initialize RHS for matmuls. |
| rhs = torch.randn(3, 4) |
| |
| # Construct SimpleMergeMatmulModule and call merge_matmul on it. |
| module = SimpleMergeMatmulModule(rhs) |
| opt_module = merge_matmul.merge_matmul(module) |
| |
| # Numerical correctness check. |
| before = module(a, b) |
| after = opt_module(a, b) |
        self.assertTrue(before.allclose(after))
| |
| # Basic graph structure check; original module should have 2 matmuls |
| # and optimized module should have 1. |
| self.assertEqual(_count_matmuls(module), 2) |
| self.assertEqual(_count_matmuls(opt_module), 1) |
| |
| # Test case in which there are multiple matmuls of different sizes to merge. |
| class FiveMergeMatmulModule(torch.nn.Module): |
| def __init__(self, rhs): |
| super().__init__() |
| self.rhs = rhs |
| |
| def forward(self, a, b, c, d, e): |
| s = torch.tensor([]) |
| matmuls = [] |
| |
| # For some reason using a list comprehension or for-loop for this |
| # doesn't work. |
| matmuls.append(torch.matmul(a, self.rhs)) |
| matmuls.append(torch.matmul(b, self.rhs)) |
| matmuls.append(torch.matmul(c, self.rhs)) |
| matmuls.append(torch.matmul(d, self.rhs)) |
| matmuls.append(torch.matmul(e, self.rhs)) |
| |
| for m in matmuls: |
| s += torch.sum(m) |
| |
| return s |
| |
| # Initialize inputs. |
| inputs = [torch.randn(2 * i + 1, 5) for i in range(5)] |
| |
| # Initialize RHS. |
| rhs = torch.randn(5, 4) |
| |
| # Construct FiveMergeMatmulModule and call merge_matmul on it. |
| module = FiveMergeMatmulModule(rhs) |
| opt_module = merge_matmul.merge_matmul(module) |
| |
| # Numerical correctness check. |
| before = module(*inputs) |
| after = opt_module(*inputs) |
        self.assertTrue(before.allclose(after))
| |
| # Basic graph structure check; original module should have len(inputs) matmuls |
| # and optimized module should have 1. |
| self.assertEqual(_count_matmuls(module), len(inputs)) |
| self.assertEqual(_count_matmuls(opt_module), 1) |
| |
| # Simple test case in which two matmuls cannot be merged due to a data dependency between |
| # the LHS operands. |
| class UnmergeableMatmulModule(torch.nn.Module): |
| def __init__(self, rhs): |
| super().__init__() |
| self.rhs = rhs |
| |
| def forward(self, x): |
| a = torch.matmul(x, self.rhs) |
| a_abs = torch.abs(a) |
| b = torch.matmul(a_abs.transpose(1, 0), self.rhs) |
| return b |
| |
| # Initialize inputs. |
| a = torch.randn(3, 3) |
| |
| # Initialize RHS for matmuls. |
| rhs = torch.randn(3, 4) |
| |
| # Construct UnmergeableMatmulModule and call merge_matmul on it. |
| module = UnmergeableMatmulModule(rhs) |
| opt_module = merge_matmul.merge_matmul(module) |
| |
| # Numerical correctness check. |
| before = module(a) |
| after = opt_module(a) |
        self.assertTrue(before.allclose(after))
| |
        # Basic graph structure check; the number of matrix multiplications should not have changed.
| self.assertEqual(_count_matmuls(module), 2) |
| self.assertEqual(_count_matmuls(opt_module), 2) |
| |
| def test_type_matches(self): |
| should_be_equal = [ |
| (int, int), |
| (numbers.Number, int), |
| (numbers.Number, float), |
| (int, type(torch.float)), |
| (Union[int, float], int), |
| (Union[int, float], float), |
| (List[int], int), |
| (List[int], create_type_hint([int, int])), |
| (List[int], create_type_hint((int, int))), |
| (List[torch.Tensor], create_type_hint([torch.Tensor, torch.Tensor])), |
| ( |
| List[torch.Tensor], |
| create_type_hint([torch.nn.Parameter, torch.nn.Parameter]), |
| ), |
| (torch.Tensor, torch.nn.Parameter), |
| (List[torch.Tensor], create_type_hint([torch.nn.Parameter, torch.Tensor])), |
| (List[torch.Tensor], create_type_hint([torch.Tensor, torch.nn.Parameter])), |
| (List[torch.Tensor], create_type_hint((torch.Tensor, torch.Tensor))), |
| ( |
| List[torch.Tensor], |
| create_type_hint((torch.nn.Parameter, torch.nn.Parameter)), |
| ), |
| (torch.Tensor, torch.nn.Parameter), |
| (List[torch.Tensor], create_type_hint((torch.nn.Parameter, torch.Tensor))), |
| (List[torch.Tensor], create_type_hint((torch.Tensor, torch.nn.Parameter))), |
| (Optional[List[torch.Tensor]], List[torch.Tensor]), |
| (Optional[List[int]], List[int]), |
| ] |
| for sig_type, arg_type in should_be_equal: |
| self.assertTrue(type_matches(sig_type, arg_type)) |
| |
| should_fail = [ |
| (int, float), |
| (Union[int, float], str), |
| (List[torch.Tensor], List[int]), |
| ] |
| |
| for sig_type, arg_type in should_fail: |
| self.assertFalse(type_matches(sig_type, arg_type)) |
| |
| @skipIfNoMkldnn |
| def test_optimize_for_inference_cpu(self): |
| import torch.nn as nn |
| |
| class Foo(nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| layers = [] |
| layers2 = [] |
| for _ in range(10): |
| layers.append(nn.Conv2d(3, 3, 1)) |
| layers.append(nn.BatchNorm2d(3)) |
| layers.append(nn.ReLU()) |
| |
| layers2.append(nn.Conv2d(3, 3, 1)) |
| layers2.append(nn.BatchNorm2d(3)) |
| layers2.append(nn.ReLU()) |
| self.model = nn.Sequential(*layers) |
| self.model2 = nn.Sequential(*layers2) |
| |
| def forward(self, x): |
| return self.model(x) + self.model2(x) |
| |
| N, C, H, W, = ( |
| 1, |
| 3, |
| 224, |
| 224, |
| ) |
| inp = torch.randn(N, C, H, W) |
| with torch.no_grad(): |
| model = Foo().eval() |
| optimized_model = optimization.optimize_for_inference(model) |
| torch.testing.assert_close(model(inp), optimized_model(inp)) |
| |
| optimized_model2 = optimization.optimize_for_inference( |
| model, pass_config={"remove_dropout": False} |
| ) |
| torch.testing.assert_close(model(inp), optimized_model2(inp)) |
| |
| @skipIfNoTorchVision |
| @skipIfNoMkldnn |
| def test_optimize_for_inference_cpu_torchvision(self): |
| models = [ |
| torchvision.models.resnet18, |
| torchvision.models.resnet50, |
| torchvision.models.densenet121, |
| torchvision.models.shufflenet_v2_x1_0, |
| torchvision.models.vgg16, |
| torchvision.models.mobilenet_v2, |
| torchvision.models.mnasnet1_0, |
| torchvision.models.resnext50_32x4d, |
| ] |
| with torch.no_grad(): |
| for model_type in models: |
| model = model_type() |
| C, H, W, = ( |
| 3, |
| 224, |
| 224, |
| ) |
| inp = torch.randn(3, C, H, W) |
| model(inp) |
| model.eval() |
| inp = torch.randn(1, C, H, W) |
| heuristic = optimization.gen_mkl_autotuner(inp, iters=0, warmup=0) |
| optimized_model = optimization.optimize_for_inference(model) |
| |
| orig_out = model(inp) |
| new_out = optimized_model(inp) |
| torch.testing.assert_close(orig_out, new_out) |
| |
| |
| class TestNormalizeOperators(JitTestCase): |
| @onlyCPU |
| @ops(op_db, allowed_dtypes=(torch.float,)) |
| def test_normalize_operator_exhaustive(self, device, dtype, op): |
        # These ops currently don't trace in FX for various reasons (e.g., they take a list of tensors)
| fx_fail = {"cat", "stack", "hstack", "vstack", "dstack", "linalg.multi_dot", "_upsample_bilinear2d_aa", "_chunk_cat"} |
| sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False) |
| if isinstance(op.op, torch._ops.OpOverload): |
| self.skipTest("normalize operator doesn't work on torch.ops") |
| for sample_input in sample_inputs_itr: |
| unsupported_arg_type = False |
| arg_values = [sample_input.input] + list(sample_input.args) |
| kwarg_values = sample_input.kwargs |
| arg_types = [] |
| kwarg_types = {} |
| |
| def jit_infer_type(v): |
| inferred_arg_type = torch._C._jit_try_infer_type(v) |
| assert inferred_arg_type.success() |
| t = _torchscript_type_to_python_type(inferred_arg_type.type()) |
| return t |
| |
| for v in arg_values: |
| if isinstance(v, torch.Tensor): |
| arg_types.append(type(v)) |
| else: |
| if isinstance(v, complex): |
| # Complex type not supported in FX |
| unsupported_arg_type = True |
| arg_types.append(jit_infer_type(v)) |
| |
| for k, v in kwarg_values.items(): |
| if isinstance(v, torch.Tensor): |
| kwarg_types[k] = type(v) |
| else: |
| if isinstance(v, complex): |
| # Complex type not supported in FX |
| unsupported_arg_type = True |
| kwarg_types[k] = jit_infer_type(v) |
| |
| if unsupported_arg_type: |
| continue |
| # Test normalize_function by itself |
| ref_out = op.op(*arg_values, **kwarg_values) |
| norm_args_and_kwargs = normalize_function( |
| op.op, arg_values, kwarg_values, arg_types, kwarg_types |
| ) |
| if norm_args_and_kwargs is None: |
| raise RuntimeError( |
| """ |
| FX failed to normalize op - add the op to the op_skip list. |
| A common reason is if your OpInfo was implemented with a lambda |
| - otherwise, file an issue |
| """ |
| ) |
| test_out = op.op(*norm_args_and_kwargs.args, **norm_args_and_kwargs.kwargs) |
| self.assertEqual(test_out, ref_out) |
| |
| # Test normalized_arguments as part of FX |
| if op.name in fx_fail: |
| continue |
| param_names = [] |
| param_values = [] |
| fx_args = [] |
| |
| idx = 0 |
| |
| def process_arg(arg, name): |
| if isinstance(arg, torch.Tensor): |
| param_names.append(name) |
| param_values.append(arg) |
| return name |
| else: |
| return f"{repr(arg)}" |
| |
| def process_arg_with_idx(arg): |
| nonlocal idx |
| res = process_arg(arg, f"arg_{idx}") |
| idx = idx + 1 |
| return res |
| |
| def str_arg(arg): |
| if isinstance(arg, tuple): |
| args = [f"{str_arg(v)}, " for v in arg] |
| return f"({' '.join(args)})" |
| elif isinstance(arg, list): |
| args = [f"{str_arg(v)}" for v in arg] |
| return f"[{', '.join(args)}]" |
| else: |
| return arg |
| |
| for v in arg_values: |
| arg = pytree.tree_map(process_arg_with_idx, v) |
| fx_args.append(str_arg(arg)) |
| |
| for k, v in kwarg_values.items(): |
| arg = pytree.tree_map(functools.partial(process_arg, name=k), v) |
| fx_args.append(f"{k} = {str_arg(arg)}") |
| |
| code = f""" |
| class TestModule(torch.nn.Module): |
| def forward(self, {', '.join(param_names)}): |
| return torch.{op.name}({', '.join(fx_args)}) |
| """ |
| |
| g = {"torch": torch, "inf": math.inf} |
| exec(code, g) |
| TestModule = g["TestModule"] |
| |
| m = TestModule() |
| traced = torch.fx.symbolic_trace(m) |
| ref_out = traced(*param_values) |
| |
| for node in traced.graph.nodes: |
| if node.op == "call_function": |
| normalized_args = node.normalized_arguments( |
| traced, arg_types, kwarg_types |
| ) |
| assert normalized_args |
| node.args = normalized_args.args |
| node.kwargs = normalized_args.kwargs |
| traced.recompile() |
| |
| test_out = traced(*param_values) |
| self.assertEqual(test_out, ref_out) |
| |
| def test_normalize_quantized_eb(self): |
| target = torch.ops.quantized.embedding_bag_byte_rowwise_offsets |
| args = ( |
| torch.empty((2, 3), dtype=torch.uint8), |
| torch.empty((2,), dtype=torch.int64), |
| torch.empty((2,), dtype=torch.int64), |
| ) |
| norm_args_and_kwargs = normalize_function( |
| target, args, normalize_to_only_use_kwargs=True |
| ) |
| self.assertTrue(norm_args_and_kwargs is not None) |
| self.assertEqual( |
| set(norm_args_and_kwargs.kwargs.keys()), |
| { |
| "weight", |
| "indices", |
| "offsets", |
| "scale_grad_by_freq", |
| "mode", |
| "pruned_weights", |
| "per_sample_weights", |
| "compressed_indices_mapping", |
| "include_last_offset", |
| }, |
| ) |
| self.assertEqual(norm_args_and_kwargs.args, ()) |
| |
| def test_normalize_args_op_overload(self): |
| for target in [torch.ops.aten.resize_as_.default, torch.ops.aten.resize_as_]: |
| inp1 = torch.rand([1]) |
| inp2 = torch.rand([4]) |
| args, kwargs = normalize_function(target, (inp1,), {"the_template": inp2}, normalize_to_only_use_kwargs=True) |
| self.assertIs(kwargs["input"], inp1) |
| self.assertIs(kwargs["the_template"], inp2) |
| |
| |
| if TEST_Z3: |
| import z3 |
| |
| import torch._dynamo.config |
| |
| from torch.fx.experimental.validator import SympyToZ3, TranslationValidator, ValidationException, z3str |
| from torch.utils._sympy.functions import FloorDiv, Mod |
| |
| class TestTranslationValidation(TestCase): |
| def _prepare_for_translation_validation(self): |
| validator = TranslationValidator() |
| |
| # SymPy symbols. |
| s0, s1, s2 = sympy.symbols("s0 s1 s2", integer=True) |
| |
| # Z3 symbols. |
            for s in (s0, s1, s2):
                validator.add_var(s, int)
| z0, z1, z2 = (validator.z3var(s) for s in (s0, s1, s2)) |
| |
| return (s0, s1, s2), (z0, z1, z2), validator |
| |
| def test_sympy_to_z3(self): |
| |
| ( |
| (s0, s1, s2), |
| (z0, z1, z2), |
| validator, |
| ) = self._prepare_for_translation_validation() |
| |
| test_cases = [ |
| # Integer constants. |
| (sympy.S.Zero, z3.IntVal(0)), |
| (sympy.S.One, z3.IntVal(1)), |
| (sympy.S.NegativeOne, z3.IntVal(-1)), |
| (sympy.Integer(2), z3.IntVal(2)), |
| ( |
| s0, |
| z0, |
| ), |
| # Arithmetic operations. |
| *[ |
| (op(s0, s1), op(z0, z1)) |
| for op in ( |
| operator.add, |
| operator.mul, |
| operator.pow, |
| ) |
| ], |
| # Logical operations. |
| *[ |
| (sympy_op(s0, s1), z3_op(z0, z1)) |
| for sympy_op, z3_op in ( |
| (sympy.Eq, operator.eq), |
| (sympy.Ne, operator.ne), |
| (sympy.Lt, operator.lt), |
| (sympy.Le, operator.le), |
| (sympy.Gt, operator.gt), |
| (sympy.Ge, operator.ge), |
| ) |
| ], |
| # Other operations. |
| ( |
| s0 - s1, |
| z0 + z3.IntVal(-1) * z1, |
| ), |
| ( |
| s0 / s1, |
| z3.ToReal(z0) * (z1**-1), |
| ), |
| (FloorDiv(s0, s1), z3.ToInt(z3.ToReal(z0) / z3.ToReal(z1))), |
| (Mod(s0, s1), z0 - z3.ToInt(z3.ToReal(z0) / z3.ToReal(z1)) * z1), |
| ( |
| Mod(s2, (s0 / s1)), |
| z2 |
| - z3.ToReal(z3.ToInt(z3.ToReal(z2) / (z3.ToReal(z0) * z1**-1))) |
| * (z3.ToReal(z0) * z1**-1), |
| ), |
| ( |
| Mod(s2, s0**3), |
| z2 - z3.ToReal(z3.ToInt(z3.ToReal(z2) / z0**3)) * z0**3, |
| ), |
| ] |
| |
| toZ3 = SympyToZ3(validator) |
| for sympy_expr, z3_expr in test_cases: |
| result = toZ3.run(sympy_expr) |
| self.assertTrue( |
| z3_expr.eq(result), msg=f"expected: {z3_expr}. Got: {result}" |
| ) |
| |
| def test_sat(self): |
| ( |
| (s0, s1, s2), |
| (z0, z1, z2), |
| validator, |
| ) = self._prepare_for_translation_validation() |
| |
| validator.add_source_expr(z0 > 5) |
| validator.add_source_expr(z1 / 2 > z0) |
| |
            # The solutions for the target are a subset of the solutions for the source.
| validator.add_target_expr(s0 > 20) |
| validator.add_target_expr(s1 > s0**2) |
| |
| validator.validate() |
| |
| def test_unsat(self): |
| ( |
| (s0, s1, s2), |
| (z0, z1, z2), |
| validator, |
| ) = self._prepare_for_translation_validation() |
| |
| validator.add_source_expr(z0 > 5) |
| validator.add_source_expr(z1 / 2 > z0) |
| |
            # The solutions for the target are NOT a subset of the solutions for the source.
| validator.add_target_expr(s0 > 20) |
| # This expression is less restrictive than its counterpart. |
| validator.add_target_expr(s1 > s0 + 2) |
| |
| with self.assertRaisesRegex(ValidationException, "translation validation failed."): |
| validator.validate() |
| |
| def test_z3str(self): |
| a = z3.Int("a") |
| b = z3.Int("b") |
| special = z3.Real("this.size()[2]") |
| |
| test_cases = [ |
| (z3.IntVal(42), "42"), |
| # Variable. |
| (a, "a"), |
| # Name with special characters. |
| (special, "this.size()[2]"), |
                # Renamed function applications.
| (a != b, "(!= a b)"), |
| (a ** b, "(pow a b)"), |
| # Chain of associative operations. |
| *[ |
| (op(op(a, 5), b), f"({opstr} 5 a b)") |
| for op, opstr in [ |
| (operator.add, "+"), |
| (operator.mul, "*") |
| ] |
| ], |
| # Revert 'Not' conversions. |
| (a != b, "(!= a b)"), |
| (a < b, "(> b a)"), |
| (a > b, "(> a b)"), |
| # Ignore 'ToInt' and 'ToReal' functions. |
| (z3.ToInt(special) + a, "(+ this.size()[2] a)"), |
| (z3.ToReal(a + b), "(+ a b)"), |
| # Convert to floor division: 'idiv'. |
| (z3.ToInt(z3.ToReal(a) / z3.ToReal(b)), "(idiv a b)"), |
| ] |
| |
| for expr, expected in test_cases: |
| self.assertEqual(z3str(expr), expected) |
| |
| |
| instantiate_device_type_tests(TestNormalizeOperators, globals()) |
| |
| if __name__ == "__main__": |
| run_tests() |