blob: 18dcab1952f21609666d37fba87019f93e2ae5fc [file] [log] [blame]
# Owner(s): ["module: inductor"]
from typing import Any, Callable
import torch
from torch._inductor.fx_passes.pre_grad import (
linear_permute_fusion,
linear_transpose,
permute_linear_fusion,
permute_matmul_fusion,
sink_cat_after_pointwise,
transpose_linear,
transpose_matmul,
)
from torch.fx.passes.shape_prop import ShapeProp
from torch.testing._internal.common_utils import run_tests, TestCase
PassFunc = Callable[[torch.fx.GraphModule, Any], torch.fx.GraphModule]
def chain_passes(*passes: PassFunc) -> PassFunc:
def parent_pass(module: torch.fx.GraphModule, input: Any) -> torch.fx.GraphModule:
for pass_ in passes:
if isinstance(module, torch.fx.GraphModule):
ShapeProp(module).propagate(*input)
module = pass_(module)
return module
return parent_pass
def count_call(module: torch.fx.GraphModule, op: str, target_op: Any) -> int:
return sum(
[1 if (n.op == op and n.target == target_op) else 0 for n in module.graph.nodes]
)
def count_call_function(module: torch.fx.GraphModule, target_op: Any) -> int:
return count_call(module, "call_function", target_op)
def count_call_method(module: torch.fx.GraphModule, target_op: Any) -> int:
return count_call(module, "call_method", target_op)
class TestFxFusion(TestCase):
def test_sink_cat_after_pointwise(self):
def test_kwarg(x, y):
return torch.cat([x, y], dim=-1).view(-1).view(128).tanh()
def test_arg(x, y):
return torch.cat([x, y], -1).view(-1).view(128).tanh()
def test_arg2(x, y):
return torch.cat([x, y]).view(-1).view(128).tanh()
def test_kwarg2(x, y):
return torch.cat(tensors=[x, y], dim=0).tanh()
def test_kwarg3(x, y):
return torch.cat(tensors=[x, y], dim=0).view(128).tanh()
trace_func = chain_passes(torch.fx.symbolic_trace, sink_cat_after_pointwise)
inputs = [
torch.randn(8, 8),
torch.randn(8, 8),
]
for f in [test_kwarg, test_arg, test_arg2, test_kwarg2, test_kwarg3]:
traced = trace_func(f, inputs)
self.assertTrue(torch.allclose(f(*inputs), traced(*inputs)))
self.assertEqual(count_call_method(traced, "tanh"), 2)
def test_linear_permute_fusion(self):
class TestModule(torch.nn.Module):
def __init__(self, k: int, n: int, has_bias: bool):
super().__init__()
self.weight = torch.nn.Parameter(torch.randn(n, k))
self.has_bias = has_bias
if has_bias:
self.bias = torch.nn.Parameter(torch.randn(n))
def forward(self, input: torch.Tensor):
if self.has_bias:
a0 = torch.nn.functional.linear(input, self.weight, self.bias)
else:
a0 = torch.nn.functional.linear(input, self.weight)
b0 = a0.permute(0, 2, 1)
return b0
m, k, n = 16, 8, 4
trace_func = chain_passes(torch.fx.symbolic_trace, linear_permute_fusion)
for has_bias in [True, False]:
module = TestModule(k, n, has_bias).eval()
input = torch.randn(6, m, k)
traced = trace_func(module, [input])
num_linear = count_call_function(traced, torch.nn.functional.linear)
num_linear_transpose = count_call_function(traced, linear_transpose)
self.assertEqual(num_linear, 0)
self.assertEqual(num_linear_transpose, 1)
self.assertTrue(torch.allclose(module(input), traced(input)))
def test_permute_linear_fusion(self):
class TestModule(torch.nn.Module):
def __init__(self, k: int, n: int, has_bias: bool):
super().__init__()
self.weight = torch.nn.Parameter(torch.randn(n, k))
self.has_bias = has_bias
if has_bias:
self.bias = torch.nn.Parameter(torch.randn(n))
def forward(self, input: torch.Tensor):
input1 = input.permute(0, 2, 1)
if self.has_bias:
return torch.nn.functional.linear(input1, self.weight, self.bias)
return torch.nn.functional.linear(input1, self.weight)
m, k, n = 16, 8, 4
trace_func = chain_passes(torch.fx.symbolic_trace, permute_linear_fusion)
for has_bias in [True, False]:
module = TestModule(k, n, has_bias).eval()
input = torch.randn(6, k, m)
traced = trace_func(module, [input])
num_linear = count_call_function(traced, torch.nn.functional.linear)
num_transpose_linear = count_call_function(traced, transpose_linear)
self.assertEqual(num_linear, 0)
self.assertEqual(num_transpose_linear, 1)
self.assertTrue(torch.allclose(module(input), traced(input)))
def test_permute_bmm_fusion(self):
class TestModule(torch.nn.Module):
def __init__(self, batch: int, k: int, n: int):
super().__init__()
self.other = torch.randn(batch, k, n)
def forward(self, input: torch.Tensor):
input1 = input.permute(0, 2, 1)
output = torch.bmm(input1, self.other)
return output
batch, m, k, n = 6, 16, 8, 4
trace_func = chain_passes(torch.fx.symbolic_trace, permute_matmul_fusion)
module = TestModule(batch, k, n).eval()
input = torch.randn(batch, k, m)
traced = trace_func(module, [input])
num_bmm = count_call_function(traced, torch.bmm)
num_transpose_matmul = count_call_function(traced, transpose_matmul)
self.assertEqual(num_bmm, 0)
self.assertEqual(num_transpose_matmul, 1)
self.assertTrue(torch.allclose(module(input), traced(input)))
if __name__ == "__main__":
run_tests()