blob: 39d4c4a10d4ce3021a928bf8ce285f0a2b0d01f0 [file] [log] [blame]
# Owner(s): ["oncall: export"]
import unittest
from collections import OrderedDict
from typing import Any, Dict, List, Optional, Tuple
import torch
import torch.utils._pytree as pytree
from torch._dynamo.test_case import TestCase
from torch._export.converter import TS2EPConverter
from torch.export import ExportedProgram
from torch.testing._internal.common_quantized import override_quantized_engine
from torch.testing._internal.common_utils import IS_WINDOWS, run_tests
from torch.testing._internal.torchbind_impls import (
_empty_tensor_queue,
init_torchbind_implementations,
)
# Skip decorator for tests that need a CUDA device. The availability check is
# evaluated once, when this module is imported.
requires_cuda = unittest.skipUnless(torch.cuda.is_available(), "requires cuda")
class TestConverter(TestCase):
def setUp(self):
    """Prepare the torchbind test ops and a fake ``_TensorQueue`` class.

    Registers a Python stand-in for ``_TorchScriptTesting::_TensorQueue``
    and records the torchbind queue ops in ``self.torch_bind_ops``.
    """
    # Chain to the parent fixture first so any per-test state the base
    # TestCase manages is initialised before we add our own.
    super().setUp()
    init_torchbind_implementations()

    @torch._library.register_fake_class("_TorchScriptTesting::_TensorQueue")
    class FakeTensorQueue:
        # Python-side stand-in backed by a plain list.
        def __init__(self, queue):
            self.queue = queue

        @classmethod
        def __obj_unflatten__(cls, flattened_ctx):
            # `flattened_ctx` is consumed as (name, value) pairs that become
            # constructor keyword arguments.
            return cls(**dict(flattened_ctx))

        def push(self, x):
            self.queue.append(x)

        def pop(self):
            # An empty queue yields an empty scalar tensor instead of raising.
            if self.is_empty():
                return torch.empty([])
            return self.queue.pop(0)

        def size(self):
            return len(self.queue)

        def is_empty(self):
            return len(self.queue) == 0

        def float_size(self):
            return float(len(self.queue))

    self.torch_bind_ops = [
        torch.ops._TorchScriptTesting.queue_pop,
        torch.ops._TorchScriptTesting.queue_push,
        torch.ops._TorchScriptTesting.queue_size,
    ]

def tearDown(self):
    """Undo ``setUp``: drop the fake class, then run the parent teardown."""
    try:
        torch._library.fake_class_registry.deregister_fake_class(
            "_TorchScriptTesting::_TensorQueue"
        )
    finally:
        # Always give the parent a chance to clean up, even if the
        # deregistration above raises.
        super().tearDown()
def _check_equal_ts_ep_converter(
self,
M,
inp,
option: Optional[List[str]] = None,
check_persistent=False,
lifted_tensor_constants=None,
) -> List[ExportedProgram]:
# By default, it tests both jit.trace and jit.script.
if option is None:
option = ["trace", "script"]
if check_persistent:
num_iterations = 10
else:
num_iterations = 1
ep_list = []
for opt in option:
if opt == "script":
# Separate two models for testing non-functional effects
if check_persistent:
original_ts_model = torch.jit.script(M())
ts_model = torch.jit.script(M())
eager_model = M()
else:
original_ts_model = torch.jit.script(M)
ts_model = torch.jit.script(M)
eager_model = M
elif opt == "trace":
if check_persistent:
original_ts_model = torch.jit.trace(M(), inp)
ts_model = torch.jit.trace(M(), inp)
eager_model = M()
else:
original_ts_model = torch.jit.trace(M, inp)
ts_model = torch.jit.trace(M, inp)
eager_model = M
else:
raise RuntimeError(f"Unrecognized mode for torch.jit: {opt}")
converter = TS2EPConverter(ts_model, inp)
ep = converter.convert()
ep_list.append(ep)
for _ in range(num_iterations):
orig_out, _ = pytree.tree_flatten(original_ts_model(*inp))
ep_out, _ = pytree.tree_flatten(ep.module()(*inp))
# Check module.
if isinstance(eager_model, torch.nn.Module):
expected_state_dict = OrderedDict()
expected_state_dict.update(ts_model.state_dict())
if lifted_tensor_constants:
expected_state_dict.update(lifted_tensor_constants)
self.assertEqual(
ep.state_dict.keys(),
expected_state_dict.keys(),
)
# Check results
self._check_tensor_list_equal(ep_out, orig_out)
return ep_list
def _check_tensor_list_equal(self, xs: List[Any], ys: List[Any]):
    """Assert that two flattened output lists match element by element.

    Tensor pairs must have the same shape and be close under
    ``torch.allclose`` defaults. Any other pair must have the same type and
    compare equal.

    Args:
        xs: First flattened output list (tensors and/or plain values).
        ys: Second flattened output list, compared position by position.
    """
    self.assertEqual(len(xs), len(ys))
    for idx, (x, y) in enumerate(zip(xs, ys)):
        if isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor):
            self.assertEqual(x.shape, y.shape)
            # Build the message only on failure: formatting tensors is not
            # free, and a bare assertTrue would only say "False is not true".
            if not torch.allclose(x, y):
                self.fail(f"Tensors differ at index {idx}: {x} vs {y}")
        else:
            self.assertEqual(type(x), type(y))
            self.assertEqual(x, y)
def test_ts2ep_converter_basic(self):
    """Convert modules returning a single tensor and a tuple of tensors."""

    class MSingle(torch.nn.Module):
        def forward(self, x, y):
            return x + y

    class MMulti(torch.nn.Module):
        def forward(self, x, y):
            x = x.cos() + 1
            y = y.sin() - 1
            return x, y

    inp = (torch.ones(1, 3), torch.ones(1, 3))
    # No option is passed, so the helper's default set of modes is used.
    self._check_equal_ts_ep_converter(MSingle(), inp)
    self._check_equal_ts_ep_converter(MMulti(), inp)
