blob: 3b7d344271678feda670363a34c4165487e91e78 [file] [log] [blame]
# Owner(s): ["oncall: jit"]
import torch
import os
import sys
from torch.testing._internal.jit_utils import JitTestCase, execWrapper
from torch.testing._internal.common_utils import IS_MACOS
from typing import List, Dict
from itertools import product
from textwrap import dedent
import cmath
# Make the helper files in test/ importable: the directory two levels above
# this file's resolved path is appended to the module search path.
_this_file = os.path.realpath(__file__)
pytorch_test_dir = os.path.dirname(os.path.dirname(_this_file))
sys.path.append(pytorch_test_dir)
class TestComplex(JitTestCase):
def test_script(self):
def fn(a: complex):
return a
self.checkScript(fn, (3 + 5j,))
def test_complexlist(self):
def fn(a: List[complex], idx: int):
return a[idx]
input = [1j, 2, 3 + 4j, -5, -7j]
self.checkScript(fn, (input, 2))
def test_complexdict(self):
def fn(a: Dict[complex, complex], key: complex) -> complex:
return a[key]
input = {2 + 3j : 2 - 3j, -4.3 - 2j: 3j}
self.checkScript(fn, (input, -4.3 - 2j))
def test_pickle(self):
class ComplexModule(torch.jit.ScriptModule):
def __init__(self):
super().__init__()
self.a = 3 + 5j
self.b = [2 + 3j, 3 + 4j, 0 - 3j, -4 + 0j]
self.c = {2 + 3j : 2 - 3j, -4.3 - 2j: 3j}
@torch.jit.script_method
def forward(self, b: int):
return b + 2j
loaded = self.getExportImportCopy(ComplexModule())
self.assertEqual(loaded.a, 3 + 5j)
self.assertEqual(loaded.b, [2 + 3j, 3 + 4j, -3j, -4])
self.assertEqual(loaded.c, {2 + 3j : 2 - 3j, -4.3 - 2j: 3j})
self.assertEqual(loaded(2), 2 + 2j)
def test_complex_parse(self):
def fn(a: int, b: torch.Tensor, dim: int):
# verifies `emitValueToTensor()` 's behavior
b[dim] = 2.4 + 0.5j
return (3 * 2j) + a + 5j - 7.4j - 4
t1 = torch.tensor(1)
t2 = torch.tensor([0.4, 1.4j, 2.35])
self.checkScript(fn, (t1, t2, 2))
def test_complex_constants_and_ops(self):
    """Compare ``cmath`` functions/constants and complex builtins between Python and TorchScript.

    Covers, in order: unary ``cmath`` functions, ``abs``, ``pow`` with
    float/complex operand mixes, ``cmath.rect`` (skipped when ``IS_MACOS``)
    and the ``cmath`` module constants.
    """
    vals = ([0.0, 1.0, 2.2, -1.0, -0.0, -2.2, 1, 0, 2]
            + [10.0 ** i for i in range(2)] + [-(10.0 ** i) for i in range(2)])
    complex_vals = tuple(complex(x, y) for x, y in product(vals, vals))

    funcs_template = dedent('''
        def func(a: complex):
            return cmath.{func_or_const}(a)
        ''')

    def checkCmath(func_name, funcs_template=funcs_template, takes_arg=True):
        """Build ``func`` from the template in Python and TorchScript and compare results.

        Args:
            func_name: ``cmath`` attribute name substituted into the template.
            funcs_template: source template with a ``{func_or_const}`` slot.
            takes_arg: whether the generated ``func`` accepts one complex
                argument. Pass ``False`` for the constants template, whose
                ``func`` takes no parameters.
        """
        funcs_str = funcs_template.format(func_or_const=func_name)
        scope = {}
        execWrapper(funcs_str, globals(), scope)
        cu = torch.jit.CompilationUnit(funcs_str)
        f_script = cu.func
        f = scope['func']

        if not takes_arg:
            # A zero-argument ``func`` must be called without arguments.
            # Calling it with one raises in Python, and the exception branch
            # below would then skip the comparison entirely.
            arg_sets = [()]
        elif func_name in ['isinf', 'isnan', 'isfinite']:
            # Classification predicates are also fed inf/nan components.
            new_vals = vals + ([float('inf'), float('nan'), -1 * float('inf')])
            arg_sets = [(complex(x, y),) for x, y in product(new_vals, new_vals)]
        else:
            arg_sets = [(a,) for a in complex_vals]

        for args in arg_sets:
            try:
                res_python = f(*args)
            except Exception as e:
                res_python = e
            try:
                res_script = f_script(*args)
            except Exception as e:
                res_script = e

            if res_python != res_script:
                # Inputs on which Python itself raises are not compared.
                if isinstance(res_python, Exception):
                    continue

                if args:
                    msg = f"Failed on {func_name} with input {args[0]}. Python: {res_python}, Script: {res_script}"
                else:
                    msg = f"Failed on {func_name}. Python: {res_python}, Script: {res_script}"
                self.assertEqual(res_python, res_script, msg=msg)

    unary_ops = ['log', 'log10', 'sqrt', 'exp', 'sin', 'cos', 'asin', 'acos', 'atan', 'sinh', 'cosh',
                 'tanh', 'asinh', 'acosh', 'atanh', 'phase', 'isinf', 'isnan', 'isfinite']

    # --- Unary ops ---
    for op in unary_ops:
        checkCmath(op)

    def fn(x: complex):
        return abs(x)

    for val in complex_vals:
        self.checkScript(fn, (val, ))

    def pow_complex_float(x: complex, y: float):
        return pow(x, y)

    def pow_float_complex(x: float, y: complex):
        return pow(x, y)

    self.checkScript(pow_float_complex, (2, 3j))
    self.checkScript(pow_complex_float, (3j, 2))

    def pow_complex_complex(x: complex, y: complex):
        return pow(x, y)

    # zip() pairs each value with itself, so only pow(v, v) is exercised here.
    for x, y in zip(complex_vals, complex_vals):
        # Reference: https://github.com/pytorch/pytorch/issues/54622
        if (x == 0):
            continue
        self.checkScript(pow_complex_complex, (x, y))

    # NOTE(review): the extent of this guard was reconstructed as covering only
    # the rect checks; confirm the constants checks below should also run on macOS.
    if not IS_MACOS:
        # --- Binary op ---
        def rect_fn(x: float, y: float):
            return cmath.rect(x, y)

        for x, y in product(vals, vals):
            self.checkScript(rect_fn, (x, y, ))

    func_constants_template = dedent('''
        def func():
            return cmath.{func_or_const}
        ''')
    float_consts = ['pi', 'e', 'tau', 'inf', 'nan']
    complex_consts = ['infj', 'nanj']
    for x in (float_consts + complex_consts):
        checkCmath(x, funcs_template=func_constants_template, takes_arg=False)
