[MPS] Add support for randperm (#91708)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91708
Approved by: https://github.com/kulinseth
diff --git a/test/test_mps.py b/test/test_mps.py
index 0525a0c..5b5d70b 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -684,6 +684,40 @@
low.grad.zero_()
high.grad.zero_()
+    def test_randperm(self, device="mps"):
+        """Test torch.randperm on MPS: several dtypes, the out= variant, n == 0 and non-contiguous out."""
+        # torch.random.fork_rng(devices=...) forks only CPU/CUDA RNG state, not the MPS generator, so two
+        # draws cannot be expected to match; instead every result (including out=) is checked to be a permutation.
+        for n in (5, 100, 50000, 100000):
+            for dtype in (torch.long, torch.half, torch.float):
+                # Large n for torch.half will raise an exception, do not test here.
+                if n > 2049 and dtype == torch.half:
+                    continue
+                res1 = torch.randperm(n, dtype=dtype, device=device)
+                res2 = torch.empty(0, dtype=dtype, device=device)
+                torch.randperm(n, out=res2, dtype=dtype, device=device)
+                for res in (res1, res2):
+                    self.assertEqual(res.cpu().sort().values.long(), torch.arange(n, device=device))
+
+        # Default type is long
+        for n in (100, 10000):
+            self.assertEqual(torch.randperm(n, device=device).dtype, torch.long)
+
+        # randperm of 0 elements is an empty tensor, both freshly allocated on the device and written via out=
+        res1 = torch.randperm(0, device=device)
+        res2 = torch.tensor(5, dtype=torch.float, device=device)
+        torch.randperm(0, out=res2)
+        self.assertEqual(res1.numel(), 0)
+        self.assertEqual(res2.numel(), 0)
+
+        # Test non-contiguous tensors: the out= tensor itself must end up holding a permutation of 0..n-1
+        for n in (4, 5, 6, 10, 20):
+            non_contiguous_tensor = torch.zeros((2, 3), dtype=torch.long, device=device).t()
+            self.assertFalse(non_contiguous_tensor.is_contiguous())
+            torch.randperm(n, out=non_contiguous_tensor)
+            self.assertEqual(non_contiguous_tensor.numel(), n)
+            self.assertEqual(non_contiguous_tensor.cpu().sort().values.long(), torch.arange(n, device=device))
+
# Test forward maxpool2d
def test_max_pool2d(self):
def helper(shape, ks, padding=0, dilation=1, ceil_mode=False, return_indices=False, test_ties=False):