[MPS] Enable adaptive avg pool 2d with larger output size (#85726)
* Handle adaptive pool 2d forward and backward when output size is larger than input size
* Disallow larger output size if not a multiple of input size
Fixes: https://github.com/pytorch/pytorch/issues/80732
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85726
Approved by: https://github.com/malfet
diff --git a/aten/src/ATen/native/mps/operations/AdaptivePooling.mm b/aten/src/ATen/native/mps/operations/AdaptivePooling.mm
index 1d58de2..c4184ee 100644
--- a/aten/src/ATen/native/mps/operations/AdaptivePooling.mm
+++ b/aten/src/ATen/native/mps/operations/AdaptivePooling.mm
@@ -19,11 +19,27 @@
int64_t &strideH, int64_t &strideW,
int64_t &kernel_sizeH, int64_t &kernel_sizeW) {
- strideH = (int64_t) (isizeH / osizeH);
- strideW = (int64_t) (isizeW / osizeW);
+ TORCH_CHECK((isizeH >= osizeH && isizeW >= osizeW) || (isizeH <= osizeH && isizeW <= osizeW),
+ "Adaptive pool MPS: Input height and width must both be greather than or equal to, or lesser than, output height and width")
- kernel_sizeH = isizeH - (osizeH-1) * strideH;
- kernel_sizeW = isizeW - (osizeW-1) * strideW;
+ TORCH_CHECK((!(isizeH <= osizeH && isizeW <= osizeW) || (osizeH % isizeH == 0 && osizeW % isizeW == 0)),
+ "Adaptive pool MPS: If output is larger than input, output sizes must be multiples of input sizes")
+
+ if(isizeH >= osizeH) {
+ strideH = (int64_t) (isizeH / osizeH);
+ strideW = (int64_t) (isizeW / osizeW);
+
+ kernel_sizeH = isizeH - (osizeH-1) * strideH;
+ kernel_sizeW = isizeW - (osizeW-1) * strideW;
+ }
+ else {
+ strideH = (int64_t) (osizeH / isizeH);
+ strideW = (int64_t) (osizeW / isizeW);
+
+ kernel_sizeH = osizeH - (isizeH-1) * strideH;
+ kernel_sizeW = osizeW - (isizeW-1) * strideW;
+ }
+
}
// Adaptive average pooling
@@ -71,13 +87,35 @@
strideH, strideW,
kernel_sizeH, kernel_sizeW);
- output = at::avg_pool2d(input,
- IntArrayRef({kernel_sizeH, kernel_sizeW}),
- IntArrayRef({strideH, strideW}),
- IntArrayRef({0, 0}),
- false,
- true,
- c10::nullopt);
+ if(isizeH >= osizeH) {
+ output = at::avg_pool2d(input,
+ IntArrayRef({kernel_sizeH, kernel_sizeW}),
+ IntArrayRef({strideH, strideW}),
+ IntArrayRef({0, 0}),
+ false,
+ true,
+ c10::nullopt);
+ } else {
+ Tensor phony_grad = at::ones_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+ auto num_input_dims = input.sizes().size();
+ int64_t phony_shape[num_input_dims];
+ for(int i = 0; i < num_input_dims - 2; i++)
+ phony_shape[i] = input.size(i);
+ phony_shape[num_input_dims-2] = output_size[0];
+ phony_shape[num_input_dims-1] = output_size[1];
+ phony_grad.resize_(IntArrayRef(phony_shape, num_input_dims));
+ output = at::avg_pool2d_backward(input,
+ phony_grad,
+ IntArrayRef({kernel_sizeH, kernel_sizeW}),
+ IntArrayRef({strideH, strideW}),
+ IntArrayRef({0, 0}),
+ false,
+ true,
+ c10::nullopt);
+ // Multiply output by kernel size
+ output = at::mul(output, kernel_sizeH*kernel_sizeW);
+ }
+
return output;
}
@@ -138,15 +176,27 @@
strideH, strideW,
kernel_sizeH, kernel_sizeW);
auto gradInput = at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
- if (gradInput.numel() != 0)
- gradInput = at::avg_pool2d_backward(gradOutput,
- input,
- IntArrayRef({kernel_sizeH, kernel_sizeW}),
- IntArrayRef({strideH, strideW}),
- IntArrayRef({0, 0}),
- false,
- true,
- c10::nullopt);
+ if (gradInput.numel() != 0) {
+ if(isizeH >= osizeH) {
+ gradInput = at::avg_pool2d_backward(gradOutput,
+ input,
+ IntArrayRef({kernel_sizeH, kernel_sizeW}),
+ IntArrayRef({strideH, strideW}),
+ IntArrayRef({0, 0}),
+ false,
+ true,
+ c10::nullopt);
+ } else {
+ gradInput = at::avg_pool2d(gradOutput,
+ IntArrayRef({kernel_sizeH, kernel_sizeW}),
+ IntArrayRef({strideH, strideW}),
+ IntArrayRef({0, 0}),
+ false,
+ true,
+ c10::nullopt);
+ gradInput = at::mul(gradInput, kernel_sizeH*kernel_sizeW);
+ }
+ }
return gradInput;
diff --git a/test/test_mps.py b/test/test_mps.py
index 19de49f..635acd5 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -3861,6 +3861,21 @@
helper((2, 16, 16), (4, 4), False)
+ # Output shape larger than input shape
+
+ helper((2, 2, 4, 4), (8, 8), False)
+ helper((2, 2, 2, 2), (4, 4), False)
+ helper((2, 2, 3, 3), (9, 9), False)
+ helper((2, 2, 2, 2), (16, 16), False)
+ helper((2, 2, 2, 16), (16, 16), False)
+
+ helper((2, 4, 4), (16, 16), False)
+
+ try:
+ helper((2, 2, 3, 3), (7, 7), False)
+ except Exception as e:
+ pass
+
# Test max avg pool2d - when the input size is a multiple of output size
# Not testing for channels last right now
def test_adaptive_max_pool2d_simple(self):