MPS: Add adaptive max pool2d op (#78410)
Implements adaptive max pool2d forward and backward for the MPS backend, with tests. The existing adaptive average pooling code moves from AdaptiveAveragePooling.mm into a shared AdaptivePooling.mm alongside the new op.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78410
Approved by: https://github.com/albanD
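
Both adaptive ops lower to the regular pooling kernels: when the input size is an integer multiple of the output size, adaptive pooling is exactly a fixed pool with stride = isize / osize and kernel = isize - (osize - 1) * stride (see set_kernel_params below). A minimal sketch of that reduction in plain PyTorch, illustrative only and backend-independent:

    import torch

    def kernel_params(isize, osize):
        # same arithmetic as set_kernel_params in AdaptivePooling.mm
        stride = isize // osize
        kernel = isize - (osize - 1) * stride
        return kernel, stride

    x = torch.randn(2, 2, 16, 16)
    kH, sH = kernel_params(16, 4)
    kW, sW = kernel_params(16, 4)
    adaptive = torch.nn.AdaptiveMaxPool2d((4, 4))(x)
    fixed = torch.nn.MaxPool2d((kH, kW), stride=(sH, sW))(x)
    assert torch.equal(adaptive, fixed)  # exact when sizes divide evenly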
diff --git a/aten/src/ATen/native/mps/operations/AdaptiveAveragePooling.mm b/aten/src/ATen/native/mps/operations/AdaptiveAveragePooling.mm
deleted file mode 100644
index c828183..0000000
--- a/aten/src/ATen/native/mps/operations/AdaptiveAveragePooling.mm
+++ /dev/null
@@ -1,154 +0,0 @@
-// Copyright © 2022 Apple Inc.
-
-#include <ATen/ATen.h>
-#include <ATen/Tensor.h>
-#include <ATen/Utils.h>
-#include <ATen/TensorUtils.h>
-#include <ATen/mps/MPSStream.h>
-#include <ATen/native/mps/OperationUtils.h>
-#include <ATen/native/Pool.h>
-#include <torch/library.h>
-
-namespace at {
-namespace native {
-
-
-void set_kernel_params
- (int64_t isizeH, int64_t isizeW,
- int64_t osizeH, int64_t osizeW,
- int64_t &strideH, int64_t &strideW,
- int64_t &kernel_sizeH, int64_t &kernel_sizeW) {
-
- strideH = (int64_t) (isizeH / osizeH);
- strideW = (int64_t) (isizeW / osizeW);
-
- kernel_sizeH = isizeH - (osizeH-1) * strideH;
- kernel_sizeW = isizeW - (osizeW-1) * strideW;
-}
-
-Tensor& adaptive_avg_pool2d_out_mps
- (const Tensor& input,
- IntArrayRef output_size,
- Tensor& output) {
-
- for (int64_t i = 1; i < input.ndimension(); i++) {
- TORCH_CHECK(input.size(i) > 0,
- "adaptive_avg_pool2d(): Expected input to have non-zero size for non-batch dimensions, "
- "but input has sizes ", input.sizes(), " with dimension ", i, " being "
- "empty");
- }
-
- int64_t isizeH = input.size(-2);
- int64_t isizeW = input.size(-1);
-
- int64_t osizeH = output_size[0];
- int64_t osizeW = output_size[1];
-
- if(input.suggest_memory_format() == at::MemoryFormat::ChannelsLast)
- TORCH_CHECK(input.ndimension() == 4,
- "adaptive_avg_pool2d(): Expected 4D tensor, but got ",
- input.sizes())
-
- switch (input.suggest_memory_format()) {
- case at::MemoryFormat::Contiguous:
- case at::MemoryFormat::ChannelsLast:
- break;
- default:
- TORCH_CHECK(
- false,
- "Unsupported memory format. Supports only ChannelsLast, Contiguous")
- }
-
- int64_t strideH;
- int64_t strideW;
- int64_t kernel_sizeH;
- int64_t kernel_sizeW;
-
- set_kernel_params(isizeH, isizeW,
- osizeH, osizeW,
- strideH, strideW,
- kernel_sizeH, kernel_sizeW);
-
- output = at::avg_pool2d(input,
- IntArrayRef({kernel_sizeH, kernel_sizeW}),
- IntArrayRef({strideH, strideW}),
- IntArrayRef({0, 0}),
- false,
- true,
- c10::nullopt);
- return output;
-}
-
-Tensor adaptive_avg_pool2d_mps
- (at::Tensor const& input,
- IntArrayRef output_size) {
-
- IntArrayRef output_shape;
-
- auto osizeH = output_size[0];
- auto osizeW = output_size[1];
-
- std::vector<long long> out_dims = {};
-
- if(input.ndimension() == 4) {
- auto sizeB = input.size(0);
- auto sizeD = input.size(1);
-
- out_dims.push_back(sizeB);
- out_dims.push_back(sizeD);
- out_dims.push_back(osizeH);
- out_dims.push_back(osizeW);
- output_shape = IntArrayRef(out_dims);
- }
- else {
- auto sizeD = input.size(0);
- out_dims.push_back(sizeD);
- out_dims.push_back(osizeH);
- out_dims.push_back(osizeW);
- output_shape = IntArrayRef(out_dims);
- }
-
- const auto memory_format = input.suggest_memory_format();
- Tensor output = at::native::empty_mps(
- output_shape,
- input.scalar_type(),
- c10::nullopt,
- kMPS,
- c10::nullopt,
- memory_format);
- return adaptive_avg_pool2d_out_mps(input, output_size, output);
-
-}
-
-Tensor adaptive_avg_pool2d_backward_mps
- (const Tensor& gradOutput,
- const Tensor& input) {
-
- int64_t isizeH = input.size(-2);
- int64_t isizeW = input.size(-1);
- int64_t osizeH = gradOutput.size(-2);
- int64_t osizeW = gradOutput.size(-1);
-
- int64_t strideH, strideW, kernel_sizeH, kernel_sizeW;
-
- set_kernel_params(isizeH, isizeW,
- osizeH, osizeW,
- strideH, strideW,
- kernel_sizeH, kernel_sizeW);
- auto gradInput = at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
- if (gradInput.numel() != 0)
- gradInput = at::avg_pool2d_backward(gradOutput,
- input,
- IntArrayRef({kernel_sizeH, kernel_sizeW}),
- IntArrayRef({strideH, strideW}),
- IntArrayRef({0, 0}),
- false,
- true,
- c10::nullopt);
-
- return gradInput;
-
-}
-
-}
-}
diff --git a/aten/src/ATen/native/mps/operations/AdaptivePooling.mm b/aten/src/ATen/native/mps/operations/AdaptivePooling.mm
new file mode 100644
index 0000000..1d58de2
--- /dev/null
+++ b/aten/src/ATen/native/mps/operations/AdaptivePooling.mm
@@ -0,0 +1,244 @@
+// Copyright © 2022 Apple Inc.
+
+#include <ATen/ATen.h>
+#include <ATen/Tensor.h>
+#include <ATen/Utils.h>
+#include <ATen/TensorUtils.h>
+#include <ATen/mps/MPSStream.h>
+#include <ATen/native/mps/OperationUtils.h>
+#include <ATen/native/Pool.h>
+#include <torch/library.h>
+
+namespace at {
+namespace native {
+
+
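+// Adaptive pooling is lowered to the regular pooling ops: for the supported
+// case where the input size is an integer multiple of the output size, a
+// fixed kernel/stride pair reproduces adaptive behavior exactly
+// (stride = isize / osize, kernel = isize - (osize - 1) * stride).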
+void set_kernel_params
+ (int64_t isizeH, int64_t isizeW,
+ int64_t osizeH, int64_t osizeW,
+ int64_t &strideH, int64_t &strideW,
+ int64_t &kernel_sizeH, int64_t &kernel_sizeW) {
+
+ strideH = (int64_t) (isizeH / osizeH);
+ strideW = (int64_t) (isizeW / osizeW);
+
+ kernel_sizeH = isizeH - (osizeH-1) * strideH;
+ kernel_sizeW = isizeW - (osizeW-1) * strideW;
+}
+
+// Adaptive average pooling
+
+Tensor& adaptive_avg_pool2d_out_mps
+ (const Tensor& input,
+ IntArrayRef output_size,
+ Tensor& output) {
+
+ for (int64_t i = 1; i < input.ndimension(); i++) {
+ TORCH_CHECK(input.size(i) > 0,
+ "adaptive_avg_pool2d(): Expected input to have non-zero size for non-batch dimensions, "
+ "but input has sizes ", input.sizes(), " with dimension ", i, " being "
+ "empty");
+ }
+
+ int64_t isizeH = input.size(-2);
+ int64_t isizeW = input.size(-1);
+
+ int64_t osizeH = output_size[0];
+ int64_t osizeW = output_size[1];
+
+ if(input.suggest_memory_format() == at::MemoryFormat::ChannelsLast)
+ TORCH_CHECK(input.ndimension() == 4,
+ "adaptive_avg_pool2d(): Expected 4D tensor, but got ",
+ input.sizes())
+
+ switch (input.suggest_memory_format()) {
+ case at::MemoryFormat::Contiguous:
+ case at::MemoryFormat::ChannelsLast:
+ break;
+ default:
+ TORCH_CHECK(
+ false,
+ "Unsupported memory format. Supports only ChannelsLast, Contiguous")
+ }
+
+ int64_t strideH;
+ int64_t strideW;
+ int64_t kernel_sizeH;
+ int64_t kernel_sizeW;
+
+ set_kernel_params(isizeH, isizeW,
+ osizeH, osizeW,
+ strideH, strideW,
+ kernel_sizeH, kernel_sizeW);
+
+ output = at::avg_pool2d(input,
+ IntArrayRef({kernel_sizeH, kernel_sizeW}),
+ IntArrayRef({strideH, strideW}),
+ IntArrayRef({0, 0}),
+ false,
+ true,
+ c10::nullopt);
+ return output;
+}
+
+Tensor adaptive_avg_pool2d_mps
+ (at::Tensor const& input,
+ IntArrayRef output_size) {
+
+ IntArrayRef output_shape;
+
+ auto osizeH = output_size[0];
+ auto osizeW = output_size[1];
+
+ std::vector<long long> out_dims = {};
+
+ if(input.ndimension() == 4) {
+ auto sizeB = input.size(0);
+ auto sizeD = input.size(1);
+
+ out_dims.push_back(sizeB);
+ out_dims.push_back(sizeD);
+ out_dims.push_back(osizeH);
+ out_dims.push_back(osizeW);
+ output_shape = IntArrayRef(out_dims);
+ }
+ else {
+ auto sizeD = input.size(0);
+ out_dims.push_back(sizeD);
+ out_dims.push_back(osizeH);
+ out_dims.push_back(osizeW);
+ output_shape = IntArrayRef(out_dims);
+ }
+
+ const auto memory_format = input.suggest_memory_format();
+ Tensor output = at::native::empty_mps(
+ output_shape,
+ input.scalar_type(),
+ c10::nullopt,
+ kMPS,
+ c10::nullopt,
+ memory_format);
+ return adaptive_avg_pool2d_out_mps(input, output_size, output);
+
+}
+
+Tensor adaptive_avg_pool2d_backward_mps
+ (const Tensor& gradOutput,
+ const Tensor& input) {
+
+ int64_t isizeH = input.size(-2);
+ int64_t isizeW = input.size(-1);
+ int64_t osizeH = gradOutput.size(-2);
+ int64_t osizeW = gradOutput.size(-1);
+
+ int64_t strideH, strideW, kernel_sizeH, kernel_sizeW;
+
+ set_kernel_params(isizeH, isizeW,
+ osizeH, osizeW,
+ strideH, strideW,
+ kernel_sizeH, kernel_sizeW);
+ auto gradInput = at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+ if (gradInput.numel() != 0)
+ gradInput = at::avg_pool2d_backward(gradOutput,
+ input,
+ IntArrayRef({kernel_sizeH, kernel_sizeW}),
+ IntArrayRef({strideH, strideW}),
+ IntArrayRef({0, 0}),
+ false,
+ true,
+ c10::nullopt);
+
+ return gradInput;
+
+}
+
+// Adaptive max pooling
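+// Forward dispatches to max_pool2d_with_indices using the kernel/stride
+// computed by set_kernel_params above; backward mirrors it with
+// max_pool2d_with_indices_backward.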
+
+TORCH_IMPL_FUNC(adaptive_max_pool2d_out_mps)
+ (const Tensor& input,
+ IntArrayRef output_size,
+ const Tensor& output,
+ const Tensor& indices) {
+
+ for (int64_t i = 1; i < input.ndimension(); i++) {
+ TORCH_CHECK(input.size(i) > 0,
+ "adaptive_max_pool2d(): Expected input to have non-zero size for non-batch dimensions, "
+ "but input has sizes ", input.sizes(), " with dimension ", i, " being "
+ "empty");
+ }
+
+ int64_t isizeH = input.size(-2);
+ int64_t isizeW = input.size(-1);
+
+ int64_t osizeH = output_size[0];
+ int64_t osizeW = output_size[1];
+
+ if(input.suggest_memory_format() == at::MemoryFormat::ChannelsLast)
+ TORCH_CHECK(input.ndimension() == 4,
+                "adaptive_max_pool2d(): Expected 4D tensor, but got ",
+ input.sizes())
+
+ switch (input.suggest_memory_format()) {
+ case at::MemoryFormat::Contiguous:
+ case at::MemoryFormat::ChannelsLast:
+ break;
+ default:
+ TORCH_CHECK(
+ false,
+ "Unsupported memory format. Supports only ChannelsLast, Contiguous")
+ }
+
+ int64_t strideH;
+ int64_t strideW;
+ int64_t kernel_sizeH;
+ int64_t kernel_sizeW;
+
+ set_kernel_params(isizeH, isizeW,
+ osizeH, osizeW,
+ strideH, strideW,
+ kernel_sizeH, kernel_sizeW);
+
+ auto outputs = at::max_pool2d_with_indices(input,
+ IntArrayRef({kernel_sizeH, kernel_sizeW}),
+ IntArrayRef({strideH, strideW}),
+ IntArrayRef({0, 0}),
+ IntArrayRef({1, 1}),
+ false);
+
+ output.copy_(std::get<0>(outputs));
+ indices.copy_(std::get<1>(outputs));
+}
+
+TORCH_IMPL_FUNC(adaptive_max_pool2d_backward_out_mps)
+ (const Tensor& gradOutput,
+ const Tensor& input,
+ const Tensor& indices,
+ const Tensor& gradInput) {
+
+ int64_t isizeH = input.size(-2);
+ int64_t isizeW = input.size(-1);
+ int64_t osizeH = gradOutput.size(-2);
+ int64_t osizeW = gradOutput.size(-1);
+
+ int64_t strideH, strideW, kernel_sizeH, kernel_sizeW;
+
+ set_kernel_params(isizeH, isizeW,
+ osizeH, osizeW,
+ strideH, strideW,
+ kernel_sizeH, kernel_sizeW);
+
+ auto returnGradInput = at::max_pool2d_with_indices_backward(gradOutput,
+ input,
+ IntArrayRef({kernel_sizeH, kernel_sizeW}),
+ IntArrayRef({strideH, strideW}),
+ IntArrayRef({0, 0}),
+ IntArrayRef({1, 1}),
+ false,
+ indices);
+
+ gradInput.copy_(returnGradInput);
+
+}
+
+}
+}
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index c81d6ba..b97c8c9 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -9791,6 +9791,7 @@
dispatch:
CPU: adaptive_max_pool2d_out_cpu
CUDA: adaptive_max_pool2d_out_cuda
+ MPS: adaptive_max_pool2d_out_mps
# Return: (Tensor output, Tensor indices)
- func: adaptive_max_pool2d(Tensor self, int[2] output_size) -> (Tensor, Tensor)
@@ -9803,6 +9804,7 @@
dispatch:
CPU: adaptive_max_pool2d_backward_out_cpu
CUDA: adaptive_max_pool2d_backward_out_cuda
+ MPS: adaptive_max_pool2d_backward_out_mps
- func: adaptive_max_pool2d_backward(Tensor grad_output, Tensor self, Tensor indices) -> Tensor
python_module: nn
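
With these dispatch entries in place, the op is reachable from Python on the mps device. A minimal usage sketch (assumes a build with MPS available):

    import torch

    pool = torch.nn.AdaptiveMaxPool2d((2, 2), return_indices=True)
    x = torch.randn(2, 2, 16, 16, device='mps', requires_grad=True)
    out, idx = pool(x)     # forward: adaptive_max_pool2d_out_mps
    out.sum().backward()   # backward: adaptive_max_pool2d_backward_out_mps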
diff --git a/test/test_mps.py b/test/test_mps.py
index 72b4be7..f9778b4 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -3089,6 +3089,50 @@
helper((2, 16, 16), (4, 4), False)
+    # Test adaptive max pool2d - when the input size is a multiple of the output size
+ # Not testing for channels last right now
+ def test_adaptive_max_pool2d_simple(self):
+ def helper(input_shape, out_shape, return_indices, dtype, channels_last=False):
+ cpu_x = None
+            if dtype in [torch.float16, torch.float32]:
+ cpu_x = torch.randn(input_shape, device='cpu', dtype=dtype, requires_grad=True)
+ else:
+ cpu_x = torch.randint(50, input_shape, device='cpu', dtype=dtype, requires_grad=True)
+            if channels_last:
+ cpu_x = cpu_x.to(memory_format=torch.channels_last)
+ cpu_x.retain_grad()
+ x = cpu_x.detach().clone().to('mps').requires_grad_()
+
+ max_result, max_indices = None, None
+ max_result_cpu, max_indices_cpu = None, None
+
+            if return_indices:
+ max_result, max_indices = torch.nn.AdaptiveMaxPool2d(out_shape, return_indices)(x)
+ max_result_cpu, max_indices_cpu = torch.nn.AdaptiveMaxPool2d(out_shape, return_indices)(cpu_x)
+ else:
+ max_result = torch.nn.AdaptiveMaxPool2d(out_shape, return_indices)(x)
+ max_result_cpu = torch.nn.AdaptiveMaxPool2d(out_shape, return_indices)(cpu_x)
+
+ cpu_grad = torch.randn(max_result_cpu.shape)
+ grad = cpu_grad.to('mps')
+
+ max_result.backward(gradient=grad)
+ max_result_cpu.backward(gradient=cpu_grad)
+
+ self.assertEqual(max_result, max_result_cpu)
+            if return_indices:
+ self.assertEqual(max_indices, max_indices_cpu)
+ self.assertEqual(x.grad, cpu_x.grad)
+
+ for dtype in [torch.float32]:
+ for return_indices in [False, True]:
+ helper((2, 2, 4, 4), (2, 2), return_indices, dtype)
+ helper((2, 2, 9, 9), (3, 3), return_indices, dtype)
+ helper((2, 2, 9, 9), (9, 9), return_indices, dtype)
+ helper((2, 2, 16, 16), (2, 2), return_indices, dtype)
+ helper((2, 2, 16, 16), (2, 16), return_indices, dtype)
+ helper((2, 16, 16), (4, 4), return_indices, dtype)
+
def test_gelu_simple(self):
def helper(shape):
cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)