[MPS] Add Metal implementation of exp op (#128421)
To improve accuracy, use `precise::exp()` (and `precise::sin()`/`precise::cos()` for the complex variant)
Reuse `test_exp1` to check that the results of the `exp` op are now closer to CPU (within 1e-8 absolute/relative tolerance)
Fix a bug in the handling of non-contiguous tensors
Fixes https://github.com/pytorch/pytorch/issues/84936
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128421
Approved by: https://github.com/kulinseth
ghstack dependencies: #128373, #128375
diff --git a/test/test_mps.py b/test/test_mps.py
index c59a598..d141e1a 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -270,6 +270,7 @@
'empty_permuted',
'empty_strided',
'eye',
+ 'exp',
'expand',
'expand_as',
'flatten',
@@ -306,6 +307,7 @@
'nn.functional.conv_transpose2d',
'nn.functional.feature_alpha_dropoutwithout_train',
'nn.functional.padcircular',
+ 'nn.functional.tanhshrink',
'nn.functional.unfold',
'nonzero',
'ones',
@@ -333,6 +335,7 @@
'sub',
'svd',
't',
+ 'tanh',
'tensor_split',
'transpose',
'T',
@@ -389,7 +392,6 @@
'eq',
'equal',
'exp2',
- 'exp',
'expm1',
'fft.fft',
'fft.fft2',
@@ -447,7 +449,6 @@
'nn.functional.pixel_unshuffle',
'nn.functional.rms_norm',
'nn.functional.softsign',
- 'nn.functional.tanhshrink',
'pinverse',
'prod',
'reciprocal',
@@ -465,7 +466,6 @@
'sum',
'sum_to_size',
'tan',
- 'tanh',
'tensordot',
'trace',
'trapz',
@@ -1612,14 +1612,19 @@
class TestMPS(TestCaseMPS):
def test_exp(self, device="mps", dtype=torch.float):
for v in (2, -2) + ((1j, 1 + 1j) if dtype.is_complex else ()):
- b = torch.arange(18, device="cpu") / 3 * math.pi
- a = torch.tensor(v, dtype=dtype, device="cpu") * b
- a = a.to(dtype).to("mps")
+ b = torch.arange(18, device=device).to(dtype) / 3 * math.pi  # cast after arange so complex dtypes work too (arange may lack complex support)
+ a = torch.tensor(v, dtype=dtype, device=device) * b  # honor the `device` parameter instead of hard-coding "mps"
self.compare_with_numpy(torch.exp, np.exp, a)
def test_exp1(self, device="mps", dtype=torch.float):
- input = torch.tensor([-0.1, 3.0, -0.9]).to('mps')
- output = torch.exp(input).to('cpu')
+ x = torch.tensor([-0.1, 1.0, -0.9, 0.1], device=device, dtype=dtype)  # avoid shadowing the builtin `input`
+ output = torch.exp(x)
+ output_cpu = torch.exp(x.cpu())  # CPU result is the accuracy reference
+ # If the MPS `exponentWithTensor:` call is used (seen on M1 running 14.5), this test fails with:
+ # Mismatched elements: 3 / 4 (75.0%)
+ # Greatest absolute difference: 1.1920928955078125e-07 at index (3,) (up to 1e-08 allowed)
+ # Greatest relative difference: 1.0786502002702036e-07 at index (3,) (up to 1e-08 allowed)
+ self.assertEqual(output, output_cpu, atol=1e-8, rtol=1e-8)
def test_exp_strided_output(self):
x = torch.rand((256, 10), device='mps')