[MPS] Fix base shape size for view ops in case of multiple slices (#85934)
Fixes https://github.com/pytorch/pytorch/issues/84364, https://github.com/pytorch/pytorch/issues/85592
Fixes a bug in view ops where the base shape would be incorrectly determined.
E.g., for the tensor `torch.tensor([0.5, 0.5], device="mps")[1][None]`, the base shape of the parent tensor could be taken to be 1, while the actual base shape is 2.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85934
Approved by: https://github.com/kulinseth
diff --git a/test/test_mps.py b/test/test_mps.py
index c251e92..fa5f663 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -1373,6 +1373,18 @@
for op in ["<=", "<", ">=", ">", "==", "!="]:
helper(op)
+ def test_slice_of_slice(self):
+ x = torch.tensor([0.5, 0.5], device="cpu")
+ x_mps = torch.tensor([0.5, 0.5], device="mps")
+
+ tensor = x[1][None]
+ tensor_mps = x_mps[1][None]
+
+ res = tensor.ne(0)
+ res_mps = tensor_mps.ne(0)
+
+ self.assertEqual(res, res_mps)
+
def test_index_storage_offset(self):
# https://github.com/pytorch/pytorch/issues/78107