[MPS] Fix strided ELU op (#125692)
Fixes https://github.com/pytorch/pytorch/issues/124834
Summary of changes:
For a non-contiguous input, the output is non-contiguous too. Saving the result directly into a non-contiguous buffer is not currently supported, so two steps are needed: first compute the result into a contiguous buffer, then scatter it back into the original output.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125692
Approved by: https://github.com/kulinseth
diff --git a/test/test_mps.py b/test/test_mps.py
index 598fde0..d13eabf 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -6500,6 +6500,18 @@
for alpha in [0.000001, 1.0, 2.3, 0.34, 23]:
helper(shape, alpha, memory_fromat)
+ def test_elu_strided_output(self):
+ # Regression test for https://github.com/pytorch/pytorch/issues/124834 (ELU on a non-contiguous MPS input).
+ elu_input = torch.randn(1, 1024, 500)
+ alpha = float(1)
+ inplace = False
+ # transpose(1, 2) returns a non-contiguous view; the MPS result on it must match the CPU reference.
+ elu_input_noncontiguous = elu_input.transpose(1, 2)
+ self.assertEqual(
+ F.elu(elu_input_noncontiguous.to('cpu'), alpha, inplace),
+ F.elu(elu_input_noncontiguous.to('mps'), alpha, inplace)
+ )
+
# Test glu
def test_glu(self):
def helper(shape, dim=0):