[MPS] Skip gather/blit calls in case of strided output (#94260)
Skip the gather/blit calls when the output is strided. This avoids:
- allocating additional memory for the output
- an additional transpose of both the input and the output
Fixes the following reproducer:
```
x = torch.rand((256,10), device='mps')
x = x.permute(1,0)
x.exp()
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94260
Approved by: https://github.com/razarmehr
diff --git a/test/test_mps.py b/test/test_mps.py
index 9ecaa30..3b3bf9e 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -251,6 +251,17 @@
input = torch.tensor([-0.1, 3.0, -0.9]).to('mps')
output = torch.exp(input).to('cpu')
+ def test_exp_strided_output(self):
+ x = torch.rand((256, 10), device='mps')
+ x_cpu = x.to("cpu")
+ # Permute both copies so exp() runs on a non-contiguous (permuted) tensor, the strided case addressed by #94260.
+ x = x.permute(1, 0)
+ x_cpu = x_cpu.permute(1, 0)
+ # The MPS result must match the CPU reference computed from the same values.
+ res = x.exp()
+ res_cpu = x_cpu.exp()
+ self.assertEqual(res, res_cpu)
+
def _testLeakyRelu(self, np_features, negative_slope, device):
cpu_x = torch.from_numpy(np_features).requires_grad_()
mps_x = torch.from_numpy(np_features).to('mps').requires_grad_()