Fix max_pool2d decomposition for empty list and integer limits (#129106) Pull Request resolved: https://github.com/pytorch/pytorch/pull/129106 Approved by: https://github.com/peterbell10, https://github.com/lezcano, https://github.com/malfet ghstack dependencies: #129096, #129097
diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py index 4a42c52..93dd2a0 100644 --- a/test/inductor/test_torchinductor_opinfo.py +++ b/test/inductor/test_torchinductor_opinfo.py
@@ -439,7 +439,6 @@ "nn.functional.gaussian_nll_loss": {f16}, "nn.functional.grid_sample": {f32, f64}, "nn.functional.interpolate.area": {f16}, - "nn.functional.max_pool2d": {f16, f32, f64, i32, i64}, "nn.functional.nll_loss": {f16, f32, f64}, "normal": {f16, f32, f64}, "put": {f16, f32, f64},
diff --git a/test/nn/test_pooling.py b/test/nn/test_pooling.py index eae5370..2175067 100644 --- a/test/nn/test_pooling.py +++ b/test/nn/test_pooling.py
@@ -1120,6 +1120,34 @@ helper(10, 512, 31, 31, 3, stride=2) helper(1, 129, 8, 8, 3, stride=2) + @onlyCPU + @dtypes(torch.int32, torch.int64) + def test_max_pool2d_corner_cases(self, device, dtype): + def check(x, args, expected, memory_format): + model = torch.nn.MaxPool2d(*args) + if isinstance(x, list): + x = torch.tensor(x, device=device, dtype=dtype).to( + memory_format=memory_format + ) + expected = torch.tensor(expected, device=device, dtype=dtype).to( + memory_format=memory_format + ) + self.assertEqual(model(x), expected) + + # Pooling args: (kernel_size, stride, padding, dilation, return_indices, ceil_mode) + check( + [[[[-1, -2], [-3, -4]]]], + (2, 2, 1, 2, False, True), + [[[[-4, -4], [-4, -4]]]], + torch.contiguous_format, + ) + check( + [[[[-1, -2], [-3, -4]]]], + (2, 2, 1, 2, False, True), + [[[[-4, -4], [-4, -4]]]], + torch.channels_last, + ) + @onlyNativeDeviceTypes @dtypes(torch.half, torch.bfloat16, torch.float, torch.double) @dtypesIfCUDA(torch.half, torch.float, torch.double)
diff --git a/test/test_mps.py b/test/test_mps.py index 77f3198..54b47d1 100644 --- a/test/test_mps.py +++ b/test/test_mps.py
@@ -731,9 +731,6 @@ 'nn.functional.interpolatearea': None, 'nn.functional.interpolatebicubic': None, 'nn.functional.interpolatetrilinear': None, - # TODO: max_pool2d for integral types fails the numerical test - 'nn.functional.max_pool2d': (integral_types() if product_version < 14.0 else - [torch.int64, torch.int32, torch.int16, torch.int8]), 'nn.functional.max_unpool1dgrad': None, 'nn.functional.max_unpool2dgrad': None, 'nn.functional.max_unpool3dgrad': None, @@ -911,6 +908,7 @@ # Error in TestConsistencyCPU.test_output_match_isin_cpu fails for integers, # not reproducible in later OS. Added assert to op if used in < 14.0 'isin': [torch.int64, torch.int32, torch.int16, torch.uint8, torch.int8], + 'nn.functional.max_pool2d': [torch.uint8], }) UNDEFINED_XFAILLIST = {