[MPS] Handle scalar input for scatter and gather (#85842)

Issue noticed in test consistency: scatter and gather on a scalar (0-dim) input failed with "Indexing dim 0 is out of bounds of tensor".
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85842
Approved by: https://github.com/kulinseth
diff --git a/test/test_mps.py b/test/test_mps.py
index 6c63602..c251e92 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -4250,6 +4250,28 @@
         helper((2, 8, 4, 5), 2, (1, 8, 10, 3))
         helper((2, 8, 4, 5), 3, (2, 5, 3, 12))
 
+    # Test pytorch gather for scalar input
+    def test_gather_scalar(self):
+        # Gather on a 0-dim tensor: MPS output and input grad must match CPU.
+        cpu_input = torch.tensor(3, device='cpu', dtype=torch.float, requires_grad=True)
+        mps_input = cpu_input.detach().clone().to('mps').requires_grad_()
+
+        # The input is 0-dim, so the test addresses its single element with index 0.
+        cpu_index = torch.tensor([0], device='cpu', dtype=torch.int64)
+        mps_index = cpu_index.detach().clone().to('mps')
+
+        mps_out = torch.gather(mps_input, dim=0, index=mps_index)
+        cpu_out = torch.gather(cpu_input, dim=0, index=cpu_index)
+
+        # Backpropagate the same random upstream gradient on both devices.
+        cpu_upstream = torch.randn([1], device='cpu', dtype=torch.float)
+        mps_upstream = cpu_upstream.to('mps')
+        mps_out.backward(gradient=mps_upstream)
+        cpu_out.backward(gradient=cpu_upstream)
+
+        self.assertEqual(mps_out, cpu_out)
+        self.assertEqual(cpu_input.grad, mps_input.grad)
+
     # Test pytorch scatter_add and scatter
     def test_scatter_add(self):
         def helper(shape, dim, idx_shape, src_shape, idx_dtype=torch.int64, do_add=True):
@@ -4317,6 +4339,46 @@
         helper((8, 3), 0, (5, 3), (5, 3), do_add=False)
         helper((10, 3), 0, (5, 3), (5, 8), do_add=False)
 
+    # Test pytorch scatter_add and scatter for scalar input
+    def test_scatter_add_scalar(self):
+        # Scatter / scatter_add on 0-dim tensors: the MPS output and the
+        # gradients of both inputs must match the CPU reference.
+        def helper(idx_dtype=torch.int64, do_add=True):
+            # idx_dtype: dtype used for the index tensor.
+            # do_add: exercise torch.scatter_add when True, torch.scatter otherwise.
+            cpu_x = torch.tensor(2, device='cpu', dtype=torch.float, requires_grad=True)
+            x = cpu_x.detach().clone().to('mps').requires_grad_()
+
+            cpu_src = torch.tensor(3, device='cpu', dtype=torch.float, requires_grad=True)
+            src = cpu_src.detach().clone().to('mps').requires_grad_()
+
+            # The input is 0-dim, so the test addresses its single element with index 0.
+            cpu_idx = torch.tensor([0], device='cpu', dtype=idx_dtype)
+            idx = cpu_idx.detach().clone().to('mps')
+
+            if do_add:
+                scatter_result = torch.scatter_add(x, dim=0, index=idx, src=src)
+                scatter_result_cpu = torch.scatter_add(cpu_x, dim=0, index=cpu_idx, src=cpu_src)
+            else:
+                scatter_result = torch.scatter(x, dim=0, index=idx, src=src)
+                scatter_result_cpu = torch.scatter(cpu_x, dim=0, index=cpu_idx, src=cpu_src)
+
+            # Backpropagate the same 0-dim upstream gradient through both
+            # results so the input and source gradients can be compared.
+            cpu_grad = torch.tensor(1.2, device='cpu', dtype=torch.float)
+            grad = cpu_grad.to('mps')
+            scatter_result.backward(gradient=grad)
+            scatter_result_cpu.backward(gradient=cpu_grad)
+
+            # Forward result, input gradient and source gradient must all agree.
+            self.assertEqual(scatter_result, scatter_result_cpu)
+            self.assertEqual(cpu_x.grad, x.grad)
+            self.assertEqual(cpu_src.grad, src.grad)
+
+        # Cover scatter_add (the default) and plain scatter.
+        helper()
+        helper(do_add=False)
+
     # Test pytorch scatter_reduce
     def test_scatter_reduce(self):
         def helper(shape, dim, idx_shape, src_shape, idx_dtype=torch.int64, reduce_str="sum"):