| # Owner(s): ["module: mkldnn"] |
| import torch |
| import unittest |
| import itertools |
| |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.testing._internal.jit_utils import JitTestCase |
| from torch.testing._internal.common_utils import run_tests, TEST_SCIPY, IS_WINDOWS, IS_MACOS |
| |
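| # Node kind that the oneDNN Graph (LLGA) fuser emits for a fused subgraph. |
| # LLGA is only exercised on MKL-DNN builds, and not on Windows or macOS. |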
| LLGA_FUSION_GROUP = 'prim::oneDNNFusionGroup' |
| LLGA_NOT_ENABLED = not torch._C.has_mkldnn or IS_WINDOWS or IS_MACOS |
| |
| |
| def warmup_forward(f, *args, profiling_count=2): |
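| # Run forward a few times so the profiling executor can record shapes and apply fusion. |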
| for i in range(profiling_count): |
| results = f(*args) |
| |
| return results |
| |
| |
| class JitLlgaTestCase(JitTestCase): |
| def setUp(self): |
| torch.jit.enable_onednn_fusion(True) |
| |
| def tearDown(self): |
| torch.jit.enable_onednn_fusion(False) |
| |
| def checkTrace(self, m, x, *args, **kwargs): |
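| # Trace the callable (and freeze it if it is an nn.Module), warm it up so the LLGA |
| # fusion pass runs, check numerical parity with eager mode, and return the traced |
| # callable together with its optimized forward graph. |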
| if isinstance(m, torch.nn.Module): |
| m.eval() |
| with torch.no_grad(), \ |
| torch._jit_internal._disable_emit_hooks(): |
| traced = torch.jit.trace(m, x) |
| if isinstance(m, torch.nn.Module): |
| traced = torch.jit.freeze(traced) |
| warmup_forward(traced, *x) |
| fwd_graph = traced.graph_for(*x) |
| |
| ref_o = m(*x) |
| jit_o = traced(*x) |
| self.assertEqual(jit_o, ref_o) |
| return traced, fwd_graph |
| |
| def assertFused(self, graph, fused_patterns): |
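| # A pattern counts as fused when it no longer appears as a standalone node in the graph. |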
| for pat in fused_patterns: |
| self.assertGraphContainsExactly(graph, pat, 0) |
| |
| |
| try: |
| import torchvision |
| HAS_TORCHVISION = True |
| except (ImportError, RuntimeError): |
| HAS_TORCHVISION = False |
| skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, 'no torchvision') |
| |
| def get_eltwise_fn(name): |
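| # Resolve an elementwise op by name, first on torch, then on torch.nn.functional. |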
| if hasattr(torch, name): |
| return getattr(torch, name) |
| elif hasattr(F, name): |
| return getattr(F, name) |
| else: |
| raise NameError('Eltwise function %s not found' % name) |
| |
| |
| @unittest.skipIf(LLGA_NOT_ENABLED, "MKL-DNN build is disabled") |
| class TestOp(JitLlgaTestCase): |
| def test_conv2d(self): |
| for [spatial, in_channels, out_channels, kernel, padding, stride, dilation, g, bias] in itertools.product( |
| [7, 8], |
| [8, 15], |
| [7, 16], |
| [3, 4], |
| [0, 2], |
| [1, 2], |
| [1, 2], |
| [1, 2], |
| [True, False]): |
| |
| m = nn.Conv2d(in_channels=in_channels * g, |
| out_channels=out_channels * g, |
| kernel_size=kernel, |
| padding=padding, |
| stride=stride, |
| dilation=dilation, |
| groups=g, |
| bias=bias) |
| |
| x = torch.rand(1, in_channels * g, spatial, spatial) |
| _, graph = self.checkTrace(m, [x]) |
| self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1) |
| |
| def test_bn2d(self): |
| m = nn.BatchNorm2d(32).eval() |
| x = torch.rand(1, 32, 28, 28) |
| _, graph = self.checkTrace(m, [x]) |
| # single-op partition shouldn't be created for batch_norm |
| self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 0) |
| |
| def test_eltwise(self): |
| class M(nn.Module): |
| def __init__(self, eltwise_fn): |
| super(M, self).__init__() |
| self.eltwise = eltwise_fn |
| |
| def forward(self, x): |
| return self.eltwise(x) |
| |
| for eltwise in ['relu', 'gelu']: |
| eltwise_fn = get_eltwise_fn(eltwise) |
| m = M(eltwise_fn) |
| x = torch.rand(1, 32, 28, 28) |
| _, graph = self.checkTrace(m, [x]) |
| # single-op partition shouldn't be created. |
| self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 0) |
| |
| def test_max_pool2d(self): |
| for [spatial, kernel, padding, stride, dilation, ceil_mode] in itertools.product( |
| [15, 16, 17, 18, 19], |
| [4, 5], |
| [0, 1, 2], |
| [1, 2], # [1, 2, 4], TODO: fix issue in pad calculation |
| [1], # [1, 2], TODO: backend support for dilation |
| [True, False]): |
| |
| m = nn.MaxPool2d(kernel_size=kernel, |
| stride=stride, |
| padding=padding, |
| dilation=dilation, |
| ceil_mode=ceil_mode) |
| |
| x = torch.rand(1, 4, spatial, spatial) |
| _, graph = self.checkTrace(m, [x]) |
| self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1) |
| |
| def test_avg_pool2d(self): |
| for [spatial, kernel, padding, stride, ceil_mode, count_include_pad] in itertools.product( |
| [15, 16, 17, 18, 19], |
| [4, 5], |
| [0, 1, 2], |
| [1, 2, 4], |
| [False], # TODO: oneDNN Graph does not fully support ceil_mode=True |
| [True, False]): |
| |
| m = nn.AvgPool2d(kernel_size=kernel, |
| stride=stride, |
| padding=padding, |
| ceil_mode=ceil_mode, |
| count_include_pad=count_include_pad) |
| |
| x = torch.rand(1, 4, spatial, spatial) |
| _, graph = self.checkTrace(m, [x]) |
| self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1) |
| |
| def test_variable_kernel_avg_pool2d(self): |
| class M(nn.Module): |
| def __init__(self): |
| super(M, self).__init__() |
| |
| def forward(self, x): |
| x = F.avg_pool2d(x, kernel_size=(x.size(2), x.size(3)), padding=0, count_include_pad=False) |
| return x |
| |
| x = torch.randn(1, 1000, 1, 1) |
| m = M() |
| _, graph = self.checkTrace(m, [x]) |
| # kernel_size is not Constant, shouldn't have any LLGA_FUSION_GROUP |
| # TODO: with shape specialization, should have 1 LLGA_FUSION_GROUP |
| self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 0) |
| |
| def test_softmax(self): |
| for dim in [-4, -3, -2, -1, 0, 1, 2, 3]: |
| m = nn.Softmax(dim=dim) |
| x = torch.rand(8, 12, 12, 12) |
| _, graph = self.checkTrace(m, [x]) |
| # single-op partition shouldn't be created for softmax |
| self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 0) |
| |
| def test_linear(self): |
| for bias in [True, False]: |
| x = torch.rand(32, 28) |
| m = torch.nn.Linear(in_features=28, out_features=64, bias=bias) |
| _, graph = self.checkTrace(m, [x]) |
| self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1) |
| self.assertFused(graph, ['aten::linear']) |
| |
| def _gen_binary_inputs(self, gen_permute=True): |
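| # Yield broadcastable (x, y) pairs; with gen_permute, also yield each pair with the shapes swapped. |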
| for xshape, yshape in [ |
| [[1, 32, 28, 28], [1, 32, 28, 28]], |
| [[1, 32, 28, 28], [1, 1, 28, 28]], |
| [[1, 32, 28, 28], [28]], |
| [[1, 32, 28, 28], [1]], |
| ]: |
| yield torch.rand(xshape), torch.rand(yshape) |
| if gen_permute and xshape != yshape: |
| yield torch.rand(yshape), torch.rand(xshape) |
| |
| def test_add(self): |
| def forward_add(x, y): |
| return torch.add(x, y, alpha=2) |
| |
| for x, y in self._gen_binary_inputs(): |
| _, graph = self.checkTrace(forward_add, [x, y]) |
| self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1) |
| |
| def test_add_scalar(self): |
| def add_scalar(x): |
| return 42 + x + 3.14 |
| |
| x = torch.rand(32, 32) |
| _, graph = self.checkTrace(add_scalar, [x]) |
| self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1) |
| |
| def test_addmm(self): |
| def addmm(x, y, z): |
| # alpha and beta are 1 by default |
| return torch.addmm(z, x, y) |
| |
| x = torch.rand(64, 32) |
| y = torch.rand(32, 32) |
| z = torch.rand(64, 32) |
| _, graph = self.checkTrace(addmm, [x, y, z]) |
| # single-op partition should be created for matmul with bias. |
| self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1) |
| |
| def test_mul(self): |
| def forward_mul(x, y): |
| return torch.mul(x, y) * 3 |
| |
| for x, y in self._gen_binary_inputs(): |
| _, graph = self.checkTrace(forward_mul, [x, y]) |
| # the two muls should fuse into a single partition; a single-op partition wouldn't be created |
| self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1) |
| |
| def test_identity_binary(self): |
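| # x * 1 + 0.0 should leave no standalone aten::mul/aten::add in the optimized graph. |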
| def forward(x): |
| return x * 1 + 0.0 |
| |
| x = torch.rand(32) |
| _, graph = self.checkTrace(forward, [x]) |
| self.assertFused(graph, ['aten::add', 'aten::mul']) |
| |
| def test_layer_norm(self): |
| # TODO: support more normalized_shape |
| m = torch.nn.LayerNorm(10) |
| x = torch.randn(2, 5, 10, 10) |
| _, graph = self.checkTrace(m, [x]) |
| self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1) |
| |
| def test_cat(self): |
| def cat_along_dim(d): |
| def forward_cat(*inputs): |
| return torch.cat(inputs, d) |
| return forward_cat |
| |
| for xshape in [ |
| [8, 8, 8, 8], |
| [64, 8, 32], |
| [2048, 64], |
| ]: |
| for d in range(len(xshape)): |
| x = torch.rand(xshape) |
| _, graph = self.checkTrace(cat_along_dim(d), [x, x, x]) |
| self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1) |
| |
| def test_typecheck(self): |
| x = torch.rand(32, 28) |
| m = torch.nn.Linear(in_features=28, out_features=64, bias=True) |
| traced, graph = self.checkTrace(m, [x]) |
| self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1) |
| self.assertFused(graph, ['aten::linear']) |
| # change the shape of the input; execution should enter the fallback graph |
| x = torch.rand(5, 28) |
| self.assertEqual(m(x), traced(x)) |
| |
| |
| @unittest.skipIf(LLGA_NOT_ENABLED, "MKL-DNN build is disabled") |
| class TestFusionPattern(JitLlgaTestCase): |
| def test_conv2d_eltwise(self): |
| class M(nn.Module): |
| def __init__(self, eltwise_fn): |
| super(M, self).__init__() |
| self.conv1 = nn.Conv2d(32, 32, 3, padding=1, bias=True) |
| self.conv2 = nn.Conv2d(32, 32, 3, padding=1, bias=False) |
| self.eltwise = eltwise_fn |
| |
| def forward(self, x): |
| x = self.conv1(x) |
| x = self.eltwise(x) |
| x = self.conv2(x) |
| x = self.eltwise(x) |
| return x |
| |
| # for eltwise in ['relu', 'sigmoid', 'sqrt', 'abs', 'square', 'hardtanh']: |
| for eltwise in ['relu']: |
| for inplace in [True, False]: |
| eltwise_fn_name = eltwise + '_' if inplace else eltwise |
| eltwise_fn = get_eltwise_fn(eltwise_fn_name) |
| |
| m = M(eltwise_fn) |
| x = torch.rand(1, 32, 28, 28) |
| _, graph = self.checkTrace(m, [x]) |
| self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 2) |
| # test that relu_ is replaced with relu by the mutation removal pass |
| self.assertFused(graph, ['aten::' + eltwise_fn_name]) |
| # test if relu is fused into the fusion group |
| self.assertFused(graph, ['aten::' + eltwise]) |
| |
| def test_conv2d_bn(self): |
| class M(nn.Module): |
| def __init__(self): |
| super(M, self).__init__() |
| self.conv1 = nn.Conv2d(32, 32, 3, padding=1, bias=True) |
| self.bn1 = nn.BatchNorm2d(32) |
| |
| def forward(self, x): |
| x = self.conv1(x) |
| x = self.bn1(x) |
| return x |
| |
| m = M().eval() |
| x = torch.rand(1, 32, 28, 28) |
| _, graph = self.checkTrace(m, [x]) |
| self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1) |
| self.assertFused(graph, ['aten::_convolution', 'aten::batch_norm']) |
|
| def test_conv2d_bn_relu(self): |
| class M(nn.Module): |
| def __init__(self): |
| super(M, self).__init__() |
| self.conv1 = nn.Conv2d(32, 32, 3, padding=1, bias=True) |
| self.bn1 = nn.BatchNorm2d(32) |
| |
| def forward(self, x): |
| x = self.conv1(x) |
| x = self.bn1(x) |
| x = F.relu(x) |
| return x |
| |
| m = M().eval() |
| x = torch.rand(1, 32, 28, 28) |
| _, graph = self.checkTrace(m, [x]) |
| self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1) |
| self.assertFused(graph, ['aten::_convolution', 'aten::batch_norm', |
| 'aten::relu']) |
| |
| def test_bn2d_eltwise(self): |
| class M(nn.Module): |
| def __init__(self, eltwise_fn): |
| super(M, self).__init__() |
| self.eltwise = eltwise_fn |
| self.bn = nn.BatchNorm2d(32) |
| |
| def forward(self, x): |
| x = self.bn(x) |
| x = self.eltwise(x) |
| return x |
| |
| for eltwise in ['relu']: |
| eltwise_fn = get_eltwise_fn(eltwise) |
| m = M(eltwise_fn).eval() |
| x = torch.rand(1, 32, 28, 28) |
| _, graph = self.checkTrace(m, [x]) |
| self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1) |
| self.assertFused(graph, ['aten::' + eltwise]) |
| |
| def test_linear_eltwise(self): |
| class M(nn.Module): |
| def __init__(self, eltwise_fn, bias): |
| super(M, self).__init__() |
| self.linear = nn.Linear(28, 64, bias) |
| self.eltwise = eltwise_fn |
| |
| def forward(self, x): |
| x = self.linear(x) |
| x = self.eltwise(x) |
| return x |
| |
| for [has_bias, eltwise] in itertools.product( |
| [True, False], |
| ['relu', 'gelu', 'sigmoid', 'hardtanh', 'relu6', 'elu']): |
| |
| eltwise_fn = get_eltwise_fn(eltwise) |
| m = M(eltwise_fn, has_bias) |
| x = torch.rand(32, 28, requires_grad=False) |
| _, graph = self.checkTrace(m, [x]) |
| self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1) |
| self.assertFused(graph, ['aten::' + eltwise]) |
| |
| def test_conv2d_sum(self): |
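| # Three fusion groups are expected, roughly one per conv+bn chain, with the |
| # add/relu absorbed into one of them. |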
| class M(nn.Module): |
| def __init__(self, bias=False): |
| super(M, self).__init__() |
| self.conv1 = nn.Conv2d(32, 32, 3, padding=1, bias=bias) |
| self.bn1 = nn.BatchNorm2d(32) |
| self.conv2 = nn.Conv2d(32, 32, 3, padding=1, bias=bias) |
| self.bn2 = nn.BatchNorm2d(32) |
| self.relu = nn.ReLU() |
| self.conv3 = nn.Conv2d(32, 32, 3, padding=1, bias=bias) |
| self.bn3 = nn.BatchNorm2d(32) |
| |
| def forward(self, x, y): |
| x = self.conv1(x) |
| x = self.bn1(x) |
| y = self.conv2(y) |
| y = self.bn2(y) |
| z = self.relu(x + y) |
| z = self.conv3(z) |
| z = self.bn3(z) |
| return z |
| |
| for bias in [True, False]: |
| m = M(bias).eval() |
| x = torch.rand(1, 32, 16, 16, requires_grad=False) |
| y = torch.rand(1, 32, 16, 16, requires_grad=False) |
| _, graph = self.checkTrace(m, [x, y]) |
| self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 3) |
| |
| def test_wildcard(self): |
| class M(nn.Module): |
| def __init__(self): |
| super(M, self).__init__() |
| self.conv1 = nn.Conv2d(32, 32, 3, padding=1, bias=True) |
| self.eltwise = nn.ReLU() |
| |
| def forward(self, x): |
| x = self.conv1(x) |
| y = self.eltwise(x) |
| return [x, y] |
| |
| # The pattern is as follows: |
| # conv |
| # | \ |
| # eltwise \ |
| # | \ |
| # ListConstruct |
| # |
| # The output of conv is used by a wildcard op: ListConstruct. |
| # Thus conv-eltwise cannot be selected into the same Partition. |
| m = M() |
| x = torch.rand(1, 32, 28, 28) |
| _, graph = self.checkTrace(m, [x]) |
| # conv can exist in a single-op oneDNN Graph partition but not relu |
| self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1) |
| self.assertFused(graph, ['aten::_convolution']) |
| |
| def test_rewrap_tensor_input_to_pytorch(self): |
| class M(nn.Module): |
| def __init__(self, eltwise_fn, data_type): |
| super(M, self).__init__() |
| self.conv1 = nn.Conv2d(32, 32, 3, padding=1, bias=True, dtype=data_type) |
| self.conv2 = nn.Conv2d(32, 32, 3, padding=1, bias=True, dtype=data_type) |
| self.eltwise = eltwise_fn |
| self.adaptive_avg_pool_2d = nn.AdaptiveAvgPool2d((5, 7)) |
| |
| def forward(self, x, y): |
| x = self.conv1(x) |
| x = self.eltwise(x) |
| x = self.conv2(x) |
| x = self.eltwise(x) |
| x = torch.add(x, y) |
| x = self.adaptive_avg_pool_2d(x) |
| return x |
| |
| eltwise_fn_name = 'relu' |
| eltwise_fn = get_eltwise_fn(eltwise_fn_name) |
| # Add bfloat16 later |
| for data_type in [torch.float]: |
| m = M(eltwise_fn, data_type) |
| m = m.to(memory_format=torch.channels_last) |
| x = torch.rand(1, 32, 28, 28, dtype=data_type).to(memory_format=torch.channels_last) |
| y = torch.rand(1, 32, 28, 28, dtype=data_type).to(memory_format=torch.channels_last) |
| # Simply test if the output is accurate |
| # The output of the second partition is input to adaptive_avg_pool2d, which is |
| # unsupported by LLGA, so it must be handled by PyTorch, which should receive |
| # correct strides info of the channels-last tensor. |
| self.checkTrace(m, [x, y]) |
| |
| |
| @unittest.skipIf(LLGA_NOT_ENABLED, "MKL-DNN build is disabled") |
| class TestEnableDisableLlgaFuser(JitTestCase): |
| def setUp(self): |
| super().setUp() |
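| # Remember the previous LLGA setting so tearDown can restore it. |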
| self.is_enabled = torch._C._jit_set_llga_enabled(False) |
| |
| def tearDown(self): |
| torch._C._jit_set_llga_enabled(self.is_enabled) |
| super().tearDown() |
| |
| def test_context_manager(self): |
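| # The 'fuser3' context enables the oneDNN Graph fuser: t1 and t2, scripted inside it, |
| # should contain LLGA fusion groups, while t3 should not. |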
| x = torch.randn(4, 8) |
| y = torch.randn(4, 8) |
| with torch.jit.fuser('fuser3'): |
| with torch.jit.fuser('fuser3'): |
| |
| def t1(x, y): |
| o = x + y |
| o = o + 2.0 |
| return o |
| t_jit = torch.jit.script(t1) |
| t_jit(x, y) |
| t_jit(x, y) |
| self.assertGraphContains(t_jit.graph_for(x, y), LLGA_FUSION_GROUP) |
| |
| def t2(x, y): |
| o = x + y |
| o = o + 3.0 |
| return o |
| t_jit_2 = torch.jit.script(t2) |
| t_jit_2(x, y) |
| t_jit_2(x, y) |
| self.assertGraphContains(t_jit_2.graph_for(x, y), LLGA_FUSION_GROUP) |
| |
| def t3(x, y): |
| o = x + y |
| o = o + 4.0 |
| return o |
| t_jit_3 = torch.jit.script(t3) |
| t_jit_3(x, y) |
| t_jit_3(x, y) |
| self.assertGraphContainsExactly(t_jit_3.graph_for(x, y), LLGA_FUSION_GROUP, 0) |
| |
| |
| @unittest.skipIf(LLGA_NOT_ENABLED, "MKL-DNN build is disabled") |
| class TestModel(JitLlgaTestCase): |
| @skipIfNoTorchVision |
| def _test_vision(self, model_name): |
| m = getattr(torchvision.models, model_name)().eval() |
| x = torch.rand(1, 3, 224, 224) / 10 |
| _, graph = self.checkTrace(m, [x]) |
| self.assertFused(graph, ['aten::_convolution', 'aten::batch_norm', |
| 'aten::relu', 'aten::linear', |
| 'aten::avg_pool2d', 'aten::max_pool2d']) |
| |
| |
| for model_name, enabled in [ |
| ['resnet50', True], |
| ['resnext50_32x4d', True], |
| ['resnext101_32x8d', True], |
| ['densenet121', True], |
| ['googlenet', TEST_SCIPY], |
| ['mobilenet_v2', True], |
| ['mnasnet1_0', True], |
| ['squeezenet1_0', True], |
| ['vgg16', True], |
| ['alexnet', True], |
| ['shufflenet_v2_x1_0', True], |
| ['wide_resnet50_2', True], |
| ]: |
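| # wrapper binds the current model name so each torchvision model gets its own test method. |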
| def wrapper(mname): |
| @unittest.skipIf(not enabled, 'Disabled') |
| def test(self): |
| return self._test_vision(mname) |
| return test |
| |
| setattr(TestModel, 'test_vision_%s' % model_name, wrapper(model_name)) |
| |
| if __name__ == '__main__': |
| run_tests() |