[MPS] Disallow reshape in slice (#95905)
Disallow reshapes for array views.
The current code allows a base tensor of shape `[2, 4, 256]` to be sliced into a view of shape `[4, 1, 256]` — which is not possible: slicing can never expand a dimension beyond its original size, so attempting to slice a smaller dimension into a bigger one will always error out.
Fixes https://github.com/pytorch/pytorch/issues/95883
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95905
Approved by: https://github.com/razarmehr, https://github.com/kulinseth
diff --git a/test/test_mps.py b/test/test_mps.py
index 9877612..9949e9c 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -1954,6 +1954,78 @@
x_cpu = x_cpu + 2
self.assertEqual(x, x_cpu)
+ def test_reshape_storage_offset(self):
+ # https://github.com/pytorch/pytorch/issues/95883
+ B = 4
+ T = 1
+
+ lin_cpu = nn.Linear(10, 256)
+ lin_mps = nn.Linear(10, 256, device="mps")
+
+ # Use the same weights and bias as the ones from the cpu
+ lin_mps.weight.data = lin_cpu.weight.data.detach().clone().to("mps").requires_grad_()
+ lin_mps.bias.data = lin_cpu.bias.data.detach().clone().to("mps").requires_grad_()
+
+ x_mps = torch.rand([B, T, 10], device="mps", requires_grad=True)
+ x_cpu = x_mps.detach().clone().cpu().requires_grad_()
+ x_mps = lin_mps(x_mps)
+ x_cpu = lin_cpu(x_cpu)
+
+ self.assertEqual(x_mps.shape, (B, T, 256))
+ self.assertEqual(x_cpu.shape, (B, T, 256))
+
+ cls_token_mps = torch.rand([1, 256], device="mps", requires_grad=True).repeat(B, 1, 1)
+ cls_token_cpu = cls_token_mps.detach().clone().cpu()
+ x_mps = torch.cat([cls_token_mps, x_mps], dim=1)
+ x_cpu = torch.cat([cls_token_cpu, x_cpu], dim=1)
+
+ x_mps = x_mps.transpose(0, 1)
+ x_cpu = x_cpu.transpose(0, 1)
+
+ target_mps = torch.rand_like(x_mps)
+ target_cpu = target_mps.detach().clone().cpu()
+ loss_mps = F.mse_loss(x_mps, target_mps)
+ loss_cpu = F.mse_loss(x_cpu, target_cpu)
+ self.assertEqual(loss_mps, loss_cpu)
+
+ loss_mps.backward()
+ loss_cpu.backward()
+ self.assertEqual(x_mps.grad, x_cpu.grad)
+
+ def test_stack(self):
+ # https://github.com/pytorch/pytorch/issues/87856
+ x_cpu = torch.tensor([[1, 2]])
+ x_mps = x_cpu.detach().clone().to("mps")
+
+ y_cpu = torch.stack((x_cpu[:, :1], x_cpu[:, -1:]), dim=-1)
+ y_mps = torch.stack((x_mps[:, :1], x_mps[:, -1:]), dim=-1)
+
+ self.assertEqual(y_cpu, y_mps)
+
+ t_mps = torch.tensor([1, 2, 3, 4], device="mps")
+ t_cpu = t_mps.detach().cpu().detach()
+
+ x_mps = t_mps[2:]
+ y_mps = t_mps[:2]
+
+ x_cpu = t_cpu[2:]
+ y_cpu = t_cpu[:2]
+
+ res_mps = torch.stack((y_mps, x_mps), dim=-1)
+ res_cpu = torch.stack((y_cpu, x_cpu), dim=-1)
+
+ self.assertEqual(res_mps, res_cpu)
+
+ def test_unsafe_chunk(self):
+ # https://github.com/pytorch/pytorch/issues/91065
+ a = torch.rand(5, dtype=torch.float32, device="cpu")
+ ret = a.unsafe_chunk(4, 0)
+ y = ret[0] * ret[2]
+ a_mps = a.to("mps")
+ ret_mps = a_mps.unsafe_chunk(4, 0)
+ y_mps = ret_mps[0] * ret_mps[2]
+ self.assertEqual(y, y_mps)
+
def test_slice_casting(self):
# generate random binary numbers
cpu_in = torch.bernoulli(torch.empty(1, 1, 128, 128).uniform_(0, 1)).to(torch.uint8)