# Owner(s): ["module: fx"]

import copy
import sys
import logging
from typing import List, Tuple

import torch
from torch.fx._symbolic_trace import symbolic_trace
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.passes.backends.nvfuser import NvFuserBackend

from torch.testing._internal.common_utils import run_tests, TEST_CUDA, TestCase
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    dtypes,
)

if not TEST_CUDA:
    print('CUDA not available, skipping tests', file=sys.stderr)
    TestCase = object  # noqa: F811

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
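

# HF_T5_Partial is a captured, aten-level trace of part of a HuggingFace T5
# block (RMS norm -> self-attention -> RMS norm -> feed-forward). It gives the
# nvFuser backend a realistic mix of fusible pointwise ops and non-fusible
# matmuls to partition.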
class HF_T5_Partial(torch.nn.Module):
    def inputs_meta(self):
        return [
            (torch.Size([512, 512]), torch.float32),
            (torch.Size([512, 512]), torch.float32),
            (torch.Size([512, 512]), torch.float32),
            (torch.Size([512, 512]), torch.float32),
            (torch.Size([512]), torch.float32),
            (torch.Size([2048, 512]), torch.float32),
            (torch.Size([512, 2048]), torch.float32),
            (torch.Size([512]), torch.float32),
            (torch.Size([8, 1024, 512]), torch.float32),
            (torch.Size([8, 8, 1024, 1024]), torch.float32),
        ]
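
    # The forward below is a flat, aten-level trace (one op per line) in the
    # style make_fx produces: every intermediate is named and the body is a
    # straight-line sequence of torch.ops.aten calls.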
    def forward(self, primals_1, primals_2, primals_3, primals_4, primals_5,
                primals_6, primals_7, primals_8, primals_9, primals_10):
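        # T5-style RMS norm: scale the hidden states (primals_9) by
        # rsqrt(mean(x^2) + eps) and the norm weight (primals_5); no mean
        # subtraction, no bias.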
        pow_1 = torch.ops.aten.pow(primals_9, 2)
        mean = torch.ops.aten.mean(pow_1, [-1], True)
        add = torch.ops.aten.add(mean, 1e-06)
        rsqrt = torch.ops.aten.rsqrt(add)
        mul = torch.ops.aten.mul(primals_9, rsqrt)
        mul_1 = torch.ops.aten.mul(primals_5, mul)
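        # q/k/v projections: three matmuls (weights primals_3, primals_1,
        # primals_4), each reshaped to [batch=8, heads=8, seq=1024, head_dim=64].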
        t = torch.ops.aten.t(primals_3)
        view = torch.ops.aten.view(mul_1, [8192, 512])
        mm = torch.ops.aten.mm(view, t)
        _unsafe_view = torch.ops.aten._unsafe_view(mm, [8, 1024, 512])
        view_1 = torch.ops.aten.view(_unsafe_view, [8, -1, 8, 64])
        transpose = torch.ops.aten.transpose(view_1, 1, 2)
        t_1 = torch.ops.aten.t(primals_1)
        view_2 = torch.ops.aten.view(mul_1, [8192, 512])
        mm_1 = torch.ops.aten.mm(view_2, t_1)
        _unsafe_view_1 = torch.ops.aten._unsafe_view(mm_1, [8, 1024, 512])
        view_3 = torch.ops.aten.view(_unsafe_view_1, [8, -1, 8, 64])
        transpose_1 = torch.ops.aten.transpose(view_3, 1, 2)
        t_2 = torch.ops.aten.t(primals_4)
        view_4 = torch.ops.aten.view(mul_1, [8192, 512])
        mm_2 = torch.ops.aten.mm(view_4, t_2)
        _unsafe_view_2 = torch.ops.aten._unsafe_view(mm_2, [8, 1024, 512])
        view_5 = torch.ops.aten.view(_unsafe_view_2, [8, -1, 8, 64])
        transpose_2 = torch.ops.aten.transpose(view_5, 1, 2)
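        # attention scores: batched q @ k^T, in-place add of the
        # [8, 8, 1024, 1024] position bias / mask (primals_10), then softmax
        # over the key dimension.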
        transpose_3 = torch.ops.aten.transpose(transpose_1, 3, 2)
        expand = torch.ops.aten.expand(transpose, [8, 8, 1024, 64])
        clone = torch.ops.aten.clone(expand, memory_format=torch.contiguous_format)
        _unsafe_view_3 = torch.ops.aten._unsafe_view(clone, [64, 1024, 64])
        expand_1 = torch.ops.aten.expand(transpose_3, [8, 8, 64, 1024])
        clone_1 = torch.ops.aten.clone(expand_1, memory_format=torch.contiguous_format)
        _unsafe_view_4 = torch.ops.aten._unsafe_view(clone_1, [64, 64, 1024])
        bmm = torch.ops.aten.bmm(_unsafe_view_3, _unsafe_view_4)
        _unsafe_view_5 = torch.ops.aten._unsafe_view(bmm, [8, 8, 1024, 1024])
        add_ = torch.ops.aten.add_(_unsafe_view_5, primals_10)
        _softmax = torch.ops.aten._softmax(add_, -1, False)
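        # attention context: probs @ v, merge heads back to [8, 1024, 512],
        # then the output projection (primals_2).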
        expand_2 = torch.ops.aten.expand(_softmax, [8, 8, 1024, 1024])
        view_6 = torch.ops.aten.view(expand_2, [64, 1024, 1024])
        expand_3 = torch.ops.aten.expand(transpose_2, [8, 8, 1024, 64])
        clone_2 = torch.ops.aten.clone(expand_3, memory_format=torch.contiguous_format)
        _unsafe_view_6 = torch.ops.aten._unsafe_view(clone_2, [64, 1024, 64])
        bmm_1 = torch.ops.aten.bmm(view_6, _unsafe_view_6)
        _unsafe_view_7 = torch.ops.aten._unsafe_view(bmm_1, [8, 8, 1024, 64])
        transpose_4 = torch.ops.aten.transpose(_unsafe_view_7, 1, 2)
        clone_3 = torch.ops.aten.clone(transpose_4, memory_format=torch.contiguous_format)
        view_7 = torch.ops.aten.view(clone_3, [8, -1, 512])
        t_3 = torch.ops.aten.t(primals_2)
        view_8 = torch.ops.aten.view(view_7, [8192, 512])
        mm_3 = torch.ops.aten.mm(view_8, t_3)
        _unsafe_view_8 = torch.ops.aten._unsafe_view(mm_3, [8, 1024, 512])
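        # first residual connection, then the feed-forward block's RMS norm
        # (weight primals_8).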
        add_1 = torch.ops.aten.add(primals_9, _unsafe_view_8)
        pow_2 = torch.ops.aten.pow(add_1, 2)
        mean_1 = torch.ops.aten.mean(pow_2, [-1], True)
        add_2 = torch.ops.aten.add(mean_1, 1e-06)
        rsqrt_1 = torch.ops.aten.rsqrt(add_2)
        mul_2 = torch.ops.aten.mul(add_1, rsqrt_1)
        mul_3 = torch.ops.aten.mul(primals_8, mul_2)
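        # feed-forward: 512 -> 2048 (primals_6), relu, 2048 -> 512 (primals_7),
        # followed by the second residual connection.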
        t_4 = torch.ops.aten.t(primals_6)
        view_9 = torch.ops.aten.view(mul_3, [8192, 512])
        mm_4 = torch.ops.aten.mm(view_9, t_4)
        _unsafe_view_9 = torch.ops.aten._unsafe_view(mm_4, [8, 1024, 2048])
        relu = torch.ops.aten.relu(_unsafe_view_9)
        t_5 = torch.ops.aten.t(primals_7)
        view_10 = torch.ops.aten.view(relu, [8192, 2048])
        mm_5 = torch.ops.aten.mm(view_10, t_5)
        _unsafe_view_10 = torch.ops.aten._unsafe_view(mm_5, [8, 1024, 512])
        add_3 = torch.ops.aten.add(add_1, _unsafe_view_10)
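        # Return the block output (add_3) together with the intermediates an
        # AOTAutograd-style forward would save for backward.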
        return [add_3, rsqrt, _unsafe_view_3, t_3, _softmax, view_6, mul_2, t, view_9, t_1, primals_5, add_1,
                _unsafe_view_4, view_2, view_10, t_5, t_2, primals_8, view_4, view_8, rsqrt_1, primals_9, t_4,
                mul, _unsafe_view_6, relu, view]


class TestFxNvFuserBackend(TestCase):
    def _generate_random_inputs(self, device, inputs_meta: List[Tuple[torch.Size, torch.dtype]]):
        inputs = []
        for meta in inputs_meta:
            shape, dtype = meta
            if dtype in {torch.int32, torch.int64, torch.bool, torch.uint8}:
                # randint's upper bound is exclusive; use 2 so integer/bool
                # inputs actually vary instead of being all zeros
                input = torch.randint(0, 2, shape, dtype=dtype, device=device)
            else:
                input = torch.rand(shape, dtype=dtype, device=device)
            inputs.append(input)

        return inputs

    @dtypes(torch.float32)
    def test_nvfuser_call_module_backend(self, device, dtype):
        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.bn = torch.nn.BatchNorm2d(3)
                self.relu = torch.nn.ReLU()

            def forward(self, inp):
                o = self.bn(inp)
                o = self.relu(o)
                return o

        inp = torch.randn(2, 3, 4, 5).to(dtype=dtype, device=device)
        m = Model().to(dtype=dtype, device=device)

        # Note that the traced module here contains only `call_module` nodes,
        # which aren't fused by the nvFuser backend, but `nvfuser.compile`
        # should still run without error.
        traced = symbolic_trace(m)

        nvfuser = NvFuserBackend()
        compiled_module = nvfuser.compile(traced)

        eager_result = m(inp)
        nvfuser_result = compiled_module(inp)

        torch.testing.assert_close(eager_result, nvfuser_result, rtol=1e-5, atol=1e-5)

    @dtypes(torch.float32)
    def test_nvfuser_backend(self, device, dtype):
        m = HF_T5_Partial()
        m.to(device)

        traced = symbolic_trace(m)

        nvfuser = NvFuserBackend()
        compiled_module = nvfuser.compile(traced)

        inputs = self._generate_random_inputs(device, m.inputs_meta())

        eager_result = m(*inputs)
        nvfuser_result = compiled_module(*inputs)

        torch.testing.assert_close(eager_result, nvfuser_result, rtol=1e-5, atol=1e-5)

    @dtypes(torch.float32)
    def test_aten_square(self, device, dtype):
        def fn(x):
            square = torch.square(x)
            a = square + 1
            b = a + 1
            return b

        inputs = torch.randn(4, device=device)
        traced = make_fx(fn)(inputs)

        nvfuser = NvFuserBackend()
        compiled_module = nvfuser.compile(copy.deepcopy(traced))

        for node in compiled_module.graph.nodes:
            if node.op == "call_function":
                assert "fused" in str(node.target), "the entire function should be fused into a single fusion group"

        eager_result = traced(inputs)
        nvfuser_result = compiled_module(inputs)

        torch.testing.assert_close(eager_result, nvfuser_result, rtol=1e-5, atol=1e-5)

    @dtypes(torch.float32)
    def test_aten_leakyrelu(self, device, dtype):
        def fn(x):
            leaky_relu = torch.ops.aten.leaky_relu(x, 0.1)
            a = leaky_relu + 1
            b = a + 1
            return b

        inputs = torch.randn(4, device=device)
        traced = make_fx(fn)(inputs)

        nvfuser = NvFuserBackend()
        compiled_module = nvfuser.compile(copy.deepcopy(traced))

        for node in compiled_module.graph.nodes:
            if node.op == "call_function":
                assert "fused" in str(node.target), "the entire function should be fused into a single fusion group"

        eager_result = traced(inputs)
        nvfuser_result = compiled_module(inputs)

        torch.testing.assert_close(eager_result, nvfuser_result, rtol=1e-5, atol=1e-5)

    @dtypes(torch.float32)
    def test_aten_where(self, device, dtype):
        def fn(x):
            where = torch.ops.aten.where(x < 0, -x, x)
            a = where + 1
            b = a + 1
            return b

        inputs = torch.randn(4, device=device)
        traced = make_fx(fn)(inputs)

        nvfuser = NvFuserBackend()
        compiled_module = nvfuser.compile(copy.deepcopy(traced))

        for node in compiled_module.graph.nodes:
            if node.op == "call_function":
                assert "fused" in str(node.target), "the entire function should be fused into a single fusion group"

        eager_result = traced(inputs)
        nvfuser_result = compiled_module(inputs)

        torch.testing.assert_close(eager_result, nvfuser_result, rtol=1e-5, atol=1e-5)


instantiate_device_type_tests(TestFxNvFuserBackend, globals(), only_for="cuda")


if __name__ == "__main__":
    run_tests()