| # Owner(s): ["oncall: mobile"] |
| |
| import torch |
| from torch.nn import functional as F |
| |
| from torch.testing._internal.common_utils import TestCase, run_tests |
| from torch.testing import FileCheck |
| import io |
| |
| class TestMetalRewritePass(TestCase): |
| @staticmethod |
| def validate_transformed_module( |
| # To please flake |
| self, |
| pattern_count_map, |
| data_shape, |
| prepack_removal=False, |
| fuse_clamping_ops=False): |
| module_instance = self |
| scripted_model = torch.jit.script(module_instance) |
| scripted_model.eval() |
| input_data = torch.normal(1, 20, size=data_shape) |
| ref_result = scripted_model(input_data) |
| torch._C._jit_pass_metal_insert_prepacked_ops(scripted_model._c) |
| if fuse_clamping_ops or prepack_removal: |
| scripted_model._c = torch._C._freeze_module(scripted_model._c) |
| if fuse_clamping_ops: |
| torch._C._jit_pass_metal_fuse_clamp_w_prepacked_conv(scripted_model._c) |
| if prepack_removal: |
| torch._C._jit_pass_metal_fold_prepacking_ops(scripted_model._c) |
| |
| buffer = io.BytesIO() |
| torch.jit.save(scripted_model, buffer) |
| buffer.seek(0) |
| deserialized_scripted_model = torch.jit.load(buffer) |
| for pattern, v in pattern_count_map.items(): |
| if (v == 0): |
| FileCheck().check(pattern).run(deserialized_scripted_model.graph) |
| elif (v == -1): |
| FileCheck().check_not(pattern).run(deserialized_scripted_model.graph) |
| else: |
| FileCheck().check_count(pattern, v, exactly=True).run(deserialized_scripted_model.graph) |
| |
| def test_conv(self): |
| # Conv params |
| batch_size = 2 |
| input_channels_per_group = 6 |
| height = 16 |
| width = 16 |
| output_channels_per_group = 6 |
| groups = 4 |
| kernel_h = kernel_w = 3 |
| stride_h = stride_w = 1 |
| pad_h = pad_w = 1 |
| dilation = 1 |
| input_channels = input_channels_per_group * groups |
| output_channels = output_channels_per_group * groups |
| kernels = (kernel_h, kernel_w) |
| strides = (stride_h, stride_w) |
| paddings = (pad_h, pad_w) |
| dilations = (dilation, dilation) |
| conv_weight_shape = (output_channels, input_channels_per_group, kernel_h, kernel_w) |
| conv_bias_shape = (output_channels) |
| |
| class Conv2D(torch.nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| self.weight = torch.nn.Parameter(torch.rand(conv_weight_shape), requires_grad=False) |
| self.bias = torch.nn.Parameter(torch.rand(conv_bias_shape), requires_grad=False) |
| self.strides = strides |
| self.paddings = paddings |
| self.dilations = dilations |
| self.groups = groups |
| |
| def forward(self, x): |
| return F.conv2d(x, self.weight, self.bias, |
| self.strides, self.paddings, self.dilations, self.groups) |
| |
| data_shape = (batch_size, input_channels, height, width) |
| pattern_count_map = {"Tensor = aten::conv2d": -1, |
| "metal_prepack::conv2d_prepack": 1, |
| "metal_prepack::conv2d_run": 1} |
| TestMetalRewritePass.validate_transformed_module(Conv2D(), pattern_count_map, data_shape) |
| |
| class Conv2DRelu(torch.nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| self.weight = torch.nn.Parameter(torch.rand(conv_weight_shape), requires_grad=False) |
| self.bias = torch.nn.Parameter(torch.rand(conv_bias_shape), requires_grad=False) |
| self.strides = strides |
| self.paddings = paddings |
| self.dilations = dilations |
| self.groups = groups |
| |
| def forward(self, x): |
| o = F.conv2d(x, self.weight, self.bias, |
| self.strides, self.paddings, self.dilations, self.groups) |
| o = F.relu(o) |
| return o |
| |
| data_shape = (batch_size, input_channels, height, width) |
| pattern_count_map = {"Tensor = aten::conv2d": -1, |
| "metal_prepack::conv2d_prepack": 1, |
| "metal_prepack::conv2d_run": 1} |
| TestMetalRewritePass.validate_transformed_module( |
| Conv2DRelu(), pattern_count_map, data_shape) |
| |
| pattern_count_map["aten::relu"] = 1 |
| pattern_count_map["metal_prepack::conv2d_prepack"] = -1 |
| TestMetalRewritePass.validate_transformed_module( |
| Conv2DRelu(), |
| pattern_count_map, |
| data_shape, |
| prepack_removal=True) |
| pattern_count_map["aten::relu"] = -1 |
| TestMetalRewritePass.validate_transformed_module( |
| Conv2DRelu(), |
| pattern_count_map, |
| data_shape, |
| prepack_removal=True, |
| fuse_clamping_ops=True) |
| |
| |
| class Conv2DHardtanh(torch.nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| self.weight = torch.nn.Parameter(torch.rand(conv_weight_shape), requires_grad=False) |
| self.bias = torch.nn.Parameter(torch.rand(conv_bias_shape), requires_grad=False) |
| self.strides = strides |
| self.paddings = paddings |
| self.dilations = dilations |
| self.groups = groups |
| |
| def forward(self, x): |
| o = F.conv2d(x, self.weight, self.bias, |
| self.strides, self.paddings, self.dilations, self.groups) |
| o = F.hardtanh(o) |
| return o |
| |
| data_shape = (batch_size, input_channels, height, width) |
| pattern_count_map = {"Tensor = aten::conv2d": -1, |
| "metal_prepack::conv2d_prepack": 1, |
| "metal_prepack::conv2d_run": 1} |
| TestMetalRewritePass.validate_transformed_module(Conv2DHardtanh(), pattern_count_map, data_shape) |
| pattern_count_map["aten::hardtanh"] = 1 |
| pattern_count_map["metal_prepack::conv2d_prepack"] = -1 |
| TestMetalRewritePass.validate_transformed_module( |
| Conv2DHardtanh(), |
| pattern_count_map, |
| data_shape, |
| prepack_removal=True) |
| pattern_count_map["aten::hardtanh"] = -1 |
| TestMetalRewritePass.validate_transformed_module( |
| Conv2DRelu(), |
| pattern_count_map, |
| data_shape, |
| prepack_removal=True, |
| fuse_clamping_ops=True) |
| |
if "__main__" == __name__:
    # Only run the test suite when this file is executed directly.
    run_tests()