def test_infj_nanj_pickle(self):
class ComplexModule(torch.jit.ScriptModule):
def __init__(self):
super().__init__()
self.a = 3 + 5j
@torch.jit.script_method
def forward(self, infj: int, nanj: int):
if infj == 2:
return infj + cmath.infj
else:
return nanj + cmath.nanj
loaded = self.getExportImportCopy(ComplexModule())
self.assertEqual(loaded(2, 3), 2 + cmath.infj)
self.assertEqual(loaded(3, 4), 4 + cmath.nanj)
def test_complex_constructor(self):
# Test all scalar types
def fn_int(real: int, img: int):
return complex(real, img)
self.checkScript(fn_int, (0, 0, ))
self.checkScript(fn_int, (-1234, 0, ))
self.checkScript(fn_int, (0, -1256, ))
self.checkScript(fn_int, (-167, -1256, ))
def fn_float(real: float, img: float):
return complex(real, img)
self.checkScript(fn_float, (0.0, 0.0, ))
self.checkScript(fn_float, (-1234.78, 0, ))
self.checkScript(fn_float, (0, 56.18, ))
self.checkScript(fn_float, (-1.9, -19.8, ))
def fn_bool(real: bool, img: bool):
return complex(real, img)
self.checkScript(fn_bool, (True, True, ))
self.checkScript(fn_bool, (False, False, ))
self.checkScript(fn_bool, (False, True, ))
self.checkScript(fn_bool, (True, False, ))
def fn_bool_int(real: bool, img: int):
return complex(real, img)
self.checkScript(fn_bool_int, (True, 0, ))
self.checkScript(fn_bool_int, (False, 0, ))
self.checkScript(fn_bool_int, (False, -1, ))
self.checkScript(fn_bool_int, (True, 3, ))
def fn_int_bool(real: int, img: bool):
return complex(real, img)
self.checkScript(fn_int_bool, (0, True, ))
self.checkScript(fn_int_bool, (0, False, ))
self.checkScript(fn_int_bool, (-3, True, ))
self.checkScript(fn_int_bool, (6, False, ))
def fn_bool_float(real: bool, img: float):
return complex(real, img)
self.checkScript(fn_bool_float, (True, 0.0, ))
self.checkScript(fn_bool_float, (False, 0.0, ))
self.checkScript(fn_bool_float, (False, -1.0, ))
self.checkScript(fn_bool_float, (True, 3.0, ))
def fn_float_bool(real: float, img: bool):
return complex(real, img)
self.checkScript(fn_float_bool, (0.0, True, ))
self.checkScript(fn_float_bool, (0.0, False, ))
self.checkScript(fn_float_bool, (-3.0, True, ))
self.checkScript(fn_float_bool, (6.0, False, ))
def fn_float_int(real: float, img: int):
return complex(real, img)
self.checkScript(fn_float_int, (0.0, 1, ))
self.checkScript(fn_float_int, (0.0, -1, ))
self.checkScript(fn_float_int, (1.8, -3, ))
self.checkScript(fn_float_int, (2.7, 8, ))
def fn_int_float(real: int, img: float):
return complex(real, img)
self.checkScript(fn_int_float, (1, 0.0, ))
self.checkScript(fn_int_float, (-1, 1.7, ))
self.checkScript(fn_int_float, (-3, 0.0, ))
self.checkScript(fn_int_float, (2, -8.9, ))
def test_torch_complex_constructor_with_tensor(self):
tensors = ([torch.rand(1), torch.randint(-5, 5, (1, )), torch.tensor([False])])
def fn_tensor_float(real, img: float):
return complex(real, img)
def fn_tensor_int(real, img: int):
return complex(real, img)
def fn_tensor_bool(real, img: bool):
return complex(real, img)
def fn_float_tensor(real: float, img):
return complex(real, img)
def fn_int_tensor(real: int, img):
return complex(real, img)
def fn_bool_tensor(real: bool, img):
return complex(real, img)
for tensor in tensors:
self.checkScript(fn_tensor_float, (tensor, 1.2))
self.checkScript(fn_tensor_int, (tensor, 3))
self.checkScript(fn_tensor_bool, (tensor, True))
self.checkScript(fn_float_tensor, (1.2, tensor))
self.checkScript(fn_int_tensor, (3, tensor))
self.checkScript(fn_bool_tensor, (True, tensor))
def fn_tensor_tensor(real, img):
return complex(real, img) + complex(2)
for x, y in product(tensors, tensors):
self.checkScript(fn_tensor_tensor, (x, y, ))
def test_comparison_ops(self):
def fn1(a: complex, b: complex):
return a == b
def fn2(a: complex, b: complex):
return a != b
def fn3(a: complex, b: float):
return a == b
def fn4(a: complex, b: float):
return a != b
x, y = 2 - 3j, 4j
self.checkScript(fn1, (x, x))
self.checkScript(fn1, (x, y))
self.checkScript(fn2, (x, x))
self.checkScript(fn2, (x, y))
x1, y1 = 1 + 0j, 1.0
self.checkScript(fn3, (x1, y1))
self.checkScript(fn4, (x1, y1))
def test_div(self):
def fn1(a: complex, b: complex):
return a / b
x, y = 2 - 3j, 4j
self.checkScript(fn1, (x, y))
def test_complex_list_sum(self):
def fn(x: List[complex]):
return sum(x)
self.checkScript(fn, (torch.randn(4, dtype=torch.cdouble).tolist(), ))
def test_tensor_attributes(self):
def tensor_real(x):
return x.real
def tensor_imag(x):
return x.imag
t = torch.randn(2, 3, dtype=torch.cdouble)
self.checkScript(tensor_real, (t, ))
self.checkScript(tensor_imag, (t, ))
def test_binary_op_complex_tensor(self):
def mul(x: complex, y: torch.Tensor):
return x * y
def add(x: complex, y: torch.Tensor):
return x + y
def eq(x: complex, y: torch.Tensor):
return x == y
def ne(x: complex, y: torch.Tensor):
return x != y
def sub(x: complex, y: torch.Tensor):
return x - y
def div(x: complex, y: torch.Tensor):
return x - y
ops = [mul, add, eq, ne, sub, div]
for shape in [(1, ), (2, 2)]:
x = 0.71 + 0.71j
y = torch.randn(shape, dtype=torch.cfloat)
for op in ops:
eager_result = op(x, y)
scripted = torch.jit.script(op)
jit_result = scripted(x, y)
self.assertEqual(eager_result, jit_result)