blob: eb4346a0f2413a13774784cf6a61beea3ca64490 [file] [log] [blame]
# Owner(s): ["oncall: quantization"]
# Copied from pytorch/test/fx/test_subgraph_rewriter.py
import os
import sys
import torch
from torch.fx import symbolic_trace, subgraph_rewriter
from torch.fx.annotate import annotate
# Make the helper files in test/ importable
from torch.fx.experimental.rewriter import RewritingTracer
# Resolve the directory two levels above this file's real path and put it on
# sys.path so that modules located there can be imported.
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase
if __name__ == '__main__':
    # Direct execution is rejected; the message names the intended entry point.
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_fx.py TESTNAME\n\n"
        "instead."
    )
class TestSubgraphRewriter(JitTestCase):
def test_subgraph_rewriter_preserves_logic(self):
class M(torch.nn.Module):
def forward(self, x):
val = torch.neg(x) + torch.relu(x)
return torch.add(val, val)
def pattern(x):
return torch.neg(x) + torch.relu(x)
def comparison(x):
val = torch.neg(x) + torch.relu(x)
return torch.add(val, val)
traced = symbolic_trace(M())
comparison_fn = symbolic_trace(comparison)
x = torch.rand(1, 3)
# Replace `pattern` with the same pattern (shouldn't change
# the underlying logic)
subgraph_rewriter.replace_pattern(traced, pattern, pattern)
traced.graph.lint()
ref_output = comparison_fn(x)
test_output = traced.forward(x)
self.assertEqual(ref_output, test_output)
def test_subgraph_rewriter_with_oneliner_pattern(self):
class M(torch.nn.Module):
def forward(self, x):
val = torch.neg(x)
return torch.add(val, val)
def pattern(x):
return torch.neg(x)
def replacement(x):
return torch.relu(x)
def comparison(x):
val = torch.relu(x)
return torch.add(val, val)
traced = symbolic_trace(M())
comparison_fn = symbolic_trace(comparison)
x = torch.rand(1, 3)
subgraph_rewriter.replace_pattern(traced, pattern, replacement)
traced.graph.lint()
ref_output = comparison_fn(x)
test_output = traced.forward(x)
self.assertEqual(ref_output, test_output)
def test_subgraph_rewriter_single_pattern_match(self):
class M(torch.nn.Module):
def forward(self, x):
val = torch.neg(x) + torch.relu(x)
return torch.add(val, val)
def pattern(x):
return torch.neg(x) + torch.relu(x)
def replacement(x):
return torch.relu(x)
def comparison(x):
val = torch.relu(x)
return torch.add(val, val)
traced = symbolic_trace(M())
comparison_fn = symbolic_trace(comparison)
x = torch.rand(1, 3)
subgraph_rewriter.replace_pattern(traced, pattern, replacement)
traced.graph.lint()
ref_output = comparison_fn(x)
test_output = traced.forward(x)
self.assertEqual(ref_output, test_output)
def test_subgraph_rewriter_multiple_pattern_match(self):
class M(torch.nn.Module):
def forward(self, x, w1, w2):
m1 = torch.cat([w1, w2]).sum()
m2 = torch.cat([w1, w2]).sum()
return x + torch.max(m1) + torch.max(m2)
def pattern(w1, w2):
return torch.cat([w1, w2]).sum()
def replacement(w1, w2):
return torch.stack([w1, w2])
def comparison(x, w1, w2):
m1 = torch.stack([w1, w2])
m2 = torch.stack([w1, w2])
return x + torch.max(m1) + torch.max(m2)
traced = symbolic_trace(M())
comparison_fn = symbolic_trace(comparison)
x = torch.rand(1, 3)
w1 = torch.rand(1, 3)
w2 = torch.rand(1, 3)
subgraph_rewriter.replace_pattern(traced, pattern, replacement)
traced.graph.lint()
ref_outs = comparison_fn(x, w1, w2)
test_outs = traced.forward(x, w1, w2)
self.assertEqual(ref_outs, test_outs)
def test_subgraph_rewriter_graph_argument_order(self):
class M(torch.nn.Module):
def forward(self, x, y):
return torch.mm(x, y)
def pattern(x, y):
return torch.mm(x, y)
def comparison(x, y):
return torch.mm(x, y)
traced = symbolic_trace(M())
comparison_fn = symbolic_trace(comparison)
x = torch.randn(3, 4)
y = torch.randn(4, 5)
subgraph_rewriter.replace_pattern(traced, pattern, pattern)
traced.graph.lint()
ref_outs = comparison_fn(x, y)
test_outs = traced.forward(x, y)
self.assertEqual(ref_outs, test_outs)
def test_subgraph_rewriter_correct_output_replacement(self):
class M(torch.nn.Module):
def forward(self, x, y):
val = torch.neg(y) + torch.relu(x)
return torch.add(val, val)
def pattern(x):
return torch.relu(x)
def replacement(x):
return torch.neg(x)
def comparison(x, y):
val = torch.neg(y) + torch.neg(x)
return torch.add(val, val)
traced = symbolic_trace(M())
comparison_fn = symbolic_trace(comparison)
x = torch.randn(4, 4)
y = torch.randn(4, 4)
subgraph_rewriter.replace_pattern(traced, pattern, replacement)
traced.graph.lint()
ref_outs = comparison_fn(x, y)
test_outs = traced.forward(x, y)
self.assertEqual(ref_outs, test_outs)
def test_subgraph_rewriter_traced_as_callable(self):
class M(torch.nn.Module):
def forward(self, x):
val = torch.neg(x) + torch.relu(x)
return torch.add(val, val)
class Pattern(torch.nn.Module):
def forward(self, x):
return torch.neg(x) + torch.relu(x)
class Replacement(torch.nn.Module):
def forward(self, x):
return torch.sigmoid(x)
def comparison(x):
val = torch.sigmoid(x)
return torch.add(val, val)
traced = symbolic_trace(M())
traced_pattern = symbolic_trace(Pattern())
traced_replacement = symbolic_trace(Replacement())
comparison_fn = symbolic_trace(comparison)
x = torch.randn(3, 4)
subgraph_rewriter.replace_pattern(traced, traced_pattern, traced_replacement)
traced.graph.lint()
ref_outs = comparison_fn(x)
test_outs = traced.forward(x)
self.assertEqual(ref_outs, test_outs)
def test_subgraph_rewriter_pattern_is_entire_graph(self):
class M(torch.nn.Module):
def forward(self, x):
a = torch.neg(x)
return torch.add(a, a)
def pattern(x):
a = torch.neg(x)
return torch.add(a, a)
def replacement(x):
a = torch.sigmoid(x)
return torch.cat([a, a])
traced = symbolic_trace(M())
comparison_fn = symbolic_trace(replacement)
x = torch.randn(3, 4)
subgraph_rewriter.replace_pattern(traced, pattern, replacement)
traced.graph.lint()
ref_outs = comparison_fn(x)
test_outs = traced.forward(x)
self.assertEqual(ref_outs, test_outs)
def test_subgraph_rewriter_pattern_output_pattern_node_can_have_users_that_are_not_matched(self):
class M(torch.nn.Module):
def forward(self, x):
y = torch.relu(x)
return torch.neg(y) - y
def pattern(x):
return torch.relu(x)
def replacement(x):
return torch.sigmoid(x)
def comparison(x):
y = torch.sigmoid(x)
return torch.neg(y) - y
traced = symbolic_trace(M())
comparison_fn = symbolic_trace(comparison)
x = torch.randn(3, 4)
subgraph_rewriter.replace_pattern(traced, pattern, replacement)
traced.graph.lint()
ref_outs = comparison_fn(x)
test_outs = traced.forward(x)
self.assertEqual(ref_outs, test_outs)
def test_subgraph_rewriter_internal_pattern_nodes_cannot_have_users_that_are_not_matched(self):
class M(torch.nn.Module):
def forward(self, x, w1, w2, b1, b2):
m0 = torch.cat([w1, w2])
m1 = torch.cat([w1, w2])
m2 = torch.cat([x, b2])
t0 = torch.addmm(b1, m1, m2.t())
t1 = torch.sum(w1, 1)
t2 = torch.addmm(b1, m1, m2.t())
return torch.sum(t1), torch.sum(t2)
def pattern(x, w1, w2, b1, b2):
m1 = torch.cat([w1, w2])
m2 = torch.cat([x, b2])
return torch.addmm(b1, m1, m2.t())
def replacement(x, w1, w2, b1, b2):
return torch.cat([x, w1, w2])
traced = symbolic_trace(M())
# Result should be [] since no matches can be found
res = subgraph_rewriter.replace_pattern(traced, pattern, replacement)
traced.graph.lint()
self.assertEqual(res, [])
def test_subgraph_rewriter_placeholder_matching(self):
"""
This tests that a placeholder Node can be matched to a Node with
a different number of input Nodes. In the example below, the
original traced Module looks like this:
opcode target args kwargs
------------- ---------------------------------------------------------- ------------------------ --------
placeholder x () {}
call_function <built-in function add> (x, 3) {}
call_method dequantize (add,) {}
call_function <built-in method sigmoid of type object at 0x7f7c1f440fe0> (dequantize,) {}
call_method to (sigmoid, torch.float16) {}
output output (to,) {}
while the pattern we want to match looks like this:
opcode target args kwargs
------------- ---------------------------------------------------------- ------------------------ --------
placeholder x () {}
call_method dequantize (x,) {}
call_function <built-in method sigmoid of type object at 0x7f7c1f440fe0> (dequantize,) {}
call_method to (sigmoid, torch.float16) {}
output output (to,) {}
Here, we want to be able to match the original graph's
`call_function.add` Node with the pattern graph's
`plaeholder.x` Node.
Credit to Jerry Zhang (GitHub: jerryzh168) for this test case
"""
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.dtype = torch.float16
def forward(self, x):
x += 3
x = x.dequantize()
x = torch.sigmoid(x)
dtype = self.dtype
x = x.to(dtype)
return x
def pattern(x):
x = x.dequantize()
x = torch.sigmoid(x)
x = x.to(torch.float16)
return x
def replacement(x):
return x
def comparison(x):
return x + 3
traced = symbolic_trace(M())
comparison_fn = symbolic_trace(comparison)
x = torch.randn(3, 4)
subgraph_rewriter.replace_pattern(traced, pattern, replacement)
traced.graph.lint()
ref_outs = comparison_fn(x)
test_outs = traced.forward(x)
self.assertEqual(ref_outs, test_outs)
def test_subgraph_rewriter_replaces_referenced_submodules(self):
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.sigmoid = torch.nn.Sigmoid()
self.submod = torch.nn.ReLU()
def forward(self, x):
x = x + 1
return self.submod(self.sigmoid(x))
class Pattern(torch.nn.Module):
def __init__(self):
super().__init__()
self.sigmoid = torch.nn.Sigmoid()
self.submod = torch.nn.ReLU()
def forward(self, x):
return self.submod(self.sigmoid(x))
class Replacement(torch.nn.Module):
def __init__(self):
super().__init__()
self.id = torch.nn.Identity()
self.submod = torch.nn.ReLU()
def forward(self, x):
return self.submod(self.id(x))
class Comparison(torch.nn.Module):
def __init__(self):
super().__init__()
self.id = torch.nn.Identity()
self.submod = torch.nn.ReLU()
def forward(self, x):
x = x + 1
return self.submod(self.id(x))
traced = symbolic_trace(M())
comparison = Comparison()
x = torch.randn(3, 4)
subgraph_rewriter.replace_pattern(traced, Pattern(), Replacement())
traced.graph.lint()
ref_outs = comparison(x)
test_outs = traced.forward(x)
self.assertEqual(ref_outs, test_outs)
traced.get_submodule("id")
with self.assertRaisesRegex(AttributeError, "has no attribute"):
traced.get_submodule("sigmoid")
submod = traced.get_submodule("submod")
self.assertEqual(type(submod), torch.nn.ReLU)
def test_subgraph_rewriter_annotations_int(self):
class M1(torch.nn.Module):
def forward(self, x):
y: int = x
return torch.add(x, y)
class M2(torch.nn.Module):
def forward(self, x):
y = annotate(x, int)
return torch.add(x, y)
ast_rewriter = RewritingTracer()
graph = ast_rewriter.trace(M1())
module = M2()
symbolic_traced: torch.fx.GraphModule = symbolic_trace(module)
for n, m in zip(symbolic_traced.graph.nodes, graph.nodes):
if n.op == 'placeholder':
assert n.type == int
assert m.type == int
def test_subgraph_writer_replace_consecutive_submodules(self):
def f(x):
x = torch.sigmoid(x)
x = torch.sigmoid(x)
return torch.sigmoid(x)
def pattern(x):
return torch.sigmoid(x)
def replacement(x):
return torch.exp(x)
def comparison(x):
x = torch.exp(x)
x = torch.exp(x)
return torch.exp(x)
traced = symbolic_trace(f)
comparison_fn = symbolic_trace(comparison)
x = torch.randn(3, 4)
subgraph_rewriter.replace_pattern(traced, pattern, replacement)
traced.graph.lint()
ref_outs = comparison_fn(x)
test_outs = traced.forward(x)
self.assertEqual(ref_outs, test_outs)