| # Owner(s): ["oncall: export"] |
| |
| import unittest |
| from collections import OrderedDict |
| from typing import Any, Dict, List, Optional, Tuple |
| |
| import torch |
| import torch.utils._pytree as pytree |
| from torch._dynamo.test_case import TestCase |
| from torch._export.converter import TS2EPConverter |
| from torch.export import ExportedProgram |
| from torch.testing._internal.common_quantized import override_quantized_engine |
| from torch.testing._internal.common_utils import IS_WINDOWS, run_tests |
| from torch.testing._internal.torchbind_impls import ( |
| _empty_tensor_queue, |
| init_torchbind_implementations, |
| ) |
| |
| |
# Test decorator: skip the decorated test unless torch reports a CUDA device.
requires_cuda = unittest.skipUnless(
    torch.cuda.is_available(),
    "requires cuda",
)
| |
| |
| class TestConverter(TestCase): |
| def setUp(self): |
| init_torchbind_implementations() |
| |
| @torch._library.register_fake_class("_TorchScriptTesting::_TensorQueue") |
| class FakeTensorQueue: |
| def __init__(self, queue): |
| self.queue = queue |
| |
| @classmethod |
| def __obj_unflatten__(cls, flattened_ctx): |
| return cls(**dict(flattened_ctx)) |
| |
| def push(self, x): |
| self.queue.append(x) |
| |
| def pop(self): |
| if self.is_empty(): |
| return torch.empty([]) |
| return self.queue.pop(0) |
| |
| def size(self): |
| return len(self.queue) |
| |
| def is_empty(self): |
| return len(self.queue) == 0 |
| |
| def float_size(self): |
| return float(len(self.queue)) |
| |
| self.torch_bind_ops = [ |
| torch.ops._TorchScriptTesting.queue_pop, |
| torch.ops._TorchScriptTesting.queue_push, |
| torch.ops._TorchScriptTesting.queue_size, |
| ] |
| |
| def tearDown(self): |
| torch._library.fake_class_registry.deregister_fake_class( |
| "_TorchScriptTesting::_TensorQueue" |
| ) |
| |
| def _check_equal_ts_ep_converter( |
| self, |
| M, |
| inp, |
| option: Optional[List[str]] = None, |
| check_persistent=False, |
| lifted_tensor_constants=None, |
| ) -> List[ExportedProgram]: |
| # By default, it tests both jit.trace and jit.script. |
| if option is None: |
| option = ["trace", "script"] |
| |
| if check_persistent: |
| num_iterations = 10 |
| else: |
| num_iterations = 1 |
| |
| ep_list = [] |
| for opt in option: |
| if opt == "script": |
| # Separate two models for testing non-functional effects |
| if check_persistent: |
| original_ts_model = torch.jit.script(M()) |
| ts_model = torch.jit.script(M()) |
| eager_model = M() |
| else: |
| original_ts_model = torch.jit.script(M) |
| ts_model = torch.jit.script(M) |
| eager_model = M |
| elif opt == "trace": |
| if check_persistent: |
| original_ts_model = torch.jit.trace(M(), inp) |
| ts_model = torch.jit.trace(M(), inp) |
| eager_model = M() |
| else: |
| original_ts_model = torch.jit.trace(M, inp) |
| ts_model = torch.jit.trace(M, inp) |
| eager_model = M |
| else: |
| raise RuntimeError(f"Unrecognized mode for torch.jit: {opt}") |
| |
| converter = TS2EPConverter(ts_model, inp) |
| ep = converter.convert() |
| ep_list.append(ep) |
| |
| for _ in range(num_iterations): |
| orig_out, _ = pytree.tree_flatten(original_ts_model(*inp)) |
| ep_out, _ = pytree.tree_flatten(ep.module()(*inp)) |
| |
| # Check module. |
| if isinstance(eager_model, torch.nn.Module): |
| expected_state_dict = OrderedDict() |
| expected_state_dict.update(ts_model.state_dict()) |
| if lifted_tensor_constants: |
| expected_state_dict.update(lifted_tensor_constants) |
| self.assertEqual( |
| ep.state_dict.keys(), |
| expected_state_dict.keys(), |
| ) |
| |
| # Check results |
| self._check_tensor_list_equal(ep_out, orig_out) |
| return ep_list |
| |
| def _check_tensor_list_equal(self, xs: List[torch.Tensor], ys: List[torch.Tensor]): |
| self.assertEqual(len(xs), len(ys)) |
| for x, y in zip(xs, ys): |
| if isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor): |
| self.assertEqual(x.shape, y.shape) |
| self.assertTrue(torch.allclose(x, y)) |
| else: |
| self.assertEqual(type(x), type(y)) |
| self.assertEqual(x, y) |
| |
| def test_ts2ep_converter_basic(self): |
| class MSingle(torch.nn.Module): |
| def forward(self, x, y): |
| return x + y |
| |
| class MMulti(torch.nn.Module): |
| def forward(self, x, y): |
| x = x.cos() + 1 |
| y = y.sin() - 1 |
| return x, y |
| |
| inp = (torch.ones(1, 3), torch.ones(1, 3)) |
| self._check_equal_ts_ep_converter(MSingle(), inp) |
| self._check_equal_ts_ep_converter(MMulti(), inp) |
| |
| def test_ts2ep_converter_container_output(self): |
| # Output is a List. |
| class MOutputList(torch.nn.Module): |
| def forward(self, x: torch.Tensor, y: torch.Tensor): |
| a = x * x |
| b = y + y |
| return [a, b] |
| |
| # Output is a Tuple. |
| class MOutputTuple(torch.nn.Module): |
| def forward(self, x: torch.Tensor, y: torch.Tensor): |
| a = x * x |
| b = y + y |
| return (a, b) |
| |
| # Output is a Dict. |
| class MOutputDict(torch.nn.Module): |
| def forward(self, x: torch.Tensor, y: torch.Tensor): |
| a = x * x |
| b = y + y |
| return {"data": {"mul": a, "add": b}} |
| |
| inp = (torch.tensor(4), torch.tensor(4)) |
| |
| # Traced function must use immutable structure as output. |
| self._check_equal_ts_ep_converter(MOutputList(), inp, ["script"]) |
| self._check_equal_ts_ep_converter(MOutputTuple(), inp) |
| self._check_equal_ts_ep_converter(MOutputDict(), inp, ["script"]) |
| |
    def test_aten_dim(self):
        """Convert a module whose output size is taken from ``x.dim()``."""

        class Module(torch.nn.Module):
            def forward(self, x):
                num_dim = x.dim()
                return torch.ones(num_dim)

        # A 2-D input, so the expected output is ``torch.ones(2)``.
        inp = (torch.ones(1, 3),)
        self._check_equal_ts_ep_converter(Module(), inp)
