# blob: b43b8e13250aa4b6c711d9e14ced329341fd0f4f [file] [log] [blame]
# Owner(s): ["module: fx"]
import os
import sys
import torch
from torch.fx import symbolic_trace
from torch.fx.experimental.proxy_tensor import make_fx
# Resolve the parent of the directory that contains this file (symlinks are
# resolved first via realpath) and append it to sys.path so that modules
# located there become importable.
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.fx.passes.utils.matcher_utils import SubgraphMatcher
from torch.testing._internal.jit_utils import JitTestCase
class TestMatcher(JitTestCase):
def test_subgraph_matcher_with_attributes(self):
class LargeModel(torch.nn.Module):
def __init__(self):
super().__init__()
self._weight = torch.nn.Parameter(torch.ones(3, 3))
self._bias = torch.nn.Parameter(torch.ones(3, 3))
def forward(self, x):
return torch.ops.aten.addmm.default(self._bias, x, self._weight)
# Large Model graph:
# opcode name target args kwargs
# ------------- ------------- ------------------ ------------------- --------
# placeholder x x () {}
# get_attr _bias _bias () {}
# get_attr _weight _weight () {}
# call_function addmm_default aten.addmm.default (_bias, x, _weight) {}
# output output output (addmm_default,) {}
large_model_graph = symbolic_trace(LargeModel()).graph
class PatternModel(torch.nn.Module):
def __init__(self):
super().__init__()
self._weight_1 = torch.nn.Parameter(torch.ones(5, 5))
self._bias_1 = torch.nn.Parameter(torch.ones(5, 5))
def forward(self, x):
return torch.ops.aten.addmm.default(self._bias_1, x, self._weight_1)
pattern_graph = torch.fx.symbolic_trace(PatternModel()).graph
subgraph_matcher = SubgraphMatcher(pattern_graph)
match_result = subgraph_matcher.match(large_model_graph)
self.assertEqual(len(match_result), 1)
def test_subgraph_matcher_with_list(self):
def original(x, y):
return torch.ops.aten.view(x, [5, y.shape[0]])
original_graph = torch.fx.symbolic_trace(original).graph
def pattern(x, y, z):
return torch.ops.aten.view(x, [z, y.shape[0]])
pattern_graph = torch.fx.symbolic_trace(pattern).graph
subgraph_matcher = SubgraphMatcher(pattern_graph)
match_result = subgraph_matcher.match(original_graph)
self.assertEqual(len(match_result), 1)
def test_subgraph_matcher_with_list_bad(self):
def original(x, y):
return torch.ops.aten._reshape_alias_copy.default(x, [1, y.shape[0]], [y.shape[1], y.shape[1]])
original_graph = torch.fx.symbolic_trace(original).graph
def pattern(x, y, b):
return torch.ops.aten._reshape_alias_copy.default(x, [b, y.shape[0], y.shape[1]], [y.shape[1]])
pattern_graph = torch.fx.symbolic_trace(pattern).graph
subgraph_matcher = SubgraphMatcher(pattern_graph)
match_result = subgraph_matcher.match(original_graph)
self.assertEqual(len(match_result), 0)
def test_subgraph_matcher_ignore_literals(self):
def original(x):
return x + 1
original_graph = make_fx(original)(torch.ones(3, 3)).graph
original_graph.eliminate_dead_code()
def pattern(x):
return x + 2
pattern_graph = make_fx(pattern)(torch.ones(4, 4)).graph
pattern_graph.eliminate_dead_code()
subgraph_matcher = SubgraphMatcher(pattern_graph)
match_result = subgraph_matcher.match(original_graph)
self.assertEqual(len(match_result), 0)
subgraph_matcher = SubgraphMatcher(pattern_graph, ignore_literals=True)
match_result = subgraph_matcher.match(original_graph)
self.assertEqual(len(match_result), 1)
def test_variatic_arg_matching(self):
inputs = (torch.randn(20, 16, 50, 32),)
def maxpool(x, kernel_size, stride, padding, dilation):
return torch.ops.aten.max_pool2d_with_indices.default(x, kernel_size, stride, padding, dilation)
maxpool_graph = torch.fx.symbolic_trace(maxpool).graph
maxpool_matcher = SubgraphMatcher(maxpool_graph)
match_result = maxpool_matcher.match(maxpool_graph)
self.assertEqual(len(match_result), 1)
# Graph only contains "stride" argument
maxpool_s = torch.nn.MaxPool2d(kernel_size=2, stride=1).eval()
maxpool_s_graph = make_fx(maxpool_s)(*inputs).graph
match_s_result = maxpool_matcher.match(maxpool_s_graph)
self.assertEqual(len(match_s_result), 1)
# Graph only contains "padding" argument
maxpool_p = torch.nn.MaxPool2d(kernel_size=2, padding=1)
maxpool_p_graph = make_fx(maxpool_p)(*inputs).graph
match_p_result = maxpool_matcher.match(maxpool_p_graph)
self.assertEqual(len(match_p_result), 1)
# Graph only contains "stride, padding" argument
maxpool_sp = torch.nn.MaxPool2d(kernel_size=2, stride=1, padding=1)
maxpool_sp_graph = make_fx(maxpool_sp)(*inputs).graph
match_sp_result = maxpool_matcher.match(maxpool_sp_graph)
self.assertEqual(len(match_sp_result), 1)