MPS: Add adaptive max pool2d op (#78410) Implement adaptive max pool2d forward and backward for MPS, with a test that compares results against CPU. Pull Request resolved: https://github.com/pytorch/pytorch/pull/78410 Approved by: https://github.com/albanD
diff --git a/test/test_mps.py b/test/test_mps.py index 72b4be7..f9778b4 100644 --- a/test/test_mps.py +++ b/test/test_mps.py
@@ -3089,6 +3089,50 @@ helper((2, 16, 16), (4, 4), False) + # Test adaptive max pool2d (forward output, returned indices, and input gradients vs. CPU) - when the input size is a multiple of output size + # Not testing for channels last right now: helper accepts channels_last, but no call below sets it + def test_adaptive_max_pool2d_simple(self): + def helper(input_shape, out_shape, return_indices, dtype, channels_last=False): + cpu_x = None + if(dtype in [torch.float16, torch.float32]): + cpu_x = torch.randn(input_shape, device='cpu', dtype=dtype, requires_grad=True) + else: + cpu_x = torch.randint(50, input_shape, device='cpu', dtype=dtype, requires_grad=True)  # NOTE(review): unreachable while only float32 is tested; requires_grad=True on an integer tensor is expected to raise in PyTorch, so confirm before adding integer dtypes + if(channels_last): + cpu_x = cpu_x.to(memory_format=torch.channels_last) + cpu_x.retain_grad()  # presumably keeps cpu_x.grad populated if the channels_last conversion made cpu_x a non-leaf; confirm + x = cpu_x.detach().clone().to('mps').requires_grad_()  # MPS copy of cpu_x's values, detached from the CPU autograd graph + + max_result, max_indices = None, None + max_result_cpu, max_indices_cpu = None, None + + if(return_indices): + max_result, max_indices = torch.nn.AdaptiveMaxPool2d(out_shape, return_indices)(x) + max_result_cpu, max_indices_cpu = torch.nn.AdaptiveMaxPool2d(out_shape, return_indices)(cpu_x) + else: + max_result = torch.nn.AdaptiveMaxPool2d(out_shape, return_indices)(x) + max_result_cpu = torch.nn.AdaptiveMaxPool2d(out_shape, return_indices)(cpu_x) + + cpu_grad = torch.randn(max_result_cpu.shape)  # one random upstream gradient, reused on MPS below so both backward passes see the same values + grad = cpu_grad.to('mps') + + max_result.backward(gradient=grad) + max_result_cpu.backward(gradient=cpu_grad) + + self.assertEqual(max_result, max_result_cpu) + if(return_indices): + self.assertEqual(max_indices, max_indices_cpu) + self.assertEqual(x.grad, cpu_x.grad) + + for dtype in [torch.float32]:  # the float16 and integer paths in helper are not exercised yet + for return_indices in [False, True]: + helper((2, 2, 4, 4), (2, 2), return_indices, dtype) + helper((2, 2, 9, 9), (3, 3), return_indices, dtype) + helper((2, 2, 9, 9), (9, 9), return_indices, dtype) + helper((2, 2, 16, 16), (2, 2), return_indices, dtype) + helper((2, 2, 16, 16), (2, 16), return_indices, dtype) + helper((2, 16, 16), (4, 4), return_indices, dtype) + def test_gelu_simple(self): def helper(shape): cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)