| |
| def test_aten_len(self): |
| class Module(torch.nn.Module): |
| def forward(self, x: torch.Tensor): |
| length = len(x) |
| return torch.ones(length) |
| |
| # aten::len.Tensor |
| inp = (torch.ones(2, 3),) |
| self._check_equal_ts_ep_converter(Module(), inp) |
| |
| class Module(torch.nn.Module): |
| def forward(self, x: List[int]): |
| length = len(x) |
| return torch.ones(length) |
| |
| # aten::len.t |
| inp = ([1, 2, 3],) |
| self._check_equal_ts_ep_converter(Module(), inp, ["script"]) |
| |
| class Module(torch.nn.Module): |
| def forward(self, x: Dict[int, str]): |
| length = len(x) |
| return torch.ones(length) |
| |
| # aten::len.Dict_int |
| inp = ({1: "a", 2: "b", 3: "c"},) |
| self._check_equal_ts_ep_converter(Module(), inp, ["script"]) |
| |
| class Module(torch.nn.Module): |
| def forward(self, x: Dict[bool, str]): |
| length = len(x) |
| return torch.ones(length) |
| |
| # aten::len.Dict_bool |
| inp = ({True: "a", False: "b"},) |
| self._check_equal_ts_ep_converter(Module(), inp, ["script"]) |
| |
| class Module(torch.nn.Module): |
| def forward(self, x: Dict[float, str]): |
| length = len(x) |
| return torch.ones(length) |
| |
| # aten::len.Dict_float |
| inp = ({1.2: "a", 3.4: "b"},) |
| self._check_equal_ts_ep_converter(Module(), inp, ["script"]) |
| |
| class Module(torch.nn.Module): |
| def forward(self, x: Dict[torch.Tensor, str]): |
| length = len(x) |
| return torch.ones(length) |
| |
| # aten::len.Dict_Tensor |
| inp = ({torch.zeros(2, 3): "a", torch.ones(2, 3): "b"},) |
| self._check_equal_ts_ep_converter(Module(), inp, ["script"]) |
| |
| # aten::len.str and aten::len.Dict_str are not supported |
| # since torch._C._jit_flatten does not support str |
| # inp = ("abcdefg",) |
| # self._check_equal_ts_ep_converter(Module(), inp) |
| # inp = ({"a": 1, "b": 2},) |
| # self._check_equal_ts_ep_converter(Module(), inp) |
| |
    def test_aten_add_t(self):
        """Convert Python list concatenation on ``List[Tensor]`` (jit.script only)."""

        # Python list concatenation (``list + list``, not ``append``), which the
        # test name associates with aten::add.t.
        class Module(torch.nn.Module):
            def forward(self, x: List[torch.Tensor]):
                out = []
                out = out + x
                a = torch.cat(out)
                out = out + x
                b = torch.cat(out)
                return a, b

        inp = ([torch.ones(2, 3), torch.ones(2, 3)],)
        self._check_equal_ts_ep_converter(Module(), inp, ["script"])
| |
    def test_aten_to_dtype_with_mutating_storage(self):
        """An in-place ``index_put_`` on the result of ``.to(dtype)`` must survive conversion."""

        class Module(torch.nn.Module):
            def forward(self, x: torch.Tensor, y: torch.Tensor):
                x = x.to(y.dtype)
                # In-place write into the converted tensor; the returned value
                # must reflect this mutation.
                torch.ops.aten.index_put_(x, [torch.tensor([0])], y)
                return x

        # float32 ``x`` is cast to y's integer dtype, then row 0 is overwritten.
        inp = (torch.ones(2, 3), torch.tensor([0, 0, 0]))
        self._check_equal_ts_ep_converter(Module(), inp)
| |
    def test_prim_min(self):
        """Cover the ``prim::min`` overloads on int/float scalars and lists.

        Every result is folded into the size of the returned tensor, so a wrong
        conversion of any overload changes the output shape.
        """

        class Module(torch.nn.Module):
            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                x_len = len(x)
                y_len = len(y)

                # prim::min.int
                len_int = min(x_len, y_len)

                # prim::min.float
                len_float = int(min(x_len * 2.0, y_len * 2.0))

                # prim::min.self_int
                len_self_int = min([x_len, y_len])

                # prim::min.self_float
                len_self_float = int(min([x_len * 2.0, y_len * 2.0]))

                # prim::min.float_int
                len_float_int = int(min(x_len * 2.0, y_len))

                # prim::min.int_float
                len_int_float = int(min(x_len, y_len * 2.0))

                return torch.ones(
                    len_int
                    + len_float
                    + len_self_int
                    + len_self_float
                    + len_float_int
                    + len_int_float
                )

        # Leading dims 10 and 5, so x_len=10 and y_len=5 inside forward.
        inp = (torch.randn(10, 2), torch.randn(5))
        self._check_equal_ts_ep_converter(Module(), inp)
| |
    def test_prim_max(self):
        """Cover the ``prim::max`` overloads on int/float scalars and lists.

        Every result is folded into the size of the returned tensor, so a wrong
        conversion of any overload changes the output shape.
        """

        class Module(torch.nn.Module):
            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                x_len = len(x)
                y_len = len(y)

                # prim::max.int
                len_int = max(x_len, y_len)

                # prim::max.float
                len_float = int(max(x_len * 2.0, y_len * 2.0))

                # prim::max.self_int
                len_self_int = max([x_len, y_len])

                # prim::max.self_float
                len_self_float = int(max([x_len * 2.0, y_len * 2.0]))

                # prim::max.float_int
                len_float_int = int(max(x_len * 2.0, y_len))

                # prim::max.int_float
                len_int_float = int(max(x_len, y_len * 2.0))

                return torch.ones(
                    len_int
                    + len_float
                    + len_self_int
                    + len_self_float
                    + len_float_int
                    + len_int_float
                )

        # Leading dims 10 and 5, so x_len=10 and y_len=5 inside forward.
        inp = (torch.randn(10, 2), torch.randn(5))
        self._check_equal_ts_ep_converter(Module(), inp)
| |
    def test_aten___getitem___list(self):
        """Index into the list of chunks returned by ``torch.split``."""

        class Module(torch.nn.Module):
            def forward(self, x):
                # Chunks of size 2 along dim 0; a (3, 2) input gives 2 chunks.
                y = torch.split(x, 2)
                return y[0]

        inp = (torch.rand((3, 2)),)
        self._check_equal_ts_ep_converter(Module(), inp)
