[MPS] Fix `index_select` for empty input (#94117)
Also add tests for this case (an empty index list) to `test_index_select` and `test_index_select_scalar`.
Fixes https://github.com/pytorch/pytorch/issues/93877
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94117
Approved by: https://github.com/orionr
diff --git a/test/test_mps.py b/test/test_mps.py
index 423f3ba..d3fcc88 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -5109,6 +5109,7 @@
helper((2, 8, 4, 5), 3, [2, 3, 0])
helper((2, 3, 3), -1, [1, 2])
helper((), 0, [0])
+ helper((5), 0, [])
def test_index_select_scalar(self):
def helper(value, dim, index, idx_dtype=torch.int32):
@@ -5124,6 +5125,7 @@
self.assertEqual(idx_result, idx_result_cpu)
helper(0.5, 0, [0, 0])
+ helper(22, 0, [])
def test_embedding_dense_backward(self):
def helper(n, d, m, idx):