| # Owner(s): ["module: fft"] |
| |
| import torch |
| import unittest |
| import math |
| from contextlib import contextmanager |
| from itertools import product |
| import itertools |
| import doctest |
| import inspect |
| |
| from torch.testing._internal.common_utils import \ |
| (TestCase, run_tests, TEST_NUMPY, TEST_LIBROSA, TEST_MKL, first_sample, TEST_WITH_ROCM, |
| make_tensor, skipIfTorchDynamo) |
| from torch.testing._internal.common_device_type import \ |
| (instantiate_device_type_tests, ops, dtypes, onlyNativeDeviceTypes, |
| skipCPUIfNoFFT, deviceCountAtLeast, onlyCUDA, OpDTypes, skipIf, toleranceOverride, tol) |
| from torch.testing._internal.common_methods_invocations import ( |
| spectral_funcs, SpectralFuncType) |
| from torch.testing._internal.common_cuda import SM53OrLater |
| from torch._prims_common import corresponding_complex_dtype |
| |
| from typing import Optional, List |
| from packaging import version |
| |
| |
| if TEST_NUMPY: |
| import numpy as np |
| |
| |
| if TEST_LIBROSA: |
| import librosa |
| |
| has_scipy_fft = False |
| try: |
| import scipy.fft |
| has_scipy_fft = True |
| except ModuleNotFoundError: |
| pass |
| |
# Normalization modes passed to the reference implementations in the
# comparison tests. "forward" and "backward" are only requested when
# NumPy >= 1.20.0 and, if scipy.fft was importable, SciPy >= 1.6.0; otherwise
# only the default (None) and "ortho" are exercised.
# `np` is bound only when TEST_NUMPY is true, so that flag is checked first;
# this keeps the module importable (and the NumPy-dependent tests skippable)
# when NumPy is not installed.
REFERENCE_NORM_MODES = (
    (None, "forward", "backward", "ortho")
    if TEST_NUMPY
    and version.parse(np.__version__) >= version.parse('1.20.0')
    and (not has_scipy_fft or version.parse(scipy.__version__) >= version.parse('1.6.0'))
    else (None, "ortho"))
| |
| |
def _complex_stft(x, *args, **kwargs):
    """Two-sided STFT of a complex signal, built from two real-input STFTs.

    The real and imaginary parts of ``x`` are each passed to ``torch.stft``
    with ``return_complex=True`` and ``onesided=False``; the two results are
    recombined as ``stft(real) + 1j * stft(imag)``. All extra positional and
    keyword arguments are forwarded unchanged to both calls.
    """
    def _two_sided(component):
        return torch.stft(component, *args, **kwargs, return_complex=True, onesided=False)

    from_real = _two_sided(x.real)
    from_imag = _two_sided(x.imag)
    return from_real + 1j * from_imag
| |
| |
def _hermitian_conj(x, dim):
    """Returns the hermitian conjugate along a single dimension

    H(x)[i] = conj(x[-i])

    Indices are taken modulo ``x.size(dim)``, so element 0 (and the middle
    element of an even-length axis) stays in place and only gets conjugated.

    Args:
        x: input tensor.
        dim: the single dimension along which to reverse.

    Returns:
        A new tensor with the same shape as ``x``.
    """
    # flip sends index i to n - 1 - i; rolling the result by one position then
    # sends i to (n - i) % n, which is exactly the x[-i] reversal. Unlike
    # slice-swapping, this also works when the axis has only 1 or 2 elements.
    reversed_x = torch.roll(x.flip(dim), shifts=1, dims=dim)
    return reversed_x.conj()
