MPS: Add adaptive max pool2d op (#78410)

Adaptive max pool 2d forward and backward with test

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78410
Approved by: https://github.com/albanD
diff --git a/aten/src/ATen/native/mps/operations/AdaptiveAveragePooling.mm b/aten/src/ATen/native/mps/operations/AdaptiveAveragePooling.mm
deleted file mode 100644
index c828183..0000000
--- a/aten/src/ATen/native/mps/operations/AdaptiveAveragePooling.mm
+++ /dev/null
@@ -1,154 +0,0 @@
-//  Copyright © 2022 Apple Inc.
-
-#include <ATen/ATen.h>
-#include <ATen/Tensor.h>
-#include <ATen/Utils.h>
-#include <ATen/TensorUtils.h>
-#include <ATen/mps/MPSStream.h>
-#include <ATen/native/mps/OperationUtils.h>
-#include <ATen/native/Pool.h>
-#include <torch/library.h>
-
-namespace at {
-namespace native {
-
-
-void set_kernel_params
-  (int64_t isizeH, int64_t isizeW,
-   int64_t osizeH, int64_t osizeW,
-   int64_t &strideH, int64_t &strideW,
-   int64_t &kernel_sizeH, int64_t &kernel_sizeW) {
-
-  strideH = (int64_t) (isizeH / osizeH);
-  strideW = (int64_t) (isizeW / osizeW);
-
-  kernel_sizeH = isizeH - (osizeH-1) * strideH;
-  kernel_sizeW = isizeW - (osizeW-1) * strideW;
-}
-
-Tensor& adaptive_avg_pool2d_out_mps
-  (const Tensor& input,
-   IntArrayRef output_size,
-   Tensor& output) {
-
-  for (int64_t i = 1; i < input.ndimension(); i++) {
-    TORCH_CHECK(input.size(i) > 0,
-      "adaptive_avg_pool2d(): Expected input to have non-zero size for non-batch dimensions, "
-      "but input has sizes ", input.sizes(), " with dimension ", i, " being "
-      "empty");
-  }
-
-  int64_t isizeH = input.size(-2);
-  int64_t isizeW = input.size(-1);
-
-  int64_t osizeH = output_size[0];
-  int64_t osizeW = output_size[1];
-
-  if(input.suggest_memory_format() == at::MemoryFormat::ChannelsLast)
-    TORCH_CHECK(input.ndimension() == 4,
-                    "adaptive_avg_pool2d(): Expected 4D tensor, but got ",
-                    input.sizes())
-
-  switch (input.suggest_memory_format()) {
-    case at::MemoryFormat::Contiguous:
-    case at::MemoryFormat::ChannelsLast:
-      break;
-    default:
-        TORCH_CHECK(
-          false,
-          "Unsupported memory format. Supports only ChannelsLast, Contiguous")
-  }
-
-  int64_t strideH;
-  int64_t strideW;
-  int64_t kernel_sizeH;
-  int64_t kernel_sizeW;
-
-  set_kernel_params(isizeH, isizeW,
-                    osizeH, osizeW,
-                    strideH, strideW,
-                    kernel_sizeH, kernel_sizeW);
-
-  output =  at::avg_pool2d(input,
-                           IntArrayRef({kernel_sizeH, kernel_sizeW}),
-                           IntArrayRef({strideH, strideW}),
-                           IntArrayRef({0, 0}),
-                           false,
-                           true,
-                           c10::nullopt);
-  return output;
-}
-
-Tensor adaptive_avg_pool2d_mps
-  (at::Tensor const& input,
-   IntArrayRef output_size) {
-
-  IntArrayRef output_shape;
-
-  auto osizeH = output_size[0];
-  auto osizeW = output_size[1];
-
-  std::vector<long long> out_dims = {};
-
-  if(input.ndimension() == 4) {
-    auto sizeB = input.size(0);
-    auto sizeD = input.size(1);
-
-    out_dims.push_back(sizeB);
-    out_dims.push_back(sizeD);
-    out_dims.push_back(osizeH);
-    out_dims.push_back(osizeW);
-    output_shape = IntArrayRef(out_dims);
-  }
-  else {
-    auto sizeD = input.size(0);
-    out_dims.push_back(sizeD);
-    out_dims.push_back(osizeH);
-    out_dims.push_back(osizeW);
-    output_shape = IntArrayRef(out_dims);
-  }
-
-  const auto memory_format = input.suggest_memory_format();
-  Tensor output = at::native::empty_mps(
-                      output_shape,
-                      input.scalar_type(),
-                      c10::nullopt,
-                      kMPS,
-                      c10::nullopt,
-                      memory_format);
-  return adaptive_avg_pool2d_out_mps(input, output_size, output);
-
-}
-
-Tensor adaptive_avg_pool2d_backward_mps
-  (const Tensor& gradOutput,
-   const Tensor& input) {
-
-    int64_t isizeH = input.size(-2);
-    int64_t isizeW = input.size(-1);
-    int64_t osizeH = gradOutput.size(-2);
-    int64_t osizeW = gradOutput.size(-1);
-
-    int64_t strideH, strideW, kernel_sizeH, kernel_sizeW;
-
-    set_kernel_params(isizeH, isizeW,
-                      osizeH, osizeW,
-                      strideH, strideW,
-                      kernel_sizeH, kernel_sizeW);
-    auto gradInput = at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
-    if (gradInput.numel() != 0)
-      gradInput = at::avg_pool2d_backward(gradOutput,
-                                          input,
-                                          IntArrayRef({kernel_sizeH, kernel_sizeW}),
-                                          IntArrayRef({strideH, strideW}),
-                                          IntArrayRef({0, 0}),
-                                          false,
-                                          true,
-                                          c10::nullopt);
-
-    return gradInput;
-
-}
-
-}
-}
diff --git a/aten/src/ATen/native/mps/operations/AdaptivePooling.mm b/aten/src/ATen/native/mps/operations/AdaptivePooling.mm
new file mode 100644
index 0000000..1d58de2
--- /dev/null
+++ b/aten/src/ATen/native/mps/operations/AdaptivePooling.mm
@@ -0,0 +1,244 @@
+//  Copyright © 2022 Apple Inc.
+
+#include <ATen/ATen.h>
+#include <ATen/Tensor.h>
+#include <ATen/Utils.h>
+#include <ATen/TensorUtils.h>
+#include <ATen/mps/MPSStream.h>
+#include <ATen/native/mps/OperationUtils.h>
+#include <ATen/native/Pool.h>
+#include <torch/library.h>
+
+namespace at {
+namespace native {
+
+
+void set_kernel_params
+  (int64_t isizeH, int64_t isizeW,
+   int64_t osizeH, int64_t osizeW,
+   int64_t &strideH, int64_t &strideW,
+   int64_t &kernel_sizeH, int64_t &kernel_sizeW) {
+
+  strideH = (int64_t) (isizeH / osizeH);
+  strideW = (int64_t) (isizeW / osizeW);
+
+  kernel_sizeH = isizeH - (osizeH-1) * strideH;
+  kernel_sizeW = isizeW - (osizeW-1) * strideW;
+}
+
+// Adaptive average pooling
+
+Tensor& adaptive_avg_pool2d_out_mps
+  (const Tensor& input,
+   IntArrayRef output_size,
+   Tensor& output) {
+
+  for (int64_t i = 1; i < input.ndimension(); i++) {
+    TORCH_CHECK(input.size(i) > 0,
+      "adaptive_avg_pool2d(): Expected input to have non-zero size for non-batch dimensions, "
+      "but input has sizes ", input.sizes(), " with dimension ", i, " being "
+      "empty");
+  }
+
+  int64_t isizeH = input.size(-2);
+  int64_t isizeW = input.size(-1);
+
+  int64_t osizeH = output_size[0];
+  int64_t osizeW = output_size[1];
+
+  if(input.suggest_memory_format() == at::MemoryFormat::ChannelsLast)
+    TORCH_CHECK(input.ndimension() == 4,
+                    "adaptive_avg_pool2d(): Expected 4D tensor, but got ",
+                    input.sizes())
+
+  switch (input.suggest_memory_format()) {
+    case at::MemoryFormat::Contiguous:
+    case at::MemoryFormat::ChannelsLast:
+      break;
+    default:
+        TORCH_CHECK(
+          false,
+          "Unsupported memory format. Supports only ChannelsLast, Contiguous")
+  }
+
+  int64_t strideH;
+  int64_t strideW;
+  int64_t kernel_sizeH;
+  int64_t kernel_sizeW;
+
+  set_kernel_params(isizeH, isizeW,
+                    osizeH, osizeW,
+                    strideH, strideW,
+                    kernel_sizeH, kernel_sizeW);
+
+  output =  at::avg_pool2d(input,
+                           IntArrayRef({kernel_sizeH, kernel_sizeW}),
+                           IntArrayRef({strideH, strideW}),
+                           IntArrayRef({0, 0}),
+                           false,
+                           true,
+                           c10::nullopt);
+  return output;
+}
+
+Tensor adaptive_avg_pool2d_mps
+  (at::Tensor const& input,
+   IntArrayRef output_size) {
+
+  IntArrayRef output_shape;
+
+  auto osizeH = output_size[0];
+  auto osizeW = output_size[1];
+
+  std::vector<long long> out_dims = {};
+
+  if(input.ndimension() == 4) {
+    auto sizeB = input.size(0);
+    auto sizeD = input.size(1);
+
+    out_dims.push_back(sizeB);
+    out_dims.push_back(sizeD);
+    out_dims.push_back(osizeH);
+    out_dims.push_back(osizeW);
+    output_shape = IntArrayRef(out_dims);
+  }
+  else {
+    auto sizeD = input.size(0);
+    out_dims.push_back(sizeD);
+    out_dims.push_back(osizeH);
+    out_dims.push_back(osizeW);
+    output_shape = IntArrayRef(out_dims);
+  }
+
+  const auto memory_format = input.suggest_memory_format();
+  Tensor output = at::native::empty_mps(
+                      output_shape,
+                      input.scalar_type(),
+                      c10::nullopt,
+                      kMPS,
+                      c10::nullopt,
+                      memory_format);
+  return adaptive_avg_pool2d_out_mps(input, output_size, output);
+
+}
+
+Tensor adaptive_avg_pool2d_backward_mps
+  (const Tensor& gradOutput,
+   const Tensor& input) {
+
+    int64_t isizeH = input.size(-2);
+    int64_t isizeW = input.size(-1);
+    int64_t osizeH = gradOutput.size(-2);
+    int64_t osizeW = gradOutput.size(-1);
+
+    int64_t strideH, strideW, kernel_sizeH, kernel_sizeW;
+
+    set_kernel_params(isizeH, isizeW,
+                      osizeH, osizeW,
+                      strideH, strideW,
+                      kernel_sizeH, kernel_sizeW);
+    auto gradInput = at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+    if (gradInput.numel() != 0)
+      gradInput = at::avg_pool2d_backward(gradOutput,
+                                          input,
+                                          IntArrayRef({kernel_sizeH, kernel_sizeW}),
+                                          IntArrayRef({strideH, strideW}),
+                                          IntArrayRef({0, 0}),
+                                          false,
+                                          true,
+                                          c10::nullopt);
+
+    return gradInput;
+
+}
+
+// Adaptive max pooling
+
+TORCH_IMPL_FUNC(adaptive_max_pool2d_out_mps)
+  (const Tensor& input,
+   IntArrayRef output_size,
+   const Tensor& output,
+   const Tensor& indices) {
+
+  for (int64_t i = 1; i < input.ndimension(); i++) {
+    TORCH_CHECK(input.size(i) > 0,
+      "adaptive_max_pool2d(): Expected input to have non-zero size for non-batch dimensions, "
+      "but input has sizes ", input.sizes(), " with dimension ", i, " being "
+      "empty");
+  }
+
+  int64_t isizeH = input.size(-2);
+  int64_t isizeW = input.size(-1);
+
+  int64_t osizeH = output_size[0];
+  int64_t osizeW = output_size[1];
+
+  if(input.suggest_memory_format() == at::MemoryFormat::ChannelsLast)
+    TORCH_CHECK(input.ndimension() == 4,
+                    "adaptive_max_pool2d(): Expected 4D tensor, but got ",
+                    input.sizes())
+
+  switch (input.suggest_memory_format()) {
+    case at::MemoryFormat::Contiguous:
+    case at::MemoryFormat::ChannelsLast:
+      break;
+    default:
+        TORCH_CHECK(
+          false,
+          "Unsupported memory format. Supports only ChannelsLast, Contiguous")
+  }
+
+  int64_t strideH;
+  int64_t strideW;
+  int64_t kernel_sizeH;
+  int64_t kernel_sizeW;
+
+  set_kernel_params(isizeH, isizeW,
+                    osizeH, osizeW,
+                    strideH, strideW,
+                    kernel_sizeH, kernel_sizeW);
+
+  auto outputs = at::max_pool2d_with_indices(input,
+                              IntArrayRef({kernel_sizeH, kernel_sizeW}),
+                              IntArrayRef({strideH, strideW}),
+                              IntArrayRef({0, 0}),
+                              IntArrayRef({1, 1}),
+                              false);
+
+  output.copy_(std::get<0>(outputs));
+  indices.copy_(std::get<1>(outputs));
+}
+
+TORCH_IMPL_FUNC(adaptive_max_pool2d_backward_out_mps)
+  (const Tensor& gradOutput,
+   const Tensor& input,
+   const Tensor& indices,
+   const Tensor& gradInput) {
+
+  int64_t isizeH = input.size(-2);
+  int64_t isizeW = input.size(-1);
+  int64_t osizeH = gradOutput.size(-2);
+  int64_t osizeW = gradOutput.size(-1);
+
+  int64_t strideH, strideW, kernel_sizeH, kernel_sizeW;
+
+  set_kernel_params(isizeH, isizeW,
+                    osizeH, osizeW,
+                    strideH, strideW,
+                    kernel_sizeH, kernel_sizeW);
+
+  auto returnGradInput = at::max_pool2d_with_indices_backward(gradOutput,
+                                                              input,
+                                                              IntArrayRef({kernel_sizeH, kernel_sizeW}),
+                                                              IntArrayRef({strideH, strideW}),
+                                                              IntArrayRef({0, 0}),
+                                                              IntArrayRef({1, 1}),
+                                                              false,
+                                                              indices);
+
+  gradInput.copy_(returnGradInput);
+
+}
+
+}
+}
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index c81d6ba..b97c8c9 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -9791,6 +9791,7 @@
   dispatch:
     CPU: adaptive_max_pool2d_out_cpu
     CUDA: adaptive_max_pool2d_out_cuda
+    MPS: adaptive_max_pool2d_out_mps
 
 # Return: (Tensor output, Tensor indices)
 - func: adaptive_max_pool2d(Tensor self, int[2] output_size) -> (Tensor, Tensor)
@@ -9803,6 +9804,7 @@
   dispatch:
     CPU: adaptive_max_pool2d_backward_out_cpu
     CUDA: adaptive_max_pool2d_backward_out_cuda
+    MPS: adaptive_max_pool2d_backward_out_mps
 
 - func: adaptive_max_pool2d_backward(Tensor grad_output, Tensor self, Tensor indices) -> Tensor
   python_module: nn
diff --git a/test/test_mps.py b/test/test_mps.py
index 72b4be7..f9778b4 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -3089,6 +3089,50 @@
 
         helper((2, 16, 16), (4, 4), False)
 
+    # Test adaptive max pool2d - when the input size is a multiple of output size
+    # Not testing for channels last right now
+    def test_adaptive_max_pool2d_simple(self):
+        def helper(input_shape, out_shape, return_indices, dtype, channels_last=False):
+            cpu_x = None
+            if(dtype in [torch.float16, torch.float32]):
+                cpu_x = torch.randn(input_shape, device='cpu', dtype=dtype, requires_grad=True)
+            else:
+                cpu_x = torch.randint(50, input_shape, device='cpu', dtype=dtype, requires_grad=True)
+            if(channels_last):
+                cpu_x = cpu_x.to(memory_format=torch.channels_last)
+                cpu_x.retain_grad()
+            x = cpu_x.detach().clone().to('mps').requires_grad_()
+
+            max_result, max_indices = None, None
+            max_result_cpu, max_indices_cpu = None, None
+
+            if(return_indices):
+                max_result, max_indices = torch.nn.AdaptiveMaxPool2d(out_shape, return_indices)(x)
+                max_result_cpu, max_indices_cpu = torch.nn.AdaptiveMaxPool2d(out_shape, return_indices)(cpu_x)
+            else:
+                max_result = torch.nn.AdaptiveMaxPool2d(out_shape, return_indices)(x)
+                max_result_cpu = torch.nn.AdaptiveMaxPool2d(out_shape, return_indices)(cpu_x)
+
+            cpu_grad = torch.randn(max_result_cpu.shape)
+            grad = cpu_grad.to('mps')
+
+            max_result.backward(gradient=grad)
+            max_result_cpu.backward(gradient=cpu_grad)
+
+            self.assertEqual(max_result, max_result_cpu)
+            if(return_indices):
+                self.assertEqual(max_indices, max_indices_cpu)
+            self.assertEqual(x.grad, cpu_x.grad)
+
+        for dtype in [torch.float32]:
+            for return_indices in [False, True]:
+                helper((2, 2, 4, 4), (2, 2), return_indices, dtype)
+                helper((2, 2, 9, 9), (3, 3), return_indices, dtype)
+                helper((2, 2, 9, 9), (9, 9), return_indices, dtype)
+                helper((2, 2, 16, 16), (2, 2), return_indices, dtype)
+                helper((2, 2, 16, 16), (2, 16), return_indices, dtype)
+                helper((2, 16, 16), (4, 4), return_indices, dtype)
+
     def test_gelu_simple(self):
         def helper(shape):
             cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)