MPS: Fix crashes in view tensors due to buffer size mismatch (#78496)
Fixes #78247, #77886
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78496
Approved by: https://github.com/albanD, https://github.com/malfet
diff --git a/test/test_mps.py b/test/test_mps.py
index 1765b3b..f0b411d 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -3963,6 +3963,25 @@
input = torch.randn((1, 1, 1, 1), dtype=torch.float)
self.assertEqual(m(input).size(), (1, 1, 1, 1))
+    def test_conv_expand(self):
+        # conv2d with an expanded (stride-0) kernel view on MPS must not crash and must match CPU.
+        input_, kernel = torch.rand(2, 3, 16, 16), torch.rand(1, 1, 3, 11)
+        out_cpu = F.conv2d(input_, kernel.expand(-1, 3, -1, -1), groups=1, padding=0, stride=1)
+        out_mps = F.conv2d(input_.to('mps'), kernel.to('mps').expand(-1, 3, -1, -1), groups=1, padding=0, stride=1)
+        self.assertEqual(out_mps, out_cpu, atol=1e-4, rtol=1e-4)  # MPS may sum in a different order
+
+    # torch.log on a permuted (non-contiguous) MPS view must not crash and must match CPU.
+    def test_permute(self):
+        x_cpu = torch.rand(5, 5) + 0.1  # strictly positive so log() is finite and comparable
+        x_mps = x_cpu.to('mps')
+        self.assertEqual(torch.log(x_mps), torch.log(x_cpu))
+        self.assertEqual(torch.log(x_mps.permute(1, 0)), torch.log(x_cpu.permute(1, 0)))
+
+    # Formatting nonzero() output produced on MPS (as print() does) should not crash.
+    def test_print_non_contiguous(self):
+        nz = torch.ones(100, 100, device='mps').nonzero()
+        self.assertEqual(str(nz.contiguous()), str(nz))  # same formatting path as print(), no stdout noise
+
def test_zero_grad(self):
i = torch.randn(2, 5, requires_grad=True)
module = nn.Linear(5, 5)