def test_ts2ep_converter_container_output(self):
    """Convert modules whose outputs are a list, a tuple and a nested dict."""

    # Output is a List.
    class MOutputList(torch.nn.Module):
        def forward(self, x: torch.Tensor, y: torch.Tensor):
            a = x * x
            b = y + y
            return [a, b]

    # Output is a Tuple.
    class MOutputTuple(torch.nn.Module):
        def forward(self, x: torch.Tensor, y: torch.Tensor):
            a = x * x
            b = y + y
            return (a, b)

    # Output is a Dict.
    class MOutputDict(torch.nn.Module):
        def forward(self, x: torch.Tensor, y: torch.Tensor):
            a = x * x
            b = y + y
            return {"data": {"mul": a, "add": b}}

    inp = (torch.tensor(4), torch.tensor(4))

    # Traced function must use immutable structure as output, so the list
    # and dict variants are only checked with script.
    self._check_equal_ts_ep_converter(MOutputList(), inp, ["script"])
    self._check_equal_ts_ep_converter(MOutputTuple(), inp)
    self._check_equal_ts_ep_converter(MOutputDict(), inp, ["script"])
def test_aten_dim(self):
    """Convert a module that uses ``Tensor.dim()`` as a size argument."""

    class Module(torch.nn.Module):
        def forward(self, x):
            num_dim = x.dim()
            return torch.ones(num_dim)

    inp = (torch.ones(1, 3),)
    self._check_equal_ts_ep_converter(Module(), inp)
def test_aten_len(self):
    """Convert ``len()`` applied to a tensor, a list and several dict types.

    ``Module`` is redefined before each case so that the annotation on ``x``
    selects the overload named in the accompanying comment.
    """

    class Module(torch.nn.Module):
        def forward(self, x: torch.Tensor):
            length = len(x)
            return torch.ones(length)

    # aten::len.Tensor
    inp = (torch.ones(2, 3),)
    self._check_equal_ts_ep_converter(Module(), inp)

    class Module(torch.nn.Module):
        def forward(self, x: List[int]):
            length = len(x)
            return torch.ones(length)

    # aten::len.t
    inp = ([1, 2, 3],)
    self._check_equal_ts_ep_converter(Module(), inp, ["script"])

    class Module(torch.nn.Module):
        def forward(self, x: Dict[int, str]):
            length = len(x)
            return torch.ones(length)

    # aten::len.Dict_int
    inp = ({1: "a", 2: "b", 3: "c"},)
    self._check_equal_ts_ep_converter(Module(), inp, ["script"])

    class Module(torch.nn.Module):
        def forward(self, x: Dict[bool, str]):
            length = len(x)
            return torch.ones(length)

    # aten::len.Dict_bool
    inp = ({True: "a", False: "b"},)
    self._check_equal_ts_ep_converter(Module(), inp, ["script"])

    class Module(torch.nn.Module):
        def forward(self, x: Dict[float, str]):
            length = len(x)
            return torch.ones(length)

    # aten::len.Dict_float
    inp = ({1.2: "a", 3.4: "b"},)
    self._check_equal_ts_ep_converter(Module(), inp, ["script"])

    class Module(torch.nn.Module):
        def forward(self, x: Dict[torch.Tensor, str]):
            length = len(x)
            return torch.ones(length)

    # aten::len.Dict_Tensor
    inp = ({torch.zeros(2, 3): "a", torch.ones(2, 3): "b"},)
    self._check_equal_ts_ep_converter(Module(), inp, ["script"])

    # aten::len.str and aten::len.Dict_str are not supported
    # since torch._C._jit_flatten does not support str
    # inp = ("abcdefg",)
    # self._check_equal_ts_ep_converter(Module(), inp)
    # inp = ({"a": 1, "b": 2},)
    # self._check_equal_ts_ep_converter(Module(), inp)
def test_aten_add_t(self):
    """Convert list concatenation (``list + list``) in a scripted module."""

    # python list append
    class Module(torch.nn.Module):
        def forward(self, x: List[torch.Tensor]):
            out = []
            out = out + x
            a = torch.cat(out)
            out = out + x
            b = torch.cat(out)
            return a, b

    inp = ([torch.ones(2, 3), torch.ones(2, 3)],)
    self._check_equal_ts_ep_converter(Module(), inp, ["script"])
def test_aten_to_dtype_with_mutating_storage(self):
    """Convert a dtype cast whose result is then mutated in place."""

    class Module(torch.nn.Module):
        def forward(self, x: torch.Tensor, y: torch.Tensor):
            x = x.to(y.dtype)
            torch.ops.aten.index_put_(x, [torch.tensor([0])], y)
            return x

    # x is a float tensor and y an integer tensor, so the cast changes dtype.
    inp = (torch.ones(2, 3), torch.tensor([0, 0, 0]))
    self._check_equal_ts_ep_converter(Module(), inp)
def test_prim_min(self):
    """Convert the ``min`` overloads named in the comments inside ``forward``."""

    class Module(torch.nn.Module):
        def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            x_len = len(x)
            y_len = len(y)

            # prim::min.int
            len_int = min(x_len, y_len)

            # prim::min.float
            len_float = int(min(x_len * 2.0, y_len * 2.0))

            # prim::min.self_int
            len_self_int = min([x_len, y_len])

            # prim::min.self_float
            len_self_float = int(min([x_len * 2.0, y_len * 2.0]))

            # prim::min.float_int
            len_float_int = int(min(x_len * 2.0, y_len))

            # prim::min.int_float
            len_int_float = int(min(x_len, y_len * 2.0))

            return torch.ones(
                len_int
                + len_float
                + len_self_int
                + len_self_float
                + len_float_int
                + len_int_float
            )

    inp = (torch.randn(10, 2), torch.randn(5))
    self._check_equal_ts_ep_converter(Module(), inp)
def test_prim_max(self):
    """Convert the ``max`` overloads named in the comments inside ``forward``."""

    class Module(torch.nn.Module):
        def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            x_len = len(x)
            y_len = len(y)

            # prim::max.int
            len_int = max(x_len, y_len)

            # prim::max.float
            len_float = int(max(x_len * 2.0, y_len * 2.0))

            # prim::max.self_int
            len_self_int = max([x_len, y_len])

            # prim::max.self_float
            len_self_float = int(max([x_len * 2.0, y_len * 2.0]))

            # prim::max.float_int
            len_float_int = int(max(x_len * 2.0, y_len))

            # prim::max.int_float
            len_int_float = int(max(x_len, y_len * 2.0))

            return torch.ones(
                len_int
                + len_float
                + len_self_int
                + len_self_float
                + len_float_int
                + len_int_float
            )

    inp = (torch.randn(10, 2), torch.randn(5))
    self._check_equal_ts_ep_converter(Module(), inp)
