# mypy: allow-untyped-defs
# Owner(s): ["module: complex"]

import torch
from torch.testing._internal.common_device_type import (
    dtypes,
    instantiate_device_type_tests,
    onlyCPU,
)
from torch.testing._internal.common_dtype import complex_types
from torch.testing._internal.common_utils import run_tests, set_default_dtype, TestCase


devices = (torch.device("cpu"), torch.device("cuda:0"))


class TestComplexTensor(TestCase):
    @dtypes(*complex_types())
    def test_to_list(self, device, dtype):
        # Check that the complex tensor converts to a Python list with the
        # expected values and no garbage entries.
        self.assertEqual(
            torch.zeros((2, 2), device=device, dtype=dtype).tolist(),
            [[0j, 0j], [0j, 0j]],
        )

    @dtypes(torch.float32, torch.float64, torch.float16)
    def test_dtype_inference(self, device, dtype):
        # issue: https://github.com/pytorch/pytorch/issues/36834
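        # A complex literal in the input should promote inference to the
        # complex counterpart of the current default float dtype
        # (float16 -> chalf, float32 -> cfloat, float64 -> cdouble).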
        with set_default_dtype(dtype):
            x = torch.tensor([3.0, 3.0 + 5.0j], device=device)
        if dtype == torch.float16:
            self.assertEqual(x.dtype, torch.chalf)
        elif dtype == torch.float32:
            self.assertEqual(x.dtype, torch.cfloat)
        else:
            self.assertEqual(x.dtype, torch.cdouble)

    @dtypes(*complex_types())
    def test_conj_copy(self, device, dtype):
        # issue: https://github.com/pytorch/pytorch/issues/106051
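        # torch.conj returns a lazy view that only sets the conjugate bit, so
        # copying that view back into the source tensor must materialize the
        # conjugated values rather than just copying the raw storage.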
        x1 = torch.tensor([5 + 1j, 2 + 2j], device=device, dtype=dtype)
        xc1 = torch.conj(x1)
        x1.copy_(xc1)
        self.assertEqual(
            x1, torch.tensor([5 - 1j, 2 - 2j], device=device, dtype=dtype)
        )

    @dtypes(*complex_types())
    def test_all(self, device, dtype):
        # issue: https://github.com/pytorch/pytorch/issues/120875
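        # Every element is nonzero in at least one of its real/imaginary
        # parts, so all() must be True.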
        x = torch.tensor([1 + 2j, 3 - 4j, 5j, 6], device=device, dtype=dtype)
        self.assertTrue(torch.all(x))

    @dtypes(*complex_types())
    def test_any(self, device, dtype):
        # issue: https://github.com/pytorch/pytorch/issues/120875
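        # Every element is some signed combination of complex zero, so any()
        # must be False.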
        x = torch.tensor(
            [0, 0j, -0 + 0j, -0 - 0j, 0 + 0j, 0 - 0j], device=device, dtype=dtype
        )
        self.assertFalse(torch.any(x))

    @onlyCPU
    @dtypes(*complex_types())
    def test_eq(self, device, dtype):
        """Test eq on complex types."""
        nan = float("nan")
        # Non-vectorized operations
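        # Single-element inputs are presumably too short for the vectorized
        # CPU kernels, so they exercise the scalar code path.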
        for a, b in (
            (
                torch.tensor([-0.0610 - 2.1172j], device=device, dtype=dtype),
                torch.tensor([-6.1278 - 8.5019j], device=device, dtype=dtype),
            ),
            (
                torch.tensor([-0.0610 - 2.1172j], device=device, dtype=dtype),
                torch.tensor([-6.1278 - 2.1172j], device=device, dtype=dtype),
            ),
            (
                torch.tensor([-0.0610 - 2.1172j], device=device, dtype=dtype),
                torch.tensor([-0.0610 - 8.5019j], device=device, dtype=dtype),
            ),
        ):
            actual = torch.eq(a, b)
            expected = torch.tensor([False], device=device, dtype=torch.bool)
            self.assertEqual(
                actual, expected, msg=f"\neq\nactual {actual}\nexpected {expected}"
            )

            actual = torch.eq(a, a)
            expected = torch.tensor([True], device=device, dtype=torch.bool)
            self.assertEqual(
                actual, expected, msg=f"\neq\nactual {actual}\nexpected {expected}"
            )

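            # With a complex `out=` tensor, the boolean results are cast into
            # its dtype, so True/False come back as 1+0j / 0+0j.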
            actual = torch.full_like(b, complex(2, 2))
            torch.eq(a, b, out=actual)
            expected = torch.tensor([complex(0)], device=device, dtype=dtype)
            self.assertEqual(
                actual, expected, msg=f"\neq(out)\nactual {actual}\nexpected {expected}"
            )

            actual = torch.full_like(b, complex(2, 2))
            torch.eq(a, a, out=actual)
            expected = torch.tensor([complex(1)], device=device, dtype=dtype)
            self.assertEqual(
                actual, expected, msg=f"\neq(out)\nactual {actual}\nexpected {expected}"
            )

        # Vectorized operations
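        # Longer (12-element) inputs exercise the vectorized CPU kernels.  The
        # entries containing NaN compare unequal to everything, including
        # themselves, so eq(a, a) is False at those positions.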
        for a, b in (
            (
                torch.tensor(
                    [
                        -0.0610 - 2.1172j,
                        5.1576 + 5.4775j,
                        complex(2.8871, nan),
                        -6.6545 - 3.7655j,
                        -2.7036 - 1.4470j,
                        0.3712 + 7.989j,
                        -0.0610 - 2.1172j,
                        5.1576 + 5.4775j,
                        complex(nan, -3.2650),
                        -6.6545 - 3.7655j,
                        -2.7036 - 1.4470j,
                        0.3712 + 7.989j,
                    ],
                    device=device,
                    dtype=dtype,
                ),
                torch.tensor(
                    [
                        -6.1278 - 8.5019j,
                        0.5886 + 8.8816j,
                        complex(2.8871, nan),
                        6.3505 + 2.2683j,
                        0.3712 + 7.9659j,
                        0.3712 + 7.989j,
                        -6.1278 - 2.1172j,
                        5.1576 + 8.8816j,
                        complex(nan, -3.2650),
                        6.3505 + 2.2683j,
                        0.3712 + 7.9659j,
                        0.3712 + 7.989j,
                    ],
                    device=device,
                    dtype=dtype,
                ),
            ),
        ):
            actual = torch.eq(a, b)
            expected = torch.tensor(
                [
                    False,
                    False,
                    False,
                    False,
                    False,
                    True,
                    False,
                    False,
                    False,
                    False,
                    False,
                    True,
                ],
                device=device,
                dtype=torch.bool,
            )
            self.assertEqual(
                actual, expected, msg=f"\neq\nactual {actual}\nexpected {expected}"
            )

            actual = torch.eq(a, a)
            expected = torch.tensor(
                [
                    True,
                    True,
                    False,
                    True,
                    True,
                    True,
                    True,
                    True,
                    False,
                    True,
                    True,
                    True,
                ],
                device=device,
                dtype=torch.bool,
            )
            self.assertEqual(
                actual, expected, msg=f"\neq\nactual {actual}\nexpected {expected}"
            )
| |
| actual = torch.full_like(b, complex(2, 2)) |
| torch.eq(a, b, out=actual) |
| expected = torch.tensor( |
| [ |
| complex(0), |
| complex(0), |
| complex(0), |
| complex(0), |
| complex(0), |
| complex(1), |
| complex(0), |
| complex(0), |
| complex(0), |
| complex(0), |
| complex(0), |
| complex(1), |
| ], |
| device=device, |
| dtype=dtype, |
| ) |
| self.assertEqual( |
| actual, expected, msg=f"\neq(out)\nactual {actual}\nexpected {expected}" |
| ) |
| |
| actual = torch.full_like(b, complex(2, 2)) |
| torch.eq(a, a, out=actual) |
| expected = torch.tensor( |
| [ |
| complex(1), |
| complex(1), |
| complex(0), |
| complex(1), |
| complex(1), |
| complex(1), |
| complex(1), |
| complex(1), |
| complex(0), |
| complex(1), |
| complex(1), |
| complex(1), |
| ], |
| device=device, |
| dtype=dtype, |
| ) |
| self.assertEqual( |
| actual, expected, msg=f"\neq(out)\nactual {actual}\nexpected {expected}" |
| ) |
| |
| @onlyCPU |
| @dtypes(*complex_types()) |
| def test_ne(self, device, dtype): |
| "Test ne on complex types" |
| nan = float("nan") |
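        # Same inputs as test_eq; ne should produce the exact logical
        # complement of eq, including at the NaN positions.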
        # Non-vectorized operations
        for a, b in (
            (
                torch.tensor([-0.0610 - 2.1172j], device=device, dtype=dtype),
                torch.tensor([-6.1278 - 8.5019j], device=device, dtype=dtype),
            ),
            (
                torch.tensor([-0.0610 - 2.1172j], device=device, dtype=dtype),
                torch.tensor([-6.1278 - 2.1172j], device=device, dtype=dtype),
            ),
            (
                torch.tensor([-0.0610 - 2.1172j], device=device, dtype=dtype),
                torch.tensor([-0.0610 - 8.5019j], device=device, dtype=dtype),
            ),
        ):
            actual = torch.ne(a, b)
            expected = torch.tensor([True], device=device, dtype=torch.bool)
            self.assertEqual(
                actual, expected, msg=f"\nne\nactual {actual}\nexpected {expected}"
            )

            actual = torch.ne(a, a)
            expected = torch.tensor([False], device=device, dtype=torch.bool)
            self.assertEqual(
                actual, expected, msg=f"\nne\nactual {actual}\nexpected {expected}"
            )

            actual = torch.full_like(b, complex(2, 2))
            torch.ne(a, b, out=actual)
            expected = torch.tensor([complex(1)], device=device, dtype=dtype)
            self.assertEqual(
                actual, expected, msg=f"\nne(out)\nactual {actual}\nexpected {expected}"
            )

            actual = torch.full_like(b, complex(2, 2))
            torch.ne(a, a, out=actual)
            expected = torch.tensor([complex(0)], device=device, dtype=dtype)
            self.assertEqual(
                actual, expected, msg=f"\nne(out)\nactual {actual}\nexpected {expected}"
            )

        # Vectorized operations
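        # NaN entries are not equal even to themselves, so ne(a, a) is True at
        # those positions.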
        for a, b in (
            (
                torch.tensor(
                    [
                        -0.0610 - 2.1172j,
                        5.1576 + 5.4775j,
                        complex(2.8871, nan),
                        -6.6545 - 3.7655j,
                        -2.7036 - 1.4470j,
                        0.3712 + 7.989j,
                        -0.0610 - 2.1172j,
                        5.1576 + 5.4775j,
                        complex(nan, -3.2650),
                        -6.6545 - 3.7655j,
                        -2.7036 - 1.4470j,
                        0.3712 + 7.989j,
                    ],
                    device=device,
                    dtype=dtype,
                ),
                torch.tensor(
                    [
                        -6.1278 - 8.5019j,
                        0.5886 + 8.8816j,
                        complex(2.8871, nan),
                        6.3505 + 2.2683j,
                        0.3712 + 7.9659j,
                        0.3712 + 7.989j,
                        -6.1278 - 2.1172j,
                        5.1576 + 8.8816j,
                        complex(nan, -3.2650),
                        6.3505 + 2.2683j,
                        0.3712 + 7.9659j,
                        0.3712 + 7.989j,
                    ],
                    device=device,
                    dtype=dtype,
                ),
            ),
        ):
            actual = torch.ne(a, b)
            expected = torch.tensor(
                [
                    True,
                    True,
                    True,
                    True,
                    True,
                    False,
                    True,
                    True,
                    True,
                    True,
                    True,
                    False,
                ],
                device=device,
                dtype=torch.bool,
            )
            self.assertEqual(
                actual, expected, msg=f"\nne\nactual {actual}\nexpected {expected}"
            )

            actual = torch.ne(a, a)
            expected = torch.tensor(
                [
                    False,
                    False,
                    True,
                    False,
                    False,
                    False,
                    False,
                    False,
                    True,
                    False,
                    False,
                    False,
                ],
                device=device,
                dtype=torch.bool,
            )
            self.assertEqual(
                actual, expected, msg=f"\nne\nactual {actual}\nexpected {expected}"
            )

            actual = torch.full_like(b, complex(2, 2))
            torch.ne(a, b, out=actual)
            expected = torch.tensor(
                [
                    complex(1),
                    complex(1),
                    complex(1),
                    complex(1),
                    complex(1),
                    complex(0),
                    complex(1),
                    complex(1),
                    complex(1),
                    complex(1),
                    complex(1),
                    complex(0),
                ],
                device=device,
                dtype=dtype,
            )
            self.assertEqual(
                actual, expected, msg=f"\nne(out)\nactual {actual}\nexpected {expected}"
            )

            actual = torch.full_like(b, complex(2, 2))
            torch.ne(a, a, out=actual)
            expected = torch.tensor(
                [
                    complex(0),
                    complex(0),
                    complex(1),
                    complex(0),
                    complex(0),
                    complex(0),
                    complex(0),
                    complex(0),
                    complex(1),
                    complex(0),
                    complex(0),
                    complex(0),
                ],
                device=device,
                dtype=dtype,
            )
            self.assertEqual(
                actual, expected, msg=f"\nne(out)\nactual {actual}\nexpected {expected}"
            )

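# Generate per-device copies of the tests above (e.g. TestComplexTensorCPU,
# TestComplexTensorCUDA) so each one runs on every available device type.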
instantiate_device_type_tests(TestComplexTensor, globals())

if __name__ == "__main__":
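    # Presumably this enables the harness check that tests leave the global
    # default dtype unchanged.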
    TestCase._default_dtype_check_enabled = True
    run_tests()