Update descriptor fields to resolve fft precision issue (#125328)
Fixes #124096
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125328
Approved by: https://github.com/kulinseth, https://github.com/malfet
diff --git a/test/test_mps.py b/test/test_mps.py
index cbf8874..ed001d1 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -387,6 +387,12 @@
'fft.ifft2',
'fft.ifftn',
'fft.ifftshift',
+ 'fft.irfftn',
+ 'fft.irfft2',
+ 'fft.irfft',
+ 'fft.hfftn',
+ 'fft.hfft2',
+ 'fft.hfft',
'flip',
'fliplr',
'flipud',
@@ -653,8 +659,6 @@
'log_sigmoid_forward': None,
'linalg.eig': None,
'linalg.eigvals': None,
- 'fft.hfft2': None,
- 'fft.hfftn': None,
'put': None,
'nn.functional.conv_transpose3d': None,
'rounddecimals_neg_3': None,
@@ -895,6 +899,8 @@
'fft.fft2': None,
'fft.fftn': None,
'fft.hfft': None,
+ 'fft.hfft2': None,
+ 'fft.hfftn': None,
'fft.ifft': None,
'fft.ifft2': None,
'fft.ifftn': None,
@@ -2638,6 +2644,19 @@
# Regression test for https://github.com/pytorch/pytorch/issues/96113
torch.nn.LayerNorm((16,), elementwise_affine=True).to("mps")(torch.randn(1, 2, 16).to("mps", dtype=torch.float16))
+ @xfailIf(product_version < 14.0)
+ def test_ifft(self):
+ # See: https://github.com/pytorch/pytorch/issues/124096
+ device = torch.device("mps")
+
+ N = 64
+ signal = torch.rand(N, device=device)
+ fft_result = torch.fft.rfft(signal)
+ ifft_result = torch.fft.irfft(fft_result, n=signal.shape[0])
+
+ # Expecting the inverse FFT to reproduce the original signal
+ self.assertEqual(ifft_result, signal)
+
def test_instance_norm(self):
def helper(shape, eps=1, momentum=0.1, wts=False, channels_last=False, track_running_stats=True, test_module=False):