[MPS] Add Metal implementation of exp op (#128421)
To improve accuracy, use `precise::exp()` (and `precise::sin()`/`precise::cos()` for the complex flavor)
Reuse `test_exp1` to check that the `exp` op now matches the CPU implementation within tight tolerances
Fix a bug in the handling of non-contiguous output tensors
Fixes https://github.com/pytorch/pytorch/issues/84936
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128421
Approved by: https://github.com/kulinseth
ghstack dependencies: #128373, #128375
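For context, the complex flavor of the new kernel relies on Euler's formula, exp(x + iy) = exp(x) * (cos(y) + i*sin(y)), evaluating each part with a `precise::` intrinsic. A minimal NumPy sketch of that identity (illustrative only, not part of the patch):

```python
import numpy as np

# exp(x + iy) = exp(x) * (cos(y) + i*sin(y)); the Metal kernel below
# evaluates the right-hand side component-wise with precise:: intrinsics.
z = 0.5 + 1.25j
lhs = np.exp(z)
rhs = np.exp(z.real) * (np.cos(z.imag) + 1j * np.sin(z.imag))
assert np.allclose(lhs, rhs)
```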
diff --git a/aten/src/ATen/native/mps/OperationUtils.mm b/aten/src/ATen/native/mps/OperationUtils.mm
index 8170bd0..82d1fe9 100644
--- a/aten/src/ATen/native/mps/OperationUtils.mm
+++ b/aten/src/ATen/native/mps/OperationUtils.mm
@@ -659,6 +659,7 @@
MTLCompileOptions* options = [[MTLCompileOptions new] autorelease];
[options setLanguageVersion:is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS) ? MTLLanguageVersion3_1
: MTLLanguageVersion2_3];
+  // [options setFastMathEnabled:NO]; // uncomment to disable fast math when debugging kernel precision
auto str = [NSString stringWithCString:src.c_str() encoding:NSASCIIStringEncoding];
auto device = MPSDevice::getInstance()->device();
library = [device newLibraryWithSource:str options:options error:&error];
diff --git a/aten/src/ATen/native/mps/UnaryConstants.h b/aten/src/ATen/native/mps/UnaryConstants.h
index b1a92f6..8a9a668 100644
--- a/aten/src/ATen/native/mps/UnaryConstants.h
+++ b/aten/src/ATen/native/mps/UnaryConstants.h
@@ -9,9 +9,9 @@
constant float c[4] = {{-1.970840454, -1.624906493, 3.429567803, 1.641345311}};
constant float d[2] = {{3.543889200, 1.637067800}};
-kernel void erfinv_mps_kernel( device {0} *output [[buffer(0)]],
- device {1} *input [[buffer(1)]],
- uint index [[thread_position_in_grid]]) {{
+kernel void erfinv_kernel( device {0} *output [[buffer(0)]],
+ device {1} *input [[buffer(1)]],
+ uint index [[thread_position_in_grid]]) {{
float y = input[index];
float x, z, num, dem; /*working variables */
@@ -35,4 +35,46 @@
}}
output[index] = {0}(x);
-}})METAL";
+}}
+
+kernel void exp_kernel( device {0} *output [[buffer(0)]],
+                        device {1} *input [[buffer(1)]],
+ uint index [[thread_position_in_grid]]) {{
+ output[index] = {0}(precise::exp(input[index]));
+}}
+
+kernel void exp_complex_kernel( device {0}2 *output [[buffer(0)]],
+                                device {0}2 *input [[buffer(1)]],
+ uint index [[thread_position_in_grid]]) {{
+ output[index].x = {0}(precise::exp(input[index].x)*precise::cos(input[index].y));
+ output[index].y = {0}(precise::exp(input[index].x)*precise::sin(input[index].y));
+}}
+
+kernel void tanh_kernel( device {0} *output [[buffer(0)]],
+                        device {1} *input [[buffer(1)]],
+ uint index [[thread_position_in_grid]]) {{
+ output[index] = {0}(precise::tanh(input[index]));
+}}
+
+
+#if __METAL_VERSION__ >= 310
+bfloat dot(bfloat2 a, bfloat2 b) {{
+ return a.x * b.x + a.y * b.y;
+}}
+#endif
+
+template<typename T>
+T complex_div(T a, T b) {{
+ auto denom = dot(b, b);
+ return T(dot(a, b), a.y * b.x - a.x * b.y)/denom;
+}}
+
+kernel void tanh_complex_kernel( device {0}2 *output [[buffer(0)]],
+                                device {0}2 *input [[buffer(1)]],
+ uint index [[thread_position_in_grid]]) {{
+  // tanh(x+iy) = (tanh(x) + i*tan(y)) / (1 + i*tanh(x)*tan(y))
+ auto tanh_x = {0}(precise::tanh(input[index].x));
+ auto tan_y = {0}(precise::tan(input[index].y));
+ output[index] = complex_div({0}2(tanh_x, tan_y), {0}2({0}(1), tanh_x * tan_y));
+}}
+)METAL";
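The `complex_div` helper above implements division via the conjugate trick: a/b = a*conj(b) / |b|^2, so the real part is `dot(a, b)` and the imaginary part is `a.y*b.x - a.x*b.y`. A Python sketch of the same arithmetic, checked against the tanh identity the complex kernel uses (illustrative only, not part of the patch):

```python
import cmath
import math

def complex_div(a: complex, b: complex) -> complex:
    # a / b = a * conj(b) / |b|^2: real part is dot(a, b),
    # imaginary part is a.y*b.x - a.x*b.y, mirroring the Metal helper.
    denom = b.real * b.real + b.imag * b.imag
    return complex(a.real * b.real + a.imag * b.imag,
                   a.imag * b.real - a.real * b.imag) / denom

# tanh(x+iy) = (tanh(x) + i*tan(y)) / (1 + i*tanh(x)*tan(y))
z = 0.3 + 0.7j
num = complex(math.tanh(z.real), math.tan(z.imag))
den = complex(1.0, math.tanh(z.real) * math.tan(z.imag))
assert cmath.isclose(complex_div(num, den), cmath.tanh(z))
```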
diff --git a/aten/src/ATen/native/mps/operations/UnaryKernel.mm b/aten/src/ATen/native/mps/operations/UnaryKernel.mm
index 5c894ef..8e0eda5 100644
--- a/aten/src/ATen/native/mps/operations/UnaryKernel.mm
+++ b/aten/src/ATen/native/mps/operations/UnaryKernel.mm
@@ -8,6 +8,8 @@
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/erfinv_native.h>
+#include <ATen/ops/exp_native.h>
+#include <ATen/ops/tanh_native.h>
#endif
#include <fmt/format.h>
@@ -15,14 +17,8 @@
namespace at::native {
static mps::MetalShaderLibrary lib(UNARY_KERNEL_TEMPLATE, 2);
-TORCH_IMPL_FUNC(erfinv_out_mps)(const Tensor& self, const Tensor& output_) {
- // handle erfinv ops using metal kernel
- // erfinv algorithm ported from aten/src/ATen/native/Math.h
- // https://github.com/pytorch/pytorch/blob/4154c8ea159fdaecc71ee9af820ac956193c875b/aten/src/ATen/native/Math.h#L152
-
- TORCH_CHECK(self.scalar_type() != ScalarType::Double, "MPS does not support erfinv op with scalar type: Double");
-
- Tensor inputTensor = self;
+static void exec_unary_kernel(const Tensor& self, const Tensor& output_, const std::string& name) {
+ Tensor inputTensor = self.contiguous();
Tensor outputTensor = output_;
bool needs_output_copy = false;
uint32_t length = output_.numel();
@@ -31,11 +27,16 @@
}
using namespace mps;
@autoreleasepool {
- auto cplState = lib.getPipelineStateForFunc("erfinv_mps_kernel",
- {scalarToMetalTypeString(outputTensor), scalarToMetalTypeString(self)});
+ id<MTLComputePipelineState> cplState = nil;
+ if (c10::isComplexType(self.scalar_type())) {
+ auto scalarStr = self.scalar_type() == kComplexFloat ? "float" : "half";
+ cplState = lib.getPipelineStateForFunc(name + "_complex_kernel", {scalarStr, scalarStr});
+ } else {
+ cplState = lib.getPipelineStateForFunc(name + "_kernel",
+ {scalarToMetalTypeString(outputTensor), scalarToMetalTypeString(self)});
+ }
- if (!self.is_contiguous()) {
- inputTensor = inputTensor.contiguous();
+ if (!outputTensor.is_contiguous()) {
outputTensor = outputTensor.contiguous();
needs_output_copy = true;
}
@@ -44,7 +45,7 @@
dispatch_sync(mpsStream->queue(), ^() {
id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
- getMPSProfiler().beginProfileKernel(cplState, "erf_inv", {inputTensor});
+ getMPSProfiler().beginProfileKernel(cplState, name, {self});
[computeEncoder setComputePipelineState:cplState];
mtl_setBuffer(computeEncoder, outputTensor, 0);
@@ -58,4 +59,19 @@
output_.copy_(outputTensor);
}
}
+TORCH_IMPL_FUNC(erfinv_out_mps)(const Tensor& self, const Tensor& output_) {
+ // handle erfinv ops using metal kernel
+ // erfinv algorithm ported from aten/src/ATen/native/Math.h
+ // https://github.com/pytorch/pytorch/blob/4154c8ea159fdaecc71ee9af820ac956193c875b/aten/src/ATen/native/Math.h#L152
+
+ TORCH_CHECK(self.scalar_type() != ScalarType::Double, "MPS does not support erfinv op with scalar type: Double");
+ exec_unary_kernel(self, output_, "erfinv");
+}
+
+TORCH_IMPL_FUNC(exp_out_mps)(const Tensor& self, const Tensor& output_) {
+ exec_unary_kernel(self, output_, "exp");
+}
+TORCH_IMPL_FUNC(tanh_out_mps)(const Tensor& self, const Tensor& output_) {
+ exec_unary_kernel(self, output_, "tanh");
+}
} // namespace at::native
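The refactored `exec_unary_kernel` is where the non-contiguous bug fix lives: the input is now made contiguous unconditionally, and the output's contiguity is checked independently, with results copied back into the original strided tensor afterwards (previously the output was only copied back when the *input* happened to be non-contiguous). A Python sketch of the pattern, with `kernel` standing in for the Metal dispatch (names are illustrative, not PyTorch API):

```python
import torch

def exec_unary(self: torch.Tensor, out: torch.Tensor, kernel) -> None:
    inp = self.contiguous()                 # kernel indexes storage linearly
    work = out
    needs_output_copy = not out.is_contiguous()
    if needs_output_copy:
        work = out.contiguous()             # temporary contiguous buffer
    kernel(work, inp)                       # launch over flat buffers
    if needs_output_copy:
        out.copy_(work)                     # scatter back into the strided out
```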
diff --git a/aten/src/ATen/native/mps/operations/UnaryOps.mm b/aten/src/ATen/native/mps/operations/UnaryOps.mm
index 46709f2..ec83c60 100644
--- a/aten/src/ATen/native/mps/operations/UnaryOps.mm
+++ b/aten/src/ATen/native/mps/operations/UnaryOps.mm
@@ -26,7 +26,6 @@
#include <ATen/ops/cumsum_native.h>
#include <ATen/ops/erf_native.h>
#include <ATen/ops/exp2_native.h>
-#include <ATen/ops/exp_native.h>
#include <ATen/ops/expm1_native.h>
#include <ATen/ops/floor_native.h>
#include <ATen/ops/frac_native.h>
@@ -54,7 +53,6 @@
#include <ATen/ops/sinh_native.h>
#include <ATen/ops/sqrt_native.h>
#include <ATen/ops/tan_native.h>
-#include <ATen/ops/tanh_native.h>
#include <ATen/ops/trunc_native.h>
#include <ATen/ops/view_as_real.h>
#endif
@@ -236,7 +234,6 @@
}); \
}
-CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(exp_out_mps, exponent)
CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(exp2_out_mps, exponentBase2)
CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(reciprocal_out_mps, reciprocal)
CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(sqrt_out_mps, squareRoot)
@@ -254,7 +251,6 @@
CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(atan_out_mps, atan)
CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(sinh_out_mps, sinh)
CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(cosh_out_mps, cosh)
-CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(tanh_out_mps, tanh)
CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(asinh_out_mps, asinh)
CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(acosh_out_mps, acosh)
CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(atanh_out_mps, atanh)
diff --git a/test/test_mps.py b/test/test_mps.py
index c59a598..d141e1a 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -270,6 +270,7 @@
'empty_permuted',
'empty_strided',
'eye',
+ 'exp',
'expand',
'expand_as',
'flatten',
@@ -306,6 +307,7 @@
'nn.functional.conv_transpose2d',
'nn.functional.feature_alpha_dropoutwithout_train',
'nn.functional.padcircular',
+ 'nn.functional.tanhshrink',
'nn.functional.unfold',
'nonzero',
'ones',
@@ -333,6 +335,7 @@
'sub',
'svd',
't',
+ 'tanh',
'tensor_split',
'transpose',
'T',
@@ -389,7 +392,6 @@
'eq',
'equal',
'exp2',
- 'exp',
'expm1',
'fft.fft',
'fft.fft2',
@@ -447,7 +449,6 @@
'nn.functional.pixel_unshuffle',
'nn.functional.rms_norm',
'nn.functional.softsign',
- 'nn.functional.tanhshrink',
'pinverse',
'prod',
'reciprocal',
@@ -465,7 +466,6 @@
'sum',
'sum_to_size',
'tan',
- 'tanh',
'tensordot',
'trace',
'trapz',
@@ -1612,14 +1612,19 @@
class TestMPS(TestCaseMPS):
def test_exp(self, device="mps", dtype=torch.float):
for v in (2, -2) + ((1j, 1 + 1j) if dtype.is_complex else ()):
- b = torch.arange(18, device="cpu") / 3 * math.pi
- a = torch.tensor(v, dtype=dtype, device="cpu") * b
- a = a.to(dtype).to("mps")
+ b = torch.arange(18, dtype=dtype, device=device) / 3 * math.pi
+ a = torch.tensor(v, dtype=dtype, device="mps") * b
self.compare_with_numpy(torch.exp, np.exp, a)
def test_exp1(self, device="mps", dtype=torch.float):
- input = torch.tensor([-0.1, 3.0, -0.9]).to('mps')
- output = torch.exp(input).to('cpu')
+ input = torch.tensor([-0.1, 1.0, -0.9, 0.1], device=device, dtype=dtype)
+ output = torch.exp(input)
+ output_cpu = torch.exp(input.cpu())
+        # If the exponentWithTensor: MPS call is used on an M1 running macOS 14.5, this test fails with:
+ # Mismatched elements: 3 / 4 (75.0%)
+ # Greatest absolute difference: 1.1920928955078125e-07 at index (3,) (up to 1e-08 allowed)
+ # Greatest relative difference: 1.0786502002702036e-07 at index (3,) (up to 1e-08 allowed)
+ self.assertEqual(output, output_cpu, atol=1e-8, rtol=1e-8)
def test_exp_strided_output(self):
x = torch.rand((256, 10), device='mps')
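For reference, a standalone repro of the accuracy comparison that `test_exp1` codifies (assumes an MPS-capable machine; illustrative only, not part of the patch):

```python
import torch

inp = torch.tensor([-0.1, 1.0, -0.9, 0.1], dtype=torch.float32)
mps_out = torch.exp(inp.to("mps")).cpu()
cpu_out = torch.exp(inp)
# With the precise::exp() kernel the two match at very tight tolerances.
torch.testing.assert_close(mps_out, cpu_out, atol=1e-8, rtol=1e-8)
```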