[MPS] Enable adaptive avg pool 2d with larger output size (#85726)

* Handle adpative pool 2d forward and backward when ouptut size is larger than input size
* Disallow larger output size if not a multiple of input size
Fixes: https://github.com/pytorch/pytorch/issues/80732
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85726
Approved by: https://github.com/malfet
diff --git a/aten/src/ATen/native/mps/operations/AdaptivePooling.mm b/aten/src/ATen/native/mps/operations/AdaptivePooling.mm
index 1d58de2..c4184ee 100644
--- a/aten/src/ATen/native/mps/operations/AdaptivePooling.mm
+++ b/aten/src/ATen/native/mps/operations/AdaptivePooling.mm
@@ -19,11 +19,27 @@
    int64_t &strideH, int64_t &strideW,
    int64_t &kernel_sizeH, int64_t &kernel_sizeW) {
 
-  strideH = (int64_t) (isizeH / osizeH);
-  strideW = (int64_t) (isizeW / osizeW);
+  TORCH_CHECK((isizeH >= osizeH && isizeW >= osizeW) || (isizeH <= osizeH && isizeW <= osizeW),
+              "Adaptive pool MPS: Input height and width must both be greather than or equal to, or lesser than, output height and width")
 
-  kernel_sizeH = isizeH - (osizeH-1) * strideH;
-  kernel_sizeW = isizeW - (osizeW-1) * strideW;
+  TORCH_CHECK((!(isizeH <= osizeH && isizeW <= osizeW) || (osizeH % isizeH == 0 && osizeW % isizeW == 0)),
+              "Adaptive pool MPS: If output is larger than input, output sizes must be multiples of input sizes")
+
+  if(isizeH >= osizeH) {
+    strideH = (int64_t) (isizeH / osizeH);
+    strideW = (int64_t) (isizeW / osizeW);
+
+    kernel_sizeH = isizeH - (osizeH-1) * strideH;
+    kernel_sizeW = isizeW - (osizeW-1) * strideW;
+  }
+  else {
+    strideH = (int64_t) (osizeH / isizeH);
+    strideW = (int64_t) (osizeW / isizeW);
+
+    kernel_sizeH = osizeH - (isizeH-1) * strideH;
+    kernel_sizeW = osizeW - (isizeW-1) * strideW;
+  }
+
 }
 
 // Adaptive average pooling
@@ -71,13 +87,35 @@
                     strideH, strideW,
                     kernel_sizeH, kernel_sizeW);
 
-  output =  at::avg_pool2d(input,
-                           IntArrayRef({kernel_sizeH, kernel_sizeW}),
-                           IntArrayRef({strideH, strideW}),
-                           IntArrayRef({0, 0}),
-                           false,
-                           true,
-                           c10::nullopt);
+  if(isizeH >= osizeH) {
+    output =  at::avg_pool2d(input,
+                             IntArrayRef({kernel_sizeH, kernel_sizeW}),
+                             IntArrayRef({strideH, strideW}),
+                             IntArrayRef({0, 0}),
+                             false,
+                             true,
+                             c10::nullopt);
+  }  else {
+    Tensor phony_grad = at::ones_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+    auto num_input_dims = input.sizes().size();
+    int64_t phony_shape[num_input_dims];
+    for(int i = 0; i < num_input_dims - 2; i++)
+      phony_shape[i] = input.size(i);
+    phony_shape[num_input_dims-2] = output_size[0];
+    phony_shape[num_input_dims-1] = output_size[1];
+    phony_grad.resize_(IntArrayRef(phony_shape, num_input_dims));
+    output =  at::avg_pool2d_backward(input,
+                                      phony_grad,
+                                      IntArrayRef({kernel_sizeH, kernel_sizeW}),
+                                      IntArrayRef({strideH, strideW}),
+                                      IntArrayRef({0, 0}),
+                                      false,
+                                      true,
+                                      c10::nullopt);
+    // Multiply output by kernel size
+    output = at::mul(output, kernel_sizeH*kernel_sizeW);
+  }
+
   return output;
 }
 
@@ -138,15 +176,27 @@
                       strideH, strideW,
                       kernel_sizeH, kernel_sizeW);
     auto gradInput = at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
-    if (gradInput.numel() != 0)
-      gradInput = at::avg_pool2d_backward(gradOutput,
-                                          input,
-                                          IntArrayRef({kernel_sizeH, kernel_sizeW}),
-                                          IntArrayRef({strideH, strideW}),
-                                          IntArrayRef({0, 0}),
-                                          false,
-                                          true,
-                                          c10::nullopt);
+    if (gradInput.numel() != 0) {
+      if(isizeH >= osizeH) {
+        gradInput = at::avg_pool2d_backward(gradOutput,
+                                            input,
+                                            IntArrayRef({kernel_sizeH, kernel_sizeW}),
+                                            IntArrayRef({strideH, strideW}),
+                                            IntArrayRef({0, 0}),
+                                            false,
+                                            true,
+                                            c10::nullopt);
+      } else {
+        gradInput = at::avg_pool2d(gradOutput,
+                                   IntArrayRef({kernel_sizeH, kernel_sizeW}),
+                                   IntArrayRef({strideH, strideW}),
+                                   IntArrayRef({0, 0}),
+                                   false,
+                                   true,
+                                   c10::nullopt);
+        gradInput = at::mul(gradInput, kernel_sizeH*kernel_sizeW);
+      }
+    }
 
     return gradInput;
 
diff --git a/test/test_mps.py b/test/test_mps.py
index 19de49f..635acd5 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -3861,6 +3861,21 @@
 
         helper((2, 16, 16), (4, 4), False)
 
+        # Output shape larger than input shape
+
+        helper((2, 2, 4, 4), (8, 8), False)
+        helper((2, 2, 2, 2), (4, 4), False)
+        helper((2, 2, 3, 3), (9, 9), False)
+        helper((2, 2, 2, 2), (16, 16), False)
+        helper((2, 2, 2, 16), (16, 16), False)
+
+        helper((2, 4, 4), (16, 16), False)
+
+        try:
+            helper((2, 2, 3, 3), (7, 7), False)
+        except Exception as e:
+            pass
+
     # Test max avg pool2d - when the input size is a multiple of output size
     # Not testing for channels last right now
     def test_adaptive_max_pool2d_simple(self):