[MPS] Pixel shuffle unshuffle support (#99306)

Fixes #83196

Now the MPS implementation is blazingly fast.

However, I have several questions about improving this PR:

1. I copied code from `test_nn.py`. Is there a better way to test this?
2. I decided to use `usePixelShuffleOrder:YES`. Is that the right choice performance-wise? According to the docs:
```
`usePixelShuffleOrder` can be
used to control how the data within spatial blocks is ordered in the
`depthAxis` dimension: with `usePixelShuffleOrder=YES` the values within the
spatial blocks are stored contiguously within the `depthAxis` dimension whereas
otherwise they are stored interleaved with existing values in the `depthAxis` dimension.
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99306
Approved by: https://github.com/kulinseth, https://github.com/malfet
diff --git a/test/test_mps.py b/test/test_mps.py
index 1159e8d..e9d7378 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -948,6 +948,134 @@
         with self.assertRaisesRegex(RuntimeError, r"MPS driver API confirmed .+"):
             leak_gpu0()
 
+
+class TestPixelShuffle(TestCaseMPS):
+    def test_pixel_shuffle_unshuffle(self):
+        """Exercise nn.PixelShuffle and nn.PixelUnshuffle on the 'mps' device.
+
+        Inputs with 1 to 5 dims are covered, both contiguous and non-contiguous,
+        together with the configurations that are expected to raise RuntimeError.
+        """
+        def _test_pixel_shuffle_unshuffle_helper(num_input_dims, valid_channels_dim=True,
+                                                 upscale_factor=None, is_contiguous=True):
+            # With >= 3 dims, a valid channels dim and a positive factor the shuffle is verified
+            # element by element; in every other configuration pixel_shuffle must raise RuntimeError.
+            def generate_input():
+                # If valid_channels_dim=False, add 1 to make channels dim indivisible by upscale_factor ** 2.
+                channels = random.randint(1, 4) * upscale_factor ** 2 + (0 if valid_channels_dim else 1)
+                height = random.randint(5, 10)
+                width = random.randint(5, 10)
+
+                if num_input_dims == 1:
+                    x = torch.rand(channels, requires_grad=True, device='mps')
+                    assert is_contiguous
+                elif num_input_dims == 2:
+                    x = torch.rand(width, height, requires_grad=True, device='mps').T
+                    if is_contiguous:
+                        x = x.contiguous()
+                else:
+                    batch_sizes = [random.randint(1, 3) for _ in range(num_input_dims - 3)]
+                    x = torch.rand(*batch_sizes, channels, width, height, requires_grad=True, device='mps')
+                    # Swap the two trailing dims to obtain a non-contiguous layout (asserted below).
+                    x = x.transpose(-1, -2)
+                    if is_contiguous:
+                        x = x.contiguous()
+
+                if not is_contiguous and x.numel() > 0:
+                    assert not x.is_contiguous()
+
+                # Hand back a fresh leaf tensor so that backward() populates its .grad.
+                x = x.detach().clone()
+                x.requires_grad = True
+                return x
+
+            # Function to imperatively ensure pixels are shuffled to the correct locations.
+            # Used to validate the batch operations in pixel_shuffle.
+            def _verify_pixel_shuffle(x, output, upscale_factor):
+                for c in range(output.size(-3)):
+                    for h in range(output.size(-2)):
+                        for w in range(output.size(-1)):
+                            height_idx = h // upscale_factor
+                            width_idx = w // upscale_factor
+                            channel_idx = (upscale_factor * (h % upscale_factor)) + (w % upscale_factor) + \
+                                          (c * upscale_factor ** 2)
+                            self.assertEqual(output[..., c, h, w], x[..., channel_idx, height_idx, width_idx])
+
+            upscale_factor = random.randint(2, 5) if upscale_factor is None else upscale_factor
+            x = generate_input()
+
+            ps = nn.PixelShuffle(upscale_factor)
+            pus = nn.PixelUnshuffle(downscale_factor=upscale_factor)
+
+            if num_input_dims >= 3 and valid_channels_dim and upscale_factor > 0:
+                output = ps(x)
+                _verify_pixel_shuffle(x, output, upscale_factor)
+                # Pixel shuffle only rearranges elements, so backpropagating the output itself
+                # must yield the input as its gradient.
+                output.backward(output.detach())
+                self.assertEqual(x.detach(), x.grad)
+
+                # Ensure unshuffle properly inverts shuffle.
+                unshuffle_output = pus(output)
+                self.assertEqual(x, unshuffle_output)
+            else:
+                with self.assertRaises(RuntimeError):
+                    ps(x)
+
+        def _test_pixel_unshuffle_error_case_helper(num_input_dims, valid_height_dim=True, valid_width_dim=True,
+                                                    downscale_factor=None):
+            # Build an input from the given flags and expect pixel_unshuffle to raise RuntimeError.
+            # Callers pass flags that make the input dims or the downscale factor invalid.
+            downscale_factor = random.randint(2, 5) if downscale_factor is None else downscale_factor
+            channels = random.randint(1, 4)
+            # If valid_height_dim=False, add 1 to make height dim indivisible by downscale_factor.
+            height = random.randint(3, 5) * abs(downscale_factor) + (0 if valid_height_dim else 1)
+            # If valid_width_dim=False, add 1 to make width dim indivisible by downscale_factor.
+            width = random.randint(3, 5) * abs(downscale_factor) + (0 if valid_width_dim else 1)
+
+            if num_input_dims == 1:
+                x = torch.rand(channels, requires_grad=True, device='mps')
+            elif num_input_dims == 2:
+                x = torch.rand(height, width, requires_grad=True, device='mps')
+            else:
+                batch_sizes = [random.randint(1, 3) for _ in range(num_input_dims - 3)]
+                x = torch.rand(*batch_sizes, channels, height, width, requires_grad=True, device='mps')
+
+            pus = nn.PixelUnshuffle(downscale_factor)
+            with self.assertRaises(RuntimeError):
+                pus(x)
+
+        def _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims):
+            # For 1D - 2D, this is an error case.
+            # For 3D - 5D, this is a success case for pixel_shuffle + pixel_unshuffle.
+            is_contiguous_check = [True, False] if num_input_dims > 1 else [True]
+            for is_contiguous in is_contiguous_check:
+                _test_pixel_shuffle_unshuffle_helper(
+                    num_input_dims=num_input_dims, is_contiguous=is_contiguous
+                )
+                # Error cases for pixel_shuffle.
+                _test_pixel_shuffle_unshuffle_helper(
+                    num_input_dims=num_input_dims, valid_channels_dim=False, is_contiguous=is_contiguous
+                )
+                _test_pixel_shuffle_unshuffle_helper(
+                    num_input_dims=num_input_dims, upscale_factor=0, is_contiguous=is_contiguous
+                )
+                _test_pixel_shuffle_unshuffle_helper(
+                    num_input_dims=num_input_dims, upscale_factor=-2, is_contiguous=is_contiguous
+                )
+
+            # Error cases for pixel_unshuffle.
+            _test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, valid_height_dim=False)
+            _test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, valid_width_dim=False)
+            _test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, downscale_factor=0)
+            _test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, downscale_factor=-2)
+
+        # Run the whole scenario once per supported input rank; subTest keeps
+        # a failure for one rank from hiding the results for the others.
+        for num_input_dims in range(1, 6):
+            with self.subTest(num_input_dims=num_input_dims):
+                _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=num_input_dims)
+
 class MPSReluTest(TestCaseMPS):
     def _npRelu(self, np_features):
         return np.maximum(np_features, np.zeros(np_features.shape)).astype(np_features.dtype)