| # Owner(s): ["oncall: jit"] |
| |
| import io |
| import os |
| import sys |
| |
| import torch |
| from torch.testing import FileCheck |
| from enum import Enum |
| from textwrap import dedent |
| from typing import Dict, List, Optional, Tuple, Union |
| |
| # Make the helper files in test/ importable |
| pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) |
| sys.path.append(pytorch_test_dir) |
| from torch.testing._internal.jit_utils import JitTestCase, make_global |
| |
| if __name__ == '__main__': |
| raise RuntimeError("This test file is not meant to be run directly, use:\n\n" |
| "\tpython test/test_jit.py TESTNAME\n\n" |
| "instead.") |
| |
| class TestUnion(JitTestCase): |
| """ |
| This class tests the functionality of `Union`. |
| |
| Note: It's important to be able to refine the type of a `Union` to |
| one of its internal types. Currently, there are differences in the |
| way Python expects `isinstance` checks and the way TorchScript |
| expects `isinstance` checks. This means that we can't use |
| `checkScript` in our test cases because either the eager mode or the |
| script mode wouldn't run! So, some test cases have separate but |
| equivalent functions to emulate `checkScript`. |
| """ |
| |
| def test_check_union_annotation(self): |
| def test_func(a: Union[int, float], b: Optional[int]): |
| return 0 |
| |
| scripted_func = torch.jit.script(test_func) |
| graph_rep = str(scripted_func.graph) |
| code_rep = str(scripted_func.code) |
| # TS graph IR for Union should be annotated as Union() |
| FileCheck().check("Union(").check("int?").run(graph_rep) |
| # Serialized code for Union should be annotated as Union[] |
| FileCheck().check("Union[").check("Optional[int]").run(code_rep) |
| self.checkScript(test_func, (5, 6)) |
| # this shouldn't error out |
| torch._C.parse_ir(str(scripted_func.graph)) |
| |
| def test_union_with_scalar_values(self): |
| def fn(x: Union[int, float]) -> str: |
| return "foo" |
| |
| self.checkScript(fn, (1,)) |
| self.checkScript(fn, (1.0,)) |
| |
| scripted = torch.jit.script(fn) |
| |
| with self.assertRaisesRegex(RuntimeError, "Expected a member of" |
| r" Union\[float, int\] but " |
| "instead found type str"): |
| scripted("1") |
| |
| def test_union_with_collections(self): |
| def fn(x: Union[Dict[str, int], List[int]]) -> str: |
| return "foo" |
| |
| self.checkScript(fn, ({"foo": 1, "bar": 2, "baz": 3},)) |
| self.checkScript(fn, ([1, 2, 3],)) |
| |
| scripted = torch.jit.script(fn) |
| |
| with self.assertRaisesRegex(RuntimeError, "Expected a member of" |
| r" Union\[List\[int\], Dict\[str, " |
| r"int\]\] but instead found type " |
| r"Dict\[str, str\]"): |
| scripted({"foo": "bar", "baz": "qux"}) |
| |
| with self.assertRaisesRegex(RuntimeError, "Expected a member of" |
| r" Union\[List\[int\], Dict\[str, " |
| r"int\]\] but instead found type " |
| r"List\[str\]"): |
| scripted(["foo", "bar", "baz"]) |
| |
| with self.assertRaisesRegex(RuntimeError, "Expected a member of" |
| r" Union\[List\[int\], Dict\[str, " |
| r"int\]\] but instead found type " |
| "str"): |
| scripted("1") |
| |
| def test_union_with_enum(self): |
| class Color(Enum): |
| RED = 1 |
| GREEN = 2 |
| |
| make_global(Color) |
| |
| def fn(x: Union[str, Color]) -> str: |
| return "foo" |
| |
| self.checkScript(fn, (Color.RED,)) |
| self.checkScript(fn, ("red",)) |
| |
| scripted = torch.jit.script(fn) |
| |
| with self.assertRaisesRegex(RuntimeError, "Expected a member of" |
| r" Union\[__torch__.jit.test_union." |
| r"Color, str\] but instead found " |
| "type int"): |
| scripted(1) |
| |
| def test_union_in_class_constructor(self): |
| |
| @torch.jit.script # noqa: B903 |
| class A(object): # noqa: B903 |
| def __init__(self, x: Union[int, str]) -> None: |
| self.x = x |
| |
| def fn(x: Union[str, int]) -> A: |
| return A(x) |
| |
| self.assertEqual(fn("foo").x, "foo") |
| self.assertEqual(fn(1).x, 1) |
| |
| scripted = torch.jit.script(fn) |
| |
| with self.assertRaisesRegex(RuntimeError, "Expected a member of" |
| r" Union\[int, str\] but instead " |
| r"found type List\[str\]"): |
| scripted(["foo", "bar", "baz"]) |
| |
| def test_union_return_type(self): |
| def fn(x: int) -> Union[int, str]: |
| return "foo" |
| |
| self.checkScript(fn, (1,)) |
| |
| def test_union_as_annotation(self): |
| def fn() -> Union[int, str]: |
| x: Union[int, str] = "foo" |
| return x |
| |
| self.checkScript(fn, ()) |
| |
| def test_union_as_annotation_in_typed_container(self): |
| def fn() -> None: |
| l: List[Union[int, str]] = [] |
| u1: Union[int, str] = "foo" |
| u2: Union[int, str] = 1 |
| l.append(u1) |
| l.append(u2) |
| |
| self.checkScript(fn, ()) |
| |
| def test_union_as_annotation_py2(self): |
| def fn(): |
| # type: () -> Union[int, str] |
| x: Union[int, str] = "foo" |
| return x |
| |
| self.checkScript(fn, ()) |
| |
| def test_union_as_internal_tuple_type(self): |
| def fn(): |
| t: Tuple[Union[int, str], Union[int, str]] = (1, "foo") |
| return t |
| |
| self.checkScript(fn, ()) |
| |
| def test_union_variable_can_be_reassigned(self): |
| @torch.jit.script |
| def aux1(i: int): |
| return int(i ** 2) |
| |
| @torch.jit.script |
| def aux2(s: str): |
| return s + s |
| |
| def fn() -> Union[int, str]: |
| x: Union[int, str] = "foo" |
| i: int = 1 |
| x = i |
| y: int = aux1(x) |
| z: str = aux2(str(y)) |
| x = z |
| return x |
| |
| self.checkScript(fn, ()) |
| |
| def test_union_does_not_replace_existing_annotated_type(self): |
| def fn(): |
| x: List[int] = [1, 2, 3] |
| x.append("foo") |
| return x |
| |
| with self.assertRaisesRegex(RuntimeError, "Could not match type str"): |
| scripted = torch.jit.script(fn) |
| scripted() |
| |
| def test_union_does_not_replace_existing_annotated_type_union(self): |
| def fn(): |
| x: List[Union[int, str]] = [1, "foo", 3] |
| x.append(2.0) |
| return x |
| |
| with self.assertRaisesRegex(RuntimeError, "Could not match type float"): |
| scripted = torch.jit.script(fn) |
| scripted() |
| |
| def test_union_does_not_replace_existing_annotated_type_empty_container(self): |
| def fn(): |
| x: List[int] = [] |
| x.append("foo") |
| return x |
| |
| with self.assertRaisesRegex(RuntimeError, "Could not match type str"): |
| scripted = torch.jit.script(fn) |
| scripted() |
| |
| def test_unions_of_unions_are_flattened(self): |
| @torch.jit.script |
| def fn(x: Union[Union[int, str], float]) -> str: |
| return "foo" |
| |
| s = fn.graph |
| |
| FileCheck().check("x : Union(float, int, str)") \ |
| .run(s) |
| |
| def test_unions_of_a_single_argument_vanish(self): |
| @torch.jit.script |
| def fn(x: Union[int]) -> str: |
| return "foo" |
| |
| s = fn.graph |
| |
| FileCheck().check("x : int") \ |
| .run(s) |
| |
| def test_union_redundant_arguments_are_skipped(self): |
| @torch.jit.script |
| def fn(x: Union[int, str, int]) -> str: |
| return "foo" |
| |
| s = fn.graph |
| |
| FileCheck().check("x : Union(int, str)") \ |
| .run(s) |
| |
| def test_union_redundant_arguments_are_skipped_optional(self): |
| @torch.jit.script |
| def fn(x: Union[int, Optional[float], Optional[int]]) -> str: |
| return "foo" |
| |
| s = fn.graph |
| |
| FileCheck().check("x : Union(float, int, NoneType)") \ |
| .run(s) |
| |
| def test_union_redundant_arguments_are_skipped_subtyping(self): |
| @torch.jit.script |
| def fn(x: Union[str, Tuple[Optional[int], int], Tuple[int, int]]) -> str: |
| return "foo" |
| |
| s = fn.graph |
| |
| FileCheck().check("x : Union((int?, int), str)") \ |
| .run(s) |
| |
| def test_union_redundant_arguments_are_skipped_container(self): |
| @torch.jit.script |
| def fn(x: Union[List[str], List[float], List[str]]) -> str: |
| return "foo" |
| |
| s = fn.graph |
| |
| FileCheck().check("x : Union(float[], str[])") \ |
| .run(s) |
| |
| def test_union_argument_order_is_ignored(self): |
| @torch.jit.script |
| def fn1(x: Union[int, str]) -> str: |
| return "foo" |
| |
| @torch.jit.script |
| def fn2(x: Union[str, int]) -> str: |
| return "foo" |
| |
| for s in (fn1.graph, fn2.graph): |
| FileCheck().check("x : Union(int, str)") \ |
| .run(s) |
| |
| def test_union_argument_order_is_ignored_container(self): |
| @torch.jit.script |
| def fn1(x: Union[List[str], List[int]]) -> str: |
| return "foo" |
| |
| @torch.jit.script |
| def fn2(x: Union[List[int], List[str]]) -> str: |
| return "foo" |
| |
| for s in (fn1.graph, fn2.graph): |
| FileCheck().check("x : Union(int[], str[])") \ |
| .run(s) |
| |
| def test_union_T_None_is_equivalent_to_optional_T(self): |
| @torch.jit.script |
| def inner(x: Union[int, None]) -> int: |
| if x is not None: |
| return x |
| else: |
| return 5 |
| |
| @torch.jit.script |
| def fn1() -> int: |
| a: Optional[int] = 5 |
| b: Optional[int] = None |
| a_ = inner(a) |
| b_ = inner(b) |
| return a_ + b_ |
| |
| self.assertEqual(fn1(), 10) |
| |
| @torch.jit.script |
| def inner2(x: Optional[int]) -> int: |
| if x is not None: |
| return x |
| else: |
| return 5 |
| |
| @torch.jit.script |
| def fn2() -> int: |
| a: Union[int, None] = 5 |
| b: Union[int, None] = None |
| a_ = inner(a) |
| b_ = inner(b) |
| return a_ + b_ |
| |
| self.assertEqual(fn2(), 10) |
| |
| def test_union_optional_of_union_is_flattened(self): |
| @torch.jit.script |
| def fn(flag: int) -> Union[str, int, None]: |
| y: Union[int, str, None] = "foo" |
| if flag == 0: |
| x: Optional[Union[int, str]] = y |
| elif flag == 1: |
| x: Optional[Union[int, str]] = 1 |
| else: |
| x: Optional[Union[int, str]] = None |
| return x |
| |
| # Can't use `checkScript` because it will flag the fact that |
| # the original code has `Optional[Union[int, str]]` but the |
| # saved/loaded code has `Union[int, NoneType, str]` (even |
| # though this is exactly what we want) |
| self.assertEqual(fn(0), "foo") |
| self.assertEqual(fn(1), 1) |
| self.assertEqual(fn(2), None) |
| |
| buffer = io.BytesIO() |
| torch.jit.save(fn, buffer) |
| buffer = io.BytesIO(buffer.getvalue()) |
| l = torch.jit.load(buffer) |
| |
| s = l.code |
| |
| FileCheck().check("Union[int, NoneType, str]") \ |
| .check("Union[int, NoneType, str]") \ |
| .run(s) |
| |
| def test_union_subclasses_larger_union(self): |
| def fn() -> Union[int, str, torch.Tensor]: |
| x: Union[int, str] = "foo" |
| return x |
| |
| self.checkScript(fn, ()) |
| |
| # TODO: We would like to eventually support this. The issue is being |
| # tracked at https://github.com/pytorch/pytorch/issues/58167 |
| def test_union_as_dict_key(self): |
| def fn(): |
| x: Dict[Union[int, str], str] = {} |
| x["foo"] = "bar" |
| x[1] = 2 |
| return x[1] |
| |
| with self.assertRaisesRegex(RuntimeError, "only int, float, " |
| "complex, Tensor, device and string keys " |
| "are supported"): |
| torch.jit.script(fn) |
| |
| def test_union_as_dict_value(self): |
| def fn(): |
| x: Dict[str, Union[int, str]] = {} |
| x["foo"] = "bar" |
| x["baz"] = 2 |
| return x["baz"] |
| |
| self.checkScript(fn, ()) |
| |
| def test_union_module_with_union_instance_variable(self): |
| class M(torch.nn.Module): |
| |
| x: Union[int, str] |
| |
| def __init__(self, x: Union[int, str]): |
| super().__init__() |
| self.x: Union[int, str] = x |
| |
| def forward(self, y: Union[int, str]): |
| self.x = y |
| return self.x |
| |
| self.checkModule(M(2,), (1,)) |
| self.checkModule(M("bar"), ("foo",)) |
| |
| def test_union_module_with_union_class_variable(self): |
| class M(torch.nn.Module): |
| x: Union[int, str] = "foo" |
| |
| def __init__(self, y: int): |
| super().__init__() |
| x = y |
| |
| def forward(self, z: str): |
| x = z |
| return x |
| |
| self.checkModule(M(1), ("foo",)) |
| |
| def test_union_type_refinement(self): |
| def fn(x: Union[int, str]) -> str: |
| if isinstance(x, str): |
| z = x + "bar" |
| return x |
| else: |
| return "baz" |
| |
| self.checkScript(fn, ("foo",)) |
| self.checkScript(fn, (1,)) |
| |
| def test_union_type_refinement_union_rhs(self): |
| def fn(x: int) -> str: |
| if torch.jit.isinstance(x, Union[int, str]): |
| return "bar" |
| else: |
| return "baz" |
| |
| self.checkScript(fn, (1,)) |
| |
| def test_union_type_refinement_tuple_rhs(self): |
| def fn(x: Union[int, float, List[str]]) -> str: |
| if isinstance(x, (int, float)): |
| if isinstance(x, int): |
| return str(x) |
| else: |
| return "foo" |
| else: |
| if len(x): |
| return x[0] |
| else: |
| return "bar" |
| |
| self.checkScript(fn, (1,)) |
| self.checkScript(fn, (1.0,)) |
| self.checkScript(fn, (["a", "b", "c"],)) |
| |
| def test_union_type_refinement_tuple_rhs_noncontained_type(self): |
| def fn(x: Union[int, List[str]]) -> str: |
| if isinstance(x, (int, float)): |
| y = x + x |
| return str(y) |
| else: |
| if len(x): |
| return x[0] |
| else: |
| return "bar" |
| |
| self.checkScript(fn, (1,)) |
| self.checkScript(fn, (["a", "b", "c"],)) |
| |
| def test_union_type_refinement_tuple_rhs_union(self): |
| @torch.jit.script |
| def fn(x: int) -> str: |
| if torch.jit.isinstance(x, (Union[int, str], float)): |
| y = x + x |
| return str(y) |
| else: |
| return "foo" |
| |
| # TODO: There's currently an unrelated bug in |
| # `torch.jit.isinstance` that makes it fail for tuple literals. |
| # Posted here: https://github.com/pytorch/pytorch/issues/60095 |
| # Change `assertEqual` to `checkScript` when the bug is fixed |
| self.assertEqual(fn(1), "2") |
| |
| def test_union_type_refinement_statically_false(self): |
| @torch.jit.script |
| def fn(x: int) -> str: |
| if torch.jit.isinstance(x, (Union[str, float], List[str], str)): |
| z = x + "foo" |
| return z |
| else: |
| return "bar" |
| |
| s = fn.graph |
| |
| # Check that we don't have any branching statements |
| FileCheck().check_not("block0()") \ |
| .check_not("block1()") \ |
| .run(s) |
| |
| def test_union_type_refinement_statically_true(self): |
| @torch.jit.script |
| def fn(x: Union[List[int], int]) -> Union[List[int], int]: |
| if not torch.jit.isinstance(x, (int, List[int])): |
| return x |
| else: |
| l = [1, 2, 3] |
| y: Union[List[int], int] = l |
| return y |
| |
| s = fn.graph |
| |
| # Check that we don't have any branching statements |
| FileCheck().check_not("block0()") \ |
| .check_not("block1()") \ |
| .run(s) |
| |
| def test_union_type_refinement_partial_static_refinement_tuple_rhs(self): |
| def fn(x: Union[List[int], int]) -> int: |
| if torch.jit.isinstance(x, (int, float, str)): |
| # We should know that `x` is an `int` here |
| z = x + 1 |
| return z |
| else: |
| return 100 |
| |
| self.checkScript(fn, ([1, 2, 3],)) |
| self.checkScript(fn, (1,)) |
| |
| def test_union_type_refinement_partial_static_refinement_union_rhs(self): |
| def fn(x: Union[List[int], int]) -> int: |
| if torch.jit.isinstance(x, Union[int, float, str]): |
| # We should know that `x` is an `int` here |
| z = x + 1 |
| return z |
| else: |
| return 100 |
| |
| self.checkScript(fn, ([1, 2, 3],)) |
| self.checkScript(fn, (1,)) |
| |
| def test_union_type_refinement_internal_declaration(self): |
| def fn(flag: bool) -> str: |
| x: Union[int, str, None] = None |
| if (flag): |
| y = "foo" |
| else: |
| y = 1 |
| if isinstance(x, str): |
| return x |
| else: |
| return "bar" |
| |
| self.checkScript(fn, (True,)) |
| self.checkScript(fn, (False,)) |
| |
| def test_union_branching_with_union_return_and_homogenous_types(self): |
| def fn(x: int) -> Union[int, str]: |
| if x % 2: |
| return "foo" |
| else: |
| return "bar" |
| |
| self.checkScript(fn, (1,)) |
| self.checkScript(fn, (8,)) |
| |
| def test_union_branching_does_not_autoinfer_undeclared_union(self): |
| def fn(x: int) -> str: |
| if x % 2: |
| y = "foo" |
| else: |
| y = x |
| if isinstance(y, str): |
| return y |
| else: |
| return "bar" |
| |
| with self.assertRaisesRegex(RuntimeError, "y is set to type str" |
| " in the true branch and type int " |
| "in the false branch"): |
| torch.jit.script(fn) |
| |
| def test_union_branching_does_not_widen_existing_inferred_type(self): |
| def fn(x: int) -> str: |
| y = "foo" |
| if x % 2: |
| y = "bar" |
| else: |
| y = x |
| if isinstance(y, str): |
| return y |
| else: |
| return "baz" |
| |
| with self.assertRaisesRegex(RuntimeError, "previously had type " |
| "str but is now being assigned to a" |
| " value of type int"): |
| torch.jit.script(fn) |
| |
| def test_union_schema_matching_on_internal_type(self): |
| def fn(x: Union[List[int], Dict[str, int]]) -> int: |
| if torch.jit.isinstance(x, List[int]): |
| return x[0] |
| else: |
| return list(x.values())[0] |
| |
| self.checkScript(fn, ([1, 2, 3],)) |
| self.checkScript(fn, ({"foo": 1, "bar": 2, "baz": 3},)) |
| |
| def test_union_subtractive_refinement(self): |
| def fn(x: Union[List[int], int]) -> int: |
| if not isinstance(x, int): |
| x.append(1) |
| return x[0] |
| else: |
| return x |
| |
| self.checkScript(fn, (1,)) |
| self.checkScript(fn, ([1, 2, 3],)) |
| |
| def test_union_subtractive_refinement_with_container(self): |
| def fn(x: Union[List[int], int]) -> int: |
| if not torch.jit.isinstance(x, List[int]): |
| return x |
| else: |
| x.append(1) |
| return x[0] |
| |
| self.checkScript(fn, (1,)) |
| self.checkScript(fn, ([1, 2, 3],)) |
| |
| def test_union_memory_aliasing(self): |
| def fn(): |
| x : List[torch.Tensor] = [] |
| z : List[Optional[List[torch.Tensor]]] = [] |
| z.append(x) |
| x_alias = z[0] |
| if torch.jit.isinstance(x_alias, List[torch.Tensor]): |
| x_alias.append(torch.tensor(3)) |
| return x |
| |
| self.checkScript(fn, ()) |
| |
| def test_union_serialization_preserves_type_annotations(self): |
| # This function will fail after being torch.jit.save'd and |
| # torch.jit.load'd if the type annotations aren't preserved |
| # for Union during serialization. We need the `Union[str, int]` |
| # annotation to make sure that `y` is typed as a Union instead |
| # of as a str in one branch and an int in the other |
| def fn(x: int) -> str: |
| if x % 2: |
| y: Union[str, int] = "bar" |
| else: |
| y: Union[str, int] = x |
| if isinstance(y, str): |
| return y |
| else: |
| return "baz" |
| |
| self.checkScript(fn, (1,)) |
| self.checkScript(fn, (8,)) |
| |
| def _assert_passes(self, template: str, ann: str, lhs: str): |
| code = template.format(ann=ann, lhs=lhs) |
| self.checkScript(code, (), name="fn") |
| |
| def _assert_raises(self, template: str, ann: str, lhs: str, msg: str): |
| code = template.format(ann=ann, lhs=lhs) |
| with self.assertRaisesRegex(RuntimeError, msg): |
| cu = torch.jit.CompilationUnit(code, _frames_up=1) |
| string_frontend = getattr(cu, "fn") # noqa: B009 |
| |
| def test_union_with_list_assignment(self): |
| template = dedent(''' |
| def fn(): |
| x: {ann} = {lhs} |
| if torch.jit.isinstance(x, List[torch.Tensor]): |
| x.append(torch.tensor(3)) |
| return x |
| ''') |
| |
| lhs = {"list_literal_empty" : "[]", |
| |
| "list_literal_of_tensor" : "[torch.arange(3), torch.arange(5)]", |
| |
| "list_literal_of_str" : "[\"foo\", \"bar\", \"baz\"]", |
| |
| "list_literal_of_mixed" : "[torch.arange(5), 1]", |
| |
| "list_comprehension_of_tensor" : |
| "[torch.add(x, 1) for x in [torch.arange(3), torch.arange(5)]]", |
| |
| "list_comprehension_of_str" : |
| "[x + \"!\" for x in [\"foo\", \"bar\", \"baz\"]]", |
| |
| "list_comprehension_of_mixed" : |
| "[torch.add(1, x) for x in [torch.arange(5), 1]]"} |
| |
| """ |
| Union[List[str], List[torch.Tensor]] |
| """ |
| self._assert_raises(template, |
| "Union[List[str], List[torch.Tensor]]", |
| lhs["list_literal_empty"], |
| "there are multiple possible List type " |
| "candidates in the Union annotation") |
| |
| self._assert_passes(template, |
| "Union[List[str], List[torch.Tensor]]", |
| lhs["list_literal_of_tensor"]) |
| |
| self._assert_passes(template, |
| "Union[List[str], List[torch.Tensor]]", |
| lhs["list_literal_of_str"]) |
| |
| self._assert_raises(template, |
| "Union[List[str], List[torch.Tensor]]", |
| lhs["list_literal_of_mixed"], |
| "none of those types match the types of the" |
| " given list elements") |
| |
| self._assert_passes(template, |
| "Union[List[str], List[torch.Tensor]]", |
| lhs["list_comprehension_of_tensor"]) |
| |
| self._assert_passes(template, |
| "Union[List[str], List[torch.Tensor]]", |
| lhs["list_comprehension_of_str"]) |
| |
| # TODO: Support mixed list comprehensions |
| self._assert_raises(template, |
| "Union[List[str], List[torch.Tensor]]", |
| lhs["list_comprehension_of_mixed"], |
| "Arguments for call are not valid") |
| |
| """ |
| Union[int, torch.Tensor] |
| """ |
| self._assert_raises(template, |
| "Union[int, torch.Tensor]", |
| lhs["list_literal_empty"], |
| "Expected an Union type annotation with an " |
| "inner List type") |
| |
| self._assert_raises(template, "Union[int, torch.Tensor]", |
| lhs["list_literal_of_tensor"], |
| "Expected an Union type annotation with an " |
| "inner List type") |
| |
| self._assert_raises(template, "Union[int, torch.Tensor]", |
| lhs["list_comprehension_of_tensor"], |
| "Expected an Union type annotation with an " |
| "inner List type") |
| |
| """ |
| Union[List[torch.Tensor], int] |
| """ |
| self._assert_passes(template, |
| "Union[List[torch.Tensor], int]", |
| lhs["list_literal_empty"]) |
| |
| self._assert_passes(template, |
| "Union[List[torch.Tensor], int]", |
| lhs["list_literal_of_tensor"]) |
| |
| self._assert_raises(template, "Union[List[torch.Tensor], int]", |
| lhs["list_literal_of_str"], |
| r"List type annotation `List\[Tensor\]` did " |
| "not match the types of the given list " |
| "elements") |
| |
| self._assert_raises(template, "Union[List[torch.Tensor], int]", |
| lhs["list_literal_of_mixed"], |
| r"List type annotation `List\[Tensor\]` did " |
| "not match the types of the given list " |
| "elements") |
| |
| self._assert_passes(template, |
| "Union[List[torch.Tensor], int]", |
| lhs["list_comprehension_of_tensor"]) |
| |
| self._assert_raises(template, |
| "Union[List[torch.Tensor], int]", |
| lhs["list_comprehension_of_str"], |
| r"List type annotation `List\[Tensor\]` did " |
| "not match the types of the given list " |
| "elements") |
| |
| # TODO(@ansley): Support mixed list comprehensions |
| self._assert_raises(template, |
| "Union[List[torch.Tensor], int]", |
| lhs["list_comprehension_of_mixed"], |
| "Arguments for call are not valid") |
| |
| def test_union_with_dict_assignment(self): |
| template = dedent(''' |
| def fn(): |
| x: {ann} = {lhs} |
| if torch.jit.isinstance(x, Dict[str, torch.Tensor]): |
| x["foo"] = torch.tensor(3) |
| return x |
| ''') |
| |
| lhs = {"dict_literal_empty" : "{}", |
| |
| "dict_literal_of_str_tensor" : |
| "{\"foo\" : torch.arange(3), \"bar\" : torch.arange(5)}", |
| |
| "dict_literal_of_str_int" : |
| "{\"foo\" : 1, \"bar\" : 2}", |
| |
| "dict_literal_of_mixed" : |
| "{\"foo\" : torch.arange(3), \"bar\" : 2}", |
| |
| "dict_comprehension_of_str_tensor" : |
| "{x : torch.add(y, 1) for x, y in \ |
| zip([\"foo\", \"bar\"], [torch.arange(3), torch.arange(5)])}", |
| |
| "dict_comprehension_of_str_int" : |
| "{x : torch.add(y, 1) for x, y in \ |
| zip([\"foo\", \"bar\"], [1, 2]}", |
| |
| "dict_comprehension_of_mixed" : |
| "{x : torch.add(y, 1) for x, y in \ |
| zip([\"foo\", \"bar\"], [torch.arange(3), 2])}", |
| |
| "dict_keyword" : |
| "dict(foo=torch.arange(3), baz=torch.arange(5))", |
| |
| "dict_keyword_with_iterable" : |
| "dict([(\"foo\", torch.arange(3)), (\"bar\", torch.arange(5))])", |
| |
| "dict_keyword_with_empty_iterable" : |
| "dict([])", |
| |
| "dict_keyword_with_internal_aggregate_function" : |
| "dict(zip([\"foo\", \"bar\"], [torch.arange(3), torch.arange(5)])", |
| |
| "dict_keyword_with_mapping" : |
| "dict({\"foo\" : torch.arange(3), \"bar\" : torch.arange(5)})", |
| |
| "dict_keyword_with_mapping_and_kwargs" : |
| "dict({\"foo\" : torch.arange(3), \"bar\" : torch.arange(5)}, baz=torch.arange(7))", |
| |
| } |
| |
| """ |
| Union[Dict[str, torch.Tensor], Dict[str, int]] |
| """ |
| self._assert_raises(template, |
| "Union[List[str], List[torch.Tensor]]", |
| lhs["dict_literal_empty"], |
| "Expected an Union type annotation with an " |
| "inner Dict type") |
| |
| self._assert_passes(template, |
| "Union[Dict[str, torch.Tensor], Dict[str, int]]", |
| lhs["dict_literal_of_str_tensor"]) |
| |
| self._assert_passes(template, |
| "Union[Dict[str, torch.Tensor], Dict[str, int]]", |
| lhs["dict_literal_of_str_int"]) |
| |
| self._assert_raises(template, "Union[Dict[str, torch.Tensor], Dict[str, int]]", |
| lhs["dict_literal_of_mixed"], |
| "none of those dict types can hold the " |
| "types of the given keys and values") |
| |
| # TODO: String frontend does not support tuple unpacking |
| # https://github.com/pytorch/pytorch/issues/64096 |
| # self._assert_passes(template, "Union[Dict[str, torch.Tensor], Dict[str, int]]", |
| # lhs["dict_comprehension_of_str_tensor"]) |
| |
| # self._assert_passes(template, "Union[Dict[str, torch.Tensor], Dict[str, int]]", |
| # lhs["dict_comprehension_of_str_int"]) |
| |
| # self._assert_raises(template, "Union[Dict[str, torch.Tensor], Dict[str, int]]", |
| # lhs["dict_comprehension_of_mixed"], |
| # "foobar") |
| |
| # self._assert_passes(template, |
| # "Union[Dict[str, torch.Tensor], Dict[str, int]]", |
| # lhs["dict_keyword_with_internal_aggregate_function"]) |
| |
| # TODO(@ansley): Follow-up project needed for full type |
| # inference with dict keyword (supported for dict comprehension |
| # and dict literal already; should not be a blocker for anyone) |
| self._assert_raises(template, |
| "Union[Dict[str, torch.Tensor], Dict[str, int]]", |
| lhs["dict_keyword"], |
| "full type inference is not yet supported") |
| |
| self._assert_raises(template, |
| "Union[Dict[str, torch.Tensor], Dict[str, int]]", |
| lhs["dict_keyword_with_iterable"], |
| "full type inference is not yet supported") |
| |
| self._assert_raises(template, |
| "Union[Dict[str, torch.Tensor], Dict[str, int]]", |
| lhs["dict_keyword_with_empty_iterable"], |
| "full type inference is not yet supported") |
| |
| self._assert_raises(template, |
| "Union[Dict[str, torch.Tensor], Dict[str, int]]", |
| lhs["dict_keyword_with_mapping"], |
| "full type inference is not yet supported") |
| |
| self._assert_raises(template, |
| "Union[Dict[str, torch.Tensor], Dict[str, int]]", |
| lhs["dict_keyword_with_mapping_and_kwargs"], |
| "full type inference is not yet supported") |
| |
| """ |
| Union[int, torch.Tensor] |
| """ |
| self._assert_raises(template, |
| "Union[int, torch.Tensor]", |
| lhs["dict_literal_empty"], |
| "Expected an Union type annotation with " |
| "an inner Dict type") |
| |
| self._assert_raises(template, |
| "Union[int, torch.Tensor]", |
| lhs["dict_literal_of_str_tensor"], |
| "Expected an Union type annotation with " |
| "an inner Dict type") |
| |
| # See above--string frontend does not support tuple unpacking |
| # self._assert_raises(template, "Union[int, torch.Tensor]", |
| # lhs["dict_comprehension_of_tensor"], |
| # "foobar") |
| |
| """ |
| Union[Dict[str, torch.Tensor], int] |
| """ |
| self._assert_passes(template, |
| "Union[Dict[str, torch.Tensor], int]", |
| lhs["dict_literal_empty"]) |
| |
| self._assert_passes(template, |
| "Union[Dict[str, torch.Tensor], int]", |
| lhs["dict_literal_of_str_tensor"]) |
| |
| self._assert_raises(template, |
| "Union[Dict[str, torch.Tensor], int]", |
| lhs["dict_literal_of_str_int"], |
| "Type annotation was inferred to be " |
| r"`Dict\[str, Tensor\]`, but the type of " |
| "values given by the dict literal is") |
| |
| self._assert_raises(template, |
| "Union[Dict[str, torch.Tensor], int]", |
| lhs["dict_literal_of_mixed"], |
| "Type annotation was inferred to be " |
| r"`Dict\[str, Tensor\]`, but the type of " |
| "values given by the dict literal is") |
| |
| self._assert_passes(template, |
| "Union[Dict[str, torch.Tensor], int]", |
| lhs["dict_keyword"]) |
| |
| self._assert_passes(template, |
| "Union[Dict[str, torch.Tensor], int]", |
| lhs["dict_keyword_with_iterable"]) |
| |
| self._assert_passes(template, |
| "Union[Dict[str, torch.Tensor], int]", |
| lhs["dict_keyword_with_empty_iterable"]) |
| |
| self._assert_passes(template, |
| "Union[Dict[str, torch.Tensor], int]", |
| lhs["dict_keyword_with_mapping"]) |
| |
| self._assert_passes(template, |
| "Union[Dict[str, torch.Tensor], int]", |
| lhs["dict_keyword_with_mapping_and_kwargs"]) |
| |
| # See above--string frontend does not support tuple unpacking |
| # self._assert_passes(template, |
| # "Union[Dict[str, torch.Tensor], int]", |
| # lhs["dict_keyword_with_internal_aggregate_function"]) |
| # |
| # self._assert_passes(template, |
| # "Union[Dict[str, torch.Tensor], int]", |
| # lhs["dict_comprehension_of_str_tensor"]) |
| |
| # self._assert_raises(template, |
| # "Union[Dict[str, torch.Tensor], int]", |
| # lhs["dict_comprehension_of_str_int"], |
| # "foobar") |
| |
| # self._assert_raises(template, |
| # "Union[Dict[str, torch.Tensor], int]", |
| # lhs["dict_comprehension_of_mixed"], |
| # "foobar") |