# Owner(s): ["module: mkldnn"]
import itertools
import unittest
from typing import NamedTuple, List
import torch
from torch import nn
from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo
from torch.testing._internal.jit_utils import JitTestCase
from test_tensorexpr import warmup_and_run_forward
FUSION_GROUP = 'prim::TensorExprGroup'
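# Description of a pointwise post-op used by the tests below: `attr`, `scalars`
# and `algorithm` are forwarded to the fused mkldnn ops, while
# `pointwise_module` computes the eager-mode reference.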
class PointwisePostOp(NamedTuple):
attr: str
pointwise_module: nn.Module
scalars: List = []
algorithm: str = ""
CONV_MODULES = {2: torch.nn.Conv2d, 3: torch.nn.Conv3d}
CONV_TRANSPOSE_MODULES = {2: torch.nn.ConvTranspose2d}
@skipIfTorchDynamo("too slow")
@unittest.skipIf(not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled")
class TestMkldnnFusion(JitTestCase):
def assertFused(self, graph, fused_patterns):
for pat in fused_patterns:
self.assertGraphContainsExactly(graph, pat, 0)
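# Script (or trace) and freeze the model with the TensorExpr CPU fuser
# enabled, check the output against eager mode, and return the optimized
# graph so callers can inspect which nodes were fused.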
def _check_model(self, m, x, trace=False):
old_fusion_inlining = torch._C._debug_get_fusion_group_inlining()
torch._C._debug_set_fusion_group_inlining(False)
old_cpu_fuser_state = torch._C._jit_can_fuse_on_cpu()
torch._C._jit_override_can_fuse_on_cpu(True)
old_te_must_use_llvm_cpu = torch._C._jit_get_te_must_use_llvm_cpu()
torch._C._jit_set_te_must_use_llvm_cpu(False)
m.eval()
with torch.no_grad():
if trace:
script = torch.jit.trace(m, x)
else:
script = torch.jit.script(m)
script = torch.jit.freeze(script)
with torch.no_grad():
y = warmup_and_run_forward(script, x)
y = script(x)
y_ref = m(x)
graph = script.graph_for(*x)
self.assertEqual(y, y_ref)
torch._C._debug_set_fusion_group_inlining(old_fusion_inlining)
torch._C._jit_override_can_fuse_on_cpu(old_cpu_fuser_state)
torch._C._jit_set_te_must_use_llvm_cpu(old_te_must_use_llvm_cpu)
return graph
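# A single Conv2d should be fused into a TensorExprGroup for channels_last
# inputs and left as a plain convolution node for contiguous inputs.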
def test_single_conv(self):
class M(nn.Module):
def __init__(self, in_channels, out_channels, bias, **kwargs):
super().__init__()
self.conv = torch.nn.Conv2d(in_channels, out_channels, bias=bias, **kwargs)
def forward(self, x):
res = self.conv(x)
return res
for memory_format, enabled in [
[torch.contiguous_format, False],
[torch.channels_last, True],
]:
for trace in [True, False]:
input_size = 224
batch_size = 1
kernel_size = 3
options = itertools.product([True, False], [1, 2], [1, 4])
for bias, dilation, groups in options:
iC = 3 * groups
oC = 10 * groups
m = M(iC,
oC,
bias,
kernel_size=(kernel_size, kernel_size),
stride=2,
padding=1,
dilation=dilation,
groups=groups).to(memory_format=memory_format)
x = torch.randn(batch_size, iC, input_size, input_size).to(memory_format=memory_format)
graph = self._check_model(m, x, trace)
conv_node_name = 'aten::_convolution' if trace else 'aten::conv2d'
if enabled:
self.assertFused(graph, [conv_node_name])
self.assertGraphContainsExactly(graph, FUSION_GROUP, 1)
else:
self.assertGraphContains(graph, kind=conv_node_name)
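# Conv2d followed by a unary op should be fused by NNC only for
# channels_last inputs.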
def test_conv_unary_fusion_nnc(self):
class M(nn.Module):
def __init__(self, unary_fn, in_channels, out_channels, bias, **kwargs):
super().__init__()
self.conv = torch.nn.Conv2d(in_channels, out_channels, bias=bias, **kwargs)
self.unary = unary_fn
def forward(self, x):
x = self.conv(x)
x = self.unary(x)
return x
for memory_format, enabled in [
[torch.contiguous_format, False],
[torch.channels_last, True],
]:
for unary_fn in [torch.relu]:
for bias in [True, False]:
for oC in [1, 10]:
m = M(unary_fn, 3, oC, bias, kernel_size=(3, 3)).to(memory_format=memory_format)
x = torch.randn(1, 3, 224, 224).to(memory_format=memory_format)
graph = self._check_model(m, x)
if enabled:
self.assertFused(graph, ['aten::conv2d', 'aten::' + unary_fn.__name__])
self.assertGraphContainsExactly(graph, FUSION_GROUP, 1)
else:
self.assertGraphContains(graph, kind='aten::conv2d')
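# Conv variants the fuser does not handle (Conv3d, ConvTranspose2d) should
# stay as plain aten::_convolution nodes in the traced graph.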
def test_unsupported_conv(self):
class M(nn.Module):
def __init__(self, m, in_channels, out_channels, bias, **kwargs):
super().__init__()
self.conv = m(in_channels, out_channels, bias=bias, **kwargs)
def forward(self, x):
res = self.conv(x)
return res
for module, dim, memory_format in [
[nn.Conv3d, 3, torch.contiguous_format],
[nn.Conv3d, 3, torch.channels_last_3d],
[nn.ConvTranspose2d, 2, torch.contiguous_format],
[nn.ConvTranspose2d, 2, torch.channels_last],
]:
trace = True
input_size = 224
batch_size = 1
kernel_size = 3
groups = 2
bias = True
iC = 3 * groups
oC = 10 * groups
dilation = 2
m = M(module,
iC,
oC,
bias,
kernel_size=kernel_size,
stride=2,
padding=1,
dilation=dilation,
groups=groups).to(memory_format=memory_format)
input_sizes = [batch_size, iC, input_size, input_size]
if dim == 3:
input_sizes.append(input_size)
x = torch.randn(input_sizes).to(memory_format=memory_format)
graph = self._check_model(m, x, trace)
self.assertGraphContains(graph, kind='aten::_convolution')
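# Unary post-ops exercised by the *_unary_fusion_ops tests below.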
def _unary_list(self):
unary_list = {
"relu": PointwisePostOp("relu", nn.ReLU()),
"sigmoid": PointwisePostOp("sigmoid", nn.Sigmoid()),
"tanh": PointwisePostOp("tanh", nn.Tanh()),
"hardswish": PointwisePostOp("hardswish", nn.Hardswish()),
"leaky_relu": PointwisePostOp("leaky_relu", nn.LeakyReLU(0.1, inplace=False), scalars=[0.1]),
"hardtanh": PointwisePostOp("hardtanh", nn.Hardtanh(min_val=-0.5, max_val=4, inplace=False), scalars=[-0.5, 4]),
"gelu_none": PointwisePostOp("gelu", nn.GELU(approximate="none"), algorithm="none"),
"gelu_tanh": PointwisePostOp("gelu", nn.GELU(approximate="tanh"), algorithm="tanh"),
}
return unary_list
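# Binary post-ops exercised by the *_binary_fusion_ops tests below.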
def _binary_list(self):
binary_list = {
"add": torch.add,
"sub": torch.sub,
"mul": torch.mul,
"div": torch.div,
}
return binary_list
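# Check torch.ops.mkldnn._linear_pointwise with a fused unary post-op
# against the unfused Linear + unary reference for several input shapes
# and strides.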
def test_linear_unary_fusion_ops(self):
class M(nn.Module):
def __init__(self, unary_fn, in_channels, out_channels, bias, **kwargs):
super().__init__()
self.linear = torch.nn.Linear(
in_channels, out_channels, bias=bias, **kwargs
)
self.unary = unary_fn
def forward(self, x):
x = self.linear(x)
x = self.unary(x)
return x
for pointwise_info in self._unary_list().values():
# A tensor with size = [1, 10] and stride = [0, 1] is contiguous, but its
# strides are not the default contiguous strides.
options = itertools.product([[[2, 3, 10], None], [[2, 10], None], [[1, 10], [0, 1]]], [True, False])
for (input_shape, input_stride), bias in options:
with torch.no_grad():
mod = M(pointwise_info.pointwise_module, input_shape[-1], 10, bias).eval()
v = torch.randn(input_shape)
if input_stride is not None:
v = v.as_strided(input_shape, input_stride)
ref = mod(v)
attr = pointwise_info.attr
scalars = pointwise_info.scalars
algorithm = pointwise_info.algorithm
fused = torch.ops.mkldnn._linear_pointwise(
v, mod.linear.weight, mod.linear.bias, attr, scalars, algorithm
)
self.assertEqual(ref, fused)
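# Check torch.ops.mkldnn._convolution_pointwise with a fused unary post-op
# against the unfused Conv2d/Conv3d + unary reference.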
def test_conv_unary_fusion_ops(self):
class M(nn.Module):
def __init__(self, unary_fn, dim, in_channels, out_channels, dilation, groups, bias, **kwargs):
super().__init__()
self.conv = CONV_MODULES[dim](in_channels, out_channels, dilation=dilation, groups=groups, bias=bias, **kwargs)
self.unary = unary_fn
def forward(self, x):
x = self.conv(x)
x = self.unary(x)
return x
input_shapes = {2: (112, 112), 3: (55, 55, 55)}
for pointwise_info in self._unary_list().values():
for dim in [2, 3]:
channels_last = torch.channels_last if dim == 2 else torch.channels_last_3d
options = itertools.product([True, False], [1, 2], [1, 4], [torch.contiguous_format, channels_last])
for bias, dilation, groups, memory_format in options:
oC = 32 * groups
iC = 3 * groups
x_shape = (1, iC) + input_shapes[dim]
x = torch.randn(x_shape, dtype=torch.float32).to(memory_format=memory_format)
mod = M(pointwise_info.pointwise_module, dim, iC, oC, dilation, groups, bias, kernel_size=3)
mod = mod.to(memory_format=memory_format).eval()
with torch.no_grad():
ref = mod(x)
attr = pointwise_info.attr
scalars = pointwise_info.scalars
algorithm = pointwise_info.algorithm
fused = torch.ops.mkldnn._convolution_pointwise(
x, mod.conv.weight, mod.conv.bias, mod.conv.padding, mod.conv.stride, mod.conv.dilation,
mod.conv.groups, attr, scalars, algorithm
)
self.assertEqual(ref, fused)
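# Check the binary variant of torch.ops.mkldnn._convolution_pointwise
# (conv + add/sub/mul/div, optionally followed by relu) against the eager
# reference, including the in-place version for "add".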
def test_conv_binary_fusion_ops(self):
class M(nn.Module):
def __init__(self, binary_fn, dim, in_channels, out_channels, dilation, groups, bias, **kwargs):
super().__init__()
self.conv = CONV_MODULES[dim](in_channels, out_channels, dilation=dilation, groups=groups, bias=bias, **kwargs)
self.binary = binary_fn
def forward(self, x, other):
x = self.conv(x)
x = self.binary(x, other)
return x
input_shapes = {2: (112, 112), 3: (22, 22, 22)}
for pointwise_name, pointwise_fn in self._binary_list().items():
for dim in [2, 3]:
channels_last = torch.channels_last if dim == 2 else torch.channels_last_3d
options = itertools.product([False, True], [True, False], [1, 2], [1, 4], [torch.contiguous_format, channels_last])
for fuse_relu, bias, dilation, groups, memory_format in options:
oC = 32 * groups
iC = 3 * groups
x_shape = (1, iC) + input_shapes[dim]
x = torch.randn(x_shape, dtype=torch.float32).to(memory_format=memory_format)
mod = M(pointwise_fn, dim, iC, oC, dilation, groups, bias, kernel_size=3)
mod = mod.to(memory_format=memory_format).eval()
other = torch.randn_like(mod.conv(x))
with torch.no_grad():
ref = mod(x, other)
unary_attr = None
if fuse_relu:
ref.relu_()
unary_attr = "relu"
attr = pointwise_name
fused = torch.ops.mkldnn._convolution_pointwise(
x, other, mod.conv.weight, mod.conv.bias, mod.conv.padding, mod.conv.stride, mod.conv.dilation,
mod.conv.groups, attr, None, unary_attr, [], None
)
# For binary add, we also support an in-place version.
if attr == "add":
fused_inplace = torch.ops.mkldnn._convolution_pointwise_(
other, x, mod.conv.weight, mod.conv.bias, mod.conv.padding, mod.conv.stride, mod.conv.dilation,
mod.conv.groups, attr, None, unary_attr, [], None
)
self.assertEqual(ref, other)
self.assertEqual(ref, fused_inplace)
self.assertEqual(ref, fused, atol=5e-4, rtol=5e-4)
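# Check the binary variant of torch.ops.mkldnn._linear_pointwise
# (linear + add/sub/mul/div) against the eager reference.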
def test_linear_binary_fusion_ops(self):
class M(nn.Module):
def __init__(self, binary_fn, in_channels, out_channels, bias, **kwargs):
super().__init__()
self.linear = torch.nn.Linear(
in_channels, out_channels, bias=bias, **kwargs
)
self.binary = binary_fn
def forward(self, x, other):
x = self.linear(x)
x = self.binary(x, other)
return x
out_feature = 20
for pointwise_name, pointwise_fn in self._binary_list().items():
# A tensor with size = [1, 10] and stride = [0, 1] is contiguous, but its
# strides are not the default contiguous strides.
options = itertools.product([[[2, 3, 10], None], [[2, 10], None], [[1, 10], [0, 1]]], [True, False])
for (input_shape, input_stride), bias in options:
with torch.no_grad():
mod = M(pointwise_fn, input_shape[-1], out_feature, bias).eval()
v = torch.randn(input_shape)
if input_stride is not None:
v = v.as_strided(input_shape, input_stride)
other = torch.randn(input_shape[:-1] + [out_feature])
ref = mod(v, other)
attr = pointwise_name
fused = torch.ops.mkldnn._linear_pointwise(
v, other, mod.linear.weight, mod.linear.bias, attr
)
self.assertEqual(ref, fused)
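# Check torch.ops.mkldnn._convolution_transpose_pointwise with a fused unary
# post-op against the eager ConvTranspose2d + unary reference, with and
# without weight prepacking via _reorder_convolution_transpose_weight.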
def test_conv_transpose_unary_fusion_ops(self):
class M(nn.Module):
def __init__(self, unary_fn, dim, in_channels, out_channels, kernel_size, **kwargs):
super().__init__()
self.conv_transpose = CONV_TRANSPOSE_MODULES[dim](in_channels, out_channels, kernel_size, **kwargs)
self.unary = unary_fn
def forward(self, x):
x = self.conv_transpose(x)
x = self.unary(x)
return x
input_shapes = {2: (28, 28)}
kernel_size = 3
for pointwise_info in self._unary_list().values():
for dim in [2]:
channels_last = torch.channels_last if dim == 2 else torch.channels_last_3d
options = itertools.product([True, False], [1, 2], [1, 4], [torch.contiguous_format, channels_last], [False, True])
for bias, dilation, groups, memory_format, prepack_weight in options:
oC = 32 * groups
iC = 3 * groups
x_shape = (1, iC) + input_shapes[dim]
x = torch.randn(x_shape, dtype=torch.float32).to(memory_format=memory_format)
mod = M(pointwise_info.pointwise_module, dim, iC, oC, kernel_size, dilation=dilation, groups=groups, bias=bias)
mod = mod.to(memory_format=memory_format).eval()
with torch.no_grad():
ref = mod(x)
attr = pointwise_info.attr
scalars = pointwise_info.scalars
algorithm = pointwise_info.algorithm
if prepack_weight:
packed_weight = torch.ops.mkldnn._reorder_convolution_transpose_weight(
mod.conv_transpose.weight,
mod.conv_transpose.padding,
mod.conv_transpose.output_padding,
mod.conv_transpose.stride,
mod.conv_transpose.dilation,
mod.conv_transpose.groups,
x.size())
mod.conv_transpose.weight = torch.nn.Parameter(
packed_weight,
requires_grad=mod.conv_transpose.weight.requires_grad,
)
fused = torch.ops.mkldnn._convolution_transpose_pointwise(
x,
mod.conv_transpose.weight,
mod.conv_transpose.bias,
mod.conv_transpose.padding,
mod.conv_transpose.output_padding,
mod.conv_transpose.stride,
mod.conv_transpose.dilation,
mod.conv_transpose.groups,
attr,
scalars,
algorithm)
self.assertEqual(ref, fused)
if __name__ == "__main__":
run_tests()