def test_aten___getitem___list(self):
    """Convert indexing into the sequence returned by ``torch.split``."""

    class Module(torch.nn.Module):
        def forward(self, x):
            y = torch.split(x, 2)
            return y[0]

    inp = (torch.rand((3, 2)),)
    self._check_equal_ts_ep_converter(Module(), inp)
def test_aten___getitem___dict(self):
    """Convert dict lookups with int, str, bool and float keys."""

    class Module(torch.nn.Module):
        def forward(self, x):
            y = torch.split(x, 2)
            d_int = {0: y[0], 1: y[1]}
            d_str = {"0": y[0], "1": y[1]}
            d_bool = {True: y[0], False: y[1]}
            d_float = {0.1: y[0], 2.3: y[1]}
            return d_int[0], d_str["0"], d_bool[True], d_float[0.1]

    inp = (torch.rand((3, 2)),)
    self._check_equal_ts_ep_converter(Module(), inp)
def test_prim_device(self):
    """Convert a module that reads ``x.device`` and creates a tensor on it."""

    class Module(torch.nn.Module):
        def forward(self, x):
            device = x.device
            return torch.ones(2, 3, device=device)

    inp = (torch.rand(3, 4),)
    self._check_equal_ts_ep_converter(Module(), inp)
@requires_cuda
def test_prim_device_cuda(self):
    """Same as ``test_prim_device`` but with an input placed on ``cuda:0``."""

    class Module(torch.nn.Module):
        def forward(self, x):
            device = x.device
            return torch.ones(2, 3, device=device)

    inp = (torch.rand((3, 4), device="cuda:0"),)
    self._check_equal_ts_ep_converter(Module(), inp)
def test_prim_dtype(self):
    """Convert a module that reads ``x.dtype``, for float and integer inputs."""

    class Module(torch.nn.Module):
        def forward(self, x):
            dtype = x.dtype
            return torch.ones(2, 3, dtype=dtype)

    float_dtypes = (torch.float32, torch.double)
    int_dtypes = (torch.uint8, torch.int8, torch.int32)

    # Floating-point inputs are sampled with torch.rand.
    for float_dtype in float_dtypes:
        sample = torch.rand((3, 4), dtype=float_dtype)
        self._check_equal_ts_ep_converter(Module(), (sample,))

    # Integer inputs are sampled with torch.randint.
    for int_dtype in int_dtypes:
        sample = torch.randint(high=128, size=(3, 4), dtype=int_dtype)
        self._check_equal_ts_ep_converter(Module(), (sample,))
def test_convert_if_basic(self):
    """Convert a tensor-predicated ``if`` and re-check it on the other branch."""

    class M(torch.nn.Module):
        def forward(self, x: torch.Tensor, y: torch.Tensor):
            if x:
                return y * y
            else:
                return y + y

    ep_list = self._check_equal_ts_ep_converter(
        M(), (torch.tensor(True), torch.tensor(4))
    )

    # The first program is skipped; the rest are re-run with a False predicate.
    for converted in ep_list[1:]:
        actual = converted.module()(torch.tensor(False), torch.tensor(4))
        expected = M()(torch.tensor(False), torch.tensor(4))
        torch.testing.assert_close(actual, expected)
def test_convert_if_tuple_out(self):
    """Convert an ``if`` whose branches each produce a tuple."""

    class M(torch.nn.Module):
        def true_fn(self, y, z):
            return (z * z, z + z)

        def false_fn(self, y, z):
            return (y * y * y, y + y)

        def forward(self, x: torch.Tensor, y: torch.Tensor):
            z = y * y
            if x:
                res = self.true_fn(y, z)
            else:
                res = self.false_fn(y, z)
            return res[0] + res[1]

    ep_list = self._check_equal_ts_ep_converter(
        M(), (torch.tensor(True), torch.tensor(4))
    )

    # The first program is skipped; the rest are re-run with a False predicate.
    for converted in ep_list[1:]:
        actual = converted.module()(torch.tensor(False), torch.tensor(4))
        expected = M()(torch.tensor(False), torch.tensor(4))
        torch.testing.assert_close(actual, expected)
def test_convert_if_multiple_out(self):
    """Convert an ``if`` whose branches assign two separate results."""

    class M(torch.nn.Module):
        def true_fn(self, y, z):
            return z * z

        def false_fn(self, y, z):
            return y * y * y

        def forward(self, x: torch.Tensor, y: torch.Tensor):
            z = y * y
            if x:
                res1 = self.true_fn(y, z)
                res2 = y
            else:
                res1 = z
                res2 = self.false_fn(y, z)
            return res1 + res2

    ep_list = self._check_equal_ts_ep_converter(
        M(), (torch.tensor(True), torch.tensor(4))
    )

    # The first program is skipped; the rest are re-run with a False predicate.
    for converted in ep_list[1:]:
        actual = converted.module()(torch.tensor(False), torch.tensor(4))
        expected = M()(torch.tensor(False), torch.tensor(4))
        torch.testing.assert_close(actual, expected)
