[MPS] Pixel shuffle unshuffle support (#99306)

Fixes #83196

Now the MPS implementation is blazingly fast.

That said, I have a few questions about improving this PR:

1. I copied the test code from `test_nn.py`. Is there a better way to test this?
2. I decided to use `usePixelShuffleOrder:YES`. Is this the right choice performance-wise? According to the docs:
```
`usePixelShuffleOrder` can be used to control how the data within spatial
blocks is ordered in the `depthAxis` dimension: with
`usePixelShuffleOrder=YES` the values within the spatial blocks are stored
contiguously within the `depthAxis` dimension, whereas otherwise they are
stored interleaved with existing values in the `depthAxis` dimension.
```
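
For intuition, "pixel shuffle order" means the op is a plain reshape/permute. Here is a minimal reference sketch of the ordering that `usePixelShuffleOrder:YES` is expected to reproduce (the `pixel_shuffle_reference` helper is illustrative, not part of this PR):

```
import torch

def pixel_shuffle_reference(x, r):
    # (N, C*r*r, H, W) -> (N, C, H*r, W*r): each r x r spatial block is filled
    # from a contiguous run of the channel ("depth") dimension, matching
    # usePixelShuffleOrder=YES.
    n, c, h, w = x.shape
    x = x.view(n, c // (r * r), r, r, h, w)
    x = x.permute(0, 1, 4, 2, 5, 3)
    return x.reshape(n, c // (r * r), h * r, w * r)

x = torch.arange(16.0).view(1, 4, 2, 2)
assert torch.equal(pixel_shuffle_reference(x, 2), torch.pixel_shuffle(x, 2))
```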

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99306
Approved by: https://github.com/kulinseth, https://github.com/malfet
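
As a quick usage sketch (assuming an MPS-capable machine on macOS 13 or newer), the two ops round-trip on the MPS device:

```
import torch
import torch.nn as nn

# Illustrative roundtrip on the MPS backend: pixel_unshuffle inverts pixel_shuffle.
x = torch.rand(2, 8, 5, 5, device="mps")
shuffled = nn.PixelShuffle(2)(x)           # -> (2, 2, 10, 10)
restored = nn.PixelUnshuffle(2)(shuffled)  # -> (2, 8, 5, 5)
torch.testing.assert_close(restored, x)
```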
diff --git a/aten/src/ATen/native/PixelShuffle.cpp b/aten/src/ATen/native/PixelShuffle.cpp
index e535909..1d500b3 100644
--- a/aten/src/ATen/native/PixelShuffle.cpp
+++ b/aten/src/ATen/native/PixelShuffle.cpp
@@ -1,6 +1,7 @@
 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
 #include <ATen/native/TensorTransformations.h>
 #include <ATen/native/cpu/PixelShuffleKernel.h>
+#include <ATen/native/PixelShuffle.h>
 
 #include <c10/util/Exception.h>
 
@@ -20,37 +21,6 @@
 namespace at {
 namespace native {
 
-static inline void check_pixel_shuffle_shapes(const Tensor& self, int64_t upscale_factor) {
-  TORCH_CHECK(self.dim() >= 3,
-              "pixel_shuffle expects input to have at least 3 dimensions, but got input with ",
-              self.dim(), " dimension(s)");
-  TORCH_CHECK(upscale_factor > 0,
-              "pixel_shuffle expects a positive upscale_factor, but got ",
-              upscale_factor);
-  int64_t c = self.size(-3);
-  int64_t upscale_factor_squared = upscale_factor * upscale_factor;
-  TORCH_CHECK(c % upscale_factor_squared == 0,
-              "pixel_shuffle expects its input's 'channel' dimension to be divisible by the square of "
-              "upscale_factor, but input.size(-3)=", c, " is not divisible by ", upscale_factor_squared);
-}
-
-static inline void check_pixel_unshuffle_shapes(const Tensor& self, int64_t downscale_factor) {
-  TORCH_CHECK(self.dim() >= 3,
-              "pixel_unshuffle expects input to have at least 3 dimensions, but got input with ",
-              self.dim(), " dimension(s)");
-  TORCH_CHECK(downscale_factor > 0,
-              "pixel_unshuffle expects a positive downscale_factor, but got ",
-              downscale_factor);
-  int64_t h = self.size(-2);
-  int64_t w = self.size(-1);
-  TORCH_CHECK(h % downscale_factor == 0,
-              "pixel_unshuffle expects height to be divisible by downscale_factor, but input.size(-2)=", h,
-              " is not divisible by ", downscale_factor);
-  TORCH_CHECK(w % downscale_factor == 0,
-              "pixel_unshuffle expects width to be divisible by downscale_factor, but input.size(-1)=", w,
-              " is not divisible by ", downscale_factor);
-}
-
 Tensor pixel_shuffle_cpu(const Tensor& self, int64_t upscale_factor) {
   check_pixel_shuffle_shapes(self, upscale_factor);
 
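
The checks moved out of this file behave identically from Python on every backend; a small illustrative sketch with shapes chosen to trip each one:

```
import torch

x = torch.rand(3, 4, 4)

# Channel dim (3) is not divisible by upscale_factor ** 2 (9).
try:
    torch.pixel_shuffle(x, 3)
except RuntimeError as e:
    print(e)  # "...divisible by the square of upscale_factor..."

# Height/width (4) are not divisible by downscale_factor (3).
try:
    torch.pixel_unshuffle(x, 3)
except RuntimeError as e:
    print(e)  # "...expects height to be divisible by downscale_factor..."
```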
diff --git a/aten/src/ATen/native/PixelShuffle.h b/aten/src/ATen/native/PixelShuffle.h
new file mode 100644
index 0000000..a9a66a3
--- /dev/null
+++ b/aten/src/ATen/native/PixelShuffle.h
@@ -0,0 +1,49 @@
+#pragma once
+
+#include <ATen/core/Tensor.h>
+#include <c10/util/Exception.h>
+
+namespace at {
+namespace native {
+
+inline void check_pixel_shuffle_shapes(const Tensor& self, int64_t upscale_factor) {
+  TORCH_CHECK(self.dim() >= 3,
+              "pixel_shuffle expects input to have at least 3 dimensions, but got input with ",
+              self.dim(), " dimension(s)");
+  TORCH_CHECK(upscale_factor > 0,
+              "pixel_shuffle expects a positive upscale_factor, but got ",
+              upscale_factor);
+  int64_t c = self.size(-3);
+  int64_t upscale_factor_squared = upscale_factor * upscale_factor;
+  TORCH_CHECK(c % upscale_factor_squared == 0,
+              "pixel_shuffle expects its input's 'channel' dimension to be divisible by the square of "
+              "upscale_factor, but input.size(-3)=", c, " is not divisible by ", upscale_factor_squared);
+}
+
+inline void check_pixel_unshuffle_shapes(const Tensor& self, int64_t downscale_factor) {
+  TORCH_CHECK(
+      self.dim() >= 3,
+      "pixel_unshuffle expects input to have at least 3 dimensions, but got input with ",
+      self.dim(),
+      " dimension(s)");
+  TORCH_CHECK(
+      downscale_factor > 0,
+      "pixel_unshuffle expects a positive downscale_factor, but got ",
+      downscale_factor);
+  int64_t h = self.size(-2);
+  int64_t w = self.size(-1);
+  TORCH_CHECK(
+      h % downscale_factor == 0,
+      "pixel_unshuffle expects height to be divisible by downscale_factor, but input.size(-2)=",
+      h,
+      " is not divisible by ",
+      downscale_factor);
+  TORCH_CHECK(
+      w % downscale_factor == 0,
+      "pixel_unshuffle expects width to be divisible by downscale_factor, but input.size(-1)=",
+      w,
+      " is not divisible by ",
+      downscale_factor);
+}
+
+}} // namespace at::native
diff --git a/aten/src/ATen/native/mps/operations/PixelShuffle.mm b/aten/src/ATen/native/mps/operations/PixelShuffle.mm
new file mode 100644
index 0000000..30e85bf
--- /dev/null
+++ b/aten/src/ATen/native/mps/operations/PixelShuffle.mm
@@ -0,0 +1,114 @@
+#include <ATen/native/PixelShuffle.h>
+#include <ATen/native/mps/OperationUtils.h>
+#include <ATen/ops/pixel_shuffle_native.h>
+#include <ATen/ops/pixel_unshuffle_native.h>
+
+using namespace at::mps;
+
+namespace at::native {
+
+static Tensor pixel_shuffle_helper(const Tensor& self, int64_t factor, bool upscale) {
+  using namespace mps;
+  using CachedGraph = MPSUnaryCachedGraph;
+
+  if (factor == 1) {
+    return self.clone();
+  }
+
+  if (upscale) {
+    check_pixel_shuffle_shapes(self, factor);
+  } else {
+    check_pixel_unshuffle_shapes(self, factor);
+  }
+
+  MPSStream* stream = getCurrentMPSStream();
+
+  const int64_t c = self.size(-3);
+  const int64_t h = self.size(-2);
+  const int64_t w = self.size(-1);
+  constexpr auto NUM_NON_BATCH_DIMS = 3;
+  const auto self_sizes_batch_end = self.sizes().end() - NUM_NON_BATCH_DIMS;
+
+  const int64_t factor_squared = factor * factor;
+  const int64_t oc = upscale ? c / factor_squared : c * factor_squared;
+  const int64_t oh = upscale ? h * factor : h / factor;
+  const int64_t ow = upscale ? w * factor : w / factor;
+
+  std::vector<int64_t> out_shape(self.sizes().begin(), self_sizes_batch_end);
+  out_shape.insert(out_shape.end(), {oc, oh, ow});
+
+  Tensor output = at::empty(out_shape, self.options());
+
+  if (output.numel() == 0) {
+    return output;
+  }
+
+  @autoreleasepool {
+    string key = (upscale ? "pixel_shuffle_" : "pixel_unshuffle_") + getTensorsStringKey({self}) + "_factor_" +
+        std::to_string(factor);
+    CachedGraph* cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
+      const auto ndims = self.ndimension();
+      MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
+      MPSGraphTensor* outputTensor = nullptr;
+      if (upscale) {
+        outputTensor = [mpsGraph depthToSpace2DTensor:inputTensor
+                                            widthAxis:ndims - 1
+                                           heightAxis:ndims - 2
+                                            depthAxis:ndims - 3
+                                            blockSize:factor
+                                 usePixelShuffleOrder:YES
+                                                 name:nil];
+      } else {
+        outputTensor = [mpsGraph spaceToDepth2DTensor:inputTensor
+                                            widthAxis:ndims - 1
+                                           heightAxis:ndims - 2
+                                            depthAxis:ndims - 3
+                                            blockSize:factor
+                                 usePixelShuffleOrder:YES
+                                                 name:nil];
+      }
+
+      newCachedGraph->inputTensor_ = inputTensor;
+      newCachedGraph->outputTensor_ = outputTensor;
+      return newCachedGraph;
+    });
+
+    Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self);
+    Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output);
+
+    // Create dictionary of inputs and outputs
+    NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds =
+        @{selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData()};
+
+    NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
+        @{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
+
+    runMPSGraph(stream, cachedGraph->graph(), feeds, results);
+  }
+
+  return output;
+}
+
+Tensor pixel_shuffle_mps(const Tensor& self, int64_t upscale_factor) {
+  if (!is_macos_13_or_newer()) {
+    TORCH_WARN_ONCE("MPS: pixel_shuffle op is supported starting from macOS 13.0. ",
+                    "Falling back on CPU. This may have performance implications.");
+
+    return at::native::pixel_shuffle_cpu(self.to("cpu"), upscale_factor).to("mps");
+  }
+
+  return pixel_shuffle_helper(self, upscale_factor, /*upscale=*/true);
+}
+
+Tensor pixel_unshuffle_mps(const Tensor& self, int64_t downscale_factor) {
+  if (!is_macos_13_or_newer()) {
+    TORCH_WARN_ONCE("MPS: pixel_unshuffle op is supported starting from macOS 13.0. ",
+                    "Falling back on CPU. This may have performance implications.");
+
+    return at::native::pixel_unshuffle_cpu(self.to("cpu"), downscale_factor).to("mps");
+  }
+
+  return pixel_shuffle_helper(self, downscale_factor, /*upscale=*/false);
+}
+
+} // namespace at::native
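
Since pre-macOS-13 systems fall back to the CPU kernel, the MPS result should match the CPU reference on either path; a quick sanity sketch, assuming an MPS device is available:

```
import torch
import torch.nn.functional as F

x = torch.rand(1, 8, 3, 3, device="mps")
y_mps = F.pixel_shuffle(x, 2)
y_cpu = F.pixel_shuffle(x.cpu(), 2)
torch.testing.assert_close(y_mps.cpu(), y_cpu)
```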
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index e87c9bc..26bbfd0 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -4318,6 +4318,7 @@
 - func: pixel_shuffle(Tensor self, int upscale_factor) -> Tensor
   dispatch:
     CPU: pixel_shuffle_cpu
+    MPS: pixel_shuffle_mps
     CompositeExplicitAutogradNonFunctional: math_pixel_shuffle
   autogen: pixel_shuffle.out
   tags: core
@@ -4325,6 +4326,7 @@
 - func: pixel_unshuffle(Tensor self, int downscale_factor) -> Tensor
   dispatch:
     CPU: pixel_unshuffle_cpu
+    MPS: pixel_unshuffle_mps
     CompositeExplicitAutogradNonFunctional: math_pixel_unshuffle
   autogen: pixel_unshuffle.out
 
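
With the dispatch entries above, both ops are reachable through the regular API, and autograd keeps working since the backward of `pixel_shuffle` is `pixel_unshuffle`; a small sketch, again assuming an MPS device:

```
import torch

x = torch.rand(2, 9, 4, 4, device="mps", requires_grad=True)
y = torch.pixel_shuffle(x, 3)  # dispatches to pixel_shuffle_mps
y.sum().backward()             # gradient flows back through pixel_unshuffle
assert x.grad.shape == x.shape
```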
diff --git a/test/test_mps.py b/test/test_mps.py
index 1159e8d..e9d7378 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -948,6 +948,134 @@
         with self.assertRaisesRegex(RuntimeError, r"MPS driver API confirmed .+"):
             leak_gpu0()
 
+
+class TestPixelShuffle(TestCaseMPS):
+    def test_pixel_shuffle_unshuffle(self):
+        def _test_pixel_shuffle_unshuffle_helper(num_input_dims, valid_channels_dim=True,
+                                                 upscale_factor=None, is_contiguous=True):
+
+            def generate_input():
+                # If valid_channels_dim=False, add 1 to make channels dim indivisible by upscale_factor ** 2.
+                channels = random.randint(1, 4) * upscale_factor ** 2 + (0 if valid_channels_dim else 1)
+                height = random.randint(5, 10)
+                width = random.randint(5, 10)
+
+                if num_input_dims == 1:
+                    input = torch.rand(channels, requires_grad=True, device='mps')
+                    assert is_contiguous
+                elif num_input_dims == 2:
+                    input = torch.rand(width, height, requires_grad=True, device='mps').T
+                    if is_contiguous:
+                        input = input.contiguous()
+                else:
+                    batch_sizes = [random.randint(1, 3) for _ in range(num_input_dims - 3)]
+                    input = torch.rand(*batch_sizes, channels, width, height, requires_grad=True, device='mps')
+                    input = input.transpose(-1, -2)
+                    if is_contiguous:
+                        input = input.contiguous()
+
+                if not is_contiguous and len(input.reshape(-1)) > 0:
+                    assert not input.is_contiguous()
+
+                input = input.detach().clone()
+                input.requires_grad = True
+                return input
+
+            # Function to imperatively ensure pixels are shuffled to the correct locations.
+            # Used to validate the batch operations in pixel_shuffle.
+            def _verify_pixel_shuffle(input, output, upscale_factor):
+                for c in range(output.size(-3)):
+                    for h in range(output.size(-2)):
+                        for w in range(output.size(-1)):
+                            height_idx = h // upscale_factor
+                            width_idx = w // upscale_factor
+                            channel_idx = (upscale_factor * (h % upscale_factor)) + (w % upscale_factor) + \
+                                          (c * upscale_factor ** 2)
+                            self.assertEqual(output[..., c, h, w], input[..., channel_idx, height_idx, width_idx])
+
+            upscale_factor = random.randint(2, 5) if upscale_factor is None else upscale_factor
+            input = generate_input()
+
+            ps = nn.PixelShuffle(upscale_factor)
+            pus = nn.PixelUnshuffle(downscale_factor=upscale_factor)
+
+            if num_input_dims >= 3 and valid_channels_dim and upscale_factor > 0:
+                output = ps(input)
+                _verify_pixel_shuffle(input, output, upscale_factor)
+                output.backward(output.data)
+                self.assertEqual(input.data, input.grad.data)
+
+                # Ensure unshuffle properly inverts shuffle.
+                unshuffle_output = pus(output)
+                self.assertEqual(input, unshuffle_output)
+            else:
+                self.assertRaises(RuntimeError, lambda: ps(input))
+
+        def _test_pixel_unshuffle_error_case_helper(num_input_dims, valid_height_dim=True, valid_width_dim=True,
+                                                    downscale_factor=None):
+            downscale_factor = random.randint(2, 5) if downscale_factor is None else downscale_factor
+            channels = random.randint(1, 4)
+            # If valid_height_dim=False, add 1 to make height dim indivisible by downscale_factor.
+            height = random.randint(3, 5) * abs(downscale_factor) + (0 if valid_height_dim else 1)
+            # If valid_width_dim=False, add 1 to make width dim indivisible by downscale_factor.
+            width = random.randint(3, 5) * abs(downscale_factor) + (0 if valid_width_dim else 1)
+
+            if num_input_dims == 1:
+                input = torch.rand(channels, requires_grad=True, device='mps')
+            elif num_input_dims == 2:
+                input = torch.rand(height, width, requires_grad=True, device='mps')
+            else:
+                batch_sizes = [random.randint(1, 3) for _ in range(num_input_dims - 3)]
+                input = torch.rand(*batch_sizes, channels, height, width, requires_grad=True, device='mps')
+
+            pus = nn.PixelUnshuffle(downscale_factor)
+            self.assertRaises(RuntimeError, lambda: pus(input))
+
+        def _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims):
+            # For 1D - 2D, this is an error case.
+            # For 3D - 5D, this is a success case for pixel_shuffle + pixel_unshuffle.
+            is_contiguous_check = [True, False] if num_input_dims > 1 else [True]
+            for is_contiguous in is_contiguous_check:
+                _test_pixel_shuffle_unshuffle_helper(
+                    num_input_dims=num_input_dims, is_contiguous=is_contiguous
+                )
+                _test_pixel_shuffle_unshuffle_helper(
+                    num_input_dims=num_input_dims, valid_channels_dim=False, is_contiguous=is_contiguous
+                )
+                _test_pixel_shuffle_unshuffle_helper(
+                    num_input_dims=num_input_dims, upscale_factor=0, is_contiguous=is_contiguous
+                )
+                _test_pixel_shuffle_unshuffle_helper(
+                    num_input_dims=num_input_dims, upscale_factor=-2, is_contiguous=is_contiguous
+                )
+
+                # Error cases for pixel_unshuffle.
+            # Error cases for pixel_unshuffle.
+            _test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, valid_width_dim=False)
+            _test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, downscale_factor=0)
+            _test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, downscale_factor=-2)
+
+        def test_pixel_shuffle_unshuffle_1D():
+            _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=1)
+
+        def test_pixel_shuffle_unshuffle_2D():
+            _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=2)
+
+        def test_pixel_shuffle_unshuffle_3D():
+            _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=3)
+
+        def test_pixel_shuffle_unshuffle_4D():
+            _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=4)
+
+        def test_pixel_shuffle_unshuffle_5D():
+            _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=5)
+
+        test_pixel_shuffle_unshuffle_1D()
+        test_pixel_shuffle_unshuffle_2D()
+        test_pixel_shuffle_unshuffle_3D()
+        test_pixel_shuffle_unshuffle_4D()
+        test_pixel_shuffle_unshuffle_5D()
+
 class MPSReluTest(TestCaseMPS):
     def _npRelu(self, np_features):
         return np.maximum(np_features, np.zeros(np_features.shape)).astype(np_features.dtype)