| # Owner(s): ["oncall: jit"] |
| |
| import os |
| import sys |
| |
| import torch |
| from torch.testing._internal.jit_utils import JitTestCase |
| from torch.testing._internal.common_utils import IS_WINDOWS |
| from collections import namedtuple |
| from typing import List, Tuple, Optional, Dict |
| |
# Make the helper files in test/ importable: put the directory that contains
# this file's directory on sys.path.
pytorch_test_dir = os.path.dirname(
    os.path.dirname(os.path.realpath(__file__))
)
sys.path.append(pytorch_test_dir)

# This module is meant to be driven by test/test_jit.py (see the message
# below), so refuse to run it standalone.
if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
| |
| class TestTyping(JitTestCase): |
| def test_dict_in_not_in(self): |
| def test_in_dict(x): |
| # type: (Dict[str, int]) -> bool |
| return 'hi' in x |
| |
| self.checkScript(test_in_dict, ({'hi': 2, 'bye': 3},)) |
| self.checkScript(test_in_dict, ({'bye': 3},)) |
| |
| # Check evaluation order |
| @torch.jit.script |
| def a(): |
| print("a") |
| return 3 |
| |
| @torch.jit.script |
| def b(): |
| print("b") |
| return {3: 2, 4: 1} |
| |
| @torch.jit.script |
| def fn(): |
| return a() in b() |
| |
| with self.capture_stdout() as captured: |
| self.assertTrue(fn()) |
| if not IS_WINDOWS: |
| # no stdout capturing on windows |
| self.assertEqual(captured[0], "a\nb\n") |
| |
| def test_not_in_dict(a): |
| # type: (Dict[str, int]) -> bool |
| if "hello" not in a: |
| return False |
| else: |
| return True |
| |
| self.checkScript(test_not_in_dict, ({"hello": 1, "world": 2}, )) |
| self.checkScript(test_not_in_dict, ({"world": 2}, )) |
| |
| def test_dict_tensor_key(a, t): |
| # type: (Dict[Tensor, int], Tensor) -> bool |
| if t in a: |
| return True |
| else: |
| return False |
| |
| inp1 = torch.tensor(3) |
| inp2 = torch.tensor(5) |
| dict_a = {inp1: 1, inp2: 3} |
| self.checkScript(test_dict_tensor_key, (dict_a, torch.tensor(4))) |
| self.checkScript(test_dict_tensor_key, (dict_a, torch.tensor(3))) |
| self.checkScript(test_dict_tensor_key, (dict_a, inp1)) |
| self.checkScript(test_dict_tensor_key, (dict_a, inp2)) |
| |
| def test_list_type_refinement_annotation_element_mismatch(self): |
| def fn(): |
| l: List[int] = [1, 2, "foo", 3] |
| return l |
| |
| with self.assertRaisesRegex(RuntimeError, "List type annotation" |
| r" `List\[int\]` did not match the " |
| "types of the given list elements"): |
| torch.jit.script(fn) |
| |
| def test_dict_type_refinement_annotation_key_mismatch(self): |
| def fn(): |
| l1 = [1, 2, "foo", 3] |
| l2 = ["foo", "bar", "baz", "qux"] |
| d: Dict[int, str] = {k : v for k, v in zip(l1, l2)} |
| return d |
| |
| with self.assertRaisesRegex(RuntimeError, "Dicts may only " |
| "contain homogeneous keys, but the " |
| "type of the first generated key " |
| r"was Union\[int, str\]"): |
| torch.jit.script(fn) |
| |
| def test_dict_type_refinement_annotation_value_mismatch(self): |
| def fn(): |
| l1 = ["foo", "bar", "baz", "qux"] |
| l2 = [1, 2, "foo", 3] |
| d: Dict[str, int] = {k : v for k, v in zip(l1, l2)} |
| return d |
| |
| with self.assertRaisesRegex(RuntimeError, "Dict type annotation" |
| r" `Dict\[str, int\]` did not match" |
| " the type of an actual value type" |
| r" `Union\[int, str\]`"): |
| torch.jit.script(fn) |
| |
| def test_dict_invalid_annotations(self): |
| # Check for invalid value type annotation |
| def wrong_value_type(dictionary: Dict[str, torch.jit.ScriptModule]): |
| return |
| with self.assertRaisesRegex(ValueError, "Unknown type annotation"): |
| torch.jit.script(wrong_value_type) |
| |
| # Check for invalid key type annotation |
| def wrong_key_type(dictionary: Dict[torch.jit.ScriptModule, str]): |
| return |
| with self.assertRaisesRegex(ValueError, "Unknown type annotation"): |
| torch.jit.script(wrong_key_type) |
| |
| # Check for invalid key and value type annotation |
| def wrong_key_value_type(dictionary: Dict[torch.jit.ScriptModule, torch.jit.ScriptModule]): |
| return |
| with self.assertRaisesRegex(ValueError, "Unknown type annotation"): |
| torch.jit.script(wrong_key_value_type) |
| |
| def test_tuple_specialization(self): |
| @torch.jit.script |
| def f(t, s): |
| # type: (Tuple[Tensor, Tuple[int, Tensor]], str) -> Tensor |
| x, t2 = t |
| _, y = t2 |
| return x + y |
| |
| t = torch.randn(2, 2), (1, torch.randn(2, 2)), |
| f(t, "hi") |
| graph = f.graph_for(t, "hi") |
| input_types = list(next(graph.inputs()).type().elements()) |
| w = input_types[0] |
| self.assertEqual(input_types[0].kind(), 'TensorType') |
| self.assertEqual(input_types[1].elements()[1].kind(), 'TensorType') |
| |
| def test_tuple_io(self): |
| def stuff(x): |
| # type: (Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor] |
| a, b = x |
| return b, a |
| |
| a = (torch.rand(3), torch.rand(3)) |
| self.checkScript(stuff, (a,)) |
| |
| def test_tuple_keyword(self): |
| def bar(): |
| f = tuple((1, 2)) # noqa: C409 |
| return f |
| |
| self.checkScript(bar, ()) |
| |
| def foo(): |
| return tuple(1, 2) |
| |
| self.checkScriptRaisesRegex(foo, (), Exception, |
| "1 argument") |
| |
| def cant_infer_size(): |
| return tuple([1, 2, 3]) # noqa: C409 |
| |
| with self.assertRaisesRegex(Exception, "cannot statically infer the expected"): |
| torch.jit.script(cant_infer_size) |
| |
| def test_tuple_create_return(self): |
| def stuff2(x): |
| # type: (int) -> Tuple[Tensor, Tensor] |
| a = (torch.ones(x), torch.zeros(x)) |
| return a |
| self.checkScript(stuff2, (3,)) |
| |
| def test_list_io(self): |
| def stuff3(x): |
| # type: (List[int]) -> Tuple[Tensor, List[int]] |
| return torch.ones(x), x |
| self.checkScript(stuff3, ([3, 2],)) |
| |
| def test_bool_list_io(self): |
| @torch.jit.script |
| def stuff4(x): |
| # type: (List[bool]) -> Tuple[List[bool], List[bool], List[List[bool]]] |
| return x, [True, False], [[True]] |
| |
| li_1, li_2, li_3 = stuff4([True]) |
| li_3 = li_3[0] |
| for li in [li_1, li_2, li_3]: |
| self.assertTrue(type(li[0]) == type(True)) |
| |
| def test_nested_list(self): |
| def foo(z): |
| # type: (Tuple[int, List[List[int]]]) -> int |
| x, y = z |
| return y[0][1] |
| self.checkScript(foo, ((1, [[1, 2], [3, 4]]),)) |
| |
| def test_list_sum(self): |
| def fn(x: List[int]) -> int: |
| return sum(x) |
| |
| def fn1(x: List[float]): |
| return sum(x) |
| |
| def fn2(x: List[bool]): |
| return sum(x) |
| |
| self.checkScript(fn, ([1, 2, 3], )) |
| self.checkScript(fn1, ([1.0, 2.0, 3.0], )) |
| self.checkScript(fn1, ([1, 2.8, 3], )) |
| self.checkScript(fn2, ([True, False, False], )) |
| self.checkScript(fn2, ([False, False, False], )) |
| self.checkScript(fn2, ([0, 1, 1, 0], )) |
| |
| def test_list_unification(self): |
| def fn(): |
| return [1, None, 2] |
| |
| def fn2(x): |
| return [torch.ones(2, 2), None, x] |
| |
| self.checkScript(fn, []) |
| self.checkScript(fn2, (torch.ones(2, 2),)) |
| |
| # to avoid defining sum_list in multiple tests |
def get_sum_list_fn(self):
    """Return a scriptable function that sums a List[int] with an explicit loop."""
    def sum_list(a):
        # type: (List[int]) -> int
        total = 0
        for elem in a:
            total += elem
        return total

    return sum_list
