[MPS] Handle scalar input for scatter and gather (#85842)
Issue noticed in test consistency - "Indexing dim 0 is out of bounds of tensor"
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85842
Approved by: https://github.com/kulinseth
diff --git a/test/test_mps.py b/test/test_mps.py
index 6c63602..c251e92 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -4250,6 +4250,28 @@
helper((2, 8, 4, 5), 2, (1, 8, 10, 3))
helper((2, 8, 4, 5), 3, (2, 5, 3, 12))
+ # Test pytorch gather
+ def test_gather_scalar(self):
+ idx_dtype = torch.int64
+ cpu_x = torch.tensor(3, device='cpu', dtype=torch.float, requires_grad=True)
+ x = cpu_x.detach().clone().to('mps').requires_grad_()
+
+ idx_np = [0]
+
+ cpu_idx = torch.tensor(idx_np, device='cpu', dtype=idx_dtype)
+ idx = cpu_idx.detach().clone().to('mps')
+
+ gather_result = torch.gather(x, dim=0, index=idx)
+ gather_result_cpu = torch.gather(cpu_x, dim=0, index=cpu_idx)
+
+ cpu_grad = torch.randn([1], device='cpu', dtype=torch.float)
+ grad = cpu_grad.to('mps')
+ gather_result.backward(gradient=grad)
+ gather_result_cpu.backward(gradient=cpu_grad)
+
+ self.assertEqual(gather_result, gather_result_cpu)
+ self.assertEqual(cpu_x.grad, x.grad)
+
# Test pytorch scatter_add and scatter
def test_scatter_add(self):
def helper(shape, dim, idx_shape, src_shape, idx_dtype=torch.int64, do_add=True):
@@ -4317,6 +4339,46 @@
helper((8, 3), 0, (5, 3), (5, 3), do_add=False)
helper((10, 3), 0, (5, 3), (5, 8), do_add=False)
+ # Test pytorch scatter_add and scatter for scalar input
+ def test_scatter_add_scalar(self):
+ def helper(idx_dtype=torch.int64, do_add=True):
+ cpu_x = torch.tensor(2, device='cpu', dtype=torch.float, requires_grad=True)
+ x = cpu_x.detach().clone().to('mps').requires_grad_()
+
+ cpu_src = torch.tensor(3, device='cpu', dtype=torch.float, requires_grad=True)
+ src = cpu_src.detach().clone().to('mps').requires_grad_()
+
+ # Indices must lie within the range of the axis along which scattering is done
+ idx_np = [0]
+
+ cpu_idx = torch.tensor(idx_np, device='cpu', dtype=idx_dtype)
+ idx = cpu_idx.detach().clone().to('mps')
+
+ scatter_result = None
+ scatter_result_cpu = None
+
+ if(do_add):
+ scatter_result = torch.scatter_add(x, dim=0, index=idx, src=src)
+ scatter_result_cpu = torch.scatter_add(cpu_x, dim=0, index=cpu_idx, src=cpu_src)
+ else:
+ scatter_result = torch.scatter(x, dim=0, index=idx, src=src)
+ scatter_result_cpu = torch.scatter(cpu_x, dim=0, index=cpu_idx, src=cpu_src)
+
+ cpu_grad = None
+ grad = None
+
+ cpu_grad = torch.tensor(1.2, device='cpu', dtype=torch.float)
+ grad = cpu_grad.to('mps')
+ scatter_result.backward(gradient=grad)
+ scatter_result_cpu.backward(gradient=cpu_grad)
+
+ self.assertEqual(scatter_result, scatter_result_cpu)
+ self.assertEqual(cpu_x.grad, x.grad)
+ self.assertEqual(cpu_src.grad, src.grad)
+
+ helper()
+ helper(do_add=False)
+
# Test pytorch scatter_reduce
def test_scatter_reduce(self):
def helper(shape, dim, idx_shape, src_shape, idx_dtype=torch.int64, reduce_str="sum"):