| # Owner(s): ["oncall: mobile"] |
| |
| import io |
| import itertools |
| import unittest |
| |
| from hypothesis import assume, given, strategies as st |
| |
| import torch |
| import torch.backends.xnnpack |
| import torch.testing._internal.hypothesis_utils as hu |
| from torch.nn import functional as F |
| from torch.testing import FileCheck |
| from torch.testing._internal.common_utils import ( |
| IS_FBCODE, |
| run_tests, |
| slowTest, |
| TEST_WITH_TSAN, |
| TestCase, |
| ) |
| from torch.utils.mobile_optimizer import optimize_for_mobile |
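
# These tests exercise the XNNPACK-backed prepacked ops against their eager
# references, TorchScript serialization round-trips of modules that use those
# ops, and the JIT rewrite passes that insert, fuse, and fold them. The ops
# under test follow a prepack/run pairing, e.g.:
#   packed = torch.ops.prepacked.linear_clamp_prepack(weight, bias)
#   out = torch.ops.prepacked.linear_clamp_run(x, packed)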
| |
| |
| @unittest.skipUnless( |
| torch.backends.xnnpack.enabled, |
| " XNNPACK must be enabled for these tests." " Please build with USE_XNNPACK=1.", |
| ) |
| @unittest.skipIf( |
| TEST_WITH_TSAN, |
| "TSAN fails with XNNPACK. Does not seem to have a good reason for failures.", |
| ) |
| class TestXNNPACKOps(TestCase): |
| @unittest.skip( |
| "Fails on some platforms, see https://github.com/pytorch/pytorch/issues/73488" |
| ) |
| @given( |
| batch_size=st.integers(0, 3), |
| data_shape=hu.array_shapes(1, 3, 2, 64), |
| weight_output_dim=st.integers(2, 64), |
| use_bias=st.booleans(), |
| ) |
| def test_linear(self, batch_size, data_shape, weight_output_dim, use_bias): |
| data_shape = [batch_size] + list(data_shape) |
| input_data = torch.rand(data_shape) |
| weight = torch.rand((weight_output_dim, data_shape[-1])) |
| if use_bias: |
| bias = torch.rand(weight_output_dim) |
| else: |
| bias = None |
| ref_result = F.linear(input_data, weight, bias) |
| packed_weight_bias = torch.ops.prepacked.linear_clamp_prepack(weight, bias) |
| output_linearprepacked = torch.ops.prepacked.linear_clamp_run( |
| input_data, packed_weight_bias |
| ) |
| torch.testing.assert_close( |
| ref_result, output_linearprepacked, rtol=1e-2, atol=1e-3 |
| ) |
| |
| @given( |
| input_size=st.integers(2, 32), |
| weight_output_dim=st.integers(2, 64), |
| use_bias=st.booleans(), |
| ) |
| def test_linear_1d_input(self, input_size, weight_output_dim, use_bias): |
| input_data = torch.rand(input_size) |
| weight = torch.rand((weight_output_dim, input_data.shape[-1])) |
| if use_bias: |
| bias = torch.rand(weight_output_dim) |
| else: |
| bias = None |
| ref_result = F.linear(input_data, weight, bias) |
| packed_weight_bias = torch.ops.prepacked.linear_clamp_prepack(weight, bias) |
| output_linearprepacked = torch.ops.prepacked.linear_clamp_run( |
| input_data, packed_weight_bias |
| ) |
| torch.testing.assert_close( |
| ref_result, output_linearprepacked, rtol=1e-2, atol=1e-3 |
| ) |
| |
| @given( |
| batch_size=st.integers(0, 3), |
| input_channels_per_group=st.integers(1, 32), |
| height=st.integers(5, 64), |
| width=st.integers(5, 64), |
| output_channels_per_group=st.integers(1, 32), |
| groups=st.integers(1, 16), |
| kernel_h=st.integers(1, 7), |
| kernel_w=st.integers(1, 7), |
| stride_h=st.integers(1, 2), |
| stride_w=st.integers(1, 2), |
| pad_h=st.integers(0, 2), |
| pad_w=st.integers(0, 2), |
| dilation=st.integers(1, 2), |
| use_bias=st.booleans(), |
| format=st.sampled_from( |
| [None, torch.preserve_format, torch.contiguous_format, torch.channels_last] |
| ), |
| ) |
| def test_conv2d( |
| self, |
| batch_size, |
| input_channels_per_group, |
| height, |
| width, |
| output_channels_per_group, |
| groups, |
| kernel_h, |
| kernel_w, |
| stride_h, |
| stride_w, |
| pad_h, |
| pad_w, |
| dilation, |
| use_bias, |
| format, |
| ): |
| input_channels = input_channels_per_group * groups |
| output_channels = output_channels_per_group * groups |
| kernels = (kernel_h, kernel_w) |
| strides = (stride_h, stride_w) |
| paddings = (pad_h, pad_w) |
| dilations = (dilation, dilation) |
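        # Skip parameter draws where the dilated kernel would not fit inside
        # the padded input.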
| assume(height + 2 * paddings[0] >= dilations[0] * (kernels[0] - 1) + 1) |
| assume(width + 2 * paddings[1] >= dilations[1] * (kernels[1] - 1) + 1) |
| |
| input_data = torch.rand((batch_size, input_channels, height, width)) |
| if format is not None: |
| input_data = input_data.contiguous(memory_format=format) |
| weight = torch.rand( |
| (output_channels, input_channels_per_group, kernel_h, kernel_w) |
| ) |
| bias = None |
| if use_bias: |
| bias = torch.rand(output_channels) |
| |
| ref_result = F.conv2d( |
| input_data, weight, bias, strides, paddings, dilations, groups |
| ) |
| packed_weight_bias = torch.ops.prepacked.conv2d_clamp_prepack( |
| weight, bias, strides, paddings, dilations, groups |
| ) |
| xnnpack_result = torch.ops.prepacked.conv2d_clamp_run( |
| input_data, packed_weight_bias |
| ) |
| torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3) |
| |
| @given( |
| batch_size=st.integers(1, 3), |
| input_channels_per_group=st.integers(1, 32), |
| height=st.integers(5, 64), |
| width=st.integers(5, 64), |
| output_channels_per_group=st.integers(1, 32), |
| groups=st.integers(1, 16), |
| kernel_h=st.integers(1, 7), |
| kernel_w=st.integers(1, 7), |
| stride_h=st.integers(1, 2), |
| stride_w=st.integers(1, 2), |
| pad_h=st.integers(0, 2), |
| pad_w=st.integers(0, 2), |
| output_pad_h=st.integers(0, 2), |
| output_pad_w=st.integers(0, 2), |
| dilation=st.integers(1, 2), |
| use_bias=st.booleans(), |
| format=st.sampled_from( |
| [None, torch.preserve_format, torch.contiguous_format, torch.channels_last] |
| ), |
| ) |
| def test_conv2d_transpose( |
| self, |
| batch_size, |
| input_channels_per_group, |
| height, |
| width, |
| output_channels_per_group, |
| groups, |
| kernel_h, |
| kernel_w, |
| stride_h, |
| stride_w, |
| pad_h, |
| pad_w, |
| output_pad_h, |
| output_pad_w, |
| dilation, |
| use_bias, |
| format, |
| ): |
| input_channels = input_channels_per_group * groups |
| output_channels = output_channels_per_group * groups |
| kernels = (kernel_h, kernel_w) |
| strides = (stride_h, stride_w) |
| paddings = (pad_h, pad_w) |
| output_paddings = (output_pad_h, output_pad_w) |
| dilations = (dilation, dilation) |
| assume(height + 2 * paddings[0] >= dilations[0] * (kernels[0] - 1) + 1) |
| assume(width + 2 * paddings[1] >= dilations[1] * (kernels[1] - 1) + 1) |
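        # PyTorch requires output_padding to be smaller than stride or dilation;
        # conservatively require it to be smaller than both.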
| assume((output_pad_h < stride_h) and (output_pad_h < dilation)) |
| assume((output_pad_w < stride_w) and (output_pad_w < dilation)) |
| |
| input_data = torch.rand((batch_size, input_channels, height, width)) |
| if format is not None: |
| input_data = input_data.contiguous(memory_format=format) |
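        # Transposed-conv weights are laid out as
        # (in_channels, out_channels_per_group, kH, kW).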
| weight = torch.rand( |
| (input_channels, output_channels_per_group, kernel_h, kernel_w) |
| ) |
| bias = None |
| if use_bias: |
| bias = torch.rand(output_channels) |
| |
        # Note: conv_transpose2d takes groups before dilation, the reverse of
        # conv2d's argument order.
| ref_result = F.conv_transpose2d( |
| input_data, |
| weight, |
| bias, |
| strides, |
| paddings, |
| output_paddings, |
| groups, |
| dilation, |
| ) |
| packed_weight_bias = torch.ops.prepacked.conv2d_transpose_clamp_prepack( |
| weight, bias, strides, paddings, output_paddings, dilations, groups |
| ) |
| xnnpack_result = torch.ops.prepacked.conv2d_transpose_clamp_run( |
| input_data, packed_weight_bias |
| ) |
| torch.testing.assert_close( |
| ref_result.contiguous(), xnnpack_result.contiguous(), rtol=1e-2, atol=1e-3 |
| ) |
| |
| |
| @unittest.skipUnless( |
| torch.backends.xnnpack.enabled, |
| " XNNPACK must be enabled for these tests." " Please build with USE_XNNPACK=1.", |
| ) |
| @unittest.skipIf( |
| TEST_WITH_TSAN, |
| "TSAN fails with XNNPACK. Does not seem to have a good reason for failures.", |
| ) |
| class TestXNNPACKSerDes(TestCase): |
| @unittest.skip( |
| "Fails on some platforms, see https://github.com/pytorch/pytorch/issues/73488" |
| ) |
| @given( |
| batch_size=st.integers(0, 3), |
| data_shape=hu.array_shapes(1, 3, 2, 64), |
| weight_output_dim=st.integers(2, 64), |
| use_bias=st.booleans(), |
| ) |
| def test_linear(self, batch_size, data_shape, weight_output_dim, use_bias): |
| class Linear(torch.nn.Module): |
| def __init__(self, weight, bias=None): |
| super().__init__() |
| self.weight = weight |
| self.bias = bias |
| |
| def forward(self, x): |
| return F.linear(x, self.weight, self.bias) |
| |
| class LinearPrePacked(torch.nn.Module): |
| def __init__(self, weight, bias=None): |
| super().__init__() |
| self.packed_weight_bias = torch.ops.prepacked.linear_clamp_prepack( |
| weight, bias |
| ) |
| |
| def forward(self, x): |
| return torch.ops.prepacked.linear_clamp_run(x, self.packed_weight_bias) |
| |
| data_shape = [batch_size] + list(data_shape) |
| weight = torch.rand((weight_output_dim, data_shape[-1])) |
| if use_bias: |
| bias = torch.rand(weight_output_dim) |
| else: |
| bias = None |
| scripted_linear = torch.jit.script(Linear(weight, bias)) |
| scripted_linear_clamp_prepacked = torch.jit.script( |
| LinearPrePacked(weight, bias) |
| ) |
| input_data = torch.rand(data_shape) |
| ref_result = scripted_linear(input_data) |
| output_linearprepacked = scripted_linear_clamp_prepacked(input_data) |
| torch.testing.assert_close( |
| ref_result, output_linearprepacked, rtol=1e-2, atol=1e-3 |
| ) |
| |
| # Serialize the modules and then deserialize |
| input_data = torch.rand(data_shape) |
| buffer = io.BytesIO() |
| torch.jit.save(scripted_linear, buffer) |
| buffer.seek(0) |
| deserialized_linear = torch.jit.load(buffer) |
| buffer = io.BytesIO() |
| torch.jit.save(scripted_linear_clamp_prepacked, buffer) |
| buffer.seek(0) |
| deserialized_linear_clamp_prepacked = torch.jit.load(buffer) |
| ref_result = deserialized_linear(input_data) |
| output_linearprepacked = deserialized_linear_clamp_prepacked(input_data) |
| torch.testing.assert_close( |
| ref_result, output_linearprepacked, rtol=1e-2, atol=1e-3 |
| ) |
| |
| @given( |
| batch_size=st.integers(0, 3), |
| input_channels_per_group=st.integers(1, 32), |
| height=st.integers(5, 64), |
| width=st.integers(5, 64), |
| output_channels_per_group=st.integers(1, 32), |
| groups=st.integers(1, 16), |
| kernel_h=st.integers(1, 7), |
| kernel_w=st.integers(1, 7), |
| stride_h=st.integers(1, 2), |
| stride_w=st.integers(1, 2), |
| pad_h=st.integers(0, 2), |
| pad_w=st.integers(0, 2), |
| dilation=st.integers(1, 2), |
| use_bias=st.booleans(), |
| format=st.sampled_from( |
| [None, torch.preserve_format, torch.contiguous_format, torch.channels_last] |
| ), |
| ) |
| def test_conv2d( |
| self, |
| batch_size, |
| input_channels_per_group, |
| height, |
| width, |
| output_channels_per_group, |
| groups, |
| kernel_h, |
| kernel_w, |
| stride_h, |
| stride_w, |
| pad_h, |
| pad_w, |
| dilation, |
| use_bias, |
| format, |
| ): |
| class Conv2D(torch.nn.Module): |
| def __init__(self, weight, bias, strides, paddings, dilations, groups): |
| super().__init__() |
| self.weight = weight |
| self.bias = bias |
| self.strides = strides |
| self.paddings = paddings |
| self.dilations = dilations |
| self.groups = groups |
| |
| def forward(self, x): |
| return F.conv2d( |
| x, |
| self.weight, |
| self.bias, |
| self.strides, |
| self.paddings, |
| self.dilations, |
| self.groups, |
| ) |
| |
| class Conv2DPrePacked(torch.nn.Module): |
| def __init__(self, weight, bias, strides, paddings, dilations, groups): |
| super().__init__() |
| self.packed_weight_bias = torch.ops.prepacked.conv2d_clamp_prepack( |
| weight, bias, strides, paddings, dilations, groups |
| ) |
| |
| def forward(self, x): |
| return torch.ops.prepacked.conv2d_clamp_run(x, self.packed_weight_bias) |
| |
| input_channels = input_channels_per_group * groups |
| output_channels = output_channels_per_group * groups |
| kernels = (kernel_h, kernel_w) |
| strides = (stride_h, stride_w) |
| paddings = (pad_h, pad_w) |
| dilations = (dilation, dilation) |
| assume(height + 2 * paddings[0] >= dilations[0] * (kernels[0] - 1) + 1) |
| assume(width + 2 * paddings[1] >= dilations[1] * (kernels[1] - 1) + 1) |
| |
| input_data = torch.rand((batch_size, input_channels, height, width)) |
| if format is not None: |
| input_data = input_data.contiguous(memory_format=format) |
| weight = torch.rand( |
| (output_channels, input_channels_per_group, kernel_h, kernel_w) |
| ) |
| bias = None |
| if use_bias: |
| bias = torch.rand(output_channels) |
| |
| scripted_conv2d = torch.jit.script( |
| Conv2D(weight, bias, strides, paddings, dilations, groups) |
| ) |
| scripted_conv2d_clamp_prepacked = torch.jit.script( |
| Conv2DPrePacked(weight, bias, strides, paddings, dilations, groups) |
| ) |
| ref_result = scripted_conv2d(input_data) |
| xnnpack_result = scripted_conv2d_clamp_prepacked(input_data) |
| torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3) |
| |
| # Serialize the modules and then deserialize |
| input_data = torch.rand((batch_size, input_channels, height, width)) |
| if format is not None: |
| input_data = input_data.contiguous(memory_format=format) |
| buffer = io.BytesIO() |
| torch.jit.save(scripted_conv2d, buffer) |
| buffer.seek(0) |
| deserialized_conv2d = torch.jit.load(buffer) |
| buffer = io.BytesIO() |
| torch.jit.save(scripted_conv2d_clamp_prepacked, buffer) |
| buffer.seek(0) |
| deserialized_conv2d_clamp_prepacked = torch.jit.load(buffer) |
| ref_result = deserialized_conv2d(input_data) |
| xnnpack_result = deserialized_conv2d_clamp_prepacked(input_data) |
| torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3) |
| |
| @given( |
| batch_size=st.integers(0, 3), |
| input_channels_per_group=st.integers(1, 32), |
| height=st.integers(5, 64), |
| width=st.integers(5, 64), |
| output_channels_per_group=st.integers(1, 32), |
| groups=st.integers(1, 16), |
| kernel_h=st.integers(1, 7), |
| kernel_w=st.integers(1, 7), |
| stride_h=st.integers(1, 2), |
| stride_w=st.integers(1, 2), |
| pad_h=st.integers(0, 2), |
| pad_w=st.integers(0, 2), |
| output_pad_h=st.integers(0, 2), |
| output_pad_w=st.integers(0, 2), |
| dilation=st.integers(1, 2), |
| use_bias=st.booleans(), |
| format=st.sampled_from( |
| [None, torch.preserve_format, torch.contiguous_format, torch.channels_last] |
| ), |
| ) |
| def test_conv2d_transpose( |
| self, |
| batch_size, |
| input_channels_per_group, |
| height, |
| width, |
| output_channels_per_group, |
| groups, |
| kernel_h, |
| kernel_w, |
| stride_h, |
| stride_w, |
| pad_h, |
| pad_w, |
| output_pad_h, |
| output_pad_w, |
| dilation, |
| use_bias, |
| format, |
| ): |
| class Conv2DT(torch.nn.Module): |
| def __init__( |
| self, |
| weight, |
| bias, |
| strides, |
| paddings, |
| output_paddings, |
| dilations, |
| groups, |
| ): |
| super().__init__() |
| self.weight = weight |
| self.bias = bias |
| self.strides = strides |
| self.paddings = paddings |
| self.output_paddings = output_paddings |
| self.dilations = dilations |
| self.groups = groups |
| |
| def forward(self, x): |
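                # groups precedes dilation in conv_transpose2d (reverse of conv2d).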
| return F.conv_transpose2d( |
| x, |
| self.weight, |
| self.bias, |
| self.strides, |
| self.paddings, |
| self.output_paddings, |
| self.groups, |
| self.dilations, |
| ) |
| |
| class Conv2DTPrePacked(torch.nn.Module): |
| def __init__( |
| self, |
| weight, |
| bias, |
| strides, |
| paddings, |
| output_paddings, |
| dilations, |
| groups, |
| ): |
| super().__init__() |
| self.packed_weight_bias = ( |
| torch.ops.prepacked.conv2d_transpose_clamp_prepack( |
| weight, |
| bias, |
| strides, |
| paddings, |
| output_paddings, |
| dilations, |
| groups, |
| ) |
| ) |
| |
| def forward(self, x): |
| return torch.ops.prepacked.conv2d_transpose_clamp_run( |
| x, self.packed_weight_bias |
| ) |
| |
| input_channels = input_channels_per_group * groups |
| output_channels = output_channels_per_group * groups |
| kernels = (kernel_h, kernel_w) |
| strides = (stride_h, stride_w) |
| paddings = (pad_h, pad_w) |
| output_paddings = (output_pad_h, output_pad_w) |
| dilations = (dilation, dilation) |
| assume(height + 2 * paddings[0] >= dilations[0] * (kernels[0] - 1) + 1) |
| assume(width + 2 * paddings[1] >= dilations[1] * (kernels[1] - 1) + 1) |
| assume((output_pad_h < stride_h) and (output_pad_h < dilation)) |
| assume((output_pad_w < stride_w) and (output_pad_w < dilation)) |
| |
| input_data = torch.rand((batch_size, input_channels, height, width)) |
| if format is not None: |
| input_data = input_data.contiguous(memory_format=format) |
| weight = torch.rand( |
| (input_channels, output_channels_per_group, kernel_h, kernel_w) |
| ) |
| bias = None |
| if use_bias: |
| bias = torch.rand(output_channels) |
| |
| scripted_conv2d = torch.jit.script( |
| Conv2DT(weight, bias, strides, paddings, output_paddings, dilations, groups) |
| ) |
| scripted_conv2d_clamp_prepacked = torch.jit.script( |
| Conv2DTPrePacked( |
| weight, bias, strides, paddings, output_paddings, dilations, groups |
| ) |
| ) |
| ref_result = scripted_conv2d(input_data) |
| xnnpack_result = scripted_conv2d_clamp_prepacked(input_data) |
| torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3) |
| |
| # Serialize the modules and then deserialize |
| input_data = torch.rand((batch_size, input_channels, height, width)) |
| if format is not None: |
| input_data = input_data.contiguous(memory_format=format) |
| buffer = io.BytesIO() |
| torch.jit.save(scripted_conv2d, buffer) |
| buffer.seek(0) |
| deserialized_conv2d = torch.jit.load(buffer) |
| buffer = io.BytesIO() |
| torch.jit.save(scripted_conv2d_clamp_prepacked, buffer) |
| buffer.seek(0) |
| deserialized_conv2d_clamp_prepacked = torch.jit.load(buffer) |
| ref_result = deserialized_conv2d(input_data) |
| xnnpack_result = deserialized_conv2d_clamp_prepacked(input_data) |
| torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3) |
| |
| @unittest.skip( |
| "Fails on some platforms, see https://github.com/pytorch/pytorch/issues/73488" |
| ) |
| @given( |
| batch_size=st.integers(0, 3), |
| input_channels_per_group=st.integers(1, 32), |
| height=st.integers(5, 64), |
| width=st.integers(5, 64), |
| output_channels_per_group=st.integers(1, 32), |
| groups=st.integers(1, 16), |
| kernel_h=st.integers(1, 7), |
| kernel_w=st.integers(1, 7), |
| stride_h=st.integers(1, 2), |
| stride_w=st.integers(1, 2), |
| pad_h=st.integers(0, 2), |
| pad_w=st.integers(0, 2), |
| dilation=st.integers(1, 2), |
| linear_weight_output_dim=st.integers(2, 64), |
| use_bias=st.booleans(), |
| format=st.sampled_from( |
| [None, torch.preserve_format, torch.contiguous_format, torch.channels_last] |
| ), |
| ) |
| def test_combined_model( |
| self, |
| batch_size, |
| input_channels_per_group, |
| height, |
| width, |
| output_channels_per_group, |
| groups, |
| kernel_h, |
| kernel_w, |
| stride_h, |
| stride_w, |
| pad_h, |
| pad_w, |
| dilation, |
| linear_weight_output_dim, |
| use_bias, |
| format, |
| ): |
| class M(torch.nn.Module): |
| def __init__( |
| self, |
| conv_weight, |
| conv_bias, |
| linear_weight, |
| linear_bias, |
| strides, |
| paddings, |
| dilations, |
| groups, |
| ): |
| super().__init__() |
| self.conv_weight = conv_weight |
| self.conv_bias = conv_bias |
| self.linear_weight = linear_weight |
| self.linear_bias = linear_bias |
| self.strides = strides |
| self.paddings = paddings |
| self.dilations = dilations |
| self.groups = groups |
| |
| def forward(self, x): |
| o = F.conv2d( |
| x, |
| self.conv_weight, |
| self.conv_bias, |
| self.strides, |
| self.paddings, |
| self.dilations, |
| self.groups, |
| ) |
| o = o.permute([0, 2, 3, 1]) |
| o = F.linear(o, self.linear_weight, self.linear_bias) |
| return F.relu(o) |
| |
| class MPrePacked(torch.nn.Module): |
| def __init__( |
| self, |
| conv_weight, |
| conv_bias, |
| linear_weight, |
| linear_bias, |
| strides, |
| paddings, |
| dilations, |
| groups, |
| ): |
| super().__init__() |
| self.conv2d_clamp_run_weight_bias = ( |
| torch.ops.prepacked.conv2d_clamp_prepack( |
| conv_weight, conv_bias, strides, paddings, dilations, groups |
| ) |
| ) |
| self.linear_clamp_run_weight_bias = ( |
| torch.ops.prepacked.linear_clamp_prepack(linear_weight, linear_bias) |
| ) |
| |
| def forward(self, x): |
| o = torch.ops.prepacked.conv2d_clamp_run( |
| x, self.conv2d_clamp_run_weight_bias |
| ) |
| o = o.permute([0, 2, 3, 1]) |
| o = torch.ops.prepacked.linear_clamp_run( |
| o, self.linear_clamp_run_weight_bias |
| ) |
| return F.relu(o) |
| |
| input_channels = input_channels_per_group * groups |
| output_channels = output_channels_per_group * groups |
| kernels = (kernel_h, kernel_w) |
| strides = (stride_h, stride_w) |
| paddings = (pad_h, pad_w) |
| dilations = (dilation, dilation) |
| assume(height + 2 * paddings[0] >= dilations[0] * (kernels[0] - 1) + 1) |
| assume(width + 2 * paddings[1] >= dilations[1] * (kernels[1] - 1) + 1) |
| |
| input_data = torch.rand((batch_size, input_channels, height, width)) |
| if format is not None: |
| input_data = input_data.contiguous(memory_format=format) |
| conv_weight = torch.rand( |
| (output_channels, input_channels_per_group, kernel_h, kernel_w) |
| ) |
| conv_bias = None |
| if use_bias: |
| conv_bias = torch.rand(output_channels) |
| |
        # Run conv2d once just to determine its output channel count, which
        # (after the permute in forward) becomes the input dimension of the
        # following linear layer.
| result = F.conv2d( |
| input_data, conv_weight, conv_bias, strides, paddings, dilations, groups |
| ) |
| linear_input_shape = result.shape[1] |
| |
| linear_weight = torch.rand((linear_weight_output_dim, linear_input_shape)) |
| linear_bias = None |
| if use_bias: |
| linear_bias = torch.rand(linear_weight_output_dim) |
| |
| scripted_m = torch.jit.script( |
| M( |
| conv_weight, |
| conv_bias, |
| linear_weight, |
| linear_bias, |
| strides, |
| paddings, |
| dilations, |
| groups, |
| ) |
| ) |
| scripted_m_prepacked = torch.jit.script( |
| MPrePacked( |
| conv_weight, |
| conv_bias, |
| linear_weight, |
| linear_bias, |
| strides, |
| paddings, |
| dilations, |
| groups, |
| ) |
| ) |
| ref_result = scripted_m(input_data) |
| xnnpack_result = scripted_m_prepacked(input_data) |
| torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3) |
| |
| # Serialize the modules and then deserialize |
| input_data = torch.rand((batch_size, input_channels, height, width)) |
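        # Unlike the tests above, the deserialized path here always runs with a
        # channels_last input, regardless of the sampled `format`.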
| input_data = input_data.contiguous(memory_format=torch.channels_last) |
| buffer = io.BytesIO() |
| torch.jit.save(scripted_m, buffer) |
| buffer.seek(0) |
| deserialized_m = torch.jit.load(buffer) |
| buffer = io.BytesIO() |
| torch.jit.save(scripted_m_prepacked, buffer) |
| buffer.seek(0) |
| deserialized_m_prepacked = torch.jit.load(buffer) |
| ref_result = deserialized_m(input_data) |
| xnnpack_result = deserialized_m_prepacked(input_data) |
| torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3) |
| |
| |
| @unittest.skipUnless( |
| torch.backends.xnnpack.enabled, |
| " XNNPACK must be enabled for these tests." " Please build with USE_XNNPACK=1.", |
| ) |
| @unittest.skipIf( |
| TEST_WITH_TSAN, |
| "TSAN fails with XNNPACK. Does not seem to have a good reason for failures.", |
| ) |
| class TestXNNPACKRewritePass(TestCase): |
| @staticmethod |
| def validate_transformed_module( |
        # `self` here is the module under test, not a TestCase; the name keeps
        # flake happy.
| self, |
| pattern_count_map, |
| data_shape, |
| prepack_removal=False, |
| fuse_clamping_ops=False, |
| ): |
| input_data = torch.normal(1, 20, size=data_shape) |
| |
| for jit_method in ["script", "trace"]: |
| module_instance = self |
| if jit_method == "script": |
| scripted_model = torch.jit.script(module_instance) |
| else: |
| scripted_model = torch.jit.trace(module_instance, input_data) |
| scripted_model.eval() |
| ref_result = scripted_model(input_data) |
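            # Rewrite eager linear/conv calls into prepacked XNNPACK ops. The
            # fusion and folding passes run on a frozen module so that weights
            # are visible as constants.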
| torch._C._jit_pass_insert_prepacked_ops(scripted_model._c) |
| if fuse_clamping_ops or prepack_removal: |
| scripted_model._c = torch._C._freeze_module(scripted_model._c) |
| if fuse_clamping_ops: |
| torch._C._jit_pass_fuse_clamp_w_prepacked_linear_conv(scripted_model._c) |
| if prepack_removal: |
| torch._C._jit_pass_fold_prepacking_ops(scripted_model._c) |
| |
| buffer = io.BytesIO() |
| torch.jit.save(scripted_model, buffer) |
| buffer.seek(0) |
| deserialized_scripted_model = torch.jit.load(buffer) |
| for pattern, v in pattern_count_map.items(): |
| if v == 0: |
| FileCheck().check(pattern).run(deserialized_scripted_model.graph) |
| elif v == -1: |
| FileCheck().check_not(pattern).run( |
| deserialized_scripted_model.graph |
| ) |
| else: |
| FileCheck().check_count(pattern, v, exactly=True).run( |
| deserialized_scripted_model.graph |
| ) |
| xnnpack_result = deserialized_scripted_model(input_data) |
| torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3) |
| |
| def test_linear(self): |
| data_shape = [2, 3, 32] |
| weight_output_dim = 24 |
| weight_shape = (weight_output_dim, data_shape[-1]) |
| |
| class Linear(torch.nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| self.weight = torch.nn.Parameter( |
| torch.rand(weight_shape), requires_grad=False |
| ) |
| self.bias = torch.nn.Parameter( |
| torch.rand(weight_output_dim), requires_grad=False |
| ) |
| |
| def forward(self, x): |
| return F.linear(x, self.weight, self.bias) |
| |
| class LinearNoBias(torch.nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| self.weight = torch.nn.Parameter( |
| torch.rand(weight_shape), requires_grad=False |
| ) |
| |
| def forward(self, x): |
| return F.linear(x, self.weight, None) |
| |
        # Both the bias and no-bias linear patterns should be rewritten into
        # the prepacked ops.
| pattern_count_map = { |
| "Tensor = prim::CallFunction": -1, |
| "prepacked::linear_clamp_prepack": 1, |
| "prepacked::linear_clamp_run": 1, |
| } |
| TestXNNPACKRewritePass.validate_transformed_module( |
| Linear(), pattern_count_map, data_shape |
| ) |
| TestXNNPACKRewritePass.validate_transformed_module( |
| LinearNoBias(), pattern_count_map, data_shape |
| ) |
| |
| # Conv params |
| batch_size = 2 |
| input_channels_per_group = 6 |
| height = 16 |
| width = 16 |
| output_channels_per_group = 6 |
| groups = 4 |
| kernel_h = kernel_w = 3 |
| stride_h = stride_w = 1 |
| pad_h = pad_w = 1 |
| output_pad_h = output_pad_w = 0 |
| dilation = 1 |
| input_channels = input_channels_per_group * groups |
| output_channels = output_channels_per_group * groups |
| kernels = (kernel_h, kernel_w) |
| strides = (stride_h, stride_w) |
| paddings = (pad_h, pad_w) |
| output_paddings = (output_pad_h, output_pad_w) |
| dilations = (dilation, dilation) |
| conv_weight_shape = ( |
| output_channels, |
| input_channels_per_group, |
| kernel_h, |
| kernel_w, |
| ) |
| conv_transpose_weight_shape = ( |
| input_channels, |
| output_channels_per_group, |
| kernel_h, |
| kernel_w, |
| ) |
| conv_bias_shape = output_channels |
| |
| class Conv2D(torch.nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| self.weight = torch.nn.Parameter( |
| torch.rand(conv_weight_shape), requires_grad=False |
| ) |
| self.bias = torch.nn.Parameter( |
| torch.rand(conv_bias_shape), requires_grad=False |
| ) |
| self.strides = strides |
| self.paddings = paddings |
| self.dilations = dilations |
| self.groups = groups |
| |
| def forward(self, x): |
| return F.conv2d( |
| x, |
| self.weight, |
| self.bias, |
| self.strides, |
| self.paddings, |
| self.dilations, |
| self.groups, |
| ) |
| |
| class Conv2DT(torch.nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| self.weight = torch.nn.Parameter( |
| torch.rand(conv_transpose_weight_shape), requires_grad=False |
| ) |
| self.bias = torch.nn.Parameter( |
| torch.rand(conv_bias_shape), requires_grad=False |
| ) |
| self.strides = strides |
| self.paddings = paddings |
| self.output_paddings = output_paddings |
| self.dilations = dilations |
| self.groups = groups |
| |
| def forward(self, x): |
| return F.conv_transpose2d( |
| x, |
| self.weight, |
| self.bias, |
| self.strides, |
| self.paddings, |
| self.output_paddings, |
| self.groups, |
| self.dilations, |
| ) |
| |
| data_shape = (batch_size, input_channels, height, width) |
| pattern_count_map = { |
| "Tensor = aten::conv2d": -1, |
| "prepacked::conv2d_clamp_prepack": 1, |
| "prepacked::conv2d_clamp_run": 1, |
| } |
| TestXNNPACKRewritePass.validate_transformed_module( |
| Conv2D(), pattern_count_map, data_shape |
| ) |
| |
| transpose_data_shape = (batch_size, input_channels, height, width) |
| transpose_pattern_count_map = { |
| "Tensor = aten::conv_transpose2d": -1, |
| "prepacked::conv2d_transpose_clamp_prepack": 1, |
| "prepacked::conv2d_transpose_clamp_run": 1, |
| } |
| TestXNNPACKRewritePass.validate_transformed_module( |
            Conv2DT(), transpose_pattern_count_map, transpose_data_shape
| ) |
| |
| input_data = torch.rand((batch_size, input_channels, height, width)) |
| conv_weight = torch.rand( |
| (output_channels, input_channels_per_group, kernel_h, kernel_w) |
| ) |
| conv_bias = torch.rand(output_channels) |
| result = F.conv2d( |
| input_data, conv_weight, conv_bias, strides, paddings, dilations, groups |
| ) |
| linear_input_shape = result.shape[1] |
| linear_weight_shape = (weight_output_dim, linear_input_shape) |
| |
| class M(torch.nn.Module): |
| def __init__(self, activation_fn=F.relu): |
| super().__init__() |
| self.conv_weight = torch.nn.Parameter( |
| torch.rand(conv_weight_shape), requires_grad=False |
| ) |
| self.conv_bias = torch.nn.Parameter( |
| torch.rand(conv_bias_shape), requires_grad=False |
| ) |
| self.linear_weight = torch.nn.Parameter( |
| torch.rand(linear_weight_shape), requires_grad=False |
| ) |
| self.linear_bias = torch.nn.Parameter( |
| torch.rand(weight_output_dim), requires_grad=False |
| ) |
| self.strides = strides |
| self.paddings = paddings |
| self.dilations = dilations |
| self.groups = groups |
| self.activation_fn = activation_fn |
| |
| def forward(self, x): |
| o = F.conv2d( |
| x, |
| self.conv_weight, |
| self.conv_bias, |
| self.strides, |
| self.paddings, |
| self.dilations, |
| self.groups, |
| ) |
| o = self.activation_fn(o) |
| o = o.permute([0, 2, 3, 1]) |
| o = F.linear(o, self.linear_weight, self.linear_bias) |
| return self.activation_fn(o) |
| |
| pattern_count_map = { |
| "Tensor = aten::conv2d": -1, |
| "prepacked::conv2d_clamp_prepack": 1, |
| "prepacked::conv2d_clamp_run": 1, |
| "prepacked::linear_clamp_prepack": 1, |
| "prepacked::linear_clamp_run": 1, |
| } |
| TestXNNPACKRewritePass.validate_transformed_module( |
| M(), pattern_count_map, data_shape |
| ) |
| pattern_count_map["prepacked::conv2d_clamp_prepack"] = -1 |
| pattern_count_map["Tensor = prim::CallFunction"] = -1 |
| pattern_count_map["prepacked::linear_clamp_prepack"] = -1 |
| TestXNNPACKRewritePass.validate_transformed_module( |
| M(), pattern_count_map, data_shape, prepack_removal=True |
| ) |
| |
        # Out-of-place relu fusion test.
| pattern_count_map = { |
| "aten::relu": 2, |
| "prepacked::conv2d_clamp_prepack": -1, |
| "prepacked::conv2d_clamp_run": 1, |
| "prepacked::linear_clamp_prepack": -1, |
| "prepacked::linear_clamp_run": 1, |
| } |
| TestXNNPACKRewritePass.validate_transformed_module( |
| M(), pattern_count_map, data_shape, prepack_removal=True |
| ) |
| pattern_count_map["prepacked::conv2d_clamp_prepack"] = -1 |
| pattern_count_map["prepacked::linear_clamp_prepack"] = -1 |
| pattern_count_map["aten::relu"] = -1 |
| TestXNNPACKRewritePass.validate_transformed_module( |
| M(), |
| pattern_count_map, |
| data_shape, |
| prepack_removal=True, |
| fuse_clamping_ops=True, |
| ) |
| |
        # In-place relu fusion test.
| pattern_count_map = { |
| "aten::relu": 2, |
| "prepacked::conv2d_clamp_prepack": -1, |
| "prepacked::conv2d_clamp_run": 1, |
| "prepacked::linear_clamp_prepack": -1, |
| "prepacked::linear_clamp_run": 1, |
| } |
| TestXNNPACKRewritePass.validate_transformed_module( |
| M(F.relu_), pattern_count_map, data_shape, prepack_removal=True |
| ) |
| pattern_count_map["prepacked::conv2d_clamp_prepack"] = -1 |
| pattern_count_map["prepacked::linear_clamp_prepack"] = -1 |
| pattern_count_map["aten::relu"] = -1 |
| TestXNNPACKRewritePass.validate_transformed_module( |
| M(F.relu_), |
| pattern_count_map, |
| data_shape, |
| prepack_removal=True, |
| fuse_clamping_ops=True, |
| ) |
| |
        # Out-of-place hardtanh fusion test.
| pattern_count_map = { |
| "aten::hardtanh": 2, |
| "prepacked::conv2d_clamp_prepack": -1, |
| "prepacked::conv2d_clamp_run": 1, |
| "prepacked::linear_clamp_prepack": -1, |
| "prepacked::linear_clamp_run": 1, |
| } |
| TestXNNPACKRewritePass.validate_transformed_module( |
| M(F.hardtanh), pattern_count_map, data_shape, prepack_removal=True |
| ) |
| pattern_count_map["prepacked::conv2d_clamp_prepack"] = -1 |
| pattern_count_map["prepacked::linear_clamp_prepack"] = -1 |
| pattern_count_map["aten::hardtanh"] = -1 |
| TestXNNPACKRewritePass.validate_transformed_module( |
| M(F.hardtanh), |
| pattern_count_map, |
| data_shape, |
| prepack_removal=True, |
| fuse_clamping_ops=True, |
| ) |
| |
        # In-place hardtanh fusion test.
| pattern_count_map = { |
| "aten::hardtanh_": 2, |
| "prepacked::conv2d_clamp_prepack": -1, |
| "prepacked::conv2d_clamp_run": 1, |
| "prepacked::linear_clamp_prepack": -1, |
| "prepacked::linear_clamp_run": 1, |
| } |
| TestXNNPACKRewritePass.validate_transformed_module( |
| M(F.hardtanh_), pattern_count_map, data_shape, prepack_removal=True |
| ) |
| pattern_count_map["prepacked::conv2d_clamp_prepack"] = -1 |
| pattern_count_map["prepacked::linear_clamp_prepack"] = -1 |
| pattern_count_map["aten::hardtanh_"] = -1 |
| TestXNNPACKRewritePass.validate_transformed_module( |
| M(F.hardtanh_), |
| pattern_count_map, |
| data_shape, |
| prepack_removal=True, |
| fuse_clamping_ops=True, |
| ) |
| |
| class MFusionAntiPattern(torch.nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| self.linear_weight = torch.nn.Parameter( |
| torch.rand(linear_weight_shape), requires_grad=False |
| ) |
| self.linear_bias = torch.nn.Parameter( |
| torch.rand(weight_output_dim), requires_grad=False |
| ) |
| self.strides = strides |
| self.paddings = paddings |
| self.dilations = dilations |
| self.groups = groups |
| |
| def forward(self, x): |
| o = F.linear(x, self.linear_weight, self.linear_bias) |
| o = F.relu(o) |
| o = F.hardtanh(o) |
| return o |
| |
        # relu fuses into the clamped linear op, but the trailing hardtanh cannot.
        pattern_count_map = {
            "aten::hardtanh": 1,  # hardtanh remains unfused.
| "aten::relu": -1, # relu is fused. |
| "prepacked::linear_clamp_prepack": -1, |
| "prepacked::linear_clamp_run": 1, |
| } |
| TestXNNPACKRewritePass.validate_transformed_module( |
| MFusionAntiPattern(), |
| pattern_count_map, |
| (16, linear_weight_shape[1]), |
| prepack_removal=True, |
| fuse_clamping_ops=True, |
| ) |
| |
| class MFusionAntiPatternParamMinMax(torch.nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| self.linear_weight = torch.nn.Parameter( |
| torch.rand(linear_weight_shape), requires_grad=False |
| ) |
| self.linear_bias = torch.nn.Parameter( |
| torch.rand(weight_output_dim), requires_grad=False |
| ) |
| self.strides = strides |
| self.paddings = paddings |
| self.dilations = dilations |
| self.groups = groups |
| |
| def forward(self, x): |
| min = x[0, 0] |
| max = min + 10 |
| o = F.linear(x, self.linear_weight, self.linear_bias) |
| o = F.hardtanh(o, min, max) |
| return o |
| |
        # hardtanh with runtime (non-constant) min/max cannot be fused.
        pattern_count_map = {
            "aten::hardtanh": 1,  # hardtanh remains unfused.
| "prepacked::linear_clamp_prepack": -1, |
| "prepacked::linear_clamp_run": 1, |
| } |
| TestXNNPACKRewritePass.validate_transformed_module( |
| MFusionAntiPatternParamMinMax(), |
| pattern_count_map, |
| (16, linear_weight_shape[1]), |
| prepack_removal=True, |
| fuse_clamping_ops=True, |
| ) |
| |
| def test_decomposed_linear(self): |
| data_shape = [2, 32] |
| weight_output_dim = 24 |
| weight_shape = (weight_output_dim, data_shape[-1]) |
| |
| class DecomposedLinearAddmm(torch.nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| self.weight = torch.nn.Parameter( |
| torch.rand(weight_shape), requires_grad=False |
| ) |
| self.bias = torch.nn.Parameter( |
| torch.rand(weight_output_dim), requires_grad=False |
| ) |
| |
| def forward(self, x): |
| weight_t = self.weight.t() |
| return torch.addmm(self.bias, x, weight_t) |
| |
| class DecomposedLinearMatmulAdd(torch.nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| self.weight = torch.nn.Parameter( |
| torch.rand(weight_shape), requires_grad=False |
| ) |
| self.bias = torch.nn.Parameter( |
| torch.rand(weight_output_dim), requires_grad=False |
| ) |
| |
| def forward(self, x): |
| weight_t = self.weight.t() |
| y = torch.matmul(x, weight_t) |
| res = y.add_(self.bias) |
| return res |
| |
| class DecomposedLinearMatmul(torch.nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| self.weight = torch.nn.Parameter( |
| torch.rand(weight_shape), requires_grad=False |
| ) |
| self.bias = torch.nn.Parameter( |
| torch.rand(weight_output_dim), requires_grad=False |
| ) |
| |
| def forward(self, x): |
| weight_t = self.weight.t() |
| res = torch.matmul(x, weight_t) |
| return res |
| |
        # Each decomposed form (addmm, matmul + add_, bare matmul) should be
        # rewritten into the prepacked linear ops.
| pattern_count_map = { |
| "Tensor = prim::CallFunction": -1, |
| "prepacked::linear_clamp_prepack": 1, |
| "prepacked::linear_clamp_run": 1, |
| } |
| TestXNNPACKRewritePass.validate_transformed_module( |
| DecomposedLinearAddmm(), pattern_count_map, data_shape |
| ) |
| TestXNNPACKRewritePass.validate_transformed_module( |
| DecomposedLinearMatmulAdd(), pattern_count_map, data_shape |
| ) |
| TestXNNPACKRewritePass.validate_transformed_module( |
| DecomposedLinearMatmul(), pattern_count_map, data_shape |
| ) |
| |
| |
| @unittest.skipUnless( |
| torch.backends.xnnpack.enabled, |
| " XNNPACK must be enabled for these tests." " Please build with USE_XNNPACK=1.", |
| ) |
| @unittest.skipIf( |
| TEST_WITH_TSAN, |
| "TSAN is not fork-safe since we're forking in a multi-threaded environment", |
| ) |
| class TestXNNPACKConv1dTransformPass(TestCase): |
| @staticmethod |
| def validate_transform_conv1d_to_conv2d( |
| self, pattern_count_transformed_map, pattern_count_optimized_map, data_shape |
| ): |
| input_data = torch.normal(1, 20, size=data_shape) |
| |
| for jit_method in ["script", "trace"]: |
| module_instance = self |
| if jit_method == "script": |
| scripted_model = torch.jit.script(module_instance) |
| else: |
| scripted_model = torch.jit.trace(module_instance, input_data) |
| scripted_model.eval() |
| ref_result = scripted_model(input_data) |
| torch._C._jit_pass_transform_conv1d_to_conv2d(scripted_model._c) |
| optimized_scripted_model = optimize_for_mobile(scripted_model) |
| |
| buffer = io.BytesIO() |
| torch.jit.save(scripted_model, buffer) |
| buffer.seek(0) |
| deserialized_scripted_model = torch.jit.load(buffer) |
| |
| for pattern, v in pattern_count_transformed_map.items(): |
| if v == 0: |
| FileCheck().check(pattern).run(deserialized_scripted_model.graph) |
| elif v == -1: |
| FileCheck().check_not(pattern).run( |
| deserialized_scripted_model.graph |
| ) |
| else: |
| FileCheck().check_count(pattern, v, exactly=True).run( |
| deserialized_scripted_model.graph |
| ) |
| transformed_result = deserialized_scripted_model(input_data) |
| torch.testing.assert_close( |
| ref_result, transformed_result, rtol=1e-2, atol=1e-3 |
| ) |
| |
| optimized_buffer = io.BytesIO() |
| torch.jit.save(optimized_scripted_model, optimized_buffer) |
| optimized_buffer.seek(0) |
| deserialized_optimized_scripted_model = torch.jit.load(optimized_buffer) |
| |
| for pattern, v in pattern_count_optimized_map.items(): |
| if v == 0: |
| FileCheck().check(pattern).run( |
| deserialized_optimized_scripted_model.graph |
| ) |
| elif v == -1: |
| FileCheck().check_not(pattern).run( |
| deserialized_optimized_scripted_model.graph |
| ) |
| else: |
| FileCheck().check_count(pattern, v, exactly=True).run( |
| deserialized_optimized_scripted_model.graph |
| ) |
| xnnpack_result = deserialized_optimized_scripted_model(input_data) |
| torch.testing.assert_close(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3) |
| |
| @unittest.skipIf(IS_FBCODE, "T137513244") |
| def test_conv1d_basic(self): |
| batch_size_list = range(1, 3) |
| input_channels_per_group_list = range(10, 12) |
| width_list = range(10, 12) |
| output_channels_per_group_list = range(10, 12) |
| groups_list = range(1, 3) |
| kernel_list = range(1, 4) |
| stride_list = range(1, 3) |
| padding_list = range(0, 3) |
| dilation_list = range(1, 3) |
| |
| for hparams in itertools.product( |
| batch_size_list, |
| input_channels_per_group_list, |
| width_list, |
| output_channels_per_group_list, |
| groups_list, |
| kernel_list, |
| stride_list, |
| padding_list, |
| dilation_list, |
| ): |
| ( |
| batch_size, |
| input_channels_per_group, |
| width, |
| output_channels_per_group, |
| groups, |
| kernel, |
| stride, |
| padding, |
| dilation, |
| ) = hparams |
| |
| input_channels = input_channels_per_group * groups |
| output_channels = output_channels_per_group * groups |
| conv_weight_shape = (output_channels, input_channels_per_group, kernel) |
| conv_bias_shape = output_channels |
| |
| class Conv1D(torch.nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| self.weight = torch.nn.Parameter( |
| torch.rand(conv_weight_shape), requires_grad=False |
| ) |
| self.bias = torch.nn.Parameter( |
| torch.rand(conv_bias_shape), requires_grad=False |
| ) |
| self.stride = stride |
| self.padding = padding |
| self.dilation = dilation |
| self.groups = groups |
| |
| def forward(self, x): |
| return F.conv1d( |
| x, |
| self.weight, |
| self.bias, |
| self.stride, |
| self.padding, |
| self.dilation, |
| self.groups, |
| ) |
| |
| data_shape = (batch_size, input_channels, width) |
| pattern_count_transformed_map = { |
| "Tensor = aten::conv1d": -1, |
| "Tensor = aten::conv2d": 1, |
| } |
| pattern_count_optimized_map = { |
| "Tensor = aten::conv1d": -1, |
| "Tensor = aten::conv2d": -1, |
| "prepacked::conv2d_clamp_prepack": -1, |
| "prepacked::conv2d_clamp_run": 1, |
| } |
| |
| TestXNNPACKConv1dTransformPass.validate_transform_conv1d_to_conv2d( |
| Conv1D(), |
| pattern_count_transformed_map, |
| pattern_count_optimized_map, |
| data_shape, |
| ) |
| |
| # See https://github.com/pytorch/pytorch/issues/46066 |
| @slowTest |
| def test_conv1d_with_relu_fc(self): |
| batch_size_list = range(1, 3) |
| input_channels_per_group_list = range(10, 12) |
| width_list = range(10, 12) |
| output_channels_per_group_list = range(10, 12) |
| groups_list = range(1, 3) |
| kernel_list = range(1, 4) |
| stride_list = range(1, 3) |
| padding_list = range(0, 3) |
| dilation_list = range(1, 3) |
| output_features_list = range(1, 3) |
| |
| for hparams in itertools.product( |
| batch_size_list, |
| input_channels_per_group_list, |
| width_list, |
| output_channels_per_group_list, |
| groups_list, |
| kernel_list, |
| stride_list, |
| padding_list, |
| dilation_list, |
| output_features_list, |
| ): |
| ( |
| batch_size, |
| input_channels_per_group, |
| width, |
| output_channels_per_group, |
| groups, |
| kernel, |
| stride, |
| padding, |
| dilation, |
| output_features, |
| ) = hparams |
| |
| input_channels = input_channels_per_group * groups |
| output_channels = output_channels_per_group * groups |
| conv_weight_shape = (output_channels, input_channels_per_group, kernel) |
| conv_bias_shape = output_channels |
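            # Standard conv output-size formula:
            # floor((W + 2 * pad - dilation * (kernel - 1) - 1) / stride) + 1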
| conv_output_width = ( |
| int((width + 2 * padding - dilation * (kernel - 1) - 1) / stride) + 1 |
| ) |
| fc_weight_shape = (output_features, output_channels * conv_output_width) |
| fc_bias_shape = output_features |
| |
| class Net(torch.nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| self.conv_weight = torch.nn.Parameter( |
| torch.rand(conv_weight_shape), requires_grad=False |
| ) |
| self.conv_bias = torch.nn.Parameter( |
| torch.rand(conv_bias_shape), requires_grad=False |
| ) |
| self.stride = stride |
| self.padding = padding |
| self.dilation = dilation |
| self.groups = groups |
| |
| self.fc_weight = torch.nn.Parameter( |
| torch.rand(fc_weight_shape), requires_grad=False |
| ) |
| self.fc_bias = torch.nn.Parameter( |
| torch.rand(fc_bias_shape), requires_grad=False |
| ) |
| |
| def forward(self, x): |
| x = F.conv1d( |
| x, |
| self.conv_weight, |
| self.conv_bias, |
| self.stride, |
| self.padding, |
| self.dilation, |
| self.groups, |
| ) |
| x = F.relu(x) |
| x = x.view(x.size(0), -1) |
| x = F.linear(x, self.fc_weight, self.fc_bias) |
| return x |
| |
| data_shape = (batch_size, input_channels, width) |
| pattern_count_transformed_map = { |
| "Tensor = aten::conv1d": -1, |
| "Tensor = aten::conv2d": 1, |
| } |
| pattern_count_optimized_map = { |
| "Tensor = aten::conv1d": -1, |
| "Tensor = aten::conv2d": -1, |
| "prepacked::conv2d_clamp_prepack": -1, |
| "prepacked::conv2d_clamp_run": 1, |
| } |
| TestXNNPACKConv1dTransformPass.validate_transform_conv1d_to_conv2d( |
| Net(), |
| pattern_count_transformed_map, |
| pattern_count_optimized_map, |
| data_shape, |
| ) |
| |
| |
| if __name__ == "__main__": |
| run_tests() |