[MPS] Support stride of stride
Fixes https://github.com/pytorch/pytorch/issues/79181
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79521
Approved by: https://github.com/kulinseth
diff --git a/test/test_mps.py b/test/test_mps.py
index e5dc637..4ce4737 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -1460,6 +1460,14 @@
self.assertEqual(t[2, 1], j)
self.assertEqual(t.sum(), 1 + i + j)
+ def test_stride_of_strides(self) -> None:
+ x = torch.rand(32, 1, device='mps')
+ y = x.as_strided(size=(32, 2), stride=(1, 0))
+ # Copying an as_strided view of an already-strided tensor to CPU used to crash with a "buffer is not large enough." assertion
+ # See https://github.com/pytorch/pytorch/issues/79181#issuecomment-1154683435
+ z = y.as_strided(size=(32, 3), stride=(1, 0)).to("cpu")
+ self.assertEqual(x.to("cpu").as_strided(size=(32, 3), stride=(1, 0)), z)
+
class TestSmoothL1Loss(TestCase):