| # Owner(s): ["oncall: quantization"] |
| # Copied from pytorch/test/fx/test_subgraph_rewriter.py |
| |
| import os |
| import sys |
| |
| import torch |
| from torch.fx import symbolic_trace, subgraph_rewriter |
| from torch.fx.annotate import annotate |
from torch.fx.experimental.rewriter import RewritingTracer

# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
| sys.path.append(pytorch_test_dir) |
| from torch.testing._internal.jit_utils import JitTestCase |
| |
| if __name__ == '__main__': |
| raise RuntimeError("This test file is not meant to be run directly, use:\n\n" |
| "\tpython test/test_fx.py TESTNAME\n\n" |
| "instead.") |
| |
| class TestSubgraphRewriter(JitTestCase): |
| |
| def test_subgraph_rewriter_preserves_logic(self): |
| class M(torch.nn.Module): |
| def forward(self, x): |
| val = torch.neg(x) + torch.relu(x) |
| return torch.add(val, val) |
| |
| def pattern(x): |
| return torch.neg(x) + torch.relu(x) |
| |
| def comparison(x): |
| val = torch.neg(x) + torch.relu(x) |
| return torch.add(val, val) |
| |
| traced = symbolic_trace(M()) |
| comparison_fn = symbolic_trace(comparison) |
| |
| x = torch.rand(1, 3) |
| |
| # Replace `pattern` with the same pattern (shouldn't change |
| # the underlying logic) |
| subgraph_rewriter.replace_pattern(traced, pattern, pattern) |
| |
| traced.graph.lint() |
| |
| ref_output = comparison_fn(x) |
| test_output = traced.forward(x) |
| self.assertEqual(ref_output, test_output) |
| |
| def test_subgraph_rewriter_with_oneliner_pattern(self): |
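        # A single-node pattern (torch.neg) should be matched and swapped for the replacement (torch.relu)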
| class M(torch.nn.Module): |
| def forward(self, x): |
| val = torch.neg(x) |
| return torch.add(val, val) |
| |
| def pattern(x): |
| return torch.neg(x) |
| |
| def replacement(x): |
| return torch.relu(x) |
| |
| def comparison(x): |
| val = torch.relu(x) |
| return torch.add(val, val) |
| |
| traced = symbolic_trace(M()) |
| comparison_fn = symbolic_trace(comparison) |
| |
| x = torch.rand(1, 3) |
| |
| subgraph_rewriter.replace_pattern(traced, pattern, replacement) |
| |
| traced.graph.lint() |
| |
| ref_output = comparison_fn(x) |
| test_output = traced.forward(x) |
| self.assertEqual(ref_output, test_output) |
| |
| def test_subgraph_rewriter_single_pattern_match(self): |
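        # A multi-node pattern that occurs exactly once should be replaced by the simpler replacement graph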
| class M(torch.nn.Module): |
| def forward(self, x): |
| val = torch.neg(x) + torch.relu(x) |
| return torch.add(val, val) |
| |
| def pattern(x): |
| return torch.neg(x) + torch.relu(x) |
| |
| def replacement(x): |
| return torch.relu(x) |
| |
| def comparison(x): |
| val = torch.relu(x) |
| return torch.add(val, val) |
| |
| traced = symbolic_trace(M()) |
| comparison_fn = symbolic_trace(comparison) |
| |
| x = torch.rand(1, 3) |
| |
| subgraph_rewriter.replace_pattern(traced, pattern, replacement) |
| |
| traced.graph.lint() |
| |
| ref_output = comparison_fn(x) |
| test_output = traced.forward(x) |
| self.assertEqual(ref_output, test_output) |
| |
| def test_subgraph_rewriter_multiple_pattern_match(self): |
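        # Both occurrences of the cat().sum() pattern should be found and replaced with torch.stack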
| class M(torch.nn.Module): |
| def forward(self, x, w1, w2): |
| m1 = torch.cat([w1, w2]).sum() |
| m2 = torch.cat([w1, w2]).sum() |
| return x + torch.max(m1) + torch.max(m2) |
| |
| def pattern(w1, w2): |
| return torch.cat([w1, w2]).sum() |
| |
| def replacement(w1, w2): |
| return torch.stack([w1, w2]) |
| |
| def comparison(x, w1, w2): |
| m1 = torch.stack([w1, w2]) |
| m2 = torch.stack([w1, w2]) |
| return x + torch.max(m1) + torch.max(m2) |
| |
| traced = symbolic_trace(M()) |
| comparison_fn = symbolic_trace(comparison) |
| |
| x = torch.rand(1, 3) |
| w1 = torch.rand(1, 3) |
| w2 = torch.rand(1, 3) |
| |
| subgraph_rewriter.replace_pattern(traced, pattern, replacement) |
| |
| traced.graph.lint() |
| |
| ref_outs = comparison_fn(x, w1, w2) |
| test_outs = traced.forward(x, w1, w2) |
| self.assertEqual(ref_outs, test_outs) |
| |
| def test_subgraph_rewriter_graph_argument_order(self): |
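        # Replacing torch.mm(x, y) with an identical pattern should preserve the placeholder argument order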
| class M(torch.nn.Module): |
| def forward(self, x, y): |
| return torch.mm(x, y) |
| |
| def pattern(x, y): |
| return torch.mm(x, y) |
| |
| def comparison(x, y): |
| return torch.mm(x, y) |
| |
| traced = symbolic_trace(M()) |
| comparison_fn = symbolic_trace(comparison) |
| |
| x = torch.randn(3, 4) |
| y = torch.randn(4, 5) |
| |
| subgraph_rewriter.replace_pattern(traced, pattern, pattern) |
| |
| traced.graph.lint() |
| |
| ref_outs = comparison_fn(x, y) |
| test_outs = traced.forward(x, y) |
| self.assertEqual(ref_outs, test_outs) |
| |
| def test_subgraph_rewriter_correct_output_replacement(self): |
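        # Only the matched relu(x) should be rewritten; the unmatched neg(y) keeps its original input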
| class M(torch.nn.Module): |
| def forward(self, x, y): |
| val = torch.neg(y) + torch.relu(x) |
| return torch.add(val, val) |
| |
| def pattern(x): |
| return torch.relu(x) |
| |
| def replacement(x): |
| return torch.neg(x) |
| |
| def comparison(x, y): |
| val = torch.neg(y) + torch.neg(x) |
| return torch.add(val, val) |
| |
| traced = symbolic_trace(M()) |
| comparison_fn = symbolic_trace(comparison) |
| |
| x = torch.randn(4, 4) |
| y = torch.randn(4, 4) |
| |
| subgraph_rewriter.replace_pattern(traced, pattern, replacement) |
| |
| traced.graph.lint() |
| |
| ref_outs = comparison_fn(x, y) |
| test_outs = traced.forward(x, y) |
| self.assertEqual(ref_outs, test_outs) |
| |
| def test_subgraph_rewriter_traced_as_callable(self): |
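        # The pattern and replacement can be passed as already-traced GraphModules rather than plain functions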
| class M(torch.nn.Module): |
| def forward(self, x): |
| val = torch.neg(x) + torch.relu(x) |
| return torch.add(val, val) |
| |
| class Pattern(torch.nn.Module): |
| def forward(self, x): |
| return torch.neg(x) + torch.relu(x) |
| |
| class Replacement(torch.nn.Module): |
| def forward(self, x): |
| return torch.sigmoid(x) |
| |
| def comparison(x): |
| val = torch.sigmoid(x) |
| return torch.add(val, val) |
| |
| traced = symbolic_trace(M()) |
| traced_pattern = symbolic_trace(Pattern()) |
| traced_replacement = symbolic_trace(Replacement()) |
| comparison_fn = symbolic_trace(comparison) |
| |
| x = torch.randn(3, 4) |
| |
| subgraph_rewriter.replace_pattern(traced, traced_pattern, traced_replacement) |
| |
| traced.graph.lint() |
| |
| ref_outs = comparison_fn(x) |
| test_outs = traced.forward(x) |
| self.assertEqual(ref_outs, test_outs) |
| |
| def test_subgraph_rewriter_pattern_is_entire_graph(self): |
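        # A pattern that spans the module's entire graph should still be matched and replaced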
| class M(torch.nn.Module): |
| def forward(self, x): |
| a = torch.neg(x) |
| return torch.add(a, a) |
| |
| def pattern(x): |
| a = torch.neg(x) |
| return torch.add(a, a) |
| |
| def replacement(x): |
| a = torch.sigmoid(x) |
| return torch.cat([a, a]) |
| |
| traced = symbolic_trace(M()) |
| comparison_fn = symbolic_trace(replacement) |
| |
| x = torch.randn(3, 4) |
| |
| subgraph_rewriter.replace_pattern(traced, pattern, replacement) |
| |
| traced.graph.lint() |
| |
| ref_outs = comparison_fn(x) |
| test_outs = traced.forward(x) |
| self.assertEqual(ref_outs, test_outs) |
| |
| def test_subgraph_rewriter_pattern_output_pattern_node_can_have_users_that_are_not_matched(self): |
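        # The matched relu is the pattern's output node, so its users outside the match (neg and the
        # subtraction) are allowed; they should be rewired to the replacement's output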
| class M(torch.nn.Module): |
| def forward(self, x): |
| y = torch.relu(x) |
| return torch.neg(y) - y |
| |
| def pattern(x): |
| return torch.relu(x) |
| |
| def replacement(x): |
| return torch.sigmoid(x) |
| |
| def comparison(x): |
| y = torch.sigmoid(x) |
| return torch.neg(y) - y |
| |
| traced = symbolic_trace(M()) |
| comparison_fn = symbolic_trace(comparison) |
| |
| x = torch.randn(3, 4) |
| |
| subgraph_rewriter.replace_pattern(traced, pattern, replacement) |
| |
| traced.graph.lint() |
| |
| ref_outs = comparison_fn(x) |
| test_outs = traced.forward(x) |
| self.assertEqual(ref_outs, test_outs) |
| |
| def test_subgraph_rewriter_internal_pattern_nodes_cannot_have_users_that_are_not_matched(self): |
| class M(torch.nn.Module): |
| def forward(self, x, w1, w2, b1, b2): |
| m0 = torch.cat([w1, w2]) |
| m1 = torch.cat([w1, w2]) |
| m2 = torch.cat([x, b2]) |
| t0 = torch.addmm(b1, m1, m2.t()) |
| t1 = torch.sum(w1, 1) |
| t2 = torch.addmm(b1, m1, m2.t()) |
| return torch.sum(t1), torch.sum(t2) |
| |
| def pattern(x, w1, w2, b1, b2): |
| m1 = torch.cat([w1, w2]) |
| m2 = torch.cat([x, b2]) |
| return torch.addmm(b1, m1, m2.t()) |
| |
| def replacement(x, w1, w2, b1, b2): |
| return torch.cat([x, w1, w2]) |
| |
| traced = symbolic_trace(M()) |
| |
        # Expect no matches (result []): the shared cat nodes are internal to the pattern but feed both
        # addmm calls, so every candidate match has internal nodes with users outside the match
| res = subgraph_rewriter.replace_pattern(traced, pattern, replacement) |
| |
| traced.graph.lint() |
| |
| self.assertEqual(res, []) |
| |
| def test_subgraph_rewriter_placeholder_matching(self): |
| """ |
| This tests that a placeholder Node can be matched to a Node with |
| a different number of input Nodes. In the example below, the |
        original traced Module looks like this:

            opcode         target                                                      args                      kwargs
            -------------  ----------------------------------------------------------  ------------------------  --------
            placeholder    x                                                           ()                        {}
            call_function  <built-in function add>                                     (x, 3)                    {}
            call_method    dequantize                                                  (add,)                    {}
            call_function  <built-in method sigmoid of type object at 0x7f7c1f440fe0>  (dequantize,)             {}
            call_method    to                                                          (sigmoid, torch.float16)  {}
            output         output                                                      (to,)                     {}

        while the pattern we want to match looks like this:

            opcode         target                                                      args                      kwargs
            -------------  ----------------------------------------------------------  ------------------------  --------
            placeholder    x                                                           ()                        {}
            call_method    dequantize                                                  (x,)                      {}
            call_function  <built-in method sigmoid of type object at 0x7f7c1f440fe0>  (dequantize,)             {}
            call_method    to                                                          (sigmoid, torch.float16)  {}
            output         output                                                      (to,)                     {}

        Here, we want to be able to match the original graph's
        `call_function.add` Node with the pattern graph's
        `placeholder.x` Node.

        Credit to Jerry Zhang (GitHub: jerryzh168) for this test case
        """
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.dtype = torch.float16 |
| |
| def forward(self, x): |
| x += 3 |
| x = x.dequantize() |
| x = torch.sigmoid(x) |
| dtype = self.dtype |
| x = x.to(dtype) |
| return x |
| |
| def pattern(x): |
| x = x.dequantize() |
| x = torch.sigmoid(x) |
| x = x.to(torch.float16) |
| return x |
| |
| def replacement(x): |
| return x |
| |
| def comparison(x): |
| return x + 3 |
| |
| traced = symbolic_trace(M()) |
| comparison_fn = symbolic_trace(comparison) |
| |
| x = torch.randn(3, 4) |
| |
| subgraph_rewriter.replace_pattern(traced, pattern, replacement) |
| |
| traced.graph.lint() |
| |
| ref_outs = comparison_fn(x) |
| test_outs = traced.forward(x) |
| self.assertEqual(ref_outs, test_outs) |
| |
| def test_subgraph_rewriter_replaces_referenced_submodules(self): |
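        # Submodules referenced by the replacement (id) should be added to the transformed module, and
        # submodules referenced only by the pattern (sigmoid) should no longer be present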
| class M(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.sigmoid = torch.nn.Sigmoid() |
| self.submod = torch.nn.ReLU() |
| |
| def forward(self, x): |
| x = x + 1 |
| return self.submod(self.sigmoid(x)) |
| |
| class Pattern(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.sigmoid = torch.nn.Sigmoid() |
| self.submod = torch.nn.ReLU() |
| |
| def forward(self, x): |
| return self.submod(self.sigmoid(x)) |
| |
| class Replacement(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.id = torch.nn.Identity() |
| self.submod = torch.nn.ReLU() |
| |
| def forward(self, x): |
| return self.submod(self.id(x)) |
| |
| class Comparison(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.id = torch.nn.Identity() |
| self.submod = torch.nn.ReLU() |
| |
| def forward(self, x): |
| x = x + 1 |
| return self.submod(self.id(x)) |
| |
| traced = symbolic_trace(M()) |
| comparison = Comparison() |
| |
| x = torch.randn(3, 4) |
| |
| subgraph_rewriter.replace_pattern(traced, Pattern(), Replacement()) |
| |
| traced.graph.lint() |
| |
| ref_outs = comparison(x) |
| test_outs = traced.forward(x) |
| self.assertEqual(ref_outs, test_outs) |
| |
| traced.get_submodule("id") |
| with self.assertRaisesRegex(AttributeError, "has no attribute"): |
| traced.get_submodule("sigmoid") |
| |
| submod = traced.get_submodule("submod") |
| self.assertEqual(type(submod), torch.nn.ReLU) |
| |
| def test_subgraph_rewriter_annotations_int(self): |
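        # Both the AST-rewritten annotation `y: int = x` (M1) and the explicit annotate(x, int) call (M2)
        # should record int as the type of the placeholder node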
| |
| class M1(torch.nn.Module): |
| def forward(self, x): |
| y: int = x |
| return torch.add(x, y) |
| |
| class M2(torch.nn.Module): |
| def forward(self, x): |
| y = annotate(x, int) |
| return torch.add(x, y) |
| |
| ast_rewriter = RewritingTracer() |
| graph = ast_rewriter.trace(M1()) |
| |
| module = M2() |
| symbolic_traced: torch.fx.GraphModule = symbolic_trace(module) |
| for n, m in zip(symbolic_traced.graph.nodes, graph.nodes): |
| if n.op == 'placeholder': |
| assert n.type == int |
| assert m.type == int |
| |
    def test_subgraph_rewriter_replace_consecutive_submodules(self):
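        # Consecutive (chained) occurrences of the pattern should each be matched and replaced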
| |
| def f(x): |
| x = torch.sigmoid(x) |
| x = torch.sigmoid(x) |
| return torch.sigmoid(x) |
| |
| def pattern(x): |
| return torch.sigmoid(x) |
| |
| def replacement(x): |
| return torch.exp(x) |
| |
| def comparison(x): |
| x = torch.exp(x) |
| x = torch.exp(x) |
| return torch.exp(x) |
| |
| traced = symbolic_trace(f) |
| comparison_fn = symbolic_trace(comparison) |
| |
| x = torch.randn(3, 4) |
| |
| subgraph_rewriter.replace_pattern(traced, pattern, replacement) |
| |
| traced.graph.lint() |
| |
| ref_outs = comparison_fn(x) |
| test_outs = traced.forward(x) |
| self.assertEqual(ref_outs, test_outs) |