| |
def test_sum_list_diff_elms(self):
    """Sum a list of several distinct ints."""
    summer = self.get_sum_list_fn()
    self.checkScript(summer, ([1, 2, 3, 4, 5],))
| |
def test_sum_list_empty(self):
    """Summing an empty list gives the loop's initial value."""
    summer = self.get_sum_list_fn()
    self.checkScript(summer, ([],))
| |
def test_sum_list_one(self):
    """Sum a single-element list."""
    summer = self.get_sum_list_fn()
    self.checkScript(summer, ([1],))
| |
| def test_sum_list_literal(self): |
| |
| def sum_list(): |
| # type: () -> int |
| sum = 0 |
| for i in [1, 2, 3, 4, 5]: |
| sum += i |
| |
| return sum |
| |
| self.checkScript(sum_list, ()) |
| |
| def test_sum_list_wrong_type(self): |
| |
| with self.assertRaisesRegex(RuntimeError, "'int' object is not iterable"): |
| @torch.jit.script |
| def sum_list(a): |
| # type: (int) -> int |
| sum = 0 |
| for i in a: # noqa: T484 |
| sum += i |
| |
| return sum |
| |
| sum_list(1) |
| |
| def test_list_iterables(self): |
| with self.assertRaisesRegex(RuntimeError, 'List of iterables is not supported currently'): |
| cu = torch.jit.CompilationUnit(''' |
| def list_iterables(x): |
| for i, j in [2, 3, 4], [5, 6, 7]: |
| x += i |
| x += j |
| return x |
| ''') |
| |
| def test_for_in_string(self): |
| def test_strings(x): |
| # type: (str) -> str |
| reverse = "" |
| for c in x: |
| reverse = c + reverse |
| return reverse |
| |
| self.checkScript(test_strings, ("hello",)) |
| self.checkScript(test_strings, ("",)) |
| |
| def test_list_strings(x): |
| # type: (List[str]) -> str |
| result = "" |
| for sub_str in x: |
| result += sub_str |
| return result |
| |
| self.checkScript(test_list_strings, (["hello", "world"],)) |
| self.checkScript(test_list_strings, (["hello", " ", "world", ""],)) |
| |
| def test_for_in_dict(self): |
| def test_dicts(x): |
| # type: (Dict[str, int]) -> int |
| sum = 0 |
| for key in x: |
| sum += x[key] |
| return sum |
| |
| self.checkScript(test_dicts, ({"a": 1, "b": 2, "c": 3},)) |
| |
| def test_dict_keys_values(x): |
| # type: (Dict[str, int]) -> Tuple[str, int] |
| key_str = "" |
| sum = 0 |
| for key in x.keys(): |
| key_str += key |
| for val in x.values(): |
| sum += val |
| return key_str, sum |
| |
| self.checkScript(test_dicts, ({"a": 1, "b": 2, "c": 3},)) |
| |
| def test_for_tuple_unpack(self): |
| def for_tuple_unpack(x, y): |
| for i, j in [[3, 4], [5, 6], [7, 8]]: |
| x += i |
| y += j |
| return x, y |
| |
| self.checkScript(for_tuple_unpack, (torch.tensor(3), torch.tensor(5))) |
| |
| def nested_tuple_unpack(x, y): |
| # type: (List[int], List[int]) -> int |
| sum = 0 |
| for i, (j, k), v in zip(x, enumerate(x), y): |
| sum += i + j + k + v |
| return sum |
| |
| self.checkScript(nested_tuple_unpack, ([1, 3, 5], [2, 4, 6])) |
| |
| def test_dict_comprehension(self): |
| def fn(): |
| return {i : chr(i + 65) for i in range(4)} |
| self.checkScript(fn, ()) |
| |
| def test_dict_comprehension_with_type_annotation(self): |
| def fn(): |
| d: Dict[int, str] = {i : chr(i + 65) for i in range(4)} |
| return d |
| self.checkScript(fn, ()) |
| |
| with self.assertRaisesRegex(RuntimeError, ""): |
| with self.assertRaisesRegex(AssertionError, "Expected Dict " |
| "type annotation for dict " |
| "comprehension, found " |
| "Tuple[int, str]"): |
| @torch.jit.script |
| def fn(): |
| d: Tuple[int, str] = {i : chr(i + 65) for i in range(4)} |
| return d |
| |
| def test_dict_comprehension_scope(self): |
| def comprehension_can_access_outer_scope_variables(): |
| lst = ["foo", "bar", "baz"] |
| return {l : len(l) for l in lst} |
| |
| self.checkScript(comprehension_can_access_outer_scope_variables, ()) |
| |
| with self.assertRaisesRegex(RuntimeError, "undefined value i"): |
| @torch.jit.script |
| def outer_scope_cannot_access_comprehension_variables(): |
| d = {i : chr(i + 65) for i in range(4)} |
| i = i + 1 |
| |
| def test_for_tuple_assign(self): |
| def test_simple_assign(x): |
| # type: (Tuple[int, float]) -> float |
| sum = 0.0 |
| for a in x: |
| sum += float(a) |
| return sum |
| |
| self.checkScript(test_simple_assign, ((1, 2.5),)) |
| |
| def test_tuple_assign(x): |
| # type: (Tuple[Tuple[int, int], Tuple[int, int]]) -> int |
| sum = 0 |
| for a in x: |
| sum += a[0] |
| sum += a[1] |
| return sum |
| |
| self.checkScript(test_tuple_assign, (((1, 2), (4, 7)), )) |
| |
| def test_single_starred_lhs(self): |
| with self.assertRaisesRegex(RuntimeError, 'A Starred expression may only appear on the lhs within the presence' |
| ' of another non-starred expression'): |
| cu = torch.jit.CompilationUnit(''' |
| def single_starred_lhs(x): |
| a = (x, x, x) |
| *b, = a |
| return b |
| ''') |
| |
| def test_singleton_tuple_unpack(self): |
| def foo(a): |
| b, = (a,) |
| return b + 1 |
| self.checkScript(foo, (torch.rand(3),)) |
| |
| def test_tuple_assignments(self): |
| def var_tuple_assign(x, y): |
| # type: (Tuple[Tensor, Tensor], Tensor) -> Tensor |
| (a, b), c = x, y |
| return a + b + c |
| |
| tuple_inputs = (torch.randn(1, 4), torch.randn(3, 4)) |
| self.checkScript(var_tuple_assign, (tuple_inputs, torch.randn(3, 4))) |
| |
| def nested_tuple_assign(x, y, z): |
| # type: (int, Tuple[int, Tuple[int, int]], Tuple[int, int]) -> int |
| a, (b, (c, d)), (e, f) = x, y, z |
| return a + b + c + d + e + f |
| |
| self.checkScript(nested_tuple_assign, ((1, (2, (3, 4)), (5, 6)))) |
| |
| def subscript_tuple_assign(a, x, i): |
| # type: (List[int], Tensor, int) -> Tuple[int, Tensor, int] |
| a[i], (x[i], b) = 1, (2, 3) |
| return a[i] + 1, x + 5, b |
| |
| self.checkScript(subscript_tuple_assign, ([12, 7, 9, 11], torch.tensor((3, 13, 17)), 0)) |
| |
| def star_tuple_assign(): |
| # type: () -> Tuple[int, int, Tuple[int, int], Tuple[int, int]] |
| a, (b, *c), *d = 1, (2, 3, 4), 5, 6 |
| return a, b, c, d |
| |
| self.checkScript(star_tuple_assign, ()) |
| |
| def subscript_tuple_augmented_assign(a): |
| # type: (Tuple[int, int]) -> Tuple[int, int] |
| a[0] += 1 |
| return a |
| |
| with self.assertRaisesRegex(RuntimeError, 'does not support augmented assign'): |
| scripted_aug_assign = torch.jit.script(subscript_tuple_augmented_assign) |
| |
| def test_multiple_assign(self): |
| def test(): |
| a = b, c = d, f = (1, 1) |
| |
| # side effect |
| ten = torch.tensor(1) |
| ten1 = ten2 = ten.add_(1) |
| |
| # ordering |
| x = 1 |
| y = 3 |
| x, y = y, x + y |
| |
| return a, b, c, d, f, ten, ten1, ten2, x, y |
| |
| self.checkScript(test, ()) |
| |
| def test_opt_opt_refinement(self): |
| @torch.jit.script |
| def test_unify(weight, bias): |
| # type: (Optional[int], Optional[int]) -> Optional[int] |
| if weight is not None: |
| opt = None |
| else: |
| if bias is not None: |
| opt = 1 |
| else: |
| opt = None |
| |
| return opt |
| |
| def test_optional_refinement(self): |
| @torch.jit.script |
| def test_if_none_assignment(x): |
| # type: (Optional[int]) -> int |
| if x is None: |
| x = 1 |
| return x + 1 |
| |
| self.assertEqual(test_if_none_assignment(1), 2) |
| |
| def test_optional_conversion(self): |
| @torch.jit.script |
| def other_fn(x=None): |
| # type: (Optional[int]) -> int |
| return torch.jit._unwrap_optional(x) |
| |
| |
| @torch.jit.script |
| def fn(x): |
| # type: (int) -> int |
| return other_fn(x) |
| |
| self.assertEqual(fn(2), 2) |
| |
| @torch.jit.script |
| def unify_to_optional(x): |
| # type: (bool) -> Optional[int] |
| if x: |
| a = None |
| else: |
| a = 2 |
| return a |
| |
| self.assertEqual(unify_to_optional(True), None) |
| self.assertEqual(unify_to_optional(False), 2) |
| |
| @torch.jit.script |
| def opt_list(x): |
| # type: (Optional[List[float]]) -> int |
| return 2 |
| |
| @torch.jit.script |
| def broadcast_opt_list(x): |
| # type: (Optional[BroadcastingList2[float]]) -> int |
| return 2 |
| |
| @torch.jit.script |
| def opt_list_tuple_caller(x): |
| # type: (Tuple[float, float]) -> int |
| return opt_list(x) + broadcast_opt_list(x) |
| |
| self.assertEqual(opt_list_tuple_caller((2., 3.)), 4) |
| |
| def test_optional_tuple(self): |
| def fn(x=None): |
| # type: (Optional[Tuple[int, int]]) -> Tuple[int, int] |
| if x is None: |
| new_x = (1, 2) |
| else: |
| new_x = x |
| return new_x |
| |
| self.checkScript(fn, ((3, 4),)) |
| self.checkScript(fn, ()) |
| |
| def test_namedtuple_redefine(self): |
| global _1, _2 |
| _1 = namedtuple('GoogLeNetOutputs', ['logits', 'aux_logits2', 'aux_logits1']) |
| _2 = namedtuple('GoogLeNetOutputs', ['different']) |
| |
| with self.assertRaisesRegex(RuntimeError, r'redefine'): |
| @torch.jit.script |
| def foo(x, y): |
| # type: (_1, _2) -> _1 |
| return x |
| |
| def test_namedtuple_py2(self): |
| global _GoogLeNetOutputs # see [local resolution in python] |
| _GoogLeNetOutputs = namedtuple('GoogLeNetOutputs', ['logits', 'aux_logits2', 'aux_logits1']) |
| |
| @torch.jit.script |
| def foo(x): |
| # type: (_GoogLeNetOutputs) -> _GoogLeNetOutputs |
| return x |
| |
| vals = torch.rand(3), torch.rand(4), torch.rand(5) |
| out = foo(_GoogLeNetOutputs(logits=vals[0], aux_logits2=vals[1], aux_logits1=vals[2])) |
| self.assertEqual(out.logits, vals[0]) |
| self.assertEqual(out.aux_logits2, vals[1]) |
| self.assertEqual(out.aux_logits1, vals[2]) |
| |
| def test_namedtuple_good_error(self): |
| global _GoogLeNetOutputs # see [local resolution in python] |
| _GoogLeNetOutputs = namedtuple('GoogLeNetOutputs', ['logits', 'aux_logits2', 'aux_logits1']) |
| |
| @torch.jit.script |
| def foo(x): |
| # type: (_GoogLeNetOutputs) -> _GoogLeNetOutputs |
| return x |
| |
| with self.assertRaisesRegex(RuntimeError, |
| r'aka NamedTuple\(logits, aux_logits2, aux_logits1\)'): |
| out = foo(_GoogLeNetOutputs(logits="3", aux_logits2="4", aux_logits1="5")) |