[MPS] Pixel shuffle unshuffle support (#99306)
Fixes #83196
Now the MPS implementation is blazingly fast.
I have a couple of questions about improving this PR, though:
1. I copied the test code from `test_nn.py`. Is there a better way to test this?
2. I decided to use `usePixelShuffleOrder:YES`. Is that the right choice performance-wise? According to the docs:
```
`usePixelShuffleOrder` can be used to control how the data within spatial
blocks is ordered in the `depthAxis` dimension: with `usePixelShuffleOrder=YES`
the values within the spatial blocks are stored contiguously within the
`depthAxis` dimension, whereas otherwise they are stored interleaved with
existing values in the `depthAxis` dimension.
```
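For intuition, `usePixelShuffleOrder=YES` stores each spatial block's values as a contiguous run of channels, which is the ordering PyTorch's `pixel_shuffle` expects. A minimal pure-PyTorch sketch of that ordering (`pixel_shuffle_reference` and `r` are illustrative names, not part of this PR):
```
import torch

def pixel_shuffle_reference(x: torch.Tensor, r: int) -> torch.Tensor:
    # (N, C*r*r, H, W) -> (N, C, H*r, W*r): each group of r*r consecutive
    # channels ("pixel shuffle order") becomes one r x r spatial block.
    n, c, h, w = x.shape
    oc = c // (r * r)
    x = x.view(n, oc, r, r, h, w)
    x = x.permute(0, 1, 4, 2, 5, 3)  # (N, oc, H, r, W, r)
    return x.reshape(n, oc, h * r, w * r)

x = torch.arange(16.0).view(1, 4, 2, 2)
assert torch.equal(pixel_shuffle_reference(x, 2), torch.pixel_shuffle(x, 2))
```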
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99306
Approved by: https://github.com/kulinseth, https://github.com/malfet
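For anyone who wants to kick the tires, a quick sanity check of the new kernels against the CPU path (a sketch, assuming a macOS 13+ machine where `torch.backends.mps.is_available()` returns `True`):
```
import torch

x = torch.rand(2, 8, 6, 6)
y_cpu = torch.pixel_shuffle(x, 2)
y_mps = torch.pixel_shuffle(x.to("mps"), 2)  # dispatches to pixel_shuffle_mps
torch.testing.assert_close(y_cpu, y_mps.cpu())

# pixel_unshuffle with the same factor inverts pixel_shuffle.
torch.testing.assert_close(torch.pixel_unshuffle(y_mps, 2).cpu(), x)
```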
diff --git a/aten/src/ATen/native/PixelShuffle.cpp b/aten/src/ATen/native/PixelShuffle.cpp
index e535909..1d500b3 100644
--- a/aten/src/ATen/native/PixelShuffle.cpp
+++ b/aten/src/ATen/native/PixelShuffle.cpp
@@ -1,6 +1,7 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/TensorTransformations.h>
#include <ATen/native/cpu/PixelShuffleKernel.h>
+#include <ATen/native/PixelShuffle.h>
#include <c10/util/Exception.h>
@@ -20,37 +21,6 @@
namespace at {
namespace native {
-static inline void check_pixel_shuffle_shapes(const Tensor& self, int64_t upscale_factor) {
- TORCH_CHECK(self.dim() >= 3,
- "pixel_shuffle expects input to have at least 3 dimensions, but got input with ",
- self.dim(), " dimension(s)");
- TORCH_CHECK(upscale_factor > 0,
- "pixel_shuffle expects a positive upscale_factor, but got ",
- upscale_factor);
- int64_t c = self.size(-3);
- int64_t upscale_factor_squared = upscale_factor * upscale_factor;
- TORCH_CHECK(c % upscale_factor_squared == 0,
- "pixel_shuffle expects its input's 'channel' dimension to be divisible by the square of "
- "upscale_factor, but input.size(-3)=", c, " is not divisible by ", upscale_factor_squared);
-}
-
-static inline void check_pixel_unshuffle_shapes(const Tensor& self, int64_t downscale_factor) {
- TORCH_CHECK(self.dim() >= 3,
- "pixel_unshuffle expects input to have at least 3 dimensions, but got input with ",
- self.dim(), " dimension(s)");
- TORCH_CHECK(downscale_factor > 0,
- "pixel_unshuffle expects a positive downscale_factor, but got ",
- downscale_factor);
- int64_t h = self.size(-2);
- int64_t w = self.size(-1);
- TORCH_CHECK(h % downscale_factor == 0,
- "pixel_unshuffle expects height to be divisible by downscale_factor, but input.size(-2)=", h,
- " is not divisible by ", downscale_factor);
- TORCH_CHECK(w % downscale_factor == 0,
- "pixel_unshuffle expects width to be divisible by downscale_factor, but input.size(-1)=", w,
- " is not divisible by ", downscale_factor);
-}
-
Tensor pixel_shuffle_cpu(const Tensor& self, int64_t upscale_factor) {
check_pixel_shuffle_shapes(self, upscale_factor);
diff --git a/aten/src/ATen/native/PixelShuffle.h b/aten/src/ATen/native/PixelShuffle.h
new file mode 100644
index 0000000..a9a66a3
--- /dev/null
+++ b/aten/src/ATen/native/PixelShuffle.h
@@ -0,0 +1,47 @@
+#include <ATen/core/Tensor.h>
+#include <c10/util/Exception.h>
+
+namespace at {
+namespace native {
+
+inline void check_pixel_shuffle_shapes(const Tensor& self, int64_t upscale_factor) {
+ TORCH_CHECK(self.dim() >= 3,
+ "pixel_shuffle expects input to have at least 3 dimensions, but got input with ",
+ self.dim(), " dimension(s)");
+ TORCH_CHECK(upscale_factor > 0,
+ "pixel_shuffle expects a positive upscale_factor, but got ",
+ upscale_factor);
+ int64_t c = self.size(-3);
+ int64_t upscale_factor_squared = upscale_factor * upscale_factor;
+ TORCH_CHECK(c % upscale_factor_squared == 0,
+ "pixel_shuffle expects its input's 'channel' dimension to be divisible by the square of "
+ "upscale_factor, but input.size(-3)=", c, " is not divisible by ", upscale_factor_squared);
+}
+
+inline void check_pixel_unshuffle_shapes(const Tensor& self, int64_t downscale_factor) {
+ TORCH_CHECK(
+ self.dim() >= 3,
+ "pixel_unshuffle expects input to have at least 3 dimensions, but got input with ",
+ self.dim(),
+ " dimension(s)");
+ TORCH_CHECK(
+ downscale_factor > 0,
+ "pixel_unshuffle expects a positive downscale_factor, but got ",
+ downscale_factor);
+ int64_t h = self.size(-2);
+ int64_t w = self.size(-1);
+ TORCH_CHECK(
+ h % downscale_factor == 0,
+ "pixel_unshuffle expects height to be divisible by downscale_factor, but input.size(-2)=",
+ h,
+ " is not divisible by ",
+ downscale_factor);
+ TORCH_CHECK(
+ w % downscale_factor == 0,
+ "pixel_unshuffle expects width to be divisible by downscale_factor, but input.size(-1)=",
+ w,
+ " is not divisible by ",
+ downscale_factor);
+}
+
+}} // namespace at::native
diff --git a/aten/src/ATen/native/mps/operations/PixelShuffle.mm b/aten/src/ATen/native/mps/operations/PixelShuffle.mm
new file mode 100644
index 0000000..30e85bf
--- /dev/null
+++ b/aten/src/ATen/native/mps/operations/PixelShuffle.mm
@@ -0,0 +1,114 @@
+#include <ATen/native/PixelShuffle.h>
+#include <ATen/native/mps/OperationUtils.h>
+#include <ATen/ops/pixel_shuffle_native.h>
+#include <ATen/ops/pixel_unshuffle_native.h>
+
+using namespace at::mps;
+
+namespace at::native {
+
+static Tensor pixel_shuffle_helper(const Tensor& self, int64_t factor, bool upscale) {
+ using namespace mps;
+ using CachedGraph = MPSUnaryCachedGraph;
+
+ if (factor == 1) {
+ return self.clone();
+ }
+
+ if (upscale) {
+ check_pixel_shuffle_shapes(self, factor);
+ } else {
+ check_pixel_unshuffle_shapes(self, factor);
+ }
+
+ MPSStream* stream = getCurrentMPSStream();
+
+ const int64_t c = self.size(-3);
+ const int64_t h = self.size(-2);
+ const int64_t w = self.size(-1);
+ constexpr auto NUM_NON_BATCH_DIMS = 3;
+ const auto self_sizes_batch_end = self.sizes().end() - NUM_NON_BATCH_DIMS;
+
+ const int64_t factor_squared = factor * factor;
+ const int64_t oc = upscale ? c / factor_squared : c * factor_squared;
+ const int64_t oh = upscale ? h * factor : h / factor;
+ const int64_t ow = upscale ? w * factor : w / factor;
+
+ std::vector<int64_t> out_shape(self.sizes().begin(), self_sizes_batch_end);
+ out_shape.insert(out_shape.end(), {oc, oh, ow});
+
+ Tensor output = at::empty(out_shape, self.options());
+
+ if (output.numel() == 0) {
+ return output;
+ }
+
+ @autoreleasepool {
+ string key = (upscale ? "pixel_shuffle_" : "pixel_unshuffle_") + getTensorsStringKey({self}) + "_factor_" +
+ std::to_string(factor);
+ CachedGraph* cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
+ const auto ndims = self.ndimension();
+ MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
+ MPSGraphTensor* outputTensor = nullptr;
+ if (upscale) {
+ outputTensor = [mpsGraph depthToSpace2DTensor:inputTensor
+ widthAxis:ndims - 1
+ heightAxis:ndims - 2
+ depthAxis:ndims - 3
+ blockSize:factor
+ usePixelShuffleOrder:YES
+ name:nil];
+ } else {
+ outputTensor = [mpsGraph spaceToDepth2DTensor:inputTensor
+ widthAxis:ndims - 1
+ heightAxis:ndims - 2
+ depthAxis:ndims - 3
+ blockSize:factor
+ usePixelShuffleOrder:YES
+ name:nil];
+ }
+
+ newCachedGraph->inputTensor_ = inputTensor;
+ newCachedGraph->outputTensor_ = outputTensor;
+ return newCachedGraph;
+ });
+
+ Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self);
+ Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output);
+
+ // Create dictionary of inputs and outputs
+ NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds =
+ @{selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData()};
+
+ NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
+ @{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};
+
+ runMPSGraph(stream, cachedGraph->graph(), feeds, results);
+ }
+
+ return output;
+}
+
+Tensor pixel_shuffle_mps(const Tensor& self, int64_t upscale_factor) {
+ if (!is_macos_13_or_newer()) {
+ TORCH_WARN_ONCE("MPS: pixel_shuffle op is supported starting from macOS 13.0. ",
+ "Falling back on CPU. This may have performance implications.");
+
+ return at::native::pixel_shuffle_cpu(self.to("cpu"), upscale_factor).to("mps");
+ }
+
+ return pixel_shuffle_helper(self, upscale_factor, /*upscale=*/true);
+}
+
+Tensor pixel_unshuffle_mps(const Tensor& self, int64_t downscale_factor) {
+ if (!is_macos_13_or_newer()) {
+ TORCH_WARN_ONCE("MPS: pixel_unshuffle op is supported starting from macOS 13.0. ",
+ "Falling back on CPU. This may have performance implications.");
+
+ return at::native::pixel_unshuffle_cpu(self.to("cpu"), downscale_factor).to("mps");
+ }
+
+ return pixel_shuffle_helper(self, downscale_factor, /*upscale=*/false);
+}
+
+} // namespace at::native
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index e87c9bc..26bbfd0 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -4318,6 +4318,7 @@
- func: pixel_shuffle(Tensor self, int upscale_factor) -> Tensor
dispatch:
CPU: pixel_shuffle_cpu
+ MPS: pixel_shuffle_mps
CompositeExplicitAutogradNonFunctional: math_pixel_shuffle
autogen: pixel_shuffle.out
tags: core
@@ -4325,6 +4326,7 @@
- func: pixel_unshuffle(Tensor self, int downscale_factor) -> Tensor
dispatch:
CPU: pixel_unshuffle_cpu
+ MPS: pixel_unshuffle_mps
CompositeExplicitAutogradNonFunctional: math_pixel_unshuffle
autogen: pixel_unshuffle.out
diff --git a/test/test_mps.py b/test/test_mps.py
index 1159e8d..e9d7378 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -948,6 +948,134 @@
with self.assertRaisesRegex(RuntimeError, r"MPS driver API confirmed .+"):
leak_gpu0()
+
+class TestPixelShuffle(TestCaseMPS):
+ def test_pixel_shuffle_unshuffle(self):
+ def _test_pixel_shuffle_unshuffle_helper(num_input_dims, valid_channels_dim=True,
+ upscale_factor=None, is_contiguous=True):
+
+ def generate_input():
+ # If valid_channels_dim=False, add 1 to make channels dim indivisible by upscale_factor ** 2.
+ channels = random.randint(1, 4) * upscale_factor ** 2 + (0 if valid_channels_dim else 1)
+ height = random.randint(5, 10)
+ width = random.randint(5, 10)
+
+ if num_input_dims == 1:
+ input = torch.rand(channels, requires_grad=True, device='mps')
+ assert is_contiguous
+ elif num_input_dims == 2:
+ input = torch.rand(width, height, requires_grad=True, device='mps').T
+ if is_contiguous:
+ input = input.contiguous()
+ else:
+ batch_sizes = [random.randint(1, 3) for _ in range(num_input_dims - 3)]
+ input = torch.rand(*batch_sizes, channels, width, height, requires_grad=True, device='mps')
+ input = input.transpose(-1, -2)
+ if is_contiguous:
+ input = input.contiguous()
+
+ if not is_contiguous and len(input.reshape(-1)) > 0:
+ assert not input.is_contiguous()
+
+ input = input.detach().clone()
+ input.requires_grad = True
+ return input
+
+ # Function to imperatively ensure pixels are shuffled to the correct locations.
+ # Used to validate the batch operations in pixel_shuffle.
+ def _verify_pixel_shuffle(input, output, upscale_factor):
+ for c in range(output.size(-3)):
+ for h in range(output.size(-2)):
+ for w in range(output.size(-1)):
+                            height_idx = h // upscale_factor
+                            width_idx = w // upscale_factor
+                            channel_idx = (upscale_factor * (h % upscale_factor)) + (w % upscale_factor) + \
+                                          (c * upscale_factor ** 2)
+                            self.assertEqual(output[..., c, h, w], input[..., channel_idx, height_idx, width_idx])
+
+ upscale_factor = random.randint(2, 5) if upscale_factor is None else upscale_factor
+ input = generate_input()
+
+ ps = nn.PixelShuffle(upscale_factor)
+ pus = nn.PixelUnshuffle(downscale_factor=upscale_factor)
+
+ if num_input_dims >= 3 and valid_channels_dim and upscale_factor > 0:
+ output = ps(input)
+ _verify_pixel_shuffle(input, output, upscale_factor)
+ output.backward(output.data)
+ self.assertEqual(input.data, input.grad.data)
+
+ # Ensure unshuffle properly inverts shuffle.
+ unshuffle_output = pus(output)
+ self.assertEqual(input, unshuffle_output)
+ else:
+ self.assertRaises(RuntimeError, lambda: ps(input))
+
+ def _test_pixel_unshuffle_error_case_helper(num_input_dims, valid_height_dim=True, valid_width_dim=True,
+ downscale_factor=None):
+ downscale_factor = random.randint(2, 5) if downscale_factor is None else downscale_factor
+ channels = random.randint(1, 4)
+ # If valid_height_dim=False, add 1 to make height dim indivisible by downscale_factor.
+ height = random.randint(3, 5) * abs(downscale_factor) + (0 if valid_height_dim else 1)
+ # If valid_width_dim=False, add 1 to make width dim indivisible by downscale_factor.
+ width = random.randint(3, 5) * abs(downscale_factor) + (0 if valid_width_dim else 1)
+
+ if num_input_dims == 1:
+ input = torch.rand(channels, requires_grad=True, device='mps')
+ elif num_input_dims == 2:
+ input = torch.rand(height, width, requires_grad=True, device='mps')
+ else:
+ batch_sizes = [random.randint(1, 3) for _ in range(num_input_dims - 3)]
+ input = torch.rand(*batch_sizes, channels, height, width, requires_grad=True, device='mps')
+
+ pus = nn.PixelUnshuffle(downscale_factor)
+ self.assertRaises(RuntimeError, lambda: pus(input))
+
+ def _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims):
+ # For 1D - 2D, this is an error case.
+ # For 3D - 5D, this is a success case for pixel_shuffle + pixel_unshuffle.
+ is_contiguous_check = [True, False] if num_input_dims > 1 else [True]
+ for is_contiguous in is_contiguous_check:
+ _test_pixel_shuffle_unshuffle_helper(
+ num_input_dims=num_input_dims, is_contiguous=is_contiguous
+ )
+ _test_pixel_shuffle_unshuffle_helper(
+ num_input_dims=num_input_dims, valid_channels_dim=False, is_contiguous=is_contiguous
+ )
+ _test_pixel_shuffle_unshuffle_helper(
+ num_input_dims=num_input_dims, upscale_factor=0, is_contiguous=is_contiguous
+ )
+ _test_pixel_shuffle_unshuffle_helper(
+ num_input_dims=num_input_dims, upscale_factor=-2, is_contiguous=is_contiguous
+ )
+
+ # Error cases for pixel_unshuffle.
+ _test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, valid_height_dim=False)
+ _test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, valid_width_dim=False)
+ _test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, downscale_factor=0)
+ _test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, downscale_factor=-2)
+
+ def test_pixel_shuffle_unshuffle_1D():
+ _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=1)
+
+ def test_pixel_shuffle_unshuffle_2D():
+ _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=2)
+
+ def test_pixel_shuffle_unshuffle_3D():
+ _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=3)
+
+ def test_pixel_shuffle_unshuffle_4D():
+ _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=4)
+
+ def test_pixel_shuffle_unshuffle_5D():
+ _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=5)
+
+ test_pixel_shuffle_unshuffle_1D()
+ test_pixel_shuffle_unshuffle_2D()
+ test_pixel_shuffle_unshuffle_3D()
+ test_pixel_shuffle_unshuffle_4D()
+ test_pixel_shuffle_unshuffle_5D()
+
class MPSReluTest(TestCaseMPS):
def _npRelu(self, np_features):
return np.maximum(np_features, np.zeros(np_features.shape)).astype(np_features.dtype)