[MPS] Fix index_select scalar input with multiple indices (#91064)
Support operations like this:
```
device="mps"
arr = torch.tensor(10, device=device)
indices = torch.tensor([0, 0], device=device) # multiple indices
torch.index_select(arr, 0, indices)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91064
Approved by: https://github.com/kulinseth
diff --git a/aten/src/ATen/native/mps/operations/Indexing.mm b/aten/src/ATen/native/mps/operations/Indexing.mm
index e93b333..bea0338 100644
--- a/aten/src/ATen/native/mps/operations/Indexing.mm
+++ b/aten/src/ATen/native/mps/operations/Indexing.mm
@@ -628,6 +628,12 @@
TORCH_CHECK(dim == 0 || dim < self.dim(),
"index_select(): Indexing dim ", dim, " is out of bounds of tensor");
+ // Scalar input
+ if (self.dim() == 0 && self.numel() == 1){
+ output.copy_(self);
+ return output;
+ }
+
// Derive from MPSCachedGraph
struct CachedGraph : public MPSCachedGraph
{
diff --git a/test/test_mps.py b/test/test_mps.py
index f496702..d7e560e 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -5108,6 +5108,22 @@
helper((2, 8, 4, 5), 2, [3, 0, 1])
helper((2, 8, 4, 5), 3, [2, 3, 0])
helper((2, 3, 3), -1, [1, 2])
+ helper((), 0, [0])
+
+ def test_index_select_scalar(self):
+ def helper(value, dim, index, idx_dtype=torch.int32):
+ cpu_x = torch.tensor(value, device='cpu', dtype=torch.float, requires_grad=False)
+ x = cpu_x.detach().clone().to('mps')
+
+ cpu_idx = torch.tensor(index, device='cpu', dtype=idx_dtype)
+ idx = cpu_idx.detach().clone().to('mps')
+
+ idx_result = torch.index_select(x, dim=dim, index=idx)
+ idx_result_cpu = torch.index_select(cpu_x, dim=dim, index=cpu_idx)
+
+ self.assertEqual(idx_result, idx_result_cpu)
+
+ helper(0.5, 0, [0, 0])
def test_embedding_dense_backward(self):
def helper(n, d, m, idx):