[MPS] Fix gelu forward and backward ops (#94529)

Forward pass:
```
fix gelu_out_mps key
add calculation for gelu with tanh
remove gelu from blocklist
```
Backward pass:
```
fix gelu_backward_out_mps key
uniform format
add caculation for tanh approximate backward pass
unblock grad test from blocklist
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94529
Approved by: https://github.com/razarmehr, https://github.com/kulinseth
diff --git a/aten/src/ATen/native/mps/operations/Activation.mm b/aten/src/ATen/native/mps/operations/Activation.mm
index ee1c3ee..a5dae09 100644
--- a/aten/src/ATen/native/mps/operations/Activation.mm
+++ b/aten/src/ATen/native/mps/operations/Activation.mm
@@ -753,6 +753,50 @@
     return  erfTensor;
 }
 
+MPSGraphTensor* tanh (MPSGraph* mpsGraph, MPSGraphTensor *inputTensor) {
+    // 0.5 * x * (1 + text{Tanh}(sqrt(2 / pi) * (x + 0.044715 * x^3)))
+    auto dataType = [inputTensor dataType];
+    constexpr float kBeta =  M_SQRT2 * M_2_SQRTPI * 0.5;
+    constexpr float kKappa = 0.044715f;
+    MPSGraphTensor *betaf = [mpsGraph constantWithScalar: kBeta
+                                                   shape: @[@1]
+                                                dataType: dataType];
+    MPSGraphTensor *kappaf = [mpsGraph constantWithScalar: kKappa
+                                                    shape: @[@1]
+                                                 dataType: dataType];
+    MPSGraphTensor *onef = [mpsGraph constantWithScalar: 1.0f
+                                                  shape: @[@1]
+                                              dataType: dataType];
+    MPSGraphTensor *halff = [mpsGraph constantWithScalar: 0.5f
+                                                    shape: @[@1]
+                                                dataType: dataType];
+    MPSGraphTensor *erfTensor = [mpsGraph multiplicationWithPrimaryTensor: inputTensor
+                                                          secondaryTensor: inputTensor
+                                                                    name : nil];
+    erfTensor = [mpsGraph multiplicationWithPrimaryTensor: erfTensor
+                                          secondaryTensor: inputTensor
+                                                    name : nil];
+    erfTensor = [mpsGraph multiplicationWithPrimaryTensor: erfTensor
+                                          secondaryTensor: kappaf
+                                                    name : nil];
+    erfTensor = [mpsGraph additionWithPrimaryTensor: erfTensor
+                                    secondaryTensor: inputTensor
+                                              name : nil];
+    erfTensor = [mpsGraph multiplicationWithPrimaryTensor: erfTensor
+                                          secondaryTensor: betaf
+                                                    name : nil];
+    erfTensor = [mpsGraph tanhWithTensor: erfTensor
+                                   name : nil];
+    erfTensor = [mpsGraph additionWithPrimaryTensor: erfTensor
+                                    secondaryTensor: onef
+                                              name : nil];
+    erfTensor = [mpsGraph multiplicationWithPrimaryTensor: erfTensor
+                                          secondaryTensor: halff
+                                                    name : nil];
+
+    return  erfTensor;
+}
+
 TORCH_IMPL_FUNC(gelu_out_mps) (
     const Tensor& self, c10::string_view approximate, const Tensor& output
   ) {
@@ -776,7 +820,7 @@
   MPSStream* stream = getCurrentMPSStream();
 
   @autoreleasepool {
-    string key = "gelu_out_mps" + getTensorsStringKey({self});
+    string key = "gelu_out_mps" + getTensorsStringKey({self}) + ":" + c10::str(approximate);
     CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
     if(!cachedGraph) {
       MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () {
@@ -791,7 +835,12 @@
                                                                   getMPSDataType(self.scalar_type()),
                                                                   getMPSShape(self));
 
-          MPSGraphTensor* outputTensor = normcdf(mpsGraph, inputTensor);
+          MPSGraphTensor* outputTensor = nil;
+          if(approximate == "tanh") {
+            outputTensor = tanh(mpsGraph, inputTensor);
+          } else {
+            outputTensor = normcdf(mpsGraph, inputTensor);
+          }
           outputTensor = [mpsGraph multiplicationWithPrimaryTensor:outputTensor
                                                    secondaryTensor:inputTensor
                                                               name:nil];
@@ -824,7 +873,6 @@
     const Tensor& grad, const Tensor& self, c10::string_view approximate, const Tensor& grad_input
   ) {
   using namespace mps;
-  constexpr float kBeta = M_2_SQRTPI * M_SQRT1_2 * (0.5);
 
   // Empty output
   if(grad_input.numel() == 0)
@@ -843,7 +891,7 @@
   MPSStream* stream = getCurrentMPSStream();
 
   @autoreleasepool {
-    string key = "gelu_backward_out_mps" + getTensorsStringKey({self, grad});
+    string key = "gelu_backward_out_mps" + getTensorsStringKey({self, grad}) + ":" + c10::str(approximate);
     CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
     if(!cachedGraph) {
       MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () {
@@ -861,32 +909,110 @@
           MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph,
                                                                   dataType,
                                                                   getMPSShape(self));
-          MPSGraphTensor* cdf = normcdf(mpsGraph, inputTensor);
-          MPSGraphTensor *halff = [mpsGraph constantWithScalar: -0.5f
-                                                    shape: @[@1]
-                                                dataType: dataType];
-          MPSGraphTensor *betaf = [mpsGraph constantWithScalar :kBeta
-                                                    shape :@[@1]
-                                                dataType:dataType];
-          MPSGraphTensor *pdfMul = [mpsGraph squareWithTensor : inputTensor
-                                                    name : nil];
-          pdfMul = [mpsGraph multiplicationWithPrimaryTensor : pdfMul
-                                          secondaryTensor : halff
-                                                    name : nil];
-          pdfMul = [mpsGraph exponentWithTensor : pdfMul
-                                        name  : nil];
-          MPSGraphTensor* pdf = [mpsGraph multiplicationWithPrimaryTensor : pdfMul
-                                                        secondaryTensor  : betaf
-                                                                  name : nil];
-          pdf = [mpsGraph multiplicationWithPrimaryTensor : inputTensor
-                                          secondaryTensor : pdf
-                                            name : nil];
-          pdf = [mpsGraph additionWithPrimaryTensor : pdf
-                                  secondaryTensor : cdf
-                                      name : nil];
-          MPSGraphTensor* outputTensor = [mpsGraph multiplicationWithPrimaryTensor : gradTensor
-                                                                   secondaryTensor : pdf
-                                                                              name : nil];
+          MPSGraphTensor* outputTensor = nil;
+          if(approximate == "tanh") {
+            constexpr float kBeta = M_SQRT2 * M_2_SQRTPI * (0.5f);
+            constexpr float kKappa = 0.044715f;
+            MPSGraphTensor *betaf = [mpsGraph constantWithScalar: kBeta
+                                                           shape: @[@1]
+                                                        dataType: dataType];
+            MPSGraphTensor *kappaf = [mpsGraph constantWithScalar: kKappa
+                                                            shape: @[@1]
+                                                         dataType: dataType];
+            MPSGraphTensor *halff = [mpsGraph constantWithScalar: 0.5f
+                                                           shape: @[@1]
+                                                        dataType: dataType];
+            MPSGraphTensor *onef = [mpsGraph constantWithScalar: 1.0f
+                                                          shape: @[@1]
+                                                       dataType: dataType];
+            MPSGraphTensor *threef = [mpsGraph constantWithScalar: 3.0f
+                                                            shape: @[@1]
+                                                         dataType: dataType];
+            MPSGraphTensor* x_sq = [mpsGraph multiplicationWithPrimaryTensor: inputTensor
+                                                             secondaryTensor: inputTensor
+                                                                        name: nil];
+            MPSGraphTensor *x_cube =  [mpsGraph multiplicationWithPrimaryTensor: x_sq
+                                                                secondaryTensor: inputTensor
+                                                                           name: nil];
+            MPSGraphTensor *inner = [mpsGraph multiplicationWithPrimaryTensor: kappaf
+                                                              secondaryTensor: x_cube
+                                                                         name: nil];
+            inner = [mpsGraph additionWithPrimaryTensor: inner
+                                        secondaryTensor: inputTensor
+                                                   name: nil];
+            inner = [mpsGraph multiplicationWithPrimaryTensor: betaf
+                                              secondaryTensor: inner
+                                                         name: nil];
+            MPSGraphTensor *tanhInner = [mpsGraph tanhWithTensor: inner
+                                                            name: nil];
+            MPSGraphTensor *left = [mpsGraph multiplicationWithPrimaryTensor: halff
+                                                             secondaryTensor: inputTensor
+                                                                        name: nil];
+            MPSGraphTensor *right = [mpsGraph additionWithPrimaryTensor: onef
+                                                        secondaryTensor: tanhInner
+                                                                   name: nil];
+            MPSGraphTensor *left_derivative = [mpsGraph multiplicationWithPrimaryTensor: halff
+                                                                        secondaryTensor: right
+                                                                                   name: nil];
+            MPSGraphTensor *tanh_derivative = [mpsGraph multiplicationWithPrimaryTensor: tanhInner
+                                                                        secondaryTensor: tanhInner
+                                                                                   name: nil];
+            tanh_derivative = [mpsGraph subtractionWithPrimaryTensor: onef
+                                                     secondaryTensor: tanh_derivative
+                                                                name: nil];
+            MPSGraphTensor *inner_derivative = [mpsGraph multiplicationWithPrimaryTensor: threef
+                                                                         secondaryTensor: kappaf
+                                                                                    name: nil];
+            inner_derivative = [mpsGraph multiplicationWithPrimaryTensor: inner_derivative
+                                                         secondaryTensor: x_sq
+                                                                    name: nil];
+            inner_derivative = [mpsGraph additionWithPrimaryTensor: inner_derivative
+                                                   secondaryTensor: onef
+                                                              name: nil];
+            inner_derivative = [mpsGraph multiplicationWithPrimaryTensor: betaf
+                                                         secondaryTensor: inner_derivative
+                                                                    name: nil];
+            MPSGraphTensor *right_derivative = [mpsGraph multiplicationWithPrimaryTensor: left
+                                                                         secondaryTensor: tanh_derivative
+                                                                                    name: nil];
+            right_derivative = [mpsGraph multiplicationWithPrimaryTensor: right_derivative
+                                                         secondaryTensor: inner_derivative
+                                                                    name: nil];
+            outputTensor = [mpsGraph additionWithPrimaryTensor: left_derivative
+                                               secondaryTensor: right_derivative
+                                                          name: nil];
+            outputTensor = [mpsGraph multiplicationWithPrimaryTensor: gradTensor
+                                                     secondaryTensor: outputTensor
+                                                                name: nil];
+          } else {
+            constexpr float kBeta = M_2_SQRTPI * M_SQRT1_2 * (0.5);
+            MPSGraphTensor *halff = [mpsGraph constantWithScalar: -0.5f
+                                                           shape: @[@1]
+                                                        dataType: dataType];
+            MPSGraphTensor *betaf = [mpsGraph constantWithScalar: kBeta
+                                                           shape: @[@1]
+                                                        dataType: dataType];
+            MPSGraphTensor* cdf = normcdf(mpsGraph, inputTensor);
+            MPSGraphTensor *pdfMul = [mpsGraph squareWithTensor: inputTensor
+                                                           name: nil];
+            pdfMul = [mpsGraph multiplicationWithPrimaryTensor: pdfMul
+                                               secondaryTensor: halff
+                                                          name: nil];
+            pdfMul = [mpsGraph exponentWithTensor: pdfMul
+                                             name: nil];
+            MPSGraphTensor* pdf = [mpsGraph multiplicationWithPrimaryTensor: pdfMul
+                                                            secondaryTensor: betaf
+                                                                       name: nil];
+            pdf = [mpsGraph multiplicationWithPrimaryTensor: inputTensor
+                                            secondaryTensor: pdf
+                                                       name: nil];
+            pdf = [mpsGraph additionWithPrimaryTensor: pdf
+                                      secondaryTensor: cdf
+                                                 name: nil];
+            outputTensor = [mpsGraph multiplicationWithPrimaryTensor: gradTensor
+                                                     secondaryTensor: pdf
+                                                                name: nil];
+          }
 
           newCachedGraph->gradTensor_ = gradTensor;
           newCachedGraph->inputTensor_ = inputTensor;
diff --git a/test/test_mps.py b/test/test_mps.py
index e3329a4..cd40f44 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -5005,6 +5005,17 @@
         finally:
             torch.set_num_threads(num_threads)
 
+    def test_gelu_tanh(self):
+        def helper(shape):
+            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float)
+            x = cpu_x.detach().clone().to('mps')
+
+            gelu_tanh_result = torch.nn.functional.gelu(x, approximate='tanh')
+            gelu_tanh_result_cpu = torch.nn.functional.gelu(cpu_x, approximate='tanh')
+            self.assertEqual(gelu_tanh_result, gelu_tanh_result_cpu)
+
+        helper((2, 8, 4, 5))
+
     # Test hardtanh
     def test_hardtanh(self):
         def helper(shape, min_val, max_val, inplace=False):
@@ -9175,6 +9186,7 @@
         '_native_batch_norm_legit': ['f32'],
         'native_batch_norm': ['f32'],
         'native_layer_norm': ['f32'],
+        'nn.functional.gelu': ['f32'],
     }
 
     # These ops that are problematic. So never run them even when