[MPS] Fix channels last copies in ELU, ReLU and Hardswish (#94664)
Fixes test_modules.py tests:
```
test_memory_format_nn_Hardswish_mps_float32
test_non_contiguous_tensors_nn_Hardswish_mps_float32
test_memory_format_nn_ReLU_mps_float32
```
Fixes elu when run with `ChannelsLast` memory format.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94664
Approved by: https://github.com/kulinseth
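For reference, a minimal repro sketch of the problem this fixes (not part of the PR; assumes a machine with an MPS device):
```
import torch

# Hypothetical repro: before this fix, ELU applied to a channels_last
# input could produce wrong values on MPS compared to the CPU reference.
cpu_x = torch.randn(2, 8, 4, 5).to(memory_format=torch.channels_last)
mps_x = cpu_x.detach().clone().to('mps')

cpu_y = torch.nn.functional.elu(cpu_x)
mps_y = torch.nn.functional.elu(mps_x)

# With the fix, the MPS result matches CPU.
torch.testing.assert_close(mps_y.cpu(), cpu_y)
```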
diff --git a/test/test_mps.py b/test/test_mps.py
index bd3f5c1..8b282a9 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -4739,10 +4739,11 @@

     # Test selu, elu, celu
     def test_elu(self):
-        def helper(shape, alpha=1.0):
-            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
-            x = cpu_x.detach().clone().to('mps').requires_grad_()
+        def helper(shape, alpha=1.0, memory_format=torch.contiguous_format):
+            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float)
+            cpu_x = cpu_x.to(memory_format=memory_format).requires_grad_()
+            x = cpu_x.detach().clone().to('mps').requires_grad_(True)

             for activation_func in [torch.nn.ELU(alpha=alpha), torch.nn.CELU(alpha=alpha), torch.nn.SELU()]:
                 elu_result = activation_func(x)
                 elu_result_cpu = activation_func(cpu_x)
@@ -4757,9 +4758,10 @@
                 self.assertEqual(x.grad, cpu_x.grad)

         # Test empty shape too
-        for shape in [[], (2, 3), (2, 8, 4, 5)]:
-            for alpha in [0.000001, 1.0, 2.3, 0.34, 23]:
-                helper(shape, alpha)
+        for memory_format in [torch.channels_last, torch.contiguous_format]:
+            for shape in [(2, 8, 4, 5)]:
+                for alpha in [0.000001, 1.0, 2.3, 0.34, 23]:
+                    helper(shape, alpha, memory_format)

     # Test glu
     def test_glu(self):
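The diff above only touches test_elu; the ReLU and Hardswish cases listed at the top are exercised by the existing test_modules.py tests. A hypothetical spot-check along the same lines (not part of this PR) could look like:
```
import torch

# Hypothetical spot-check mirroring the fixed test_modules.py cases:
# compare MPS against CPU for ReLU and Hardswish on channels_last inputs.
for module in (torch.nn.ReLU(), torch.nn.Hardswish()):
    cpu_x = torch.randn(2, 8, 4, 5).to(memory_format=torch.channels_last)
    mps_x = cpu_x.detach().clone().to('mps')
    torch.testing.assert_close(module(mps_x).cpu(), module(cpu_x))
```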