def test_profiler__record_function(self):
    """Convert a module that wraps its computation in profiler enter/exit ops."""

    class Module(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            handle = torch.ops.profiler._record_function_enter_new("foo", None)
            y = x * 2 + 4
            torch.ops.profiler._record_function_exit(handle)
            return y

    x = torch.randn(10, 10)
    self._check_equal_ts_ep_converter(Module(), (x,))
def test_aten_floordiv(self):
    """Convert floor division of a tensor by a Python int."""

    class Module(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x // 2

    x = torch.randn(10, 10)
    self._check_equal_ts_ep_converter(Module(), (x,))
def test_aten___is__(self):
    """Convert an ``is`` comparison between two tensors (script only)."""

    class Module(torch.nn.Module):
        def forward(
            self, x: torch.Tensor, y: torch.Tensor
        ) -> Tuple[bool, torch.Tensor]:
            z = x + 1
            return x is y, z

    # Traced function must return output that has tensors.
    inp = (torch.randn(10, 10), torch.rand(10, 10))
    self._check_equal_ts_ep_converter(Module(), inp, ["script"])
def test_aten___isnot__(self):
    """Convert an ``is not`` comparison between two tensors (script only)."""

    class Module(torch.nn.Module):
        def forward(
            self, x: torch.Tensor, y: torch.Tensor
        ) -> Tuple[bool, torch.Tensor]:
            z = x + 1
            return x is not y, z

    # Traced function must return output that has tensors.
    inp = (torch.randn(10, 10), torch.rand(10, 10))
    self._check_equal_ts_ep_converter(Module(), inp, ["script"])
def test_aten___not__(self):
    """Convert a boolean ``not`` applied to an ``is not`` result (script only)."""

    class Module(torch.nn.Module):
        def forward(
            self, x: torch.Tensor, y: torch.Tensor
        ) -> Tuple[bool, torch.Tensor]:
            z = x + 1
            return not (x is not y), z

    # Traced function must return output that has tensors.
    inp = (torch.randn(10, 10), torch.rand(10, 10))
    self._check_equal_ts_ep_converter(Module(), inp, ["script"])
def test_ts2ep_converter_unpack(self):
    """Convert unpacking of a ``torch.split`` result and of a tuple input."""

    class MUnpackList(torch.nn.Module):
        def forward(self, x):
            x, y = torch.split(x, 2)
            return x + y

    class MUnpackTuple(torch.nn.Module):
        def forward(self, x_tuple: Tuple[torch.Tensor, torch.Tensor]):
            x, y = x_tuple
            x = x.cos()
            return x + y

    # A 4-element input split in chunks of 2 yields exactly two pieces.
    inp = (torch.ones(4),)
    self._check_equal_ts_ep_converter(MUnpackList(), inp)
    inp = ((torch.zeros(1, 4), torch.ones(1, 4)),)
    self._check_equal_ts_ep_converter(MUnpackTuple(), inp)
@unittest.skipIf(
    IS_WINDOWS,
    # The two literals are concatenated, so the first needs a trailing space.
    "torch.cond doesn't go through torch.compile on windows "
    "causing output not normalized as list",
)
def test_convert_retrace_nested_scripted_modules(self):
    """Convert modules that hold already-scripted submodules.

    The same scripted submodule is reached both directly (``mod1``) and
    through a plain eager ``Wrapper`` (``mod2``), at two nesting levels.
    """

    class Wrapper(torch.nn.Module):
        def __init__(self, mod) -> None:
            super().__init__()
            self.mod = mod

        def forward(self, x, y):
            return self.mod(x, y)

    class LinearM(torch.nn.Module):
        def __init__(self, dim: int) -> None:
            super().__init__()
            self.linear = torch.nn.Linear(dim, dim)

        def forward(self, x, y):
            return self.linear(y)

    class M(torch.nn.Module):
        def __init__(self, dim: int) -> None:
            super().__init__()
            m = LinearM(dim)
            m = torch.jit.script(m)
            self.mod1 = m
            self.mod2 = Wrapper(m)

        def forward(self, x: torch.Tensor, y: torch.Tensor):
            if x:
                return -self.mod1(x, y) - self.mod2(x, y)
            else:
                return -self.mod1(x, y) + self.mod2(x, y)

    class NestedM(torch.nn.Module):
        def __init__(self, dim: int) -> None:
            super().__init__()
            m = M(dim)
            m = torch.jit.script(m)
            self.mod1 = m
            self.mod2 = Wrapper(m)

        def forward(self, x: torch.Tensor, y: torch.Tensor):
            if x:
                return self.mod1(x, y) + self.mod2(x, y)
            else:
                return self.mod1(x, y) - self.mod2(x, y)

    inp = (
        torch.tensor(True),
        torch.randn([3, 3]),
    )
    self._check_equal_ts_ep_converter(NestedM(3), inp)
def test_convert_nn_module_with_nested_param(self):
    """Convert modules whose ``Linear`` parameters sit at two and three levels."""

    class M(torch.nn.Module):
        def __init__(self, dim: int) -> None:
            super().__init__()
            self.linear = torch.nn.Linear(dim, dim)

        def forward(self, x: torch.Tensor):
            return self.linear(x)

    class NestedM(torch.nn.Module):
        def __init__(self, dim: int) -> None:
            super().__init__()
            self.linear = torch.nn.Linear(dim, dim)
            self.m = M(dim)

        def forward(self, x: torch.Tensor):
            return self.linear(self.m(x))

    class SuperNestedM(torch.nn.Module):
        def __init__(self, dim: int) -> None:
            super().__init__()
            self.linear = torch.nn.Linear(dim, dim)
            self.m = NestedM(dim)

        def forward(self, x: torch.Tensor):
            return self.linear(self.m(x))

    inp = (torch.ones(3),)
    orig_m = NestedM(3)
    self._check_equal_ts_ep_converter(orig_m, inp)
    orig_m = SuperNestedM(3)
    self._check_equal_ts_ep_converter(orig_m, inp)
def test_convert_nn_module_with_nested_buffer(self):
    """Convert modules whose buffers sit at two and three nesting levels."""

    class M(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.w = torch.nn.Buffer(torch.randn(1))

        def forward(self, x: torch.Tensor):
            return self.w + x

    class NestedM(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.m = M()
            self.w = torch.nn.Buffer(torch.randn(1))

        def forward(self, x: torch.Tensor):
            return self.w + self.m(x)

    class SuperNestedM(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.m = NestedM()
            self.w = torch.nn.Buffer(torch.randn(1))

        def forward(self, x: torch.Tensor):
            return self.w + self.m(x)

    inp = (torch.ones(1),)
    orig_m = NestedM()
    self._check_equal_ts_ep_converter(orig_m, inp)
    orig_m = SuperNestedM()
    self._check_equal_ts_ep_converter(orig_m, inp)
def test_convert_nn_module_with_nested_if_and_buffer(self):
    """Convert nested modules that mix buffers with data-dependent branches."""

    class M(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.w = torch.nn.Buffer(torch.randn(1))
            self.count = 1

        def forward(self, x: torch.Tensor):
            return self.w + x + self.count

    class NestedM(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.m1 = M()
            self.m2 = M()
            self.w = torch.nn.Buffer(torch.randn(1))

        def forward(self, x: torch.Tensor):
            if torch.sum(x) > 1:
                return self.w + self.m1(x)
            else:
                return self.w + self.m2(x)

    # Super nested, parameters need to be lifted
    # multiple times.
    class SuperNestedM(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.m1 = NestedM()
            self.m2 = NestedM()
            self.w = torch.nn.Buffer(torch.randn(1))

        def forward(self, x: torch.Tensor):
            if torch.max(x) > 1:
                return self.w + self.m1(x)
            else:
                return self.w + self.m2(x)

    # Super nested module testing.
    inp = (torch.ones(1),)
    orig_m = SuperNestedM()
    ep_list = self._check_equal_ts_ep_converter(orig_m, inp)

    # In-place update: the tensor held in `inp` becomes 0. Both the original
    # value (1) and the new one (0) fail the `> 1` conditions, so every
    # program in ep_list is compared here.
    t = inp[0]
    t -= 1
    for ep in ep_list:
        torch.testing.assert_close(
            ep.module()(*inp),
            orig_m(*inp),
        )
@unittest.skipIf(
    IS_WINDOWS,
    # The two literals are concatenated, so the first needs a trailing space.
    "torch.cond doesn't go through torch.compile on windows "
    "causing output not normalized as list",
)
def test_convert_nn_module_with_nested_if_and_param(self):
    """Convert nested modules that mix ``Linear`` parameters with branches.

    Each case converts with an all-ones input, then shrinks that input in
    place and compares the converted programs (all but the first) against
    the eager module on the new value.
    """

    class M(torch.nn.Module):
        def __init__(self, dim: int) -> None:
            super().__init__()
            self.linear = torch.nn.Linear(dim, dim)

        def forward(self, x: torch.Tensor):
            return self.linear(x)

    class NestedM(torch.nn.Module):
        def __init__(self, dim: int) -> None:
            super().__init__()
            self.m1 = M(dim)
            self.m2 = M(dim)
            self.linear = torch.nn.Linear(dim, dim)

        def forward(self, x: torch.Tensor):
            if torch.sum(x) > 1:
                return self.linear(self.m1(x))
            else:
                return self.linear(self.m2(x))

    # Super nested, parameters need to be lifted
    # multiple times.
    class SuperNestedM1(torch.nn.Module):
        def __init__(self, dim: int) -> None:
            super().__init__()
            self.m1 = NestedM(dim)
            self.m2 = NestedM(dim)
            self.linear = torch.nn.Linear(dim, dim)

        def forward(self, x: torch.Tensor):
            if torch.max(x) > 1:
                return self.linear(self.m1(x))
            else:
                return self.linear(self.m2(x))

    # Super nested, even the input needs to be
    # lifted recursively due to value propagation optimization.
    class SuperNestedM2(torch.nn.Module):
        def __init__(self, dim: int) -> None:
            super().__init__()
            self.m1 = NestedM(dim)
            self.m2 = NestedM(dim)
            self.linear = torch.nn.Linear(dim, dim)

        def forward(self, x: torch.Tensor):
            if torch.sum(x) > 1:
                return self.linear(self.m1(x))
            else:
                return self.linear(self.m2(x))

    # Basic module testing.
    inp = (torch.ones(3),)
    orig_m = M(3)
    ep_list = self._check_equal_ts_ep_converter(orig_m, inp)

    # In-place update: the tensor held in `inp` changes from 1.0 to 0.2.
    t = inp[0]
    t -= 0.8
    for ep in ep_list[1:]:
        torch.testing.assert_close(
            ep.module()(*inp),
            orig_m(*inp),
        )

    # Nested module testing.
    inp = (torch.ones(3),)
    orig_m = NestedM(3)
    ep_list = self._check_equal_ts_ep_converter(orig_m, inp)

    t = inp[0]
    t -= 0.8
    # Skip jit.traced because it specializes on one path.
    for ep in ep_list[1:]:
        torch.testing.assert_close(
            ep.module()(*inp),
            orig_m(*inp),
        )

    # Super nested module testing.
    inp = (torch.ones(3),)
    orig_m = SuperNestedM1(3)
    ep_list = self._check_equal_ts_ep_converter(orig_m, inp)

    t = inp[0]
    t -= 0.8
    # Skip jit.traced because it specializes on one path.
    for ep in ep_list[1:]:
        torch.testing.assert_close(
            ep.module()(*inp),
            orig_m(*inp),
        )

    # Super nested module testing.
    inp = (torch.ones(3),)
    orig_m = SuperNestedM2(3)
    ep_list = self._check_equal_ts_ep_converter(orig_m, inp)

    t = inp[0]
    t -= 0.8
    # Skip jit.traced because it specializes on one path.
    for ep in ep_list[1:]:
        torch.testing.assert_close(
            ep.module()(*inp),
            orig_m(*inp),
        )
def test_ts2ep_converter_contains(self):
    """Convert ``in`` checks on a dtype list and on a tensor-keyed dict."""

    class MIn(torch.nn.Module):
        def forward(self, x: torch.Tensor):
            return x.dtype in [torch.float32, torch.float64]

    class MNotIn(torch.nn.Module):
        def forward(self, x: torch.Tensor):
            return x.dtype in [torch.int8]

    class MTensorIn(torch.nn.Module):
        def forward(self, x: torch.Tensor, x_dict: Dict[torch.Tensor, str]):
            return x in x_dict

    # Traced function must return output that has tensors.
    inp = (torch.tensor(4),)
    self._check_equal_ts_ep_converter(MIn(), inp, ["script"])
    self._check_equal_ts_ep_converter(MNotIn(), inp, ["script"])

    # TODO: update test to use reference for in.
    inp = (torch.tensor(4), {torch.tensor(4): "foo"})
    self._check_equal_ts_ep_converter(MTensorIn(), inp, ["script"])
    inp = (torch.tensor(1), {torch.tensor(4): "foo"})
    self._check_equal_ts_ep_converter(MTensorIn(), inp, ["script"])
def test_ts2ep_converter_custom_op(self):
    """Convert a module that calls a custom op defined in a scoped library.

    The op ``mylib::foo`` is defined, given a real implementation and a fake
    (meta) implementation, and then called from a module's ``forward``.
    """
    with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
        # Scope the dynamo flags to this test. Assigning them directly on
        # torch._dynamo.config would leave them enabled for every test
        # that runs afterwards in the same process.
        with torch._dynamo.config.patch(
            capture_scalar_outputs=True,
            capture_dynamic_output_shape_ops=True,
        ):
            torch.library.define(
                "mylib::foo",
                "(Tensor x) -> Tensor",
                lib=lib,
            )

            # PyTorch custom op implementation
            @torch.library.impl(
                "mylib::foo",
                "CompositeExplicitAutograd",
                lib=lib,
            )
            def foo_impl(x):
                return x + x

            # Meta function of the custom op. `register_fake` is the current
            # name of the deprecated `impl_abstract`.
            @torch.library.register_fake(
                "mylib::foo",
                lib=lib,
            )
            def foo_meta(x):
                return x + x

            class M(torch.nn.Module):
                def forward(self, x):
                    return torch.ops.mylib.foo(x)

            inp = (torch.randn(3, 3),)
            m = M()
            self._check_equal_ts_ep_converter(m, inp)
def test_convert_func_without_param(self):
    """Convert plain functions, one of them with a data-dependent branch."""

    def func1(x, y):
        return x + y

    def func2(x, y):
        if x.sum() > 0:
            return x + y
        else:
            return x - y

    inp = (torch.tensor(1), torch.tensor(1))
    self._check_equal_ts_ep_converter(func1, inp)
    ep_list = self._check_equal_ts_ep_converter(func2, inp)

    # Mutate the first input in place (1 -> 0) so func2 takes its else
    # branch, then compare every program except the first.
    inp[0].sub_(1)
    for converted in ep_list[1:]:
        actual = converted.module()(*inp)
        expected = func2(*inp)
        torch.testing.assert_close(actual, expected)
def test_implicit_constant_to_tensor_handling(self):
    """Convert code that mixes Python scalars, tensor literals and sizes."""

    def func1(x):
        return x + 2

    def func2(x, y):
        return x * y / (x - 2 * y) + y

    def func3(x):
        return x + torch.tensor([3])

    def func4():
        val = torch.tensor(float("inf"))
        return torch.full((10, 10), val)

    def func5():
        x = -1
        return x * torch.ones(1, dtype=torch.float), torch.zeros(
            1, dtype=torch.float
        )

    def func6(x1, x2, x3, x4):
        return (
            x1.numel(),
            x1.size(),
            x2.numel(),
            x2.size(),
            x3.numel(),
            x3.size(),
            x4.numel(),
            x4.size(),
            torch.ones(x1.numel()),  # Just make sure downstream ops still work.
            torch.ones(x1.size()),  # Just make sure downstream ops still work.
        )

    class M1(torch.nn.Module):
        def __init__(self, value):
            super().__init__()
            # Stored as a plain attribute, not as a parameter or buffer.
            self.x = torch.tensor(value)

        def forward(self):
            return self.x.clone()

    class M2(torch.nn.Module):
        def forward(self, x):
            return torch.tensor(4) + x

    inp = (torch.randn([2, 2]),)
    self._check_equal_ts_ep_converter(func1, inp)
    inp = (torch.randn([2, 2]), torch.randn([2, 2]))
    self._check_equal_ts_ep_converter(func2, inp)

    inp = (torch.randn([2, 2]),)
    self._check_equal_ts_ep_converter(func3, inp)

    # func4, M1 and func5 take no inputs.
    self._check_equal_ts_ep_converter(func4, ())
    self._check_equal_ts_ep_converter(M1(5), ())

    inp = (torch.randn(2),)
    self._check_equal_ts_ep_converter(M2(), inp)

    self._check_equal_ts_ep_converter(func5, ())
    inp = (
        torch.randn([2, 3, 4]).to(torch.int8),
        torch.randn([2, 3, 4]).to(torch.int32),
        torch.randn([2, 3, 4]).to(torch.float32),
        torch.randn([2, 3, 4]).to(torch.float64),
    )
    ep_list = self._check_equal_ts_ep_converter(func6, inp)

    # TODO: Additional check once dynamic shape is supported.
    # for ep in ep_list:
    #     self.assertEqual(
    #         ep.module()(
    #             torch.randn([1, 1, 1]).to(torch.int8),
    #             torch.randn([1, 1, 1]).to(torch.int32),
    #             torch.randn([1, 1, 1]).to(torch.float32),
    #             torch.randn([1, 1, 1]).to(torch.float64),
    #         )[0], 1
    #     )
def test_aten_tensor_dtype_int(self):
    """Convert ``torch.tensor`` with a literal dtype; expect one constant."""

    class M(torch.nn.Module):
        def forward(self, x):
            y = torch.tensor(1, dtype=torch.int32)
            return y + x

    ep_list = self._check_equal_ts_ep_converter(M(), (torch.tensor(1),))
    # Each exported program is expected to hold exactly one constant.
    for ep in ep_list:
        self.assertEqual(len(ep.constants), 1)
def test_aten_tensor_prim_dtype(self):
    """Convert ``torch.tensor`` whose dtype comes from the input; expect one constant."""

    class M(torch.nn.Module):
        def forward(self, x):
            y = torch.tensor(1, dtype=x.dtype)
            return y + x

    ep_list = self._check_equal_ts_ep_converter(M(), (torch.tensor(1),))
    # Each exported program is expected to hold exactly one constant.
    for ep in ep_list:
        self.assertEqual(len(ep.constants), 1)
def test_aten_tensor_dynamic(self):
    """Convert ``torch.tensor`` built from values derived from the input shape."""

    class M(torch.nn.Module):
        def forward(self, x):
            s = x.shape[0]
            y = torch.tensor(s)
            return y

    ep_list = self._check_equal_ts_ep_converter(M(), (torch.ones(3),))
    # No constants are expected in any of the exported programs.
    for ep in ep_list:
        self.assertEqual(len(ep.constants), 0)

    # TODO: Additional check once dynamic shape is supported.
    # for ep in ep_list:
    #     torch.testing.assert_close(
    #         ep.module()(torch.ones(4)),
    #         M()(torch.ones(4)),
    #     )

    class M(torch.nn.Module):
        def forward(self, x):
            s = x.shape[0]
            y = torch.tensor([s, s * 2, 1])
            return y

    ep_list = self._check_equal_ts_ep_converter(M(), (torch.ones(3),))
    # Trace directly inline a tensor constant, so the first program is skipped.
    for ep in ep_list[1:]:
        self.assertEqual(len(ep.constants), 0)

    # TODO: Additional check once dynamic shape is supported.
    # for ep in ep_list:
    #     torch.testing.assert_close(
    #         ep.module()(torch.ones(4)),
    #         M()(torch.ones(4)),
    #     )
def test_prim_tolist(self):
    """Convert ``Tensor.tolist()`` for 1-D and 2-D integer tensors (script only)."""

    class Module(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> List[int]:
            return x.tolist()

    inp = (torch.tensor([1, 2, 3]),)
    self._check_equal_ts_ep_converter(Module(), inp, ["script"])

    # Same call, but annotated to return a nested list.
    class Module(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> List[List[int]]:
            return x.tolist()

    inp = (torch.tensor([[1, 2, 3], [4, 5, 6]]),)
    self._check_equal_ts_ep_converter(Module(), inp, ["script"])
def test_get_tensor_constants(self):
    """Convert modules that read plain tensor attributes, including nested ones."""

    # Since self.data is only read but not written, it is lifted as
    # constant tensors.
    class Foo(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.data = torch.randn(3, 2)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x + self.data

    class Goo(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.data = torch.randn(3, 2)
            self.foo = Foo()

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # Reads its own tensor, the child's tensor, and calls the child.
            return x + self.data + self.foo.data + self.foo(x)

    inp = (torch.randn(3, 2),)
    goo = Goo()
    self._check_equal_ts_ep_converter(goo, inp)
def test_prim_SetAttr(self):
    """Convert modules whose ``forward`` assigns to attributes of ``self``.

    Cases with ``check_persistent=True`` pass the module class itself rather
    than an instance, because the helper calls it to build fresh instances.
    """

    class Module(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.data = torch.nn.Buffer(torch.ones(3, 2))

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            self.data = self.data + x
            return x + x

    inp = (torch.ones(3, 2),)
    self._check_equal_ts_ep_converter(
        Module, inp, ["script"], check_persistent=True
    )

    # Same as above, but the updated buffer also feeds the output.
    class Module(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.data = torch.nn.Buffer(torch.ones(3, 2))

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            self.data = self.data + x
            return x + self.data

    inp = (torch.ones(3, 2),)
    self._check_equal_ts_ep_converter(
        Module, inp, ["script"], check_persistent=True
    )

    # export lifts a tensor constant (self.data) as an input if it is not assigned.
    # If it is assigned, export will error and ask users to register it as a buffer.
    # In converter, we change tensor constants that are assigned as a buffer automatically,
    # since it might be hard to manually register them as buffers.
    class Module(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.data = torch.ones(3, 2)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            self.data = self.data + x
            return x + self.data

    inp = (torch.ones(3, 2),)
    self._check_equal_ts_ep_converter(
        Module,
        inp,
        ["script"],
        check_persistent=True,
        lifted_tensor_constants=OrderedDict([("data", torch.ones(3, 2))]),
    )

    class Module(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.count = 0

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            self.count += 1
            return x + self.count

    # check_persistent is False since export specializes on non-tensor constants
    inp = (torch.ones(3, 2),)
    self._check_equal_ts_ep_converter(
        Module(), inp, ["script"], check_persistent=False
    )

    # Several reads interleaved with writes of the same int attribute.
    class M(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.count = 0

        def forward(self, x):
            count1 = self.count
            self.count += 1
            count2 = self.count
            self.count += 1
            count3 = self.count
            return x + count1 + count2 + count3

    inp = (torch.ones(1),)
    self._check_equal_ts_ep_converter(M(), inp, ["script"], check_persistent=False)

    # In-place update of a buffer that is also the return value.
    class M(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.w2 = torch.nn.Buffer(torch.ones(1))

        def forward(self, x: torch.Tensor):
            self.w2 += 1
            return self.w2

    inp = (torch.ones(1),)
    self._check_equal_ts_ep_converter(M, inp, ["script"], check_persistent=True)
def test_raise_exception(self):
    """Convert modules with a ``raise`` branch, on inputs that do and do not hit it."""

    class Module(torch.nn.Module):
        def forward(self, x: torch.Tensor, y: int) -> torch.Tensor:
            if y > 0:
                raise RuntimeError("test")
            return x + y

    # match non-strict export behavior that errors when the given input leads to
    # RaiseException.
    with self.assertRaisesRegex(torch.jit.Error, "builtins.RuntimeError"):
        inp = (torch.randn(3, 2), 1)
        self._check_equal_ts_ep_converter(Module(), inp, ["script"])

    # Matching non-strict export behavior that only executes 1 if-branch according
    # to the given input.
    inp = (torch.randn(3, 2), 0)
    self._check_equal_ts_ep_converter(Module(), inp, ["script"])

    # Variant where the raise sits in an if/else that also assigns a local.
    class Module(torch.nn.Module):
        def forward(self, x: torch.Tensor, y: int) -> torch.Tensor:
            z = x
            if y > 0:
                raise RuntimeError("test")
                # z = x
            else:
                z = x + y
            return x + y + z

    # match non-strict export behavior that errors when the given input leads to
    # RaiseException.
    with self.assertRaisesRegex(torch.jit.Error, "builtins.RuntimeError"):
        inp = (torch.randn(3, 2), 1)
        self._check_equal_ts_ep_converter(Module(), inp, ["script"])

    # Matching non-strict export behavior that only executes 1 if-branch according
    # to the given input.
    inp = (torch.randn(3, 2), 0)
    self._check_equal_ts_ep_converter(Module(), inp, ["script"])
def test_context_manager(self):
    """Convert a module whose ``forward`` runs inside a Python context manager."""

    # Minimal context manager that only counts enter/exit calls.
    class ContextManager:
        def __init__(self) -> None:
            self.count = 0
            return

        def __enter__(self):
            self.count += 1
            return

        def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
            self.count -= 1
            return

    class M(torch.nn.Module):
        def forward(self, x, y):
            with ContextManager():
                res = x + y
            return res

    inp = (torch.ones(3, 3), torch.ones(3, 3))
    self._check_equal_ts_ep_converter(M(), inp)
def test_hidden_input_name(self):
    """Convert an already-scripted function, a builtin op and a varargs function."""

    @torch.jit.script
    def func1(x):
        return x + 1

    def func2(*args):
        v = torch.cat(args, dim=1)
        return v * v

    inp = (torch.randn([1, 1]),)
    self._check_equal_ts_ep_converter(func1, inp)

    inp = (torch.ones(5, 5),)
    # Cannot script again.
    self._check_equal_ts_ep_converter(torch.ops.aten.relu, inp, ["trace"])

    # One empty tensor followed by three (M, N) tensors of differing width.
    M = 2
    Ns = [4, 2, 1]
    empty = torch.tensor([], dtype=torch.double)
    values = [empty] + [torch.randn(M, N) for N in Ns]
    # Cannot script variable length inputs.
    self._check_equal_ts_ep_converter(func2, tuple(values), ["trace"])
def test_ts2ep_multi_outputs_on_call_ops(self):
    """Convert calls that each return more than one value."""

    class M(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            # return_indices=True makes the pool return a pair of results.
            self.pool = torch.nn.AdaptiveMaxPool2d((2, 2), return_indices=True)

        def forward(self, x: torch.Tensor, y: torch.Tensor):
            return (
                torch.max(x, dim=0),
                torch.topk(x, 3),
                torch.sort(x, dim=0),
                self.pool(y),
            )

    inp = (torch.randn([4, 4]), torch.randn([1, 1, 10, 10]))
    self._check_equal_ts_ep_converter(M(), inp)
def test_aten_append_t(self):
    """Convert ``list.append`` interleaved with reads of the growing list."""

    class M(torch.nn.Module):
        def forward(self, x: List[torch.Tensor]):
            out = []
            out.append(x[0] + x[1])
            out.append(x[0] - x[1])
            out1 = torch.cat(out)
            out.append(x[0] * x[1])
            out2 = torch.cat(out)
            return out, out1, out2

    inp = ([torch.ones(2, 3), torch.ones(2, 3)],)
    # Trace already unrolls the list.
    self._check_equal_ts_ep_converter(M(), inp, ["script"])
def test_convert_script_object(self):
    """Convert a module that holds a torchbind tensor queue (script only)."""

    class M1(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.tq = _empty_tensor_queue()

        def forward(self, x: torch.Tensor):
            # The queue is used both via its methods and via the queue ops.
            self.tq.push(x)
            torch.ops._TorchScriptTesting.queue_push(self.tq, x.cos())
            return torch.ops._TorchScriptTesting.queue_pop(self.tq), self.tq.pop()

    inp = (torch.randn(2, 3),)
    self._check_equal_ts_ep_converter(M1(), inp, ["script"])
def test_ts2ep_with_loop(self):
    """Convert a scripted function containing nested ``for`` loops.

    Only ``func1`` is exercised; ``func2`` and ``func3`` are defined for the
    TODO cases that are still commented out at the end of this test.
    """

    def func1(x, x_list: List[torch.Tensor]):
        a, b, c = x, x, x
        for i in range(1, 5, 2):
            for k in range(5):
                a = a + a + k
                b = b + b - k
                x_list.append(x_list[k] + x_list[k + 1])
            for k in range(5):
                b = b + b - k
                c = c + c * k
                x_list.append(x_list[k] + x_list[k + 1] - x_list[k + 2])
        return x, x_list

    def func2(x):
        for i in range(x.size(0)):
            x = x * x * i
        return x

    def func3(x):
        while x.sum() < 10:
            x += x.sin()
        return x

    inp = (
        torch.tensor(1),
        [torch.ones([2, 2]), torch.ones([2, 2]) * 2],
    )
    # Trace unrolls the loop.
    self._check_equal_ts_ep_converter(func1, inp, ["script"])

    # TODO: (2/N)
    # Trace unrolls the loop.
    # self._check_equal_ts_ep_converter(func2, inp, ["script"])

    # TODO: (3/N)
    # Trace unrolls the loop.
    # self._check_equal_ts_ep_converter(func3, inp, ["script"])
@unittest.skipIf(
    IS_WINDOWS,
    "Windows does not support qnnpack",
)
def test_ts2ep_convert_quantized_model(self):
    """Quantize a small conv model with qnnpack, script it, and convert it."""

    class Standalone(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.quant = torch.ao.quantization.QuantStub()
            self.conv1 = torch.nn.Conv2d(1, 1, 1)
            self.conv2 = torch.nn.Conv2d(1, 1, 1)
            self.relu = torch.nn.ReLU()
            self.dequant = torch.ao.quantization.DeQuantStub()

        def forward(self, x):
            x = self.quant(x)
            x = self.conv1(x)
            x = self.conv2(x)
            x = self.relu(x)
            x = self.dequant(x)
            return x

        def fuse_model(self):
            # Fuse conv2 with the following relu, modifying the model in place.
            torch.ao.quantization.fuse_modules(
                self, [["conv2", "relu"]], inplace=True
            )

    with override_quantized_engine("qnnpack"):
        model = Standalone()
        model.qconfig = torch.ao.quantization.get_default_qconfig("qnnpack")
        model.fuse_model()
        torch.ao.quantization.prepare(model, inplace=True)
        # One forward pass between prepare and convert.
        model(torch.randn(4, 1, 4, 4))
        torch.ao.quantization.convert(model, inplace=True)

        # Use customized checking here, because state_dict of quantization will be
        # modified by the quantization pass.
        inp = (torch.randn(4, 1, 4, 4),)
        original_ts_model = torch.jit.script(model)
        ts_model = torch.jit.script(model)

        converter = TS2EPConverter(ts_model, inp)
        ep = converter.convert()

        orig_out, _ = pytree.tree_flatten(original_ts_model(*inp))
        ep_out, _ = pytree.tree_flatten(ep.module()(*inp))
        self._check_tensor_list_equal(orig_out, ep_out)
def test_ts2ep_convert_quantized_model_with_opcontext(self):
    """Convert a module that holds a prepacked linear op context (script only)."""

    class M(torch.nn.Module):
        def __init__(self, linear_op):
            super().__init__()
            self.linear_op = linear_op

        def forward(self, x):
            x = torch.ops.prepacked.linear_clamp_run(x, self.linear_op)
            return x

    # Prepack a 10x10 weight and a 10-element bias into the op context.
    linear_op = torch.ops.prepacked.linear_clamp_prepack(
        torch.randn(10, 10), torch.randn(10)
    )

    m = M(linear_op)
    inp = (torch.randn(1, 10),)
    self._check_equal_ts_ep_converter(m, inp, ["script"])
# Run the tests through the imported `run_tests` entry point when this file
# is executed directly.
if __name__ == "__main__":
    run_tests()