| # Owner(s): ["module: fx"] |
| |
import copy
import logging
import sys
| from typing import List, Tuple |
| |
| import torch |
| from torch.fx._symbolic_trace import symbolic_trace |
| from torch.fx.experimental.proxy_tensor import make_fx |
| from torch.fx.passes.backends.nvfuser import NvFuserBackend |
| |
| from torch.testing._internal.common_utils import run_tests, TEST_CUDA, TestCase |
| from torch.testing._internal.common_device_type import ( |
| instantiate_device_type_tests, |
| dtypes, |
| ) |
| |
| if not TEST_CUDA: |
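    # Rebinding TestCase to a plain object below keeps unittest discovery from
    # collecting these CUDA-only tests on machines without CUDA.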
| print('CUDA not available, skipping tests', file=sys.stderr) |
| TestCase = object # noqa: F811 |
| |
| logging.basicConfig(level=logging.DEBUG) |
| logger = logging.getLogger(__name__) |
| |
| class HF_T5_Partial(torch.nn.Module): |
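    """Partial forward graph of a HuggingFace T5 block, lowered to aten ops.

    The op sequence below (q/k/v projections, bmm/softmax attention,
    RMS-style layer norms, and a relu feed-forward) appears to come from an
    AOT-autograd-style trace; the exact provenance isn't recorded here.
    """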
| |
| def inputs_meta(self): |
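        # (shape, dtype) pairs for the ten positional arguments of forward(),
        # used by the tests to synthesize random inputs on the target device.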
| return [ |
| (torch.Size([512, 512]), torch.float32), |
| (torch.Size([512, 512]), torch.float32), |
| (torch.Size([512, 512]), torch.float32), |
| (torch.Size([512, 512]), torch.float32), |
| (torch.Size([512]), torch.float32), |
| (torch.Size([2048, 512]), torch.float32), |
| (torch.Size([512, 2048]), torch.float32), |
| (torch.Size([512]), torch.float32), |
| (torch.Size([8, 1024, 512]), torch.float32), |
| (torch.Size([8, 8, 1024, 1024]), torch.float32), |
| ] |
| |
| def forward(self, primals_1, primals_2, primals_3, primals_4, primals_5, |
| primals_6, primals_7, primals_8, primals_9, primals_10): |
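        # RMS-style layer norm of the hidden states (primals_9), scaled by primals_5.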
| pow_1 = torch.ops.aten.pow(primals_9, 2) |
| mean = torch.ops.aten.mean(pow_1, [-1], True) |
| add = torch.ops.aten.add(mean, 1e-06) |
| rsqrt = torch.ops.aten.rsqrt(add) |
| mul = torch.ops.aten.mul(primals_9, rsqrt) |
| mul_1 = torch.ops.aten.mul(primals_5, mul) |
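        # Q/K/V projections (by usage: Q from primals_3, K from primals_1,
        # V from primals_4), each reshaped to [batch, heads, seq, head_dim].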
| t = torch.ops.aten.t(primals_3) |
| view = torch.ops.aten.view(mul_1, [8192, 512]) |
| mm = torch.ops.aten.mm(view, t) |
| _unsafe_view = torch.ops.aten._unsafe_view(mm, [8, 1024, 512]) |
| view_1 = torch.ops.aten.view(_unsafe_view, [8, -1, 8, 64]) |
| transpose = torch.ops.aten.transpose(view_1, 1, 2) |
| t_1 = torch.ops.aten.t(primals_1) |
| view_2 = torch.ops.aten.view(mul_1, [8192, 512]) |
| mm_1 = torch.ops.aten.mm(view_2, t_1) |
| _unsafe_view_1 = torch.ops.aten._unsafe_view(mm_1, [8, 1024, 512]) |
| view_3 = torch.ops.aten.view(_unsafe_view_1, [8, -1, 8, 64]) |
| transpose_1 = torch.ops.aten.transpose(view_3, 1, 2) |
| t_2 = torch.ops.aten.t(primals_4) |
| view_4 = torch.ops.aten.view(mul_1, [8192, 512]) |
| mm_2 = torch.ops.aten.mm(view_4, t_2) |
| _unsafe_view_2 = torch.ops.aten._unsafe_view(mm_2, [8, 1024, 512]) |
| view_5 = torch.ops.aten.view(_unsafe_view_2, [8, -1, 8, 64]) |
| transpose_2 = torch.ops.aten.transpose(view_5, 1, 2) |
| transpose_3 = torch.ops.aten.transpose(transpose_1, 3, 2) |
| expand = torch.ops.aten.expand(transpose, [8, 8, 1024, 64]) |
| clone = torch.ops.aten.clone(expand, memory_format=torch.contiguous_format) |
| _unsafe_view_3 = torch.ops.aten._unsafe_view(clone, [64, 1024, 64]) |
| expand_1 = torch.ops.aten.expand(transpose_3, [8, 8, 64, 1024]) |
| clone_1 = torch.ops.aten.clone(expand_1, memory_format=torch.contiguous_format) |
| _unsafe_view_4 = torch.ops.aten._unsafe_view(clone_1, [64, 64, 1024]) |
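        # Attention scores: Q @ K^T plus the additive bias in primals_10
        # (presumably T5's relative position bias), then softmax.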
| bmm = torch.ops.aten.bmm(_unsafe_view_3, _unsafe_view_4) |
| _unsafe_view_5 = torch.ops.aten._unsafe_view(bmm, [8, 8, 1024, 1024]) |
| add_ = torch.ops.aten.add_(_unsafe_view_5, primals_10) |
| _softmax = torch.ops.aten._softmax(add_, -1, False) |
| expand_2 = torch.ops.aten.expand(_softmax, [8, 8, 1024, 1024]) |
| view_6 = torch.ops.aten.view(expand_2, [64, 1024, 1024]) |
| expand_3 = torch.ops.aten.expand(transpose_2, [8, 8, 1024, 64]) |
| clone_2 = torch.ops.aten.clone(expand_3, memory_format=torch.contiguous_format) |
| _unsafe_view_6 = torch.ops.aten._unsafe_view(clone_2, [64, 1024, 64]) |
| bmm_1 = torch.ops.aten.bmm(view_6, _unsafe_view_6) |
| _unsafe_view_7 = torch.ops.aten._unsafe_view(bmm_1, [8, 8, 1024, 64]) |
| transpose_4 = torch.ops.aten.transpose(_unsafe_view_7, 1, 2) |
| clone_3 = torch.ops.aten.clone(transpose_4, memory_format=torch.contiguous_format) |
| view_7 = torch.ops.aten.view(clone_3, [8, -1, 512]) |
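        # Attention output projection (primals_2) and residual add onto primals_9.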
| t_3 = torch.ops.aten.t(primals_2) |
| view_8 = torch.ops.aten.view(view_7, [8192, 512]) |
| mm_3 = torch.ops.aten.mm(view_8, t_3) |
| _unsafe_view_8 = torch.ops.aten._unsafe_view(mm_3, [8, 1024, 512]) |
| add_1 = torch.ops.aten.add(primals_9, _unsafe_view_8) |
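        # Second RMS-style norm (scaled by primals_8), then the relu
        # feed-forward block (primals_6, primals_7) and its residual add.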
| pow_2 = torch.ops.aten.pow(add_1, 2) |
| mean_1 = torch.ops.aten.mean(pow_2, [-1], True) |
| add_2 = torch.ops.aten.add(mean_1, 1e-06) |
| rsqrt_1 = torch.ops.aten.rsqrt(add_2) |
| mul_2 = torch.ops.aten.mul(add_1, rsqrt_1) |
| mul_3 = torch.ops.aten.mul(primals_8, mul_2) |
| t_4 = torch.ops.aten.t(primals_6) |
| view_9 = torch.ops.aten.view(mul_3, [8192, 512]) |
| mm_4 = torch.ops.aten.mm(view_9, t_4) |
| _unsafe_view_9 = torch.ops.aten._unsafe_view(mm_4, [8, 1024, 2048]) |
| relu = torch.ops.aten.relu(_unsafe_view_9) |
| t_5 = torch.ops.aten.t(primals_7) |
| view_10 = torch.ops.aten.view(relu, [8192, 2048]) |
| mm_5 = torch.ops.aten.mm(view_10, t_5) |
| _unsafe_view_10 = torch.ops.aten._unsafe_view(mm_5, [8, 1024, 512]) |
| add_3 = torch.ops.aten.add(add_1, _unsafe_view_10) |
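        # Besides the block output (add_3), return the intermediates that an
        # AOT-autograd-style trace would save for the backward pass.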
| return [add_3, rsqrt, _unsafe_view_3, t_3, _softmax, view_6, mul_2, t, view_9, t_1, primals_5, add_1, |
| _unsafe_view_4, view_2, view_10, t_5, t_2, primals_8, view_4, view_8, rsqrt_1, primals_9, t_4, |
| mul, _unsafe_view_6, relu, view] |
| |
| |
| class TestFxNvFuserBackend(TestCase): |
| |
| def _generate_random_inputs(self, device, inputs_meta: List[Tuple[torch.Size, torch.dtype]]): |
| inputs = [] |
| for meta in inputs_meta: |
| shape, dtype = meta |
| |
            # torch.rand doesn't support integer/bool dtypes, so sample small
            # integers for those instead (torch.int is an alias of torch.int32).
            if dtype in {torch.int32, torch.int64, torch.bool, torch.uint8}:
                input = torch.randint(0, 2, shape, dtype=dtype, device=device)
| else: |
| input = torch.rand(shape, dtype=dtype, device=device) |
| |
| inputs.append(input) |
| |
| return inputs |
| @dtypes(torch.float32) |
| def test_nvfuser_call_module_backend(self, device, dtype): |
| |
| class Model(torch.nn.Module): |
| |
| def __init__(self): |
                super().__init__()
| self.bn = torch.nn.BatchNorm2d(3) |
| self.relu = torch.nn.ReLU() |
| |
| def forward(self, inp): |
| o = self.bn(inp) |
| o = self.relu(o) |
| return o |
| |
| inp = torch.randn(2, 3, 4, 5).to(dtype=dtype, device=device) |
| m = Model().to(dtype=dtype, device=device) |
| |
        # The traced module here contains only `call_module` nodes, which the
        # nvfuser backend doesn't fuse, but `nvfuser.compile` should still run
        # without error.
| traced = symbolic_trace(m) |
| |
| nvfuser = NvFuserBackend() |
| compiled_module = nvfuser.compile(traced) |
| |
| eager_result = m(inp) |
| nvfuser_result = compiled_module(inp) |
| |
| torch.testing.assert_close(eager_result, nvfuser_result, rtol=1e-5, atol=1e-5) |
| @dtypes(torch.float32) |
| def test_nvfuser_backend(self, device, dtype): |
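        # End-to-end parity check: run the partial T5 graph both eagerly and
        # through the nvfuser-compiled module and compare outputs.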
| m = HF_T5_Partial() |
| m.to(device) |
| |
| traced = symbolic_trace(m) |
| |
| nvfuser = NvFuserBackend() |
| compiled_module = nvfuser.compile(traced) |
| |
| inputs = self._generate_random_inputs(device, m.inputs_meta()) |
| |
| eager_result = m(*inputs) |
| nvfuser_result = compiled_module(*inputs) |
| |
| torch.testing.assert_close(eager_result, nvfuser_result, rtol=1e-5, atol=1e-5) |
| |
| @dtypes(torch.float32) |
| def test_aten_square(self, device, dtype): |
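        # A small pointwise chain that should collapse into a single nvfuser
        # fusion group (checked via the "fused" node-name assertion below).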
| |
| def fn(x): |
| square = torch.square(x) |
| a = square + 1 |
| b = a + 1 |
| return b |
| |
| inputs = torch.randn(4, device=device) |
| traced = make_fx(fn)(inputs) |
| |
| nvfuser = NvFuserBackend() |
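        # Compile a deepcopy so `traced` stays intact for the eager reference
        # run below (compile presumably rewrites the GraphModule in place).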
| compiled_module = nvfuser.compile(copy.deepcopy(traced)) |
| |
| for node in compiled_module.graph.nodes: |
| if node.op == "call_function": |
| assert "fused" in str(node.target), "the entire function should be fused into a single fusion group" |
| |
| eager_result = traced(inputs) |
| nvfuser_result = compiled_module(inputs) |
| torch.testing.assert_close(eager_result, nvfuser_result, rtol=1e-5, atol=1e-5) |
| |
| @dtypes(torch.float32) |
| def test_aten_leakyrelu(self, device, dtype): |
| |
| def fn(x): |
            leaky_relu = torch.ops.aten.leaky_relu(x, 0.1)
            a = leaky_relu + 1
| b = a + 1 |
| return b |
| |
| inputs = torch.randn(4, device=device) |
| traced = make_fx(fn)(inputs) |
| |
| nvfuser = NvFuserBackend() |
| compiled_module = nvfuser.compile(copy.deepcopy(traced)) |
| |
| for node in compiled_module.graph.nodes: |
| if node.op == "call_function": |
| assert "fused" in str(node.target), "the entire function should be fused into a single fusion group" |
| |
| eager_result = traced(inputs) |
| nvfuser_result = compiled_module(inputs) |
| torch.testing.assert_close(eager_result, nvfuser_result, rtol=1e-5, atol=1e-5) |
| |
| @dtypes(torch.float32) |
| def test_aten_where(self, device, dtype): |
| |
| def fn(x): |
| where = torch.ops.aten.where(x < 0, -x, x) |
| a = where + 1 |
| b = a + 1 |
| return b |
| |
| inputs = torch.randn(4, device=device) |
| traced = make_fx(fn)(inputs) |
| |
| nvfuser = NvFuserBackend() |
| compiled_module = nvfuser.compile(copy.deepcopy(traced)) |
| |
| for node in compiled_module.graph.nodes: |
| if node.op == "call_function": |
| assert "fused" in str(node.target), "the entire function should be fused into a single fusion group" |
| |
| eager_result = traced(inputs) |
| nvfuser_result = compiled_module(inputs) |
| torch.testing.assert_close(eager_result, nvfuser_result, rtol=1e-5, atol=1e-5) |
| |
# nvfuser requires CUDA, so generate only the CUDA variants of these tests.
instantiate_device_type_tests(TestFxNvFuserBackend, globals(), only_for="cuda")
| |
| if __name__ == "__main__": |
| run_tests() |