[MPS] Add sort and argsort ops (#94697)
Add MPS support for `torch.sort` and `torch.argsort`, with tests in test/test_mps.py.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94697
Approved by: https://github.com/DenisVieriu97
diff --git a/test/test_mps.py b/test/test_mps.py
index a8d17ba..314ad5c 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -4375,6 +4375,26 @@
helper((5, 9, 7, 4))
helper((50, 20, 7, 4))
+ def test_sort(self):
+ for SIZE in (4, 2049):
+ device = 'mps'
+ x = torch.rand(4, SIZE, device=device)
+ res1val, res1ind = torch.sort(x)
+
+ res2val = torch.tensor((), device=device)
+ res2ind = torch.tensor((), device=device, dtype=torch.long)
+ torch.sort(x, out=(res2val, res2ind))
+ self.assertEqual(res1val, res2val, atol=0, rtol=0)
+ self.assertEqual(res1ind, res2ind, atol=0, rtol=0)
+ self.assertEqual(torch.argsort(x), res1ind)
+ self.assertEqual(x.argsort(), res1ind)
+
+ self.assertEqual(
+ torch.sort(torch.tensor((50, 40, 30, 20, 10), device=device))[0],
+ torch.tensor((10, 20, 30, 40, 50), device=device),
+ atol=0, rtol=0
+ )
+
def test_upsample_nearest2d(self):
def helper(N, C, H, W):
inputCPU = torch.arange(N * C * H * W, device='cpu', dtype=torch.float,
@@ -9076,6 +9096,8 @@
'tile': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
'topk': ['f32', 'f16'],
'trapz': ['f16', 'f32', 'i16', 'i32', 'i64'],
+ 'sort': ['f32', 'i16', 'i32', 'i64'],
+ 'argsort': ['f32', 'i16', 'i32', 'i64'],
'tril': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
'tril_indices': ['i32', 'i64'],
'triu': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],