[MPS] Pixel shuffle unshuffle support (#99306)
Fixes #83196
With this change, the MPS implementation is blazingly fast.
However, I have several questions about improving this PR:
1. I copied the test code from `test_nn.py`. Is there a better way to test this?
2. I decided to use `usePixelShuffleOrder:YES`. Is that the right choice performance-wise? According to the docs:
`usePixelShuffleOrder` can be
used to control how the data within spatial blocks is ordered in the
`depthAxis` dimension: with `usePixelShuffleOrder=YES` the values within the
spatial blocks are stored contiguously within the `depthAxis` dimension whereas
otherwise they are stored interleaved with existing values in the `depthAxis` dimension.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99306
Approved by: https://github.com/kulinseth, https://github.com/malfet
diff --git a/test/test_mps.py b/test/test_mps.py
index 1159e8d..e9d7378 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -948,6 +948,134 @@
with self.assertRaisesRegex(RuntimeError, r"MPS driver API confirmed .+"):
+class TestPixelShuffle(TestCaseMPS):
+ def test_pixel_shuffle_unshuffle(self):
+ def _test_pixel_shuffle_unshuffle_helper(num_input_dims, valid_channels_dim=True,
+ upscale_factor=None, is_contiguous=True):
+ def generate_input():
+ # If valid_channels_dim=False, add 1 to make channels dim indivisible by upscale_factor ** 2.
+ channels = random.randint(1, 4) * upscale_factor ** 2 + (0 if valid_channels_dim else 1)
+ height = random.randint(5, 10)
+ width = random.randint(5, 10)
+ if num_input_dims == 1:
+ input = torch.rand(channels, requires_grad=True, device='mps')
+ assert is_contiguous
+ elif num_input_dims == 2:
+ input = torch.rand(width, height, requires_grad=True, device='mps').T
+ if is_contiguous:
+ input = input.contiguous()
+ else:
+ batch_sizes = [random.randint(1, 3) for _ in range(num_input_dims - 3)]
+ input = torch.rand(*batch_sizes, channels, width, height, requires_grad=True, device='mps')
+ input = input.transpose(-1, -2)
+ if is_contiguous:
+ input = input.contiguous()
+ if not is_contiguous and len(input.reshape(-1)) > 0:
+ assert not input.is_contiguous()
+ input = input.detach().clone()
+ input.requires_grad = True
+ return input
+ # Function to imperatively ensure pixels are shuffled to the correct locations.
+ # Used to validate the batch operations in pixel_shuffle.
+ def _verify_pixel_shuffle(input, output, upscale_factor):
+ for c in range(output.size(-3)):
+ for h in range(output.size(-2)):
+ for w in range(output.size(-1)):
+ height_idx = h // upscale_factor
+ weight_idx = w // upscale_factor
+ channel_idx = (upscale_factor * (h % upscale_factor)) + (w % upscale_factor) + \
+ (c * upscale_factor ** 2)
+ self.assertEqual(output[..., c, h, w], input[..., channel_idx, height_idx, weight_idx])
+ upscale_factor = random.randint(2, 5) if upscale_factor is None else upscale_factor
+ input = generate_input()
+ ps = nn.PixelShuffle(upscale_factor)
+ pus = nn.PixelUnshuffle(downscale_factor=upscale_factor)
+ if num_input_dims >= 3 and valid_channels_dim and upscale_factor > 0:
+ output = ps(input)
+ _verify_pixel_shuffle(input, output, upscale_factor)
+ output.backward(output.data)
+ self.assertEqual(input.data, input.grad.data)
+ # Ensure unshuffle properly inverts shuffle.
+ unshuffle_output = pus(output)
+ self.assertEqual(input, unshuffle_output)
+ else:
+ self.assertRaises(RuntimeError, lambda: ps(input))
+ def _test_pixel_unshuffle_error_case_helper(num_input_dims, valid_height_dim=True, valid_width_dim=True,
+ downscale_factor=None):
+ downscale_factor = random.randint(2, 5) if downscale_factor is None else downscale_factor
+ channels = random.randint(1, 4)
+ # If valid_height_dim=False, add 1 to make height dim indivisible by downscale_factor.
+ height = random.randint(3, 5) * abs(downscale_factor) + (0 if valid_height_dim else 1)
+ # If valid_width_dim=False, add 1 to make width dim indivisible by downscale_factor.
+ width = random.randint(3, 5) * abs(downscale_factor) + (0 if valid_width_dim else 1)
+ if num_input_dims == 1:
+ input = torch.rand(channels, requires_grad=True, device='mps')
+ elif num_input_dims == 2:
+ input = torch.rand(height, width, requires_grad=True, device='mps')
+ else:
+ batch_sizes = [random.randint(1, 3) for _ in range(num_input_dims - 3)]
+ input = torch.rand(*batch_sizes, channels, height, width, requires_grad=True, device='mps')
+ pus = nn.PixelUnshuffle(downscale_factor)
+ self.assertRaises(RuntimeError, lambda: pus(input))
+ def _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims):
+ # For 1D - 2D, this is an error case.
+ # For 3D - 5D, this is a success case for pixel_shuffle + pixel_unshuffle.
+ is_contiguous_check = [True, False] if num_input_dims > 1 else [True]
+ for is_contiguous in is_contiguous_check:
+ _test_pixel_shuffle_unshuffle_helper(
+ num_input_dims=num_input_dims, is_contiguous=is_contiguous
+ )
+ _test_pixel_shuffle_unshuffle_helper(
+ num_input_dims=num_input_dims, valid_channels_dim=False, is_contiguous=is_contiguous
+ )
+ _test_pixel_shuffle_unshuffle_helper(
+ num_input_dims=num_input_dims, upscale_factor=0, is_contiguous=is_contiguous
+ )
+ _test_pixel_shuffle_unshuffle_helper(
+ num_input_dims=num_input_dims, upscale_factor=-2, is_contiguous=is_contiguous
+ )
+ # Error cases for pixel_unshuffle.
+ _test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, valid_height_dim=False)
+ _test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, valid_width_dim=False)
+ _test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, downscale_factor=0)
+ _test_pixel_unshuffle_error_case_helper(num_input_dims=num_input_dims, downscale_factor=-2)
+ def test_pixel_shuffle_unshuffle_1D():
+ _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=1)
+ def test_pixel_shuffle_unshuffle_2D():
+ _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=2)
+ def test_pixel_shuffle_unshuffle_3D():
+ _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=3)
+ def test_pixel_shuffle_unshuffle_4D():
+ _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=4)
+ def test_pixel_shuffle_unshuffle_5D():
+ _test_pixel_shuffle_unshuffle_for_input_dims(num_input_dims=5)
+ test_pixel_shuffle_unshuffle_1D()
+ test_pixel_shuffle_unshuffle_2D()
+ test_pixel_shuffle_unshuffle_3D()
+ test_pixel_shuffle_unshuffle_4D()
+ test_pixel_shuffle_unshuffle_5D()
class MPSReluTest(TestCaseMPS):
def _npRelu(self, np_features):
return np.maximum(np_features, np.zeros(np_features.shape)).astype(np_features.dtype)