[MPS] Fix index_select scalar input with multiple indices (#91064)
Support `index_select` on a 0-dim (scalar) input tensor when the index holds multiple entries, e.g.:
```
device="mps"
arr = torch.tensor(10, device=device)
indices = torch.tensor([0, 0], device=device) # multiple indices
torch.index_select(arr, 0, indices)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91064
Approved by: https://github.com/kulinseth
diff --git a/test/test_mps.py b/test/test_mps.py
index f496702..d7e560e 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -5108,6 +5108,22 @@
helper((2, 8, 4, 5), 2, [3, 0, 1])
helper((2, 8, 4, 5), 3, [2, 3, 0])
helper((2, 3, 3), -1, [1, 2])
+ helper((), 0, [0])
+
+ def test_index_select_scalar(self):
+ def helper(value, dim, index, idx_dtype=torch.int32):
+ cpu_x = torch.tensor(value, device='cpu', dtype=torch.float, requires_grad=False)
+ x = cpu_x.detach().clone().to('mps')
+
+ cpu_idx = torch.tensor(index, device='cpu', dtype=idx_dtype)
+ idx = cpu_idx.detach().clone().to('mps')
+
+ idx_result = torch.index_select(x, dim=dim, index=idx)
+ idx_result_cpu = torch.index_select(cpu_x, dim=dim, index=cpu_idx)
+
+ self.assertEqual(idx_result, idx_result_cpu)
+
+ helper(0.5, 0, [0, 0])
def test_embedding_dense_backward(self):
def helper(n, d, m, idx):