[MPS] Add Metal implementation of exp op (#128421)

To improve accuracy, use `precise::exp()` (and `precise::sin()`/`precise::cos()` for the complex flavor)
Reuse `test_exp1` to check that the accuracy of the `exp` op now closely matches the CPU results
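
For reference, the complex flavor relies on Euler's formula, exp(x + i*y) = exp(x) * (cos(y) + i*sin(y)), which is why the precise sin/cos variants matter for accuracy. A minimal sketch of that identity in plain PyTorch (CPU only; the sample value is illustrative):

```python
import torch

# Numerical sanity check (on CPU) of the identity the complex kernel uses:
#   exp(x + i*y) = exp(x) * (cos(y) + i*sin(y))
# The sample value is illustrative only.
z = torch.tensor(0.5 + 0.25j)
euler = torch.exp(z.real) * (torch.cos(z.imag) + 1j * torch.sin(z.imag))
assert torch.allclose(torch.exp(z), euler)
```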

Fix a bug in the handling of non-contiguous tensors
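
A hypothetical repro of the kind of case the fix targets, assuming an MPS-capable machine: a contiguous input written through a non-contiguous output view (similar in spirit to `test_exp_strided_output` below):

```python
import torch

# Hypothetical repro sketch (assumes an MPS-capable machine): a contiguous
# input written through a non-contiguous output view.
if torch.backends.mps.is_available():
    x = torch.rand(4, 3, device="mps")
    out = torch.empty(3, 4, device="mps").t()  # transposed view, non-contiguous
    torch.exp(x, out=out)
    assert torch.allclose(out.cpu(), torch.exp(x.cpu()))
```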

Fixes https://github.com/pytorch/pytorch/issues/84936
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128421
Approved by: https://github.com/kulinseth
ghstack dependencies: #128373, #128375
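
For context, a rough sketch of the accuracy check `test_exp1` performs after this change (inputs copied from the updated test; assumes an MPS-capable machine):

```python
import torch

# Rough sketch of the accuracy comparison that test_exp1 performs (inputs
# copied from the updated test; assumes an MPS-capable machine).
if torch.backends.mps.is_available():
    inp = torch.tensor([-0.1, 1.0, -0.9, 0.1], device="mps")
    mps_out = torch.exp(inp).cpu()
    cpu_out = torch.exp(inp.cpu())
    torch.testing.assert_close(mps_out, cpu_out, atol=1e-8, rtol=1e-8)
```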
diff --git a/aten/src/ATen/native/mps/OperationUtils.mm b/aten/src/ATen/native/mps/OperationUtils.mm
index 8170bd0..82d1fe9 100644
--- a/aten/src/ATen/native/mps/OperationUtils.mm
+++ b/aten/src/ATen/native/mps/OperationUtils.mm
@@ -659,6 +659,7 @@
   MTLCompileOptions* options = [[MTLCompileOptions new] autorelease];
   [options setLanguageVersion:is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS) ? MTLLanguageVersion3_1
                                                                                       : MTLLanguageVersion2_3];
+  // [options setFastMathEnabled: NO];
   auto str = [NSString stringWithCString:src.c_str() encoding:NSASCIIStringEncoding];
   auto device = MPSDevice::getInstance()->device();
   library = [device newLibraryWithSource:str options:options error:&error];
diff --git a/aten/src/ATen/native/mps/UnaryConstants.h b/aten/src/ATen/native/mps/UnaryConstants.h
index b1a92f6..8a9a668 100644
--- a/aten/src/ATen/native/mps/UnaryConstants.h
+++ b/aten/src/ATen/native/mps/UnaryConstants.h
@@ -9,9 +9,9 @@
 constant float c[4] = {{-1.970840454, -1.624906493, 3.429567803, 1.641345311}};
 constant float d[2] = {{3.543889200, 1.637067800}};
 
-kernel void erfinv_mps_kernel( device {0} *output [[buffer(0)]],
-                            device {1} *input [[buffer(1)]],
-                            uint index [[thread_position_in_grid]]) {{
+kernel void erfinv_kernel( device {0} *output [[buffer(0)]],
+                           device {1} *input [[buffer(1)]],
+                           uint index [[thread_position_in_grid]]) {{
 
   float y = input[index];
   float x, z, num, dem; /*working variables */
@@ -35,4 +35,46 @@
   }}
 
   output[index] = {0}(x);
-}})METAL";
+}}
+
+kernel void exp_kernel( device {0} *output [[buffer(0)]],
+                        device {1} *input [[ buffer(1)]],
+                        uint index [[thread_position_in_grid]]) {{
+  output[index] = {0}(precise::exp(input[index]));
+}}
+
+kernel void exp_complex_kernel( device {0}2 *output [[buffer(0)]],
+                                device {0}2 *input [[ buffer(1)]],
+                                uint index [[thread_position_in_grid]]) {{
+  output[index].x = {0}(precise::exp(input[index].x)*precise::cos(input[index].y));
+  output[index].y = {0}(precise::exp(input[index].x)*precise::sin(input[index].y));
+}}
+
+kernel void tanh_kernel( device {0} *output [[buffer(0)]],
+                        device {1} *input [[ buffer(1)]],
+                        uint index [[thread_position_in_grid]]) {{
+  output[index] = {0}(precise::tanh(input[index]));
+}}
+
+
+#if __METAL_VERSION__ >= 310
+bfloat dot(bfloat2 a, bfloat2 b) {{
+  return a.x * b.x + a.y * b.y;
+}}
+#endif
+
+template<typename T>
+T complex_div(T a, T b) {{
+  auto denom = dot(b, b);
+  return T(dot(a, b), a.y * b.x - a.x * b.y)/denom;
+}}
+
+kernel void tanh_complex_kernel( device {0}2 *output [[buffer(0)]],
+                                 device {0}2 *input [[ buffer(1)]],
+                                 uint index [[thread_position_in_grid]]) {{
+  // tanh(x+iy) = (tanh(x) + i*tan(y)) / (1 + i*tanh(x)*tan(y))
+  auto tanh_x = {0}(precise::tanh(input[index].x));
+  auto tan_y = {0}(precise::tan(input[index].y));
+  output[index] = complex_div({0}2(tanh_x, tan_y), {0}2({0}(1), tanh_x * tan_y));
+}}
+)METAL";
diff --git a/aten/src/ATen/native/mps/operations/UnaryKernel.mm b/aten/src/ATen/native/mps/operations/UnaryKernel.mm
index 5c894ef..8e0eda5 100644
--- a/aten/src/ATen/native/mps/operations/UnaryKernel.mm
+++ b/aten/src/ATen/native/mps/operations/UnaryKernel.mm
@@ -8,6 +8,8 @@
 #include <ATen/NativeFunctions.h>
 #else
 #include <ATen/ops/erfinv_native.h>
+#include <ATen/ops/exp_native.h>
+#include <ATen/ops/tanh_native.h>
 #endif
 
 #include <fmt/format.h>
@@ -15,14 +17,8 @@
 namespace at::native {
 static mps::MetalShaderLibrary lib(UNARY_KERNEL_TEMPLATE, 2);
 
-TORCH_IMPL_FUNC(erfinv_out_mps)(const Tensor& self, const Tensor& output_) {
-  // handle erfinv ops using metal kernel
-  // erfinv algorithm ported from aten/src/ATen/native/Math.h
-  // https://github.com/pytorch/pytorch/blob/4154c8ea159fdaecc71ee9af820ac956193c875b/aten/src/ATen/native/Math.h#L152
-
-  TORCH_CHECK(self.scalar_type() != ScalarType::Double, "MPS does not support erfinv op with scalar type: Double");
-
-  Tensor inputTensor = self;
+static void exec_unary_kernel(const Tensor& self, const Tensor& output_, const std::string& name) {
+  Tensor inputTensor = self.contiguous();
   Tensor outputTensor = output_;
   bool needs_output_copy = false;
   uint32_t length = output_.numel();
@@ -31,11 +27,16 @@
   }
   using namespace mps;
   @autoreleasepool {
-    auto cplState = lib.getPipelineStateForFunc("erfinv_mps_kernel",
-                                                {scalarToMetalTypeString(outputTensor), scalarToMetalTypeString(self)});
+    id<MTLComputePipelineState> cplState = nil;
+    if (c10::isComplexType(self.scalar_type())) {
+      auto scalarStr = self.scalar_type() == kComplexFloat ? "float" : "half";
+      cplState = lib.getPipelineStateForFunc(name + "_complex_kernel", {scalarStr, scalarStr});
+    } else {
+      cplState = lib.getPipelineStateForFunc(name + "_kernel",
+                                             {scalarToMetalTypeString(outputTensor), scalarToMetalTypeString(self)});
+    }
 
-    if (!self.is_contiguous()) {
-      inputTensor = inputTensor.contiguous();
+    if (!outputTensor.is_contiguous()) {
       outputTensor = outputTensor.contiguous();
       needs_output_copy = true;
     }
@@ -44,7 +45,7 @@
     dispatch_sync(mpsStream->queue(), ^() {
       id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
 
-      getMPSProfiler().beginProfileKernel(cplState, "erf_inv", {inputTensor});
+      getMPSProfiler().beginProfileKernel(cplState, name, {self});
 
       [computeEncoder setComputePipelineState:cplState];
       mtl_setBuffer(computeEncoder, outputTensor, 0);
@@ -58,4 +59,19 @@
     output_.copy_(outputTensor);
   }
 }
+TORCH_IMPL_FUNC(erfinv_out_mps)(const Tensor& self, const Tensor& output_) {
+  // handle erfinv ops using metal kernel
+  // erfinv algorithm ported from aten/src/ATen/native/Math.h
+  // https://github.com/pytorch/pytorch/blob/4154c8ea159fdaecc71ee9af820ac956193c875b/aten/src/ATen/native/Math.h#L152
+
+  TORCH_CHECK(self.scalar_type() != ScalarType::Double, "MPS does not support erfinv op with scalar type: Double");
+  exec_unary_kernel(self, output_, "erfinv");
+}
+
+TORCH_IMPL_FUNC(exp_out_mps)(const Tensor& self, const Tensor& output_) {
+  exec_unary_kernel(self, output_, "exp");
+}
+TORCH_IMPL_FUNC(tanh_out_mps)(const Tensor& self, const Tensor& output_) {
+  exec_unary_kernel(self, output_, "tanh");
+}
 } // namespace at::native
diff --git a/aten/src/ATen/native/mps/operations/UnaryOps.mm b/aten/src/ATen/native/mps/operations/UnaryOps.mm
index 46709f2..ec83c60 100644
--- a/aten/src/ATen/native/mps/operations/UnaryOps.mm
+++ b/aten/src/ATen/native/mps/operations/UnaryOps.mm
@@ -26,7 +26,6 @@
 #include <ATen/ops/cumsum_native.h>
 #include <ATen/ops/erf_native.h>
 #include <ATen/ops/exp2_native.h>
-#include <ATen/ops/exp_native.h>
 #include <ATen/ops/expm1_native.h>
 #include <ATen/ops/floor_native.h>
 #include <ATen/ops/frac_native.h>
@@ -54,7 +53,6 @@
 #include <ATen/ops/sinh_native.h>
 #include <ATen/ops/sqrt_native.h>
 #include <ATen/ops/tan_native.h>
-#include <ATen/ops/tanh_native.h>
 #include <ATen/ops/trunc_native.h>
 #include <ATen/ops/view_as_real.h>
 #endif
@@ -236,7 +234,6 @@
     });                                                                                                          \
   }
 
-CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(exp_out_mps, exponent)
 CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(exp2_out_mps, exponentBase2)
 CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(reciprocal_out_mps, reciprocal)
 CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(sqrt_out_mps, squareRoot)
@@ -254,7 +251,6 @@
 CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(atan_out_mps, atan)
 CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(sinh_out_mps, sinh)
 CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(cosh_out_mps, cosh)
-CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(tanh_out_mps, tanh)
 CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(asinh_out_mps, asinh)
 CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(acosh_out_mps, acosh)
 CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(atanh_out_mps, atanh)
diff --git a/test/test_mps.py b/test/test_mps.py
index c59a598..d141e1a 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -270,6 +270,7 @@
         'empty_permuted',
         'empty_strided',
         'eye',
+        'exp',
         'expand',
         'expand_as',
         'flatten',
@@ -306,6 +307,7 @@
         'nn.functional.conv_transpose2d',
         'nn.functional.feature_alpha_dropoutwithout_train',
         'nn.functional.padcircular',
+        'nn.functional.tanhshrink',
         'nn.functional.unfold',
         'nonzero',
         'ones',
@@ -333,6 +335,7 @@
         'sub',
         'svd',
         't',
+        'tanh',
         'tensor_split',
         'transpose',
         'T',
@@ -389,7 +392,6 @@
         'eq',
         'equal',
         'exp2',
-        'exp',
         'expm1',
         'fft.fft',
         'fft.fft2',
@@ -447,7 +449,6 @@
         'nn.functional.pixel_unshuffle',
         'nn.functional.rms_norm',
         'nn.functional.softsign',
-        'nn.functional.tanhshrink',
         'pinverse',
         'prod',
         'reciprocal',
@@ -465,7 +466,6 @@
         'sum',
         'sum_to_size',
         'tan',
-        'tanh',
         'tensordot',
         'trace',
         'trapz',
@@ -1612,14 +1612,19 @@
 class TestMPS(TestCaseMPS):
     def test_exp(self, device="mps", dtype=torch.float):
         for v in (2, -2) + ((1j, 1 + 1j) if dtype.is_complex else ()):
-            b = torch.arange(18, device="cpu") / 3 * math.pi
-            a = torch.tensor(v, dtype=dtype, device="cpu") * b
-            a = a.to(dtype).to("mps")
+            b = torch.arange(18, dtype=dtype, device=device) / 3 * math.pi
+            a = torch.tensor(v, dtype=dtype, device="mps") * b
             self.compare_with_numpy(torch.exp, np.exp, a)
 
     def test_exp1(self, device="mps", dtype=torch.float):
-        input = torch.tensor([-0.1, 3.0, -0.9]).to('mps')
-        output = torch.exp(input).to('cpu')
+        input = torch.tensor([-0.1, 1.0, -0.9, 0.1], device=device, dtype=dtype)
+        output = torch.exp(input)
+        output_cpu = torch.exp(input.cpu())
+        # If the exponentWithTensor: MPS call is used on an M1 running 14.5, the test fails with
+        # Mismatched elements: 3 / 4 (75.0%)
+        # Greatest absolute difference: 1.1920928955078125e-07 at index (3,) (up to 1e-08 allowed)
+        # Greatest relative difference: 1.0786502002702036e-07 at index (3,) (up to 1e-08 allowed)
+        self.assertEqual(output, output_cpu, atol=1e-8, rtol=1e-8)
 
     def test_exp_strided_output(self):
         x = torch.rand((256, 10), device='mps')