[MPS] Fix gelu forward and backward ops (#94529)
Forward pass:
```
fix gelu_out_mps key
add calculation for gelu with the tanh approximation
remove gelu from blocklist
```
Backward pass:
```
fix gelu_backward_out_mps key
use uniform formatting
add calculation for the tanh-approximate backward pass
unblock grad test from blocklist
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94529
Approved by: https://github.com/razarmehr, https://github.com/kulinseth
diff --git a/test/test_mps.py b/test/test_mps.py
index e3329a4..cd40f44 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -5005,6 +5005,17 @@
finally:
torch.set_num_threads(num_threads)
+ def test_gelu_tanh(self):
+ def helper(shape):
+ cpu_x = torch.randn(shape, device='cpu', dtype=torch.float)
+ x = cpu_x.detach().clone().to('mps')
+
+ gelu_tanh_result = torch.nn.functional.gelu(x, approximate='tanh')
+ gelu_tanh_result_cpu = torch.nn.functional.gelu(cpu_x, approximate='tanh')
+ self.assertEqual(gelu_tanh_result, gelu_tanh_result_cpu)
+
+ helper((2, 8, 4, 5))
+
# Test hardtanh
def test_hardtanh(self):
def helper(shape, min_val, max_val, inplace=False):
@@ -9175,6 +9186,7 @@
'_native_batch_norm_legit': ['f32'],
'native_batch_norm': ['f32'],
'native_layer_norm': ['f32'],
+ 'nn.functional.gelu': ['f32'],
}
# These ops that are problematic. So never run them even when