| |
| |
def _complex_istft(x, *args, **kwargs):
    """Inverse STFT producing a complex signal from a two-sided spectrogram.

    ``x`` is split along its frequency axis (dim -2) into a Hermitian part
    (the FFT of a real signal) and an anti-Hermitian part (the FFT of an
    imaginary signal). Each part is truncated to its first
    ``n_fft // 2 + 1`` frequency bins and inverted with a one-sided
    ``torch.istft``; the two real results become the real and imaginary
    components of the returned tensor. Extra arguments are forwarded to both
    ``torch.istft`` calls.
    """
    n_fft = x.size(-2)
    onesided_bins = (Ellipsis, slice(None, n_fft // 2 + 1), slice(None))

    mirrored = _hermitian_conj(x, dim=-2)
    hermitian_part = (x + mirrored) / 2
    antihermitian_part = (x - mirrored) / 2

    real_signal = torch.istft(
        hermitian_part[onesided_bins], *args, **kwargs, onesided=True)
    # Multiplying by -1j turns the anti-Hermitian spectrum into a Hermitian one
    imag_signal = torch.istft(
        -1j * antihermitian_part[onesided_bins], *args, **kwargs, onesided=True)
    return torch.complex(real_signal, imag_signal)
| |
| |
def _stft_reference(x, hop_length, window):
    r"""Reference stft implementation

    This doesn't implement all of torch.stft, only the STFT definition:

    .. math:: X(m, \omega) = \sum_n x[n]w[n - m] e^{-jn\omega}

    Args:
        x: 1-D signal (indexed and sliced along its only dimension).
        hop_length: number of samples between the starts of adjacent frames.
        window: window tensor; its number of elements is used as ``n_fft``.

    Returns:
        A ``torch.cdouble`` tensor of shape ``(n_fft, n_frames)`` whose column
        ``m`` is the FFT of the windowed frame starting at ``m * hop_length``.
    """
    n_fft = window.numel()
    signal_len = x.numel()
    n_frames = (signal_len - n_fft + hop_length) // hop_length
    X = torch.empty((n_fft, n_frames), device=x.device, dtype=torch.cdouble)

    for m in range(X.size(1)):
        start = m * hop_length
        stop = start + n_fft
        if stop > signal_len:
            # Frame runs past the end of the signal: copy what is available
            # into a fresh n_fft-sized buffer (the tail is left as allocated).
            frame = torch.empty(n_fft, device=x.device, dtype=x.dtype)
            tail = x[start:]
            frame[:tail.numel()] = tail
        else:
            frame = x[start:stop]
        X[:, m] = torch.fft.fft(frame * window)
    return X
| |
| |
def skip_helper_for_fft(device, dtype):
    """Skip the calling test when half/complex32 FFTs cannot run on ``device``.

    Returns ``None`` for every dtype other than ``torch.half`` and
    ``torch.complex32``. For those two dtypes, raises ``unittest.SkipTest``
    when the device type is ``'cpu'`` or when ``SM53OrLater`` is false.
    """
    device_type = torch.device(device).type
    if dtype in (torch.half, torch.complex32):
        if device_type == 'cpu':
            raise unittest.SkipTest("half and complex32 are not supported on CPU")
        if not SM53OrLater:
            raise unittest.SkipTest("half and complex32 are only supported on CUDA device with SM>53")
| |
| |
| # Tests of functions related to Fourier analysis in the torch.fft namespace |
| class TestFFT(TestCase): |
| exact_dtype = True |
| |
@onlyNativeDeviceTypes
@unittest.skipIf(not TEST_NUMPY, 'NumPy not found')
@ops([op for op in spectral_funcs if op.ndimensional == SpectralFuncType.OneD],
     allowed_dtypes=(torch.float, torch.cfloat))
def test_reference_1d(self, device, dtype, op):
    # Compares every 1-D spectral op against its reference implementation
    # (`op.ref`) over several input shapes, output lengths `n`, dims and norm
    # modes. The input is handed to the reference as a NumPy array, hence the
    # NumPy skip guard (same guard as the n-d reference test).
    if op.ref is None:
        raise unittest.SkipTest("No reference implementation")

    norm_modes = REFERENCE_NORM_MODES
    test_args = [
        *product(
            # input
            (torch.randn(67, device=device, dtype=dtype),
             torch.randn(80, device=device, dtype=dtype),
             torch.randn(12, 14, device=device, dtype=dtype),
             torch.randn(9, 6, 3, device=device, dtype=dtype)),
            # n
            (None, 50, 6),
            # dim
            (-1, 0),
            # norm
            norm_modes
        ),
        # Test transforming middle dimensions of multi-dim tensor
        *product(
            (torch.randn(4, 5, 6, 7, device=device, dtype=dtype),),
            (None,),
            (1, 2, -2,),
            norm_modes
        )
    ]

    # The result dtype is only required to match the reference exactly for
    # double-precision inputs.
    exact_dtype = dtype in (torch.double, torch.complex128)

    for x, *args in test_args:
        expected = op.ref(x.cpu().numpy(), *args)
        actual = op(x, *args)
        self.assertEqual(actual, expected, exact_dtype=exact_dtype)
| |
@skipCPUIfNoFFT
@onlyNativeDeviceTypes
@toleranceOverride({
    torch.half : tol(1e-2, 1e-2),
    torch.chalf : tol(1e-2, 1e-2),
})
@dtypes(torch.half, torch.float, torch.double, torch.complex32, torch.complex64, torch.complex128)
def test_fft_round_trip(self, device, dtype):
    # Test that round trip through ifft(fft(x)) is the identity
    skip_helper_for_fft(device, dtype)

    if dtype in (torch.half, torch.complex32):
        # cuFFT supports powers of 2 for half and complex half precision
        input_shapes = ((64,), (128,), (4, 16), (8, 6, 2))
    else:
        input_shapes = ((67,), (80,), (12, 14), (9, 6, 3))

    inputs = tuple(torch.randn(*shape, device=device, dtype=dtype)
                   for shape in input_shapes)
    dims = (-1, 0)
    norms = (None, "forward", "backward", "ortho")
    test_args = list(product(inputs, dims, norms))

    transform_pairs = [(torch.fft.fft, torch.fft.ifft)]
    if not dtype.is_complex:
        # Real-only functions
        # NOTE: Using ihfft as "forward" transform to avoid needing to
        # generate true half-complex input
        transform_pairs.append((torch.fft.rfft, torch.fft.irfft))
        transform_pairs.append((torch.fft.ihfft, torch.fft.hfft))

    for forward, backward in transform_pairs:
        for x, dim, norm in test_args:
            kwargs = dict(n=x.size(dim), dim=dim, norm=norm)
            y = backward(forward(x, **kwargs), **kwargs)

            if x.dtype is torch.half and y.dtype is torch.complex32:
                # Since type promotion currently doesn't work with complex32
                # manually promote `x` to complex32
                x = x.to(torch.complex32)

            # For real input, ifft(fft(x)) will convert to complex
            check_dtype = forward != torch.fft.fft or x.is_complex()
            self.assertEqual(x, y, exact_dtype=check_dtype)
| |
| # Note: NumPy will throw a ValueError for an empty input |
@onlyNativeDeviceTypes
@ops(spectral_funcs, allowed_dtypes=(torch.half, torch.float, torch.complex32, torch.cfloat))
def test_empty_fft(self, device, dtype, op):
    # Calling a spectral op on a (1, 0)-shaped tensor must raise a
    # RuntimeError complaining about the number of data points.
    empty_input = torch.empty(1, 0, device=device, dtype=dtype)
    expected_msg = r"Invalid number of data points \([-\d]*\) specified"

    with self.assertRaisesRegex(RuntimeError, expected_msg):
        op(empty_input)
| |
@onlyNativeDeviceTypes
def test_empty_ifft(self, device):
    # Each complex-to-real transform, called with default arguments on a
    # (2, 1) complex64 tensor, must raise the "Invalid number of data points"
    # RuntimeError.
    spectrum = torch.empty(2, 1, device=device, dtype=torch.complex64)
    expected_msg = r"Invalid number of data points \([-\d]*\) specified"

    c2r_transforms = (
        torch.fft.irfft, torch.fft.irfft2, torch.fft.irfftn,
        torch.fft.hfft, torch.fft.hfft2, torch.fft.hfftn,
    )
    for transform in c2r_transforms:
        with self.assertRaisesRegex(RuntimeError, expected_msg):
            transform(spectrum)
| |
@onlyNativeDeviceTypes
def test_fft_invalid_dtypes(self, device):
    # Real-input transforms must reject a complex tensor with a
    # function-specific error message.
    complex_input = torch.randn(64, device=device, dtype=torch.complex128)

    expected_errors = (
        (torch.fft.rfft, "rfft expects a real input tensor"),
        (torch.fft.rfftn, "rfftn expects a real-valued input tensor"),
        (torch.fft.ihfft, "ihfft expects a real input tensor"),
    )
    for transform, message in expected_errors:
        with self.assertRaisesRegex(RuntimeError, message):
            transform(complex_input)
| |
@skipCPUIfNoFFT
@onlyNativeDeviceTypes
@dtypes(torch.int8, torch.half, torch.float, torch.double,
        torch.complex32, torch.complex64, torch.complex128)
def test_fft_type_promotion(self, device, dtype):
    # Checks the output dtype of fft (any -> complex), hfft (any -> real) and,
    # for non-complex inputs, rfft (real -> complex).
    skip_helper_for_fft(device, dtype)

    is_integral = not (dtype.is_complex or dtype.is_floating_point)
    if is_integral:
        t = torch.randint(-2, 2, (64,), device=device, dtype=dtype)
    else:
        t = torch.randn(64, device=device, dtype=dtype)

    # fft: complex-to-complex
    c2c_result_dtype = {
        torch.int8: torch.complex64,
        torch.half: torch.complex32,
        torch.float: torch.complex64,
        torch.double: torch.complex128,
        torch.complex32: torch.complex32,
        torch.complex64: torch.complex64,
        torch.complex128: torch.complex128,
    }
    self.assertEqual(torch.fft.fft(t).dtype, c2c_result_dtype[dtype])

    # hfft: complex-to-real
    c2r_result_dtype = {
        torch.int8: torch.float,
        torch.half: torch.half,
        torch.float: torch.float,
        torch.double: torch.double,
        torch.complex32: torch.half,
        torch.complex64: torch.float,
        torch.complex128: torch.double,
    }
    hfft_input = t
    if dtype in (torch.half, torch.complex32):
        # cuFFT supports powers of 2 for half and complex half precision
        # NOTE: With hfft and default args where output_size n=2*(input_size - 1),
        # we make sure that logical fft size is a power of two.
        hfft_input = torch.randn(65, device=device, dtype=dtype)
    self.assertEqual(torch.fft.hfft(hfft_input).dtype, c2r_result_dtype[dtype])

    # rfft: real-to-complex, only defined for non-complex inputs
    if dtype.is_complex:
        return
    r2c_result_dtype = {
        torch.int8: torch.complex64,
        torch.half: torch.complex32,
        torch.float: torch.complex64,
        torch.double: torch.complex128,
    }
    self.assertEqual(torch.fft.rfft(t).dtype, r2c_result_dtype[dtype])
| |
@onlyNativeDeviceTypes
@ops(spectral_funcs, dtypes=OpDTypes.unsupported,
     allowed_dtypes=[torch.half, torch.bfloat16])
def test_fft_half_and_bfloat16_errors(self, device, dtype, op):
    # TODO: Remove torch.half error when complex32 is fully implemented
    sample = first_sample(self, op.sample_inputs(device, dtype))
    device_type = torch.device(device).type

    # The cuFFT-specific message is only expected for half on a non-ROCm CUDA
    # device when SM53OrLater is false; every other combination reports the
    # generic unsupported-dtype error.
    half_on_cuda = dtype is torch.half and device_type == 'cuda'
    if half_on_cuda and not TEST_WITH_ROCM and not SM53OrLater:
        err_msg = "cuFFT doesn't support signals of half type with compute capability less than SM_53"
    else:
        err_msg = "Unsupported dtype"

    with self.assertRaisesRegex(RuntimeError, err_msg):
        op(sample.input, *sample.args, **sample.kwargs)
| |
@onlyNativeDeviceTypes
@ops(spectral_funcs, allowed_dtypes=(torch.half, torch.chalf))
def test_fft_half_and_chalf_not_power_of_two_error(self, device, dtype, op):
    # Half / complex-half transforms over non-power-of-two sizes must raise,
    # both for the input's own 13x13 shape and for an explicitly requested
    # size of 12.
    t = make_tensor(13, 13, device=device, dtype=dtype)
    err_msg = "cuFFT only supports dimensions whose sizes are powers of two"

    multi_dim = op.ndimensional in (SpectralFuncType.ND, SpectralFuncType.TwoD)
    explicit_size = {'s': (12, 12)} if multi_dim else {'n': 12}

    with self.assertRaisesRegex(RuntimeError, err_msg):
        op(t)

    with self.assertRaisesRegex(RuntimeError, err_msg):
        op(t, **explicit_size)
| |
| # nd-fft tests |
@onlyNativeDeviceTypes
@unittest.skipIf(not TEST_NUMPY, 'NumPy not found')
@ops([op for op in spectral_funcs if op.ndimensional == SpectralFuncType.ND],
     allowed_dtypes=(torch.cfloat, torch.cdouble))
def test_reference_nd(self, device, dtype, op):
    # Compares every n-d spectral op against its reference implementation for
    # a range of input ranks, output sizes `s`, dims and norm modes.
    if op.ref is None:
        raise unittest.SkipTest("No reference implementation")

    # (input_ndim, s, dim)
    transform_desc = [
        *product(range(2, 5), (None,), (None, (0,), (0, -1))),
        *product(range(2, 5), (None, (4, 10)), (None,)),
        (6, None, None),
        (5, None, (1, 3, 4)),
        (3, None, (1,)),
        (1, None, (0,)),
        (4, (10, 10), None),
        (4, (10, 10), (0, 1))
    ]

    # The result dtype must match the reference only for double precision
    exact_dtype = dtype in (torch.double, torch.complex128)

    for input_ndim, s, dim in transform_desc:
        # Sizes cycle through 4, 5, 6, 7, 8 for as many dims as needed
        shape = tuple(itertools.islice(itertools.cycle(range(4, 9)), input_ndim))
        x = torch.randn(*shape, device=device, dtype=dtype)

        for norm in REFERENCE_NORM_MODES:
            expected = op.ref(x.cpu().numpy(), s, dim, norm)
            actual = op(x, s, dim, norm)
            self.assertEqual(actual, expected, exact_dtype=exact_dtype)
| |
@skipCPUIfNoFFT
@onlyNativeDeviceTypes
@toleranceOverride({
    torch.half : tol(1e-2, 1e-2),
    torch.chalf : tol(1e-2, 1e-2),
})
@dtypes(torch.half, torch.float, torch.double,
        torch.complex32, torch.complex64, torch.complex128)
def test_fftn_round_trip(self, device, dtype):
    # Applying an n-d transform followed by its inverse, with matching
    # s/dim/norm, must reproduce the input.
    skip_helper_for_fft(device, dtype)

    norm_modes = (None, "forward", "backward", "ortho")

    # (input_ndim, dim)
    transform_desc = [
        *product(range(2, 5), (None, (0,), (0, -1))),
        (7, None),
        (5, (1, 3, 4)),
        (3, (1,)),
        (1, 0),
    ]

    transform_pairs = [(torch.fft.fftn, torch.fft.ifftn)]
    if not dtype.is_complex:
        # Real-only functions
        # NOTE: Using ihfftn as "forward" transform to avoid needing to
        # generate true half-complex input
        transform_pairs.append((torch.fft.rfftn, torch.fft.irfftn))
        transform_pairs.append((torch.fft.ihfftn, torch.fft.hfftn))

    if dtype in (torch.half, torch.complex32):
        # cuFFT supports powers of 2 for half and complex half precision
        size_pool = (2, 4, 8)
    else:
        size_pool = range(4, 9)

    for input_ndim, dim in transform_desc:
        shape = itertools.islice(itertools.cycle(size_pool), input_ndim)
        x = torch.randn(*shape, device=device, dtype=dtype)

        # Signal size along the transformed dims; depends only on x and dim
        if isinstance(dim, tuple):
            s = [x.size(d) for d in dim]
        elif dim is None:
            s = x.size()
        else:
            s = x.size(dim)

        for (forward, backward), norm in product(transform_pairs, norm_modes):
            kwargs = {'s': s, 'dim': dim, 'norm': norm}
            y = backward(forward(x, **kwargs), **kwargs)

            if x.dtype is torch.half and y.dtype is torch.chalf:
                # Since type promotion currently doesn't work with complex32
                # manually promote `x` to complex32
                self.assertEqual(x.to(torch.chalf), y)
                continue

            # For real input, ifftn(fftn(x)) will convert to complex
            check_dtype = forward != torch.fft.fftn or x.is_complex()
            self.assertEqual(x, y, exact_dtype=check_dtype)
| |
@onlyNativeDeviceTypes
@ops([op for op in spectral_funcs if op.ndimensional == SpectralFuncType.ND],
     allowed_dtypes=[torch.float, torch.cfloat])
def test_fftn_invalid(self, device, dtype, op):
    # Invalid `s` / `dim` combinations on a 3-d input must raise the expected
    # exception type and message.
    a = torch.rand(10, 10, 10, device=device, dtype=dtype)
    # FIXME: https://github.com/pytorch/pytorch/issues/108205
    errMsg = "dims must be unique"

    # (exception type, message regex, keyword arguments), checked in order
    invalid_calls = (
        (RuntimeError, errMsg, {'dim': (0, 1, 0)}),
        (RuntimeError, errMsg, {'dim': (2, -1)}),
        (RuntimeError, "dim and shape .* same length", {'s': (1,), 'dim': (0, 1)}),
        (IndexError, "Dimension out of range", {'dim': (3,)}),
        (RuntimeError, "tensor only has 3 dimensions", {'s': (10, 10, 10, 10)}),
    )
    for exc_type, pattern, kwargs in invalid_calls:
        with self.assertRaisesRegex(exc_type, pattern):
            op(a, **kwargs)
| |
@skipCPUIfNoFFT
@onlyNativeDeviceTypes
@dtypes(torch.half, torch.float, torch.double, torch.cfloat, torch.cdouble)
def test_fftn_noop_transform(self, device, dtype):
    # A transform over an empty list of dims must return the input values
    # unchanged, converted to the corresponding complex dtype for real input.
    skip_helper_for_fft(device, dtype)
    RESULT_TYPE = {
        torch.half: torch.chalf,
        torch.float: torch.cfloat,
        torch.double: torch.cdouble,
    }

    for op in [
        torch.fft.fftn,
        torch.fft.ifftn,
        torch.fft.fft2,
        torch.fft.ifft2,
    ]:
        inp = make_tensor((10, 10), device=device, dtype=dtype)
        # Call the op selected by the loop; hard-coding torch.fft.fftn here
        # would leave the other three transforms untested.
        out = op(inp, dim=[])

        # Complex inputs keep their dtype; real inputs are promoted
        expect_dtype = RESULT_TYPE.get(inp.dtype, inp.dtype)
        expect = inp.to(expect_dtype)
        self.assertEqual(expect, out)
| |
| |
@skipCPUIfNoFFT
@onlyNativeDeviceTypes
@toleranceOverride({
    torch.half : tol(1e-2, 1e-2),
})
@dtypes(torch.half, torch.float, torch.double)
def test_hfftn(self, device, dtype):
    # Builds a one-sided spectrum by truncating ifftn(expect) along the last
    # transformed dim, then checks that hfftn recovers `expect`.
    skip_helper_for_fft(device, dtype)

    # (input_ndim, dim)
    transform_desc = [
        *product(range(2, 5), (None, (0,), (0, -1))),
        (6, None),
        (5, (1, 3, 4)),
        (3, (1,)),
        (1, (0,)),
        (4, (0, 1))
    ]

    # Half precision uses power-of-two sizes only
    size_pool = (2, 4, 8) if dtype is torch.half else range(4, 9)

    for input_ndim, dim in transform_desc:
        transformed_dims = dim if dim is not None else list(range(input_ndim))
        shape = tuple(itertools.islice(itertools.cycle(size_pool), input_ndim))

        expect = torch.randn(*shape, device=device, dtype=dtype)
        full_spectrum = torch.fft.ifftn(expect, dim=dim, norm="ortho")

        # Keep only the first n // 2 + 1 entries along the last transformed dim
        last = transformed_dims[-1]
        kept = full_spectrum.size(last) // 2 + 1
        selector = [slice(None)] * input_ndim
        selector[last] = slice(0, kept)
        onesided = full_spectrum[selector]

        s = [shape[d] for d in transformed_dims]
        actual = torch.fft.hfftn(onesided, s=s, dim=dim, norm="ortho")

        self.assertEqual(expect, actual)
| |
@skipCPUIfNoFFT
@onlyNativeDeviceTypes
@toleranceOverride({
    torch.half : tol(1e-2, 1e-2),
})
@dtypes(torch.half, torch.float, torch.double)
def test_ihfftn(self, device, dtype):
    # ihfftn of a real input must equal ifftn of the same input truncated to
    # the first n // 2 + 1 entries along the last transformed dim.
    skip_helper_for_fft(device, dtype)

    # (input_ndim, dim)
    transform_desc = [
        *product(range(2, 5), (None, (0,), (0, -1))),
        (6, None),
        (5, (1, 3, 4)),
        (3, (1,)),
        (1, (0,)),
        (4, (0, 1))
    ]

    # Half precision uses power-of-two sizes only
    size_pool = (2, 4, 8) if dtype is torch.half else range(4, 9)

    for input_ndim, dim in transform_desc:
        shape = tuple(itertools.islice(itertools.cycle(size_pool), input_ndim))
        x = torch.randn(*shape, device=device, dtype=dtype)
        full_spectrum = torch.fft.ifftn(x, dim=dim, norm="ortho")

        # Slice off the half-symmetric component
        last = dim[-1] if dim is not None else -1
        kept = full_spectrum.size(last) // 2 + 1
        selector = [slice(None)] * input_ndim
        selector[last] = slice(0, kept)
        expect = full_spectrum[selector]

        actual = torch.fft.ihfftn(x, dim=dim, norm="ortho")
        self.assertEqual(expect, actual)
| |
| |
| # 2d-fft tests |
| |
| # NOTE: 2d transforms are only thin wrappers over n-dim transforms, |
| # so don't require exhaustive testing. |
| |
| |
@skipCPUIfNoFFT
@onlyNativeDeviceTypes
@unittest.skipIf(not TEST_NUMPY, 'NumPy not found')
@dtypes(torch.double, torch.complex128)
def test_fft2_numpy(self, device, dtype):
    # Compares the 2-d transforms (eager and TorchScript-compiled) against
    # np.fft, or scipy.fft for the hfft variants, with default and explicit
    # dims. np.fft is used directly, hence the NumPy skip guard.
    norm_modes = REFERENCE_NORM_MODES

    # (input_ndim, s)
    transform_desc = [
        *product(range(2, 5), (None, (4, 10))),
    ]

    fft_functions = ['fft2', 'ifft2', 'irfft2', 'hfft2']
    if dtype.is_floating_point:
        # Real-only functions
        fft_functions += ['rfft2', 'ihfft2']

    for input_ndim, s in transform_desc:
        shape = itertools.islice(itertools.cycle(range(4, 9)), input_ndim)
        input = torch.randn(*shape, device=device, dtype=dtype)
        # The NumPy copy depends only on `input`, so convert once per input
        input_np = input.cpu().numpy()

        for fname, norm in product(fft_functions, norm_modes):
            torch_fn = getattr(torch.fft, fname)
            if "hfft" in fname:
                if not has_scipy_fft:
                    continue  # Requires scipy to compare against
                numpy_fn = getattr(scipy.fft, fname)
            else:
                numpy_fn = getattr(np.fft, fname)

            def fn(t: torch.Tensor, s: Optional[List[int]], dim: List[int] = (-2, -1), norm: Optional[str] = None):
                return torch_fn(t, s, dim, norm)

            torch_fns = (torch_fn, torch.jit.script(fn))

            # Once with dim defaulted
            expected = numpy_fn(input_np, s, norm=norm)
            for impl in torch_fns:
                actual = impl(input, s, norm=norm)
                self.assertEqual(actual, expected)

            # Once with explicit dims
            dim = (1, 0)
            expected = numpy_fn(input_np, s, dim, norm)
            for impl in torch_fns:
                actual = impl(input, s, dim, norm)
                self.assertEqual(actual, expected)
| |
@skipCPUIfNoFFT
@onlyNativeDeviceTypes
@dtypes(torch.float, torch.complex64)
def test_fft2_fftn_equivalence(self, device, dtype):
    # Each `<name>2` transform must agree with the matching `<name>n`
    # transform; when no dim is given, the n-d call is made over (-2, -1).
    norm_modes = (None, "forward", "backward", "ortho")

    # (input_ndim, s, dim)
    transform_desc = [
        *product(range(2, 5), (None, (4, 10)), (None, (1, 0))),
        (3, None, (0, 2)),
    ]

    base_names = ['fft', 'ifft', 'irfft', 'hfft']
    if dtype.is_floating_point:
        # Real-only functions
        base_names.extend(['rfft', 'ihfft'])

    for input_ndim, s, dim in transform_desc:
        shape = itertools.islice(itertools.cycle(range(4, 9)), input_ndim)
        x = torch.randn(*shape, device=device, dtype=dtype)

        for base, norm in product(base_names, norm_modes):
            two_d = getattr(torch.fft, base + '2')
            n_d = getattr(torch.fft, base + 'n')

            two_d_kwargs = {'s': s, 'norm': norm}
            if dim is not None:
                two_d_kwargs['dim'] = dim

            # The n-d call always gets an explicit dim
            n_d_kwargs = dict(two_d_kwargs)
            n_d_kwargs.setdefault('dim', (-2, -1))

            expect = n_d(x, **n_d_kwargs)
            actual = two_d(x, **two_d_kwargs)

            self.assertEqual(actual, expect)
| |
@skipCPUIfNoFFT
@onlyNativeDeviceTypes
def test_fft2_invalid(self, device):
    # Invalid `s` / `dim` arguments to the 2-d transforms must raise the
    # expected errors, and rfft2 must reject complex input.
    a = torch.rand(10, 10, 10, device=device)
    transforms = (torch.fft.fft2, torch.fft.ifft2,
                  torch.fft.rfft2, torch.fft.irfft2)

    # (exception type, message regex, keyword arguments), checked in order
    invalid_calls = (
        (RuntimeError, "dims must be unique", {'dim': (0, 0)}),
        (RuntimeError, "dims must be unique", {'dim': (2, -1)}),
        (RuntimeError, "dim and shape .* same length", {'s': (1,)}),
        (IndexError, "Dimension out of range", {'dim': (2, 3)}),
    )
    for transform in transforms:
        for exc_type, pattern, kwargs in invalid_calls:
            with self.assertRaisesRegex(exc_type, pattern):
                transform(a, **kwargs)

    complex_input = torch.complex(a, a)
    with self.assertRaisesRegex(RuntimeError, "rfftn expects a real-valued input"):
        torch.fft.rfft2(complex_input)
| |
| # Helper functions |
| |
@skipCPUIfNoFFT
@onlyNativeDeviceTypes
@unittest.skipIf(not TEST_NUMPY, 'NumPy not found')
@dtypes(torch.float, torch.double)
def test_fftfreq_numpy(self, device, dtype):
    # torch.fft.fftfreq / rfftfreq must match their np.fft counterparts for
    # n in 1..19, with and without an explicit sample spacing d.
    lengths = range(1, 20)
    spacings = (None, 10.0)
    test_args = list(product(lengths, spacings))

    for fname in ('fftfreq', 'rfftfreq'):
        torch_fn = getattr(torch.fft, fname)
        numpy_fn = getattr(np.fft, fname)

        for n, d in test_args:
            # Leave d out entirely when it is None so the default is used
            positional = (n, d) if d is not None else (n,)
            expected = numpy_fn(*positional)
            actual = torch_fn(*positional, device=device, dtype=dtype)
            self.assertEqual(actual, expected, exact_dtype=False)
| |
@skipCPUIfNoFFT
@onlyNativeDeviceTypes
@dtypes(torch.float, torch.double)
def test_fftfreq_out(self, device, dtype):
    # Writing into a 0-dim `out` tensor must emit the resize warning and
    # produce the same values as the functional call.
    for freq_fn in (torch.fft.fftfreq, torch.fft.rfftfreq):
        expect = freq_fn(n=100, d=.5, device=device, dtype=dtype)
        out_buffer = torch.empty((), device=device, dtype=dtype)
        with self.assertWarnsRegex(UserWarning, "out tensor will be resized"):
            freq_fn(n=100, d=.5, out=out_buffer)
        self.assertEqual(out_buffer, expect)
| |
| |
@skipCPUIfNoFFT
@onlyNativeDeviceTypes
@unittest.skipIf(not TEST_NUMPY, 'NumPy not found')
@dtypes(torch.float, torch.double, torch.complex64, torch.complex128)
def test_fftshift_numpy(self, device, dtype):
    # torch.fft.fftshift / ifftshift must match np.fft for several shapes and
    # dim specifications (None, a single int, or a tuple).
    test_args = [
        # shape, dim
        *product(((11,), (12,)), (None, 0, -1)),
        *product(((4, 5), (6, 6)), (None, 0, (-1,))),
        *product(((1, 1, 4, 6, 7, 2),), (None, (3, 4))),
    ]

    for shape, dim in test_args:
        x = torch.rand(*shape, device=device, dtype=dtype)
        x_np = x.cpu().numpy()

        for fname in ('fftshift', 'ifftshift'):
            expected = getattr(np.fft, fname)(x_np, axes=dim)
            actual = getattr(torch.fft, fname)(x, dim=dim)
            self.assertEqual(actual, expected)
| |
@skipCPUIfNoFFT
@onlyNativeDeviceTypes
@unittest.skipIf(not TEST_NUMPY, 'NumPy not found')
@dtypes(torch.float, torch.double)
def test_fftshift_frequencies(self, device, dtype):
    # For n in 10..14: fftshift must sort the fftfreq output into ascending
    # order, and ifftshift must undo fftshift.
    for n in range(10, 15):
        half = n // 2
        ascending = torch.arange(-half, n - half, device=device, dtype=dtype)
        freqs = torch.fft.fftfreq(n, d=1 / n, device=device, dtype=dtype)

        # Test fftshift sorts the fftfreq output
        shifted = torch.fft.fftshift(freqs)
        self.assertEqual(shifted, shifted.sort().values)
        self.assertEqual(ascending, shifted)

        # And ifftshift is the inverse
        self.assertEqual(freqs, torch.fft.ifftshift(shifted))
| |
| # Legacy fft tests |
def _test_fft_ifft_rfft_irfft(self, device, dtype):
    # Round-trip checks (transform followed by its inverse) over the trailing
    # `signal_ndim` dims, for real inputs of `dtype` and complex inputs of the
    # corresponding complex dtype, on contiguous and non-contiguous tensors.
    complex_dtype = corresponding_complex_dtype(dtype)

    def _identity(t):
        return t

    def _check_complex(sizes, signal_ndim, prepro_fn=_identity):
        x = prepro_fn(torch.randn(*sizes, dtype=complex_dtype, device=device))
        dim = tuple(range(-signal_ndim, 0))
        round_trips = (
            (torch.fft.fftn, torch.fft.ifftn, 'fft and ifft'),
            (torch.fft.ifftn, torch.fft.fftn, 'ifft and fft'),
        )
        for norm in ('ortho', None):
            for first, second, label in round_trips:
                res = first(x, dim=dim, norm=norm)
                rec = second(res, dim=dim, norm=norm)
                self.assertEqual(x, rec, atol=1e-8, rtol=0, msg=label)

    def _check_real(sizes, signal_ndim, prepro_fn=_identity):
        x = prepro_fn(torch.randn(*sizes, dtype=dtype, device=device))
        signal_sizes = x.size()[-signal_ndim:]
        dim = tuple(range(-signal_ndim, 0))
        for norm in (None, 'ortho'):
            # Real-to-complex transform and back
            res = torch.fft.rfftn(x, dim=dim, norm=norm)
            rec = torch.fft.irfftn(res, s=signal_sizes, dim=dim, norm=norm)
            self.assertEqual(x, rec, atol=1e-8, rtol=0, msg='rfft and irfft')

            # Complex transform of the real input; result comes back complex
            res = torch.fft.fftn(x, dim=dim, norm=norm)
            rec = torch.fft.ifftn(res, dim=dim, norm=norm)
            x_complex = torch.complex(x, torch.zeros_like(x))
            self.assertEqual(x_complex, rec, atol=1e-8, rtol=0, msg='fft and ifft (from real)')

    # contiguous case: (sizes, signal_ndim)
    contiguous_real = (
        ((100,), 1),
        ((10, 1, 10, 100), 1),
        ((100, 100), 2),
        ((2, 2, 5, 80, 60), 2),
        ((50, 40, 70), 3),
        ((30, 1, 50, 25, 20), 3),
    )
    contiguous_complex = (
        ((100,), 1),
        ((100, 100), 1),
        ((100, 100), 2),
        ((1, 20, 80, 60), 2),
        ((50, 40, 70), 3),
        ((6, 5, 50, 25, 20), 3),
    )
    for sizes, signal_ndim in contiguous_real:
        _check_real(sizes, signal_ndim)
    for sizes, signal_ndim in contiguous_complex:
        _check_complex(sizes, signal_ndim)

    # non-contiguous case: (sizes, signal_ndim, view-producing preprocessing)
    strided_real = (
        ((165,), 1, lambda x: x.narrow(0, 25, 100)),  # input is not aligned to complex type
        ((100, 100, 3), 1, lambda x: x[:, :, 0]),
        ((100, 100), 2, lambda x: x.t()),
        ((20, 100, 10, 10), 2, lambda x: x.view(20, 100, 100)[:, :60]),
        ((65, 80, 115), 3, lambda x: x[10:60, 13:53, 10:80]),
        ((30, 20, 50, 25), 3, lambda x: x.transpose(1, 2).transpose(2, 3)),
    )
    strided_complex = (
        ((100,), 1, lambda x: x.expand(100, 100)),
        ((20, 90, 110), 2, lambda x: x[:, 5:85].narrow(2, 5, 100)),
        ((40, 60, 3, 80), 3, lambda x: x.transpose(2, 0).select(0, 2)[5:55, :, 10:]),
        ((30, 55, 50, 22), 3, lambda x: x[:, 3:53, 15:40, 1:21]),
    )
    for sizes, signal_ndim, prepro_fn in strided_real:
        _check_real(sizes, signal_ndim, prepro_fn)
    for sizes, signal_ndim, prepro_fn in strided_complex:
        _check_complex(sizes, signal_ndim, prepro_fn)
| |
@skipCPUIfNoFFT
@onlyNativeDeviceTypes
@dtypes(torch.double)
def test_fft_ifft_rfft_irfft(self, device, dtype):
    # Runs the shared round-trip checks on the device under test.
    self._test_fft_ifft_rfft_irfft(device=device, dtype=dtype)
| |
@deviceCountAtLeast(1)
@onlyCUDA
@dtypes(torch.double)
def test_cufft_plan_cache(self, devices, dtype):
    # Exercises torch.backends.cuda.cufft_plan_cache: runs the shared FFT
    # round-trip checks under several `max_size` settings and after clearing
    # the cache, checks the errors raised for invalid cache operations, and,
    # when more than one device is available, checks that each device has its
    # own cache and that the un-indexed cache object follows the current
    # device.
    @contextmanager
    def plan_cache_max_size(device, n):
        # Temporarily set `max_size` of a plan cache to `n`; the previous
        # value is restored on exit. `device=None` selects the un-indexed
        # cache object instead of a per-device one.
        if device is None:
            plan_cache = torch.backends.cuda.cufft_plan_cache
        else:
            plan_cache = torch.backends.cuda.cufft_plan_cache[device]
        original = plan_cache.max_size
        plan_cache.max_size = n
        try:
            yield
        finally:
            plan_cache.max_size = original

    # Limit derived from the current cache size, but never below 1
    with plan_cache_max_size(devices[0], max(1, torch.backends.cuda.cufft_plan_cache.size - 10)):
        self._test_fft_ifft_rfft_irfft(devices[0], dtype)

    # max_size of zero
    with plan_cache_max_size(devices[0], 0):
        self._test_fft_ifft_rfft_irfft(devices[0], dtype)

    torch.backends.cuda.cufft_plan_cache.clear()

    # check that still works after clearing cache
    with plan_cache_max_size(devices[0], 10):
        self._test_fft_ifft_rfft_irfft(devices[0], dtype)

    # Invalid operations on the cache object must raise
    with self.assertRaisesRegex(RuntimeError, r"must be non-negative"):
        torch.backends.cuda.cufft_plan_cache.max_size = -1

    with self.assertRaisesRegex(RuntimeError, r"read-only property"):
        torch.backends.cuda.cufft_plan_cache.size = -1

    with self.assertRaisesRegex(RuntimeError, r"but got device with index"):
        torch.backends.cuda.cufft_plan_cache[torch.cuda.device_count() + 10]

    # Multigpu tests
    if len(devices) > 1:
        # Test that different GPU has different cache
        x0 = torch.randn(2, 3, 3, device=devices[0])
        x1 = x0.to(devices[1])
        self.assertEqual(torch.fft.rfftn(x0, dim=(-2, -1)), torch.fft.rfftn(x1, dim=(-2, -1)))
        # If a plan is used across different devices, the following line (or
        # the assert above) would trigger illegal memory access. Other ways
        # to trigger the error include
        # (1) setting CUDA_LAUNCH_BLOCKING=1 (pytorch/pytorch#19224) and
        # (2) printing a device 1 tensor.
        x0.copy_(x1)

        # Test that un-indexed `torch.backends.cuda.cufft_plan_cache` uses current device
        with plan_cache_max_size(devices[0], 10):
            with plan_cache_max_size(devices[1], 11):
                self.assertEqual(torch.backends.cuda.cufft_plan_cache[0].max_size, 10)
                self.assertEqual(torch.backends.cuda.cufft_plan_cache[1].max_size, 11)

                self.assertEqual(torch.backends.cuda.cufft_plan_cache.max_size, 10)  # default is cuda:0
                with torch.cuda.device(devices[1]):
                    self.assertEqual(torch.backends.cuda.cufft_plan_cache.max_size, 11)  # default is cuda:1
                    with torch.cuda.device(devices[0]):
                        self.assertEqual(torch.backends.cuda.cufft_plan_cache.max_size, 10)  # default is cuda:0

            self.assertEqual(torch.backends.cuda.cufft_plan_cache[0].max_size, 10)
            with torch.cuda.device(devices[1]):
                with plan_cache_max_size(None, 11):  # default is cuda:1
                    self.assertEqual(torch.backends.cuda.cufft_plan_cache[0].max_size, 10)
                    self.assertEqual(torch.backends.cuda.cufft_plan_cache[1].max_size, 11)

                    self.assertEqual(torch.backends.cuda.cufft_plan_cache.max_size, 11)  # default is cuda:1
                    with torch.cuda.device(devices[0]):
                        self.assertEqual(torch.backends.cuda.cufft_plan_cache.max_size, 10)  # default is cuda:0
                    self.assertEqual(torch.backends.cuda.cufft_plan_cache.max_size, 11)  # default is cuda:1
| |
@onlyCUDA
@dtypes(torch.cfloat, torch.cdouble)
def test_cufft_context(self, device, dtype):
    # Regression test for https://github.com/pytorch/pytorch/issues/109448
    # The gradient of iFFT(FFT(x)) for an all-zero upstream gradient must
    # equal FFT(iFFT(dout)) exactly, and must differ from x itself.
    x = torch.randn(32, dtype=dtype, device=device, requires_grad=True)
    dout = torch.zeros(32, dtype=dtype, device=device)

    # compute iFFT(FFT(x))
    spectrum = torch.fft.fft(x)
    out = torch.fft.ifft(spectrum)
    out.backward(dout, retain_graph=True)

    # Expected gradient, computed directly with the transforms swapped
    dx = torch.fft.fft(torch.fft.ifft(dout))

    grad_error = (x.grad - dx).abs().max()
    self.assertTrue(grad_error == 0)

    grad_vs_input = (x.grad - x).abs().max()
    self.assertFalse(grad_vs_input == 0)
| |
@skipIfTorchDynamo("cannot set WRITEABLE flag to True of this array")
@skipCPUIfNoFFT
@onlyNativeDeviceTypes
@dtypes(torch.double)
def test_stft(self, device, dtype):
    # Compares Tensor.stft against librosa.stft for a range of parameters and
    # checks that invalid arguments raise.
    if not TEST_LIBROSA:
        raise unittest.SkipTest('librosa not found')

    def librosa_stft(x, n_fft, hop_length, win_length, window, center):
        # Reference result with real and imaginary parts stacked along a new
        # trailing dimension, matching the return_complex=False layout.
        if window is None:
            window = np.ones(n_fft if win_length is None else win_length)
        else:
            window = window.cpu().numpy()
        input_1d = x.dim() == 1
        if input_1d:
            x = x.view(1, -1)

        # NOTE: librosa 0.9 changed default pad_mode to 'constant' (zero padding)
        # however, we use the pre-0.9 default ('reflect')
        pad_mode = 'reflect'

        result = []
        for xi in x:
            ri = librosa.stft(xi.cpu().numpy(), n_fft=n_fft, hop_length=hop_length,
                              win_length=win_length, window=window, center=center,
                              pad_mode=pad_mode)
            result.append(torch.from_numpy(np.stack([ri.real, ri.imag], -1)))
        result = torch.stack(result, 0)
        if input_1d:
            result = result[0]
        return result

    def _test(sizes, n_fft, hop_length=None, win_length=None, win_sizes=None,
              center=True, expected_error=None):
        # With expected_error=None the output is compared against librosa;
        # otherwise the call itself must raise expected_error.
        x = torch.randn(*sizes, dtype=dtype, device=device)
        if win_sizes is not None:
            window = torch.randn(*win_sizes, dtype=dtype, device=device)
        else:
            window = None
        if expected_error is None:
            result = x.stft(n_fft, hop_length, win_length, window,
                            center=center, return_complex=False)
            # NB: librosa defaults to np.complex64 output, no matter what
            # the input dtype
            ref_result = librosa_stft(x, n_fft, hop_length, win_length, window, center)
            self.assertEqual(result, ref_result, atol=7e-6, rtol=0, msg='stft comparison against librosa', exact_dtype=False)
            # With return_complex=True, the result is the same but viewed as complex instead of real
            result_complex = x.stft(n_fft, hop_length, win_length, window, center=center, return_complex=True)
            self.assertEqual(result_complex, torch.view_as_complex(result))
        else:
            # return_complex is passed explicitly: for real input, omitting it
            # raises a RuntimeError on its own, which would let this check pass
            # without ever reaching the argument validation under test.
            with self.assertRaises(expected_error):
                x.stft(n_fft, hop_length, win_length, window,
                       center=center, return_complex=True)

    for center in [True, False]:
        _test((10,), 7, center=center)
        _test((10, 4000), 1024, center=center)

        _test((10,), 7, 2, center=center)
        _test((10, 4000), 1024, 512, center=center)

        _test((10,), 7, 2, win_sizes=(7,), center=center)
        _test((10, 4000), 1024, 512, win_sizes=(1024,), center=center)

        # spectral oversample
        _test((10,), 7, 2, win_length=5, center=center)
        _test((10, 4000), 1024, 512, win_length=100, center=center)

    # Invalid shapes / sizes must be rejected.
    _test((10, 4, 2), 1, 1, expected_error=RuntimeError)
    _test((10,), 11, 1, center=False, expected_error=RuntimeError)
    _test((10,), -1, 1, expected_error=RuntimeError)
    _test((10,), 3, win_length=5, expected_error=RuntimeError)
    _test((10,), 5, 4, win_sizes=(11,), expected_error=RuntimeError)
    _test((10,), 5, 4, win_sizes=(1, 1), expected_error=RuntimeError)
| |
@skipIfTorchDynamo("double")
@skipCPUIfNoFFT
@onlyNativeDeviceTypes
@dtypes(torch.double)
def test_istft_against_librosa(self, device, dtype):
    # Inverts the same spectrogram with Tensor.istft and librosa.istft and
    # requires both reconstructions to agree.
    if not TEST_LIBROSA:
        raise unittest.SkipTest('librosa not found')

    def librosa_istft(x, n_fft, hop_length, win_length, window, length, center):
        # librosa is always handed an explicit numpy window; a missing window
        # becomes a rectangular one of win_length (or n_fft) samples.
        if window is not None:
            np_window = window.cpu().numpy()
        else:
            np_window = np.ones(n_fft if win_length is None else win_length)

        spectrogram = x.cpu().numpy()
        return librosa.istft(spectrogram, n_fft=n_fft, hop_length=hop_length,
                             win_length=win_length, length=length, window=np_window, center=center)

    def _test(size, n_fft, hop_length=None, win_length=None, win_sizes=None,
              length=None, center=True):
        signal = torch.randn(size, dtype=dtype, device=device)
        window = None
        if win_sizes is not None:
            window = torch.randn(*win_sizes, dtype=dtype, device=device)

        spectrogram = signal.stft(n_fft, hop_length, win_length, window, center=center,
                                  onesided=True, return_complex=True)

        expected = librosa_istft(spectrogram, n_fft, hop_length, win_length,
                                 window, length, center)
        actual = spectrogram.istft(n_fft, hop_length, win_length, window,
                                   length=length, center=center)
        self.assertEqual(actual, expected)

    for center in [True, False]:
        # default hop length
        _test(10, 7, center=center)
        _test(4000, 1024, center=center)
        _test(4000, 1024, center=center, length=4000)

        # explicit hop length
        _test(10, 7, 2, center=center)
        _test(4000, 1024, 512, center=center)
        _test(4000, 1024, 512, center=center, length=4000)

        # explicit hop length and random window
        _test(10, 7, 2, win_sizes=(7,), center=center)
        _test(4000, 1024, 512, win_sizes=(1024,), center=center)
        _test(4000, 1024, 512, win_sizes=(1024,), center=center, length=4000)
| |
@onlyNativeDeviceTypes
@skipCPUIfNoFFT
@dtypes(torch.double, torch.cdouble)
def test_complex_stft_roundtrip(self, device, dtype):
    # stft followed by istft must reproduce the input, through both the
    # functional (torch.stft / torch.istft) and the Tensor method interfaces.
    test_args = list(product(
        # input
        (torch.randn(600, device=device, dtype=dtype),
         torch.randn(807, device=device, dtype=dtype),
         torch.randn(12, 60, device=device, dtype=dtype)),
        # n_fft
        (50, 27),
        # hop_length
        (None, 10),
        # center
        (True,),
        # pad_mode
        ("constant", "reflect", "circular"),
        # normalized
        (True, False),
        # onesided
        (True, False) if not dtype.is_complex else (False,),
    ))

    for args in test_args:
        x, n_fft, hop_length, center, pad_mode, normalized, onesided = args
        # Arguments shared by the forward and the inverse transform;
        # pad_mode is only passed to stft.
        common_kwargs = {
            'n_fft': n_fft, 'hop_length': hop_length, 'center': center,
            'normalized': normalized, 'onesided': onesided,
        }

        # Functional interface
        x_stft = torch.stft(x, pad_mode=pad_mode, return_complex=True, **common_kwargs)
        x_roundtrip = torch.istft(x_stft, return_complex=dtype.is_complex,
                                  length=x.size(-1), **common_kwargs)
        self.assertEqual(x_roundtrip, x)

        # Tensor method interface (both directions go through the methods)
        x_stft = x.stft(pad_mode=pad_mode, return_complex=True, **common_kwargs)
        x_roundtrip = x_stft.istft(return_complex=dtype.is_complex,
                                   length=x.size(-1), **common_kwargs)
        self.assertEqual(x_roundtrip, x)
| |
@onlyNativeDeviceTypes
@skipCPUIfNoFFT
@dtypes(torch.double, torch.cdouble)
def test_stft_roundtrip_complex_window(self, device, dtype):
    # stft/istft round trip using a complex-valued window, for both real and
    # complex input signals.
    test_args = list(product(
        # input
        (torch.randn(600, device=device, dtype=dtype),
         torch.randn(807, device=device, dtype=dtype),
         torch.randn(12, 60, device=device, dtype=dtype)),
        # n_fft
        (50, 27),
        # hop_length
        (None, 10),
        # pad_mode
        ("constant", "reflect", "replicate", "circular"),
        # normalized
        (True, False),
    ))
    for args in test_args:
        x, n_fft, hop_length, pad_mode, normalized = args
        window = torch.rand(n_fft, device=device, dtype=torch.cdouble)
        x_stft = torch.stft(
            x, n_fft=n_fft, hop_length=hop_length, window=window,
            center=True, pad_mode=pad_mode, normalized=normalized)
        self.assertEqual(x_stft.dtype, torch.cdouble)
        self.assertEqual(x_stft.size(-2), n_fft)  # Not onesided

        x_roundtrip = torch.istft(
            x_stft, n_fft=n_fft, hop_length=hop_length, window=window,
            center=True, normalized=normalized, length=x.size(-1),
            return_complex=True)
        # Check the dtype of the inverse itself (requested via return_complex=True).
        self.assertEqual(x_roundtrip.dtype, torch.cdouble)

        if not dtype.is_complex:
            # A real signal must come back with a (numerically) zero imaginary part.
            self.assertEqual(x_roundtrip.imag, torch.zeros_like(x_roundtrip.imag),
                             atol=1e-6, rtol=0)
            self.assertEqual(x_roundtrip.real, x)
        else:
            self.assertEqual(x_roundtrip, x)
| |
| |
@skipCPUIfNoFFT
@dtypes(torch.cdouble)
def test_complex_stft_definition(self, device, dtype):
    # Checks torch.stft (center=False) on complex input against the
    # _stft_reference helper defined elsewhere in this file.
    signals = (
        torch.randn(600, device=device, dtype=dtype),
        torch.randn(807, device=device, dtype=dtype),
    )
    n_ffts = (50, 27)
    hop_lengths = (10, 15)

    for x, n_fft, hop_length in product(signals, n_ffts, hop_lengths):
        # One fresh random window per parameter combination.
        window = torch.randn(n_fft, device=device, dtype=dtype)
        expected = _stft_reference(x, hop_length, window)
        actual = torch.stft(x, n_fft, hop_length, window=window, center=False)
        self.assertEqual(actual, expected)
| |
@onlyNativeDeviceTypes
@skipCPUIfNoFFT
@dtypes(torch.cdouble)
def test_complex_stft_real_equiv(self, device, dtype):
    # torch.stft on complex input must match _complex_stft, which transforms
    # the real and imaginary components separately and recombines them.
    signals = (
        torch.rand(600, device=device, dtype=dtype),
        torch.rand(807, device=device, dtype=dtype),
        torch.rand(14, 50, device=device, dtype=dtype),
        torch.rand(6, 51, device=device, dtype=dtype),
    )
    parameter_grid = product(
        signals,
        (50, 27),                           # n_fft
        (None, 10),                         # hop_length
        (None, 20),                         # win_length
        (False, True),                      # center
        ("constant", "reflect", "circular"),  # pad_mode
        (True, False),                      # normalized
    )

    for x, n_fft, hop_length, win_length, center, pad_mode, normalized in parameter_grid:
        stft_kwargs = {
            'hop_length': hop_length, 'win_length': win_length,
            'pad_mode': pad_mode, 'center': center, 'normalized': normalized,
        }
        expected = _complex_stft(x, n_fft, **stft_kwargs)
        actual = torch.stft(x, n_fft, **stft_kwargs)
        self.assertEqual(expected, actual)
| |
@skipCPUIfNoFFT
@dtypes(torch.cdouble)
def test_complex_istft_real_equiv(self, device, dtype):
    # torch.istft with return_complex=True must match the _complex_istft
    # helper defined elsewhere in this file.
    spectrograms = (
        torch.rand(40, 20, device=device, dtype=dtype),
        torch.rand(25, 1, device=device, dtype=dtype),
        torch.rand(4, 20, 10, device=device, dtype=dtype),
    )
    parameter_grid = product(
        spectrograms,
        (None, 10),     # hop_length
        (False, True),  # center
        (True, False),  # normalized
    )

    for x, hop_length, center, normalized in parameter_grid:
        # The inputs are two-sided, so n_fft equals the frequency dimension.
        n_fft = x.size(-2)
        istft_kwargs = {'hop_length': hop_length, 'center': center, 'normalized': normalized}
        expected = _complex_istft(x, n_fft, **istft_kwargs)
        actual = torch.istft(x, n_fft, return_complex=True, **istft_kwargs)
        self.assertEqual(expected, actual)
| |
@skipCPUIfNoFFT
def test_complex_stft_onesided(self, device):
    # stft of complex input cannot be onesided
    candidate_dtypes = (torch.double, torch.cdouble)
    for x_dtype, window_dtype in product(candidate_dtypes, repeat=2):
        x = torch.rand(100, device=device, dtype=x_dtype)
        window = torch.rand(10, device=device, dtype=window_dtype)

        if not (x_dtype.is_complex or window_dtype.is_complex):
            # Purely real signal and window: a onesided output is allowed.
            y = x.stft(10, window=window, pad_mode='constant', onesided=True,
                       return_complex=True)
            self.assertEqual(y.dtype, torch.cdouble)
            self.assertEqual(y.size(), (6, 51))
            continue

        # Any complex operand makes onesided=True an error.
        with self.assertRaisesRegex(RuntimeError, 'complex'):
            x.stft(10, window=window, pad_mode='constant', onesided=True)

    # Same restriction applies when no window is given.
    x = torch.rand(100, device=device, dtype=torch.cdouble)
    with self.assertRaisesRegex(RuntimeError, 'complex'):
        x.stft(10, pad_mode='constant', onesided=True)
| |
# stft raises for real input unless return_complex is passed explicitly
@onlyNativeDeviceTypes
@skipCPUIfNoFFT
def test_stft_requires_complex(self, device):
    # Build the input on the device under test so every native device type is
    # actually exercised.
    x = torch.rand(100, device=device)
    with self.assertRaisesRegex(RuntimeError, 'stft requires the return_complex parameter'):
        x.stft(10, pad_mode='constant')
| |
# stft and istft are currently warning if a window is not provided
@onlyNativeDeviceTypes
@skipCPUIfNoFFT
def test_stft_requires_window(self, device):
    # Build the input on the device under test so every native device type is
    # actually exercised.
    x = torch.rand(100, device=device)
    with self.assertWarnsOnceRegex(UserWarning, "A window was not provided"):
        x.stft(10, pad_mode='constant', return_complex=True)
| |
@onlyNativeDeviceTypes
@skipCPUIfNoFFT
def test_istft_requires_window(self, device):
    # istft without a window must emit the "window was not provided" warning.
    # 51 = n_fft // 2 + 1 (one-sided bins for n_fft=100), 5 = number of frames
    stft = torch.rand((51, 5), dtype=torch.cdouble, device=device)
    with self.assertWarnsOnceRegex(UserWarning, "A window was not provided"):
        torch.istft(stft, n_fft=100, length=100)
| |
@skipCPUIfNoFFT
def test_fft_input_modification(self, device):
    # FFT functions should not modify their input (gh-34551)
    transform_dims = (-2, -1)

    signal = torch.ones((2, 2, 2), device=device)
    signal_snapshot = signal.clone()

    # complex-to-complex forward transform
    spectrum = torch.fft.fftn(signal, dim=transform_dims)
    self.assertEqual(signal, signal_snapshot)

    # complex-to-complex inverse transform
    spectrum_snapshot = spectrum.clone()
    torch.fft.ifftn(spectrum, dim=transform_dims)
    self.assertEqual(spectrum, spectrum_snapshot)

    # real-to-complex transform
    half_spectrum = torch.fft.rfftn(signal, dim=transform_dims)
    self.assertEqual(signal, signal_snapshot)

    # complex-to-real transform (run on the clone, compared with the source)
    half_spectrum_snapshot = half_spectrum.clone()
    torch.fft.irfftn(half_spectrum_snapshot, s=(2, 2), dim=transform_dims)
    self.assertEqual(half_spectrum, half_spectrum_snapshot)
| |
@onlyNativeDeviceTypes
@skipCPUIfNoFFT
def test_fft_plan_repeatable(self, device):
    # Regression test for gh-58724 and gh-63152
    # Transforming the same data twice must give the same answer.
    for length in (2048, 3199, 5999):
        complex_input = torch.randn(length, device=device, dtype=torch.complex64)
        first = torch.fft.fftn(complex_input)
        second = torch.fft.fftn(complex_input.clone())
        self.assertEqual(first, second)

        real_input = torch.randn(length, device=device, dtype=torch.float64)
        first = torch.fft.rfft(real_input)
        second = torch.fft.rfft(real_input.clone())
        self.assertEqual(first, second)
| |
@onlyNativeDeviceTypes
@skipCPUIfNoFFT
@dtypes(torch.double)
def test_istft_round_trip_simple_cases(self, device, dtype):
    """stft -> istft should recover the original signal"""
    def _test(input, n_fft, length):
        spectrogram = torch.stft(input, n_fft=n_fft, return_complex=True)
        recovered = torch.istft(spectrogram, n_fft=n_fft, length=length)
        self.assertEqual(input, recovered, exact_dtype=True)

    # Constant signals: all ones, then all zeros.
    for make_signal in (torch.ones, torch.zeros):
        _test(make_signal(4, dtype=dtype, device=device), 4, 4)
| |
@onlyNativeDeviceTypes
@skipCPUIfNoFFT
@dtypes(torch.double)
def test_istft_round_trip_various_params(self, device, dtype):
    """stft -> istft should recover the original signal"""
    def _test_istft_is_inverse_of_stft(stft_kwargs):
        # Draws a fresh random signal for every trial and runs it through
        # stft followed by istft to check the signal is reconstructed.
        num_trials = 100
        # istft takes the same arguments as stft except for pad_mode.
        istft_kwargs = stft_kwargs.copy()
        istft_kwargs.pop('pad_mode')
        for sizes in [(2, 20), (3, 15), (4, 10)]:
            for _ in range(num_trials):
                original = torch.randn(*sizes, dtype=dtype, device=device)
                spectrogram = torch.stft(original, return_complex=True, **stft_kwargs)
                recovered = torch.istft(spectrogram, length=original.size(1), **istft_kwargs)
                self.assertEqual(
                    recovered, original, msg='istft comparison against original',
                    atol=7e-6, rtol=0, exact_dtype=True)

    patterns = [
        # hann_window, centered, normalized, onesided
        dict(
            n_fft=12,
            hop_length=4,
            win_length=12,
            window=torch.hann_window(12, dtype=dtype, device=device),
            center=True,
            pad_mode='reflect',
            normalized=True,
            onesided=True,
        ),
        # hann_window, centered, not normalized, not onesided
        dict(
            n_fft=12,
            hop_length=2,
            win_length=8,
            window=torch.hann_window(8, dtype=dtype, device=device),
            center=True,
            pad_mode='reflect',
            normalized=False,
            onesided=False,
        ),
        # hamming_window, centered, normalized, not onesided
        dict(
            n_fft=15,
            hop_length=3,
            win_length=11,
            window=torch.hamming_window(11, dtype=dtype, device=device),
            center=True,
            pad_mode='constant',
            normalized=True,
            onesided=False,
        ),
        # hamming_window, centered, not normalized, onesided
        # window same size as n_fft
        dict(
            n_fft=5,
            hop_length=2,
            win_length=5,
            window=torch.hamming_window(5, dtype=dtype, device=device),
            center=True,
            pad_mode='constant',
            normalized=False,
            onesided=True,
        ),
    ]
    for pattern in patterns:
        _test_istft_is_inverse_of_stft(pattern)
| |
@onlyNativeDeviceTypes
@skipCPUIfNoFFT
@dtypes(torch.double)
def test_istft_round_trip_with_padding(self, device, dtype):
    """long hop_length or not centered may cause length mismatch in the inversed signal"""
    def _test_istft_is_inverse_of_stft_with_padding(stft_kwargs):
        # Draws a fresh random signal for every trial and runs it through
        # stft followed by istft to check the signal is reconstructed.
        num_trials = 100
        # 'size' describes the signal, it is not an stft argument: take it out.
        sizes = stft_kwargs.pop('size')
        # istft takes the same arguments as stft except for pad_mode.
        istft_kwargs = stft_kwargs.copy()
        istft_kwargs.pop('pad_mode')
        for _ in range(num_trials):
            original = torch.randn(*sizes, dtype=dtype, device=device)
            spectrogram = torch.stft(original, return_complex=True, **stft_kwargs)
            with self.assertWarnsOnceRegex(UserWarning, "The length of signal is shorter than the length parameter."):
                inversed = torch.istft(spectrogram, length=original.size(-1), **istft_kwargs)

            # Number of samples the frames can actually reconstruct.
            n_frames = spectrogram.size(-1)
            hop_span = stft_kwargs["hop_length"] * (n_frames - 1)
            if stft_kwargs["center"] is True:
                len_expected = stft_kwargs["n_fft"] // 2 + hop_span
            else:
                len_expected = stft_kwargs["n_fft"] + hop_span

            # trim the original for case when constructed signal is shorter than original
            padding = inversed[..., len_expected:]
            inversed = inversed[..., :len_expected]
            original = original[..., :len_expected]

            # test the padding points of the inversed signal are all zeros
            self.assertEqual(
                padding, torch.zeros_like(padding, device=padding.device),
                msg='istft padding values against zeros',
                atol=7e-6, rtol=0, exact_dtype=True)
            self.assertEqual(
                inversed, original, msg='istft comparison against original',
                atol=7e-6, rtol=0, exact_dtype=True)

    patterns = [
        # hamming_window, not centered, not normalized, not onesided
        # window same size as n_fft
        dict(
            size=[2, 20],
            n_fft=3,
            hop_length=2,
            win_length=3,
            window=torch.hamming_window(3, dtype=dtype, device=device),
            center=False,
            pad_mode='reflect',
            normalized=False,
            onesided=False,
        ),
        # hamming_window, centered, not normalized, onesided, long hop_length
        # window same size as n_fft
        dict(
            size=[2, 500],
            n_fft=256,
            hop_length=254,
            win_length=256,
            window=torch.hamming_window(256, dtype=dtype, device=device),
            center=True,
            pad_mode='constant',
            normalized=False,
            onesided=True,
        ),
    ]
    for pattern in patterns:
        _test_istft_is_inverse_of_stft_with_padding(pattern)
| |
@onlyNativeDeviceTypes
@skipCPUIfNoFFT
def test_istft_throws(self, device):
    """istft should throw exception for invalid parameters"""
    # The inputs are complex spectrograms living on `device`, so the calls are
    # not rejected merely for a real-valued input or a device mismatch before
    # the parameter under test is looked at.
    # Shape (3, 5): 3 == n_fft // 2 + 1 one-sided bins for n_fft=4, 5 frames.
    stft = torch.zeros((3, 5), dtype=torch.cfloat, device=device)
    # the window is size 1 but it hops 20 so there is a gap which throw an error
    self.assertRaises(
        RuntimeError, torch.istft, stft, n_fft=4,
        hop_length=20, win_length=1, window=torch.ones(1, device=device))
    # A window of zeros does not meet NOLA
    invalid_window = torch.zeros(4, device=device)
    self.assertRaises(
        RuntimeError, torch.istft, stft, n_fft=4, win_length=4, window=invalid_window)
    # Input cannot be empty
    for empty_shape in ((3, 0), (0, 3)):
        empty = torch.zeros(empty_shape, dtype=torch.cfloat, device=device)
        self.assertRaises(RuntimeError, torch.istft, empty, 2)
| |
@skipIfTorchDynamo("Failed running call_function")
@onlyNativeDeviceTypes
@skipCPUIfNoFFT
@dtypes(torch.double)
def test_istft_of_sine(self, device, dtype):
    # Feeds istft a hand-built spectrum of a sine wave and checks that the
    # sine comes back.
    complex_dtype = corresponding_complex_dtype(dtype)

    def _test(amplitude, L, n):
        # stft of amplitude*sin(2*pi/L*n*x) with the hop length and window size equaling L
        samples = torch.arange(2 * L + 1, device=device, dtype=dtype)
        original = amplitude * torch.sin(2 * math.pi / L * samples * n)

        # The spectrum is written by hand; it corresponds to
        #   torch.stft(original, L, hop_length=L, win_length=L,
        #              window=torch.ones(L), center=False, normalized=False)
        stft = torch.zeros((L // 2 + 1, 2), device=device, dtype=complex_dtype)
        n_bins = stft.size(0)
        peak = (amplitude * L) / 2.0
        if n < n_bins:
            stft[n].imag = torch.tensor(-peak, dtype=dtype)

        mirrored_bin = L - n
        if 0 <= mirrored_bin < n_bins:
            # symmetric about L // 2
            stft[mirrored_bin].imag = torch.tensor(peak, dtype=dtype)

        rectangular_window = torch.ones(L, device=device, dtype=dtype)
        inverse = torch.istft(
            stft, L, hop_length=L, win_length=L,
            window=rectangular_window, center=False, normalized=False)
        # There is a larger error due to the scaling of amplitude
        original = original[..., :inverse.size(-1)]
        self.assertEqual(inverse, original, atol=1e-3, rtol=0)

    # (amplitude, L, n)
    cases = [
        (123, 5, 1),
        (150, 5, 2),
        (111, 5, 3),
        (160, 7, 4),
        (145, 8, 5),
        (80, 9, 6),
        (99, 10, 7),
    ]
    for amplitude, L, n in cases:
        _test(amplitude=amplitude, L=L, n=n)
| |
@onlyNativeDeviceTypes
@skipCPUIfNoFFT
@dtypes(torch.double)
def test_istft_linearity(self, device, dtype):
    # istft(a * X + b * Y) must equal a * istft(X) + b * istft(Y).
    num_trials = 100
    complex_dtype = corresponding_complex_dtype(dtype)

    def _test(data_size, kwargs):
        for _ in range(num_trials):
            spec1 = torch.randn(data_size, device=device, dtype=complex_dtype)
            spec2 = torch.randn(data_size, device=device, dtype=complex_dtype)
            a, b = torch.rand(2, dtype=dtype, device=device)
            # Also compare method vs. functional call signature
            combined_after = a * spec1.istft(**kwargs) + b * spec2.istft(**kwargs)
            combined_before = torch.istft(a * spec1 + b * spec2, **kwargs)
            self.assertEqual(combined_after, combined_before, atol=1e-5, rtol=0)

    def _pattern(data_size, window_fn, center, normalized, onesided):
        # Every pattern uses n_fft=12 with a window of the same length.
        kwargs = {
            'n_fft': 12,
            'window': window_fn(12, device=device, dtype=dtype),
            'center': center,
            'normalized': normalized,
            'onesided': onesided,
        }
        return data_size, kwargs

    patterns = [
        # hann_window, centered, normalized, onesided
        _pattern((2, 7, 7), torch.hann_window, center=True, normalized=True, onesided=True),
        # hann_window, centered, not normalized, not onesided
        _pattern((2, 12, 7), torch.hann_window, center=True, normalized=False, onesided=False),
        # hamming_window, centered, normalized, not onesided
        _pattern((2, 12, 7), torch.hamming_window, center=True, normalized=True, onesided=False),
        # hamming_window, not centered, not normalized, onesided
        _pattern((2, 7, 3), torch.hamming_window, center=False, normalized=False, onesided=True),
    ]
    for data_size, kwargs in patterns:
        _test(data_size, kwargs)
| |
@onlyNativeDeviceTypes
@skipCPUIfNoFFT
def test_batch_istft(self, device):
    # Batched istft must agree with the unbatched result repeated per batch entry.
    rows = [
        [4., 4., 4., 4., 4.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
    ]
    original = torch.tensor(rows, device=device, dtype=torch.complex64)
    istft_kwargs = {'n_fft': 4, 'length': 4}

    # Batch of one and batch of four copies of the same spectrogram.
    single = original.repeat(1, 1, 1)
    multi = original.repeat(4, 1, 1)

    i_original = torch.istft(original, **istft_kwargs)
    i_single = torch.istft(single, **istft_kwargs)
    i_multi = torch.istft(multi, **istft_kwargs)

    self.assertEqual(i_original.repeat(1, 1), i_single, atol=1e-6, rtol=0, exact_dtype=True)
    self.assertEqual(i_original.repeat(4, 1), i_multi, atol=1e-6, rtol=0, exact_dtype=True)
| |
@onlyCUDA
@skipIf(not TEST_MKL, "Test requires MKL")
def test_stft_window_device(self, device):
    # Test the (i)stft window must be on the same device as the input
    # x and window are created on the default (CPU) device; `device` is CUDA.
    x = torch.randn(1000, dtype=torch.complex64)
    window = torch.randn(100, dtype=torch.complex64)

    with self.assertRaisesRegex(RuntimeError, "stft input and window must be on the same device"):
        torch.stft(x, n_fft=100, window=window.to(device))

    with self.assertRaisesRegex(RuntimeError, "stft input and window must be on the same device"):
        torch.stft(x.to(device), n_fft=100, window=window)

    # Spectrogram (on the CPU) used as the istft input below.
    X = torch.stft(x, n_fft=100, window=window)

    with self.assertRaisesRegex(RuntimeError, "istft input and window must be on the same device"):
        torch.istft(X, n_fft=100, window=window.to(device))

    # Move the spectrogram X (not the time-domain signal x) to the device.
    with self.assertRaisesRegex(RuntimeError, "istft input and window must be on the same device"):
        torch.istft(X.to(device), n_fft=100, window=window)
| |
| |
class FFTDocTestFinder:
    '''The default doctest finder doesn't like that function.__module__ doesn't
    match torch.fft. It assumes the functions are leaked imports.
    '''
    def __init__(self) -> None:
        self.parser = doctest.DocTestParser()

    def find(self, obj, name=None, module=None, globs=None, extraglobs=None):
        '''Collect a doctest for every documented routine listed in ``obj.__all__``.

        Args:
            obj: object exposing ``__all__``; each listed name is looked up
                with ``getattr``.
            name, module, extraglobs: accepted for signature compatibility
                with ``doctest.DocTestFinder.find``; not used.
            globs: globals handed to each parsed doctest (``{}`` when None).

        Returns:
            A list of ``doctest.DocTest`` objects, one per routine that has a
            docstring, each named after the attribute it was parsed from.
        '''
        globs = {} if globs is None else globs

        doctests = []
        for fname in obj.__all__:
            func = getattr(obj, fname)
            # Only functions/methods/builtins carry examples worth running.
            if not inspect.isroutine(func):
                continue

            docstring = inspect.getdoc(func)
            if docstring is None:
                continue

            doctests.append(self.parser.get_doctest(
                docstring, globs=globs, name=fname, filename=None, lineno=None))

        return doctests
| |
| |
| class TestFFTDocExamples(TestCase): |
| pass |
| |
def generate_doc_test(doc_test):
    # Wraps a single doctest in a test method and attaches it to
    # TestFFTDocExamples as 'test_<doctest name>'.
    def test(self, device):
        self.assertEqual(device, 'cpu')
        runner = doctest.DocTestRunner()
        runner.run(doc_test)

        if runner.failures:
            # Print the failure summary before failing the test.
            runner.summarize()
            self.fail('Doctest failed')

    method_name = 'test_' + doc_test.name
    setattr(TestFFTDocExamples, method_name, skipCPUIfNoFFT(test))
| |
# Attach one test method to TestFFTDocExamples for every documented routine
# listed in torch.fft.__all__; `torch` is made available to the examples.
for doc_test in FFTDocTestFinder().find(torch.fft, globs=dict(torch=torch)):
    generate_doc_test(doc_test)


# Generate the device-specific test classes; the doc examples are only
# instantiated for the CPU.
instantiate_device_type_tests(TestFFT, globals())
instantiate_device_type_tests(TestFFTDocExamples, globals(), only_for='cpu')

if __name__ == '__main__':
    run_tests()