| |
    def test_aten___getitem___dict(self):
        """Look up tensors in dicts keyed by int, str, bool and float."""

        class Module(torch.nn.Module):
            def forward(self, x):
                y = torch.split(x, 2)
                # One dict per key type; each lookup selects the first chunk.
                d_int = {0: y[0], 1: y[1]}
                d_str = {"0": y[0], "1": y[1]}
                d_bool = {True: y[0], False: y[1]}
                d_float = {0.1: y[0], 2.3: y[1]}
                return d_int[0], d_str["0"], d_bool[True], d_float[0.1]

        inp = (torch.rand((3, 2)),)
        self._check_equal_ts_ep_converter(Module(), inp)
| |
    def test_prim_device(self):
        """Forward the input's ``device`` attribute into a tensor factory (CPU)."""

        class Module(torch.nn.Module):
            def forward(self, x):
                device = x.device
                return torch.ones(2, 3, device=device)

        inp = (torch.rand(3, 4),)
        self._check_equal_ts_ep_converter(Module(), inp)
| |
    @requires_cuda
    def test_prim_device_cuda(self):
        """Same as ``test_prim_device`` with the input placed on ``cuda:0``."""

        class Module(torch.nn.Module):
            def forward(self, x):
                device = x.device
                return torch.ones(2, 3, device=device)

        inp = (torch.rand((3, 4), device="cuda:0"),)
        self._check_equal_ts_ep_converter(Module(), inp)
| |
| def test_prim_dtype(self): |
| class Module(torch.nn.Module): |
| def forward(self, x): |
| dtype = x.dtype |
| return torch.ones(2, 3, dtype=dtype) |
| |
| for dtype in [ |
| torch.float32, |
| torch.double, |
| ]: |
| inp = (torch.rand((3, 4), dtype=dtype),) |
| self._check_equal_ts_ep_converter(Module(), inp) |
| |
| for dtype in [ |
| torch.uint8, |
| torch.int8, |
| torch.int32, |
| ]: |
| inp = (torch.randint(high=128, size=(3, 4), dtype=dtype),) |
| self._check_equal_ts_ep_converter(Module(), inp) |
| |
| def test_convert_if_basic(self): |
| class M(torch.nn.Module): |
| def forward(self, x: torch.Tensor, y: torch.Tensor): |
| if x: |
| return y * y |
| else: |
| return y + y |
| |
| inp = (torch.tensor(True), torch.tensor(4)) |
| ep_list = self._check_equal_ts_ep_converter(M(), inp) |
| |
| for ep in ep_list[1:]: |
| torch.testing.assert_close( |
| ep.module()(torch.tensor(False), torch.tensor(4)), |
| M()(torch.tensor(False), torch.tensor(4)), |
| ) |
| |
| def test_convert_if_tuple_out(self): |
| class M(torch.nn.Module): |
| def true_fn(self, y, z): |
| return (z * z, z + z) |
| |
| def false_fn(self, y, z): |
| return (y * y * y, y + y) |
| |
| def forward(self, x: torch.Tensor, y: torch.Tensor): |
| z = y * y |
| |
| if x: |
| res = self.true_fn(y, z) |
| else: |
| res = self.false_fn(y, z) |
| |
| return res[0] + res[1] |
| |
| inp = (torch.tensor(True), torch.tensor(4)) |
| ep_list = self._check_equal_ts_ep_converter(M(), inp) |
| |
| for ep in ep_list[1:]: |
| torch.testing.assert_close( |
| ep.module()(torch.tensor(False), torch.tensor(4)), |
| M()(torch.tensor(False), torch.tensor(4)), |
| ) |
| |
| def test_convert_if_multiple_out(self): |
| class M(torch.nn.Module): |
| def true_fn(self, y, z): |
| return z * z |
| |
| def false_fn(self, y, z): |
| return y * y * y |
| |
| def forward(self, x: torch.Tensor, y: torch.Tensor): |
| z = y * y |
| |
| if x: |
| res1 = self.true_fn(y, z) |
| res2 = y |
| else: |
| res1 = z |
| res2 = self.false_fn(y, z) |
| |
| return res1 + res2 |
| |
| inp = (torch.tensor(True), torch.tensor(4)) |
| ep_list = self._check_equal_ts_ep_converter(M(), inp) |
| |
| for ep in ep_list[1:]: |
| torch.testing.assert_close( |
| ep.module()(torch.tensor(False), torch.tensor(4)), |
| M()(torch.tensor(False), torch.tensor(4)), |
| ) |
| |
    def test_profiler__record_function(self):
        """Profiler record-function enter/exit ops around a computation convert cleanly."""

        class Module(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                # Open a profiler range named "foo" around the arithmetic below.
                handle = torch.ops.profiler._record_function_enter_new("foo", None)
                y = x * 2 + 4
                torch.ops.profiler._record_function_exit(handle)
                return y

        x = torch.randn(10, 10)
        self._check_equal_ts_ep_converter(Module(), (x,))
| |
    def test_aten_floordiv(self):
        """Floor division of a float tensor by a Python int."""

        class Module(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return x // 2

        x = torch.randn(10, 10)
        self._check_equal_ts_ep_converter(Module(), (x,))
| |
    def test_aten___is__(self):
        """Identity comparison ``x is y`` between two tensor inputs (jit.script only)."""

        class Module(torch.nn.Module):
            def forward(
                self, x: torch.Tensor, y: torch.Tensor
            ) -> Tuple[bool, torch.Tensor]:
                z = x + 1
                return x is y, z

        # Traced function must return output that has tensors, and the bool
        # result here is not one, so only jit.script is used.
        inp = (torch.randn(10, 10), torch.rand(10, 10))
        self._check_equal_ts_ep_converter(Module(), inp, ["script"])
| |
    def test_aten___isnot__(self):
        """Identity comparison ``x is not y`` between two tensor inputs (jit.script only)."""

        class Module(torch.nn.Module):
            def forward(
                self, x: torch.Tensor, y: torch.Tensor
            ) -> Tuple[bool, torch.Tensor]:
                z = x + 1
                return x is not y, z

        # Traced function must return output that has tensors, and the bool
        # result here is not one, so only jit.script is used.
        inp = (torch.randn(10, 10), torch.rand(10, 10))
        self._check_equal_ts_ep_converter(Module(), inp, ["script"])
| |
    def test_aten___not__(self):
        """Boolean negation applied to an ``is not`` comparison (jit.script only)."""

        class Module(torch.nn.Module):
            def forward(
                self, x: torch.Tensor, y: torch.Tensor
            ) -> Tuple[bool, torch.Tensor]:
                z = x + 1
                return not (x is not y), z

        # Traced function must return output that has tensors, and the bool
        # result here is not one, so only jit.script is used.
        inp = (torch.randn(10, 10), torch.rand(10, 10))
        self._check_equal_ts_ep_converter(Module(), inp, ["script"])
| |
| def test_ts2ep_converter_unpack(self): |
| class MUnpackList(torch.nn.Module): |
| def forward(self, x): |
| x, y = torch.split(x, 2) |
| return x + y |
| |
| class MUnpackTuple(torch.nn.Module): |
| def forward(self, x_tuple: Tuple[torch.Tensor, torch.Tensor]): |
| x, y = x_tuple |
| x = x.cos() |
| return x + y |
| |
| inp = (torch.ones(4),) |
| self._check_equal_ts_ep_converter(MUnpackList(), inp) |
| inp = ((torch.zeros(1, 4), torch.ones(1, 4)),) |
| self._check_equal_ts_ep_converter(MUnpackTuple(), inp) |
| |
| @unittest.skipIf( |
| IS_WINDOWS, |
| "torch.cond doesn't go through torch.compile on windows" |
| "causing output not normalized as list", |
| ) |
| def test_convert_retrace_nested_scripted_modules(self): |
| class Wrapper(torch.nn.Module): |
| def __init__(self, mod) -> None: |
| super().__init__() |
| self.mod = mod |
| |
| def forward(self, x, y): |
| return self.mod(x, y) |
| |
| class LinearM(torch.nn.Module): |
| def __init__(self, dim: int) -> None: |
| super().__init__() |
| self.linear = torch.nn.Linear(dim, dim) |
| |
| def forward(self, x, y): |
| return self.linear(y) |
| |
| class M(torch.nn.Module): |
| def __init__(self, dim: int) -> None: |
| super().__init__() |
| m = LinearM(dim) |
| m = torch.jit.script(m) |
| self.mod1 = m |
| self.mod2 = Wrapper(m) |
| |
| def forward(self, x: torch.Tensor, y: torch.Tensor): |
| if x: |
| return -self.mod1(x, y) - self.mod2(x, y) |
| else: |
| return -self.mod1(x, y) + self.mod2(x, y) |
| |
| class NestedM(torch.nn.Module): |
| def __init__(self, dim: int) -> None: |
| super().__init__() |
| m = M(dim) |
| m = torch.jit.script(m) |
| self.mod1 = m |
| self.mod2 = Wrapper(m) |
| |
| def forward(self, x: torch.Tensor, y: torch.Tensor): |
| if x: |
| return self.mod1(x, y) + self.mod2(x, y) |
| else: |
| return self.mod1(x, y) - self.mod2(x, y) |
| |
| inp = ( |
| torch.tensor(True), |
| torch.randn([3, 3]), |
| ) |
| self._check_equal_ts_ep_converter(NestedM(3), inp) |
| |
| def test_convert_nn_module_with_nested_param(self): |
| class M(torch.nn.Module): |
| def __init__(self, dim: int) -> None: |
| super().__init__() |
| self.linear = torch.nn.Linear(dim, dim) |
| |
| def forward(self, x: torch.Tensor): |
| return self.linear(x) |
| |
| class NestedM(torch.nn.Module): |
| def __init__(self, dim: int) -> None: |
| super().__init__() |
| self.linear = torch.nn.Linear(dim, dim) |
| self.m = M(dim) |
| |
| def forward(self, x: torch.Tensor): |
| return self.linear(self.m(x)) |
| |
| class SuperNestedM(torch.nn.Module): |
| def __init__(self, dim: int) -> None: |
| super().__init__() |
| self.linear = torch.nn.Linear(dim, dim) |
| self.m = NestedM(dim) |
| |
| def forward(self, x: torch.Tensor): |
| return self.linear(self.m(x)) |
| |
| inp = (torch.ones(3),) |
| orig_m = NestedM(3) |
| self._check_equal_ts_ep_converter(orig_m, inp) |
| orig_m = SuperNestedM(3) |
| self._check_equal_ts_ep_converter(orig_m, inp) |
| |
| def test_convert_nn_module_with_nested_buffer(self): |
| class M(torch.nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| self.w = torch.nn.Buffer(torch.randn(1)) |
| |
| def forward(self, x: torch.Tensor): |
| return self.w + x |
| |
| class NestedM(torch.nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| self.m = M() |
| self.w = torch.nn.Buffer(torch.randn(1)) |
| |
| def forward(self, x: torch.Tensor): |
| return self.w + self.m(x) |
| |
| class SuperNestedM(torch.nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| self.m = NestedM() |
| self.w = torch.nn.Buffer(torch.randn(1)) |
| |
| def forward(self, x: torch.Tensor): |
| return self.w + self.m(x) |
| |
| inp = (torch.ones(1),) |
| orig_m = NestedM() |
| self._check_equal_ts_ep_converter(orig_m, inp) |
| orig_m = SuperNestedM() |
| self._check_equal_ts_ep_converter(orig_m, inp) |
| |
| def test_convert_nn_module_with_nested_if_and_buffer(self): |
| class M(torch.nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| self.w = torch.nn.Buffer(torch.randn(1)) |
| self.count = 1 |
| |
| def forward(self, x: torch.Tensor): |
| return self.w + x + self.count |
| |
| class NestedM(torch.nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| self.m1 = M() |
| self.m2 = M() |
| self.w = torch.nn.Buffer(torch.randn(1)) |
| |
| def forward(self, x: torch.Tensor): |
| if torch.sum(x) > 1: |
| return self.w + self.m1(x) |
| else: |
| return self.w + self.m2(x) |
| |
| # Super nested, parameters neeed to lifted |
| # multiple times. |
| class SuperNestedM(torch.nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| self.m1 = NestedM() |
| self.m2 = NestedM() |
| self.w = torch.nn.Buffer(torch.randn(1)) |
| |
| def forward(self, x: torch.Tensor): |
| if torch.max(x) > 1: |
| return self.w + self.m1(x) |
| else: |
| return self.w + self.m2(x) |
| |
| # Super nested module testing. |
| inp = (torch.ones(1),) |
| orig_m = SuperNestedM() |
| ep_list = self._check_equal_ts_ep_converter(orig_m, inp) |
| |
| t = inp[0] |
| t -= 1 |
| for ep in ep_list: |
| torch.testing.assert_close( |
| ep.module()(*inp), |
| orig_m(*inp), |
| ) |
| |
| @unittest.skipIf( |
| IS_WINDOWS, |
| "torch.cond doesn't go through torch.compile on windows" |
| "causing output not normalized as list", |
| ) |
| def test_convert_nn_module_with_nested_if_and_param(self): |
| class M(torch.nn.Module): |
| def __init__(self, dim: int) -> None: |
| super().__init__() |
| self.linear = torch.nn.Linear(dim, dim) |
| |
| def forward(self, x: torch.Tensor): |
| return self.linear(x) |
| |
| class NestedM(torch.nn.Module): |
| def __init__(self, dim: int) -> None: |
| super().__init__() |
| self.m1 = M(dim) |
| self.m2 = M(dim) |
| self.linear = torch.nn.Linear(dim, dim) |
| |
| def forward(self, x: torch.Tensor): |
| if torch.sum(x) > 1: |
| return self.linear(self.m1(x)) |
| else: |
| return self.linear(self.m2(x)) |
| |
| # Super nested, parameters neeed to lifted |
| # multiple times. |
| class SuperNestedM1(torch.nn.Module): |
| def __init__(self, dim: int) -> None: |
| super().__init__() |
| self.m1 = NestedM(dim) |
| self.m2 = NestedM(dim) |
| self.linear = torch.nn.Linear(dim, dim) |
| |
| def forward(self, x: torch.Tensor): |
| if torch.max(x) > 1: |
| return self.linear(self.m1(x)) |
| else: |
| return self.linear(self.m2(x)) |
| |
| # Super nested, even the input needs to be |
| # lifted recursively due to value propogation optimiztaion. |
| class SuperNestedM2(torch.nn.Module): |
| def __init__(self, dim: int) -> None: |
| super().__init__() |
| self.m1 = NestedM(dim) |
| self.m2 = NestedM(dim) |
| self.linear = torch.nn.Linear(dim, dim) |
| |
| def forward(self, x: torch.Tensor): |
| if torch.sum(x) > 1: |
| return self.linear(self.m1(x)) |
| else: |
| return self.linear(self.m2(x)) |
| |
| # Basic module testing. |
| inp = (torch.ones(3),) |
| orig_m = M(3) |
| ep_list = self._check_equal_ts_ep_converter(orig_m, inp) |
| |
| t = inp[0] |
| t -= 0.8 |
| for ep in ep_list[1:]: |
| torch.testing.assert_close( |
| ep.module()(*inp), |
| orig_m(*inp), |
| ) |
| |
| # Nested module testing. |
| inp = (torch.ones(3),) |
| orig_m = NestedM(3) |
| ep_list = self._check_equal_ts_ep_converter(orig_m, inp) |
| |
| t = inp[0] |
| t -= 0.8 |
| # Skip jit.traced because it specializes on one path. |
| for ep in ep_list[1:]: |
| torch.testing.assert_close( |
| ep.module()(*inp), |
| orig_m(*inp), |
| ) |
| |
| # Super nested module testing. |
| inp = (torch.ones(3),) |
| orig_m = SuperNestedM1(3) |
| ep_list = self._check_equal_ts_ep_converter(orig_m, inp) |
| |
| t = inp[0] |
| t -= 0.8 |
| # Skip jit.traced because it specializes on one path. |
| for ep in ep_list[1:]: |
| torch.testing.assert_close( |
| ep.module()(*inp), |
| orig_m(*inp), |
| ) |
| |
| # Super nested module testing. |
| inp = (torch.ones(3),) |
| orig_m = SuperNestedM2(3) |
| ep_list = self._check_equal_ts_ep_converter(orig_m, inp) |
| |
| t = inp[0] |
| t -= 0.8 |
| # Skip jit.traced because it specializes on one path. |
| for ep in ep_list[1:]: |
| torch.testing.assert_close( |
| ep.module()(*inp), |
| orig_m(*inp), |
| ) |
| |
| def test_ts2ep_converter_contains(self): |
| class MIn(torch.nn.Module): |
| def forward(self, x: torch.Tensor): |
| return x.dtype in [torch.float32, torch.float64] |
| |
| class MNotIn(torch.nn.Module): |
| def forward(self, x: torch.Tensor): |
| return x.dtype in [torch.int8] |
| |
| class MTensorIn(torch.nn.Module): |
| def forward(self, x: torch.Tensor, x_dict: Dict[torch.Tensor, str]): |
| return x in x_dict |
| |
| # Traced function must return output that has tensors. |
| inp = (torch.tensor(4),) |
| self._check_equal_ts_ep_converter(MIn(), inp, ["script"]) |
| self._check_equal_ts_ep_converter(MNotIn(), inp, ["script"]) |
| |
| # TODO: update test to use reference for in. |
| inp = (torch.tensor(4), {torch.tensor(4): "foo"}) |
| self._check_equal_ts_ep_converter(MTensorIn(), inp, ["script"]) |
| inp = (torch.tensor(1), {torch.tensor(4): "foo"}) |
| self._check_equal_ts_ep_converter(MTensorIn(), inp, ["script"]) |
| |
| def test_ts2ep_converter_custom_op(self): |
| with torch.library._scoped_library("mylib", "FRAGMENT") as lib: |
| torch._dynamo.config.capture_scalar_outputs = True |
| torch._dynamo.config.capture_dynamic_output_shape_ops = True |
| |
| torch.library.define( |
| "mylib::foo", |
| "(Tensor x) -> Tensor", |
| lib=lib, |
| ) |
| |
| # PyTorch custorm op implementation |
| @torch.library.impl( |
| "mylib::foo", |
| "CompositeExplicitAutograd", |
| lib=lib, |
| ) |
| def foo_impl(x): |
| return x + x |
| |
| # Meta function of the custom op. |
| @torch.library.impl_abstract( |
| "mylib::foo", |
| lib=lib, |
| ) |
| def foo_meta(x): |
| return x + x |
| |
| class M(torch.nn.Module): |
| def forward(self, x): |
| return torch.ops.mylib.foo(x) |
| |
| inp = (torch.randn(3, 3),) |
| m = M() |
| self._check_equal_ts_ep_converter(m, inp) |
| |
| def test_convert_func_without_param(self): |
| def func1(x, y): |
| return x + y |
| |
| def func2(x, y): |
| if x.sum() > 0: |
| return x + y |
| else: |
| return x - y |
| |
| inp = ( |
| torch.tensor(1), |
| torch.tensor(1), |
| ) |
| self._check_equal_ts_ep_converter(func1, inp) |
| |
| ep_list = self._check_equal_ts_ep_converter(func2, inp) |
| |
| t = inp[0] |
| t -= 1 |
| for ep in ep_list[1:]: |
| torch.testing.assert_close( |
| ep.module()(*inp), |
| func2(*inp), |
| ) |
| |
    def test_implicit_constant_to_tensor_handling(self):
        """Python scalars and tensor constants mixed with tensor inputs convert.

        Covers scalar operands in arithmetic, tensors created inside the graph
        (``torch.tensor``/``torch.full``), a plain tensor module attribute, and
        functions returning ``numel()``/``size()`` values next to tensors.
        """

        # Tensor + Python int.
        def func1(x):
            return x + 2

        # Several scalar/tensor arithmetic ops combined.
        def func2(x, y):
            return x * y / (x - 2 * y) + y

        # Tensor + tensor created inside the function.
        def func3(x):
            return x + torch.tensor([3]) 

        # No inputs: a 0-d inf tensor used as the fill value.
        def func4():
            val = torch.tensor(float("inf"))
            return torch.full((10, 10), val)

        # No inputs: Python int times a tensor, returned with a second tensor.
        def func5():
            x = -1
            return x * torch.ones(1, dtype=torch.float), torch.zeros(
                1, dtype=torch.float
            )

        # Returns numel()/size() of inputs with different dtypes, plus tensors
        # built from them.
        def func6(x1, x2, x3, x4):
            return (
                x1.numel(),
                x1.size(),
                x2.numel(),
                x2.size(),
                x3.numel(),
                x3.size(),
                x4.numel(),
                x4.size(),
                torch.ones(x1.numel()),  # Just make sure downstream ops still work.
                torch.ones(x1.size()),  # Just make sure downstream ops still work.
            )

        # Module holding a plain tensor attribute (not a parameter or buffer).
        class M1(torch.nn.Module):
            def __init__(self, value):
                super().__init__()
                self.x = torch.tensor(value)

            def forward(self):
                return self.x.clone()

        # Module creating a tensor constant in forward.
        class M2(torch.nn.Module):
            def forward(self, x):
                return torch.tensor(4) + x

        inp = (torch.randn([2, 2]),)
        self._check_equal_ts_ep_converter(func1, inp)
        inp = (torch.randn([2, 2]), torch.randn([2, 2]))
        self._check_equal_ts_ep_converter(func2, inp)

        inp = (torch.randn([2, 2]),)
        self._check_equal_ts_ep_converter(func3, inp)

        self._check_equal_ts_ep_converter(func4, ())
        self._check_equal_ts_ep_converter(M1(5), ())

        inp = (torch.randn(2),)
        self._check_equal_ts_ep_converter(M2(), inp)

        self._check_equal_ts_ep_converter(func5, ())
        inp = (
            torch.randn([2, 3, 4]).to(torch.int8),
            torch.randn([2, 3, 4]).to(torch.int32),
            torch.randn([2, 3, 4]).to(torch.float32),
            torch.randn([2, 3, 4]).to(torch.float64),
        )
        # ep_list is currently unused; it is kept for the dynamic-shape check below.
        ep_list = self._check_equal_ts_ep_converter(func6, inp)

        # TODO: Additional check once dynamic shape is supported.
        # for ep in ep_list:
        #     self.assertEqual(
        #         ep.module()(
        #             torch.randn([1, 1, 1]).to(torch.int8),
        #             torch.randn([1, 1, 1]).to(torch.int32),
        #             torch.randn([1, 1, 1]).to(torch.float32),
        #             torch.randn([1, 1, 1]).to(torch.float64),
        #         )[0], 1
        #     )
| |
    def test_aten_tensor_dtype_int(self):
        """``torch.tensor`` with an explicit int dtype yields exactly one program constant."""

        class M(torch.nn.Module):
            def forward(self, x):
                y = torch.tensor(1, dtype=torch.int32)
                return y + x

        ep_list = self._check_equal_ts_ep_converter(M(), (torch.tensor(1),))
        # Applies to both the traced and the scripted program.
        for ep in ep_list:
            self.assertEqual(len(ep.constants), 1)
| |
    def test_aten_tensor_prim_dtype(self):
        """``torch.tensor`` whose dtype comes from the input yields one program constant."""

        class M(torch.nn.Module):
            def forward(self, x):
                y = torch.tensor(1, dtype=x.dtype)
                return y + x

        ep_list = self._check_equal_ts_ep_converter(M(), (torch.tensor(1),))
        # Applies to both the traced and the scripted program.
        for ep in ep_list:
            self.assertEqual(len(ep.constants), 1)
| |
    def test_aten_tensor_dynamic(self):
        """``torch.tensor`` built from an input-shape value must not become a constant."""

        # Scalar built from x.shape[0].
        class M(torch.nn.Module):
            def forward(self, x):
                s = x.shape[0]
                y = torch.tensor(s)
                return y

        ep_list = self._check_equal_ts_ep_converter(M(), (torch.ones(3),))
        for ep in ep_list:
            self.assertEqual(len(ep.constants), 0)

        # TODO: Additional check once dynamic shape is supported.
        # for ep in ep_list:
        #     torch.testing.assert_close(
        #         ep.module()(torch.ones(4)),
        #         M()(torch.ones(4)),
        #     )

        # List of values derived from x.shape[0].
        class M(torch.nn.Module):
            def forward(self, x):
                s = x.shape[0]
                y = torch.tensor([s, s * 2, 1])
                return y

        ep_list = self._check_equal_ts_ep_converter(M(), (torch.ones(3),))
        # jit.trace inlines the tensor directly as a constant, so the traced
        # program (ep_list[0]) is excluded from the zero-constants check.
        for ep in ep_list[1:]:
            self.assertEqual(len(ep.constants), 0)

        # TODO: Additional check once dynamic shape is supported.
        # for ep in ep_list:
        #     torch.testing.assert_close(
        #         ep.module()(torch.ones(4)),
        #         M()(torch.ones(4)),
        #     )
| |
    def test_prim_tolist(self):
        """``Tensor.tolist()`` for 1-D and 2-D int tensors (jit.script only).

        The return annotation tells TorchScript the nesting depth of the list.
        """

        # 1-D tensor -> List[int]
        class Module(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> List[int]:
                return x.tolist()

        inp = (torch.tensor([1, 2, 3]),)
        self._check_equal_ts_ep_converter(Module(), inp, ["script"])

        # 2-D tensor -> List[List[int]]
        class Module(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> List[List[int]]:
                return x.tolist()

        inp = (torch.tensor([[1, 2, 3], [4, 5, 6]]),)
        self._check_equal_ts_ep_converter(Module(), inp, ["script"])
| |
    def test_get_tensor_constants(self):
        """Plain tensor attributes read in forward, including a submodule's, convert."""

        # Since self.data is only read but not written, it is lifted as
        # constant tensors.
        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.data = torch.randn(3, 2)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return x + self.data

        class Goo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.data = torch.randn(3, 2)
                self.foo = Foo()

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                # Reads its own constant, the submodule's constant directly,
                # and the submodule's constant again via Foo.forward.
                return x + self.data + self.foo.data + self.foo(x)

        inp = (torch.randn(3, 2),)
        goo = Goo()
        self._check_equal_ts_ep_converter(goo, inp)
| |
    def test_prim_SetAttr(self):
        """Attribute assignment in forward (prim::SetAttr) on buffers, tensors and ints.

        Cases with ``check_persistent=True`` pass the module class, so the
        helper builds fresh instances and compares outputs over repeated calls
        while the attribute state evolves.
        """

        # Buffer reassigned; the output does not read it.
        class Module(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.data = torch.nn.Buffer(torch.ones(3, 2))

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                self.data = self.data + x
                return x + x

        inp = (torch.ones(3, 2),)
        self._check_equal_ts_ep_converter(
            Module, inp, ["script"], check_persistent=True
        )

        # Buffer reassigned and then read by the output.
        class Module(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.data = torch.nn.Buffer(torch.ones(3, 2))

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                self.data = self.data + x
                return x + self.data

        inp = (torch.ones(3, 2),)
        self._check_equal_ts_ep_converter(
            Module, inp, ["script"], check_persistent=True
        )

        # export lifts a tensor constant (self.data) as an input if it is not assigned.
        # If it is assigned, export will error and ask users to register it as a buffer.
        # In converter, we change tensor constants that are assigned as a buffer automatically,
        # since it might be hard to manually register them as buffers.
        class Module(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.data = torch.ones(3, 2)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                self.data = self.data + x
                return x + self.data

        inp = (torch.ones(3, 2),)
        # The automatically created buffer is expected in ep.state_dict as "data".
        self._check_equal_ts_ep_converter(
            Module,
            inp,
            ["script"],
            check_persistent=True,
            lifted_tensor_constants=OrderedDict([("data", torch.ones(3, 2))]),
        )

        # Int attribute incremented in forward.
        class Module(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.count = 0

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                self.count += 1
                return x + self.count

        # check_persistent is False since export specializes on non-tensor constants
        inp = (torch.ones(3, 2),)
        self._check_equal_ts_ep_converter(
            Module(), inp, ["script"], check_persistent=False
        )

        # Int attribute read between two increments in one forward call.
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.count = 0

            def forward(self, x):
                count1 = self.count
                self.count += 1
                count2 = self.count
                self.count += 1
                count3 = self.count
                return x + count1 + count2 + count3

        inp = (torch.ones(1),)
        self._check_equal_ts_ep_converter(M(), inp, ["script"], check_persistent=False)

        # Augmented assignment (+=) on a buffer.
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.w2 = torch.nn.Buffer(torch.ones(1))

            def forward(self, x: torch.Tensor):
                self.w2 += 1
                return self.w2

        inp = (torch.ones(1),)
        self._check_equal_ts_ep_converter(M, inp, ["script"], check_persistent=True)
| |
def test_raise_exception(self):
    """Convert forward() methods that ``raise`` inside an ``if`` branch.

    This mirrors non-strict export, which executes only the branch selected by
    the given input. An input that reaches the ``raise`` must error. An input
    that avoids it must convert and pass the equality check.
    """

    # ``raise`` inside an ``if`` that has no ``else``.
    class RaiseInIf(torch.nn.Module):
        def forward(self, x: torch.Tensor, y: int) -> torch.Tensor:
            if y > 0:
                raise RuntimeError("test")
            return x + y

    # Build inputs outside the assertion context so that only the converter
    # call is expected to raise.
    inp = (torch.randn(3, 2), 1)
    with self.assertRaisesRegex(torch.jit.Error, "builtins.RuntimeError"):
        self._check_equal_ts_ep_converter(RaiseInIf(), inp, ["script"])

    # y == 0 skips the raising branch.
    inp = (torch.randn(3, 2), 0)
    self._check_equal_ts_ep_converter(RaiseInIf(), inp, ["script"])

    # ``raise`` in the ``if`` arm of an ``if``/``else`` whose other arm rebinds ``z``.
    class RaiseInIfElse(torch.nn.Module):
        def forward(self, x: torch.Tensor, y: int) -> torch.Tensor:
            z = x
            if y > 0:
                raise RuntimeError("test")
            else:
                z = x + y
            return x + y + z

    inp = (torch.randn(3, 2), 1)
    with self.assertRaisesRegex(torch.jit.Error, "builtins.RuntimeError"):
        self._check_equal_ts_ep_converter(RaiseInIfElse(), inp, ["script"])

    # y == 0 takes the ``else`` arm.
    inp = (torch.randn(3, 2), 0)
    self._check_equal_ts_ep_converter(RaiseInIfElse(), inp, ["script"])
| |
def test_context_manager(self):
    """Convert a forward() whose computation runs inside a user-defined ``with`` block."""

    class CountingContext:
        # Tracks nesting depth: incremented on enter, decremented on exit.
        def __init__(self) -> None:
            self.count = 0

        def __enter__(self):
            self.count += 1

        def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
            self.count -= 1

    class AddInsideContext(torch.nn.Module):
        def forward(self, x, y):
            with CountingContext():
                res = x + y
            return res

    operands = (torch.ones(3, 3), torch.ones(3, 3))
    self._check_equal_ts_ep_converter(AddInsideContext(), operands)
| |
def test_hidden_input_name(self):
    """Convert callables whose inputs are not named module arguments.

    Covers an already-scripted function, a bare aten op and a variadic
    Python function.
    """

    @torch.jit.script
    def add_one(x):
        return x + 1

    def cat_and_square(*args):
        joined = torch.cat(args, dim=1)
        return joined * joined

    self._check_equal_ts_ep_converter(add_one, (torch.randn([1, 1]),))

    # Cannot script again, so trace only.
    self._check_equal_ts_ep_converter(
        torch.ops.aten.relu, (torch.ones(5, 5),), ["trace"]
    )

    # An empty double tensor followed by rows x {4, 2, 1} float tensors,
    # concatenated along dim 1.
    rows = 2
    empty = torch.tensor([], dtype=torch.double)
    varargs = (empty, *(torch.randn(rows, cols) for cols in (4, 2, 1)))
    # Cannot script variable length inputs, so trace only.
    self._check_equal_ts_ep_converter(cat_and_square, varargs, ["trace"])
| |
def test_ts2ep_multi_outputs_on_call_ops(self):
    """Convert calls to ops that each return several tensors."""

    class MultiOutputOps(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.pool = torch.nn.AdaptiveMaxPool2d((2, 2), return_indices=True)

        def forward(self, x: torch.Tensor, y: torch.Tensor):
            return (
                torch.max(x, dim=0),
                torch.topk(x, 3),
                torch.sort(x, dim=0),
                self.pool(y),
            )

    matrix = torch.randn([4, 4])
    image = torch.randn([1, 1, 10, 10])
    self._check_equal_ts_ep_converter(MultiOutputOps(), (matrix, image))
| |
def test_aten_append_t(self):
    """Convert aten::append on a list that is concatenated at two different lengths."""

    class AppendThenCat(torch.nn.Module):
        def forward(self, x: List[torch.Tensor]):
            out = []
            out.append(x[0] + x[1])
            out.append(x[0] - x[1])
            out1 = torch.cat(out)
            out.append(x[0] * x[1])
            out2 = torch.cat(out)
            return out, out1, out2

    pair = [torch.ones(2, 3), torch.ones(2, 3)]
    # Script only: tracing already unrolls the list.
    self._check_equal_ts_ep_converter(AppendThenCat(), (pair,), ["script"])
| |
def test_convert_script_object(self):
    """Convert a module that owns a torchbind tensor queue (a script object)."""

    class QueueRoundTrip(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.tq = _empty_tensor_queue()

        def forward(self, x: torch.Tensor):
            # Both the method-call form and the torch.ops form of the queue API.
            self.tq.push(x)
            torch.ops._TorchScriptTesting.queue_push(self.tq, x.cos())
            return torch.ops._TorchScriptTesting.queue_pop(self.tq), self.tq.pop()

    sample = (torch.randn(2, 3),)
    self._check_equal_ts_ep_converter(QueueRoundTrip(), sample, ["script"])
| |
def test_ts2ep_with_loop(self):
    """Convert scripted functions that contain loops.

    Only ``func1`` is checked today. It has nested constant-range ``for`` loops
    that append to a list input. ``func2`` loops ``x.size(0)`` times and
    ``func3`` uses a tensor-dependent ``while`` loop; both are parked behind
    the TODOs below.
    """

    def func1(x, x_list: List[torch.Tensor]):
        a, b, c = x, x, x
        for i in range(1, 5, 2):
            for k in range(5):
                a = a + a + k
                b = b + b - k
                x_list.append(x_list[k] + x_list[k + 1])
            for k in range(5):
                b = b + b - k
                c = c + c * k
                x_list.append(x_list[k] + x_list[k + 1] - x_list[k + 2])
        return x, x_list

    def func2(x):
        for i in range(x.size(0)):
            x = x * x * i
        return x

    def func3(x):
        while x.sum() < 10:
            x += x.sin()
        return x

    inp = (
        torch.tensor(1),
        [torch.ones([2, 2]), torch.ones([2, 2]) * 2],
    )
    # Script only: tracing unrolls the loop.
    self._check_equal_ts_ep_converter(func1, inp, ["script"])

    # TODO: (2/N)
    # Script only: tracing unrolls the loop. func2 takes a single tensor with at
    # least one dimension (it loops over x.size(0)), so it must not reuse the
    # two-element ``inp`` above.
    # self._check_equal_ts_ep_converter(func2, (torch.ones([2, 2]),), ["script"])

    # TODO: (3/N)
    # Script only: tracing unrolls the loop. func3 also takes a single tensor and
    # updates it in place (``x += ...``), so give it a fresh one.
    # self._check_equal_ts_ep_converter(func3, (torch.ones([2, 2]),), ["script"])
| |
@unittest.skipIf(
    IS_WINDOWS,
    "Windows does not support qnnpack",
)
# Builds without qnnpack (not only Windows) cannot select that engine below,
# so skip there instead of erroring.
@unittest.skipUnless(
    "qnnpack" in torch.backends.quantized.supported_engines,
    "qnnpack is not a supported quantized engine in this build",
)
def test_ts2ep_convert_quantized_model(self):
    """Convert a scripted model that was eager-mode quantized with qnnpack.

    The model is fused, prepared, calibrated and converted under the qnnpack
    engine, then scripted and converted to an ExportedProgram. The outputs are
    compared with a custom check instead of ``_check_equal_ts_ep_converter``,
    because the quantization pass modifies the model's state_dict.
    """

    class Standalone(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.quant = torch.ao.quantization.QuantStub()
            self.conv1 = torch.nn.Conv2d(1, 1, 1)
            self.conv2 = torch.nn.Conv2d(1, 1, 1)
            self.relu = torch.nn.ReLU()
            self.dequant = torch.ao.quantization.DeQuantStub()

        def forward(self, x):
            x = self.quant(x)
            x = self.conv1(x)
            x = self.conv2(x)
            x = self.relu(x)
            x = self.dequant(x)
            return x

        def fuse_model(self):
            torch.ao.quantization.fuse_modules(
                self, [["conv2", "relu"]], inplace=True
            )

    with override_quantized_engine("qnnpack"):
        model = Standalone()
        model.qconfig = torch.ao.quantization.get_default_qconfig("qnnpack")
        model.fuse_model()
        torch.ao.quantization.prepare(model, inplace=True)
        # Calibration forward pass for the observers inserted by prepare().
        model(torch.randn(4, 1, 4, 4))
        torch.ao.quantization.convert(model, inplace=True)

        # Custom check: the quantization pass modifies the state_dict.
        inp = (torch.randn(4, 1, 4, 4),)
        # A separately scripted copy serves as the reference, presumably so that
        # anything the converter does to ``ts_model`` cannot affect it.
        original_ts_model = torch.jit.script(model)
        ts_model = torch.jit.script(model)
        converter = TS2EPConverter(ts_model, inp)
        ep = converter.convert()

        orig_out, _ = pytree.tree_flatten(original_ts_model(*inp))
        ep_out, _ = pytree.tree_flatten(ep.module()(*inp))
        self._check_tensor_list_equal(orig_out, ep_out)
| |
# The torch.ops.prepacked.linear_clamp_* ops are only registered in builds with
# XNNPACK enabled; skip elsewhere instead of failing on the missing op.
@unittest.skipUnless(
    torch._C._is_xnnpack_enabled(),
    "prepacked linear_clamp ops require a build with XNNPACK enabled",
)
def test_ts2ep_convert_quantized_model_with_opcontext(self):
    """Convert a module that stores a prepacked linear op context as an attribute."""

    class M(torch.nn.Module):
        def __init__(self, linear_op):
            super().__init__()
            self.linear_op = linear_op

        def forward(self, x):
            x = torch.ops.prepacked.linear_clamp_run(x, self.linear_op)
            return x

    # Prepack a 10x10 weight and a length-10 bias into an op-context object.
    linear_op = torch.ops.prepacked.linear_clamp_prepack(
        torch.randn(10, 10), torch.randn(10)
    )
    m = M(linear_op)
    inp = (torch.randn(1, 10),)
    self._check_equal_ts_ep_converter(m, inp, ["script"])
| |
| |
if __name__ == "__main__":
    # Entry point when this file is executed directly: hand control to
    # run_tests() from torch.testing._internal.common_utils.
    run_tests()