blob: 05b9b30ea12aa80b40e926d5c9a79ec492a97e88 [file] [log] [blame]
# Owner(s): ["oncall: mobile"]
import inspect
import io
from typing import Dict, List

import torch
import torch.utils.bundled_inputs
from torch.jit.mobile import _export_operator_list, _load_for_lite_interpreter
from torch.testing import FileCheck
from torch.testing._internal.common_quantization import (
    AnnotatedNestedModel,
    AnnotatedSingleLayerLinearModel,
    QuantizationLiteTestCase,
    TwoLayerLinearModel,
)

# TemporaryFileName comes from torch's test utilities; the stdlib ``tempfile``
# module has no such name, so importing it from there raises ImportError.
from torch.testing._internal.common_utils import (
    run_tests,
    TemporaryFileName,
    TestCase,
)
class TestLiteScriptModule(TestCase):
def getScriptExportImportCopy(
self, m, save_mobile_debug_info=True, also_test_file=False
):
m_scripted = torch.jit.script(m)
if not also_test_file:
buffer = io.BytesIO(
m_scripted._save_to_buffer_for_lite_interpreter(
_save_mobile_debug_info=save_mobile_debug_info
)
)
buffer.seek(0)
mobile_module = _load_for_lite_interpreter(buffer)
return mobile_module
with TemporaryFileName() as fname:
m_scripted._save_for_lite_interpreter(
fname, _save_mobile_debug_info=save_mobile_debug_info
)
mobile_module = _load_for_lite_interpreter(fname)
return mobile_module
def test_load_mobile_module(self):
class MyTestModule(torch.nn.Module):
def forward(self, x):
return x + 10
input = torch.tensor([1])
script_module = torch.jit.script(MyTestModule())
script_module_result = script_module(input)
buffer = io.BytesIO(script_module._save_to_buffer_for_lite_interpreter())
buffer.seek(0)
mobile_module = _load_for_lite_interpreter(buffer)
mobile_module_result = mobile_module(input)
torch.testing.assert_close(script_module_result, mobile_module_result)
mobile_module_forward_result = mobile_module.forward(input)
torch.testing.assert_close(script_module_result, mobile_module_forward_result)
mobile_module_run_method_result = mobile_module.run_method("forward", input)
torch.testing.assert_close(
script_module_result, mobile_module_run_method_result
)
def test_save_mobile_module_with_debug_info_with_trace(self):
class A(torch.nn.Module):
def forward(self, x, y):
return x * y
class B(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.A0 = A()
self.A1 = A()
def forward(self, x, y, z):
return self.A0(x, y) + self.A1(y, z)
for export_method in ["trace", "script"]:
x = torch.rand((2, 3))
y = torch.rand((2, 3))
z = torch.rand((2, 3))
if export_method == "trace":
trace_module = torch.jit.trace(B(), [x, y, z])
else:
trace_module = torch.jit.script(B())
exported_module = trace_module._save_to_buffer_for_lite_interpreter(
_save_mobile_debug_info=True
)
buffer = io.BytesIO(exported_module)
buffer.seek(0)
assert b"callstack_debug_map.pkl" in exported_module
mobile_module = _load_for_lite_interpreter(buffer)
with self.assertRaisesRegex(
RuntimeError,
r"Module hierarchy:top\(B\)::<unknown>.A0\(A\)::forward.aten::mul",
):
x = torch.rand((2, 3))
y = torch.rand((8, 10))
z = torch.rand((8, 10))
mobile_module(x, y, z)
with self.assertRaisesRegex(
RuntimeError,
r"Module hierarchy:top\(B\)::<unknown>.A1\(A\)::forward.aten::mul",
):
x = torch.rand((2, 3))
y = torch.rand((2, 3))
z = torch.rand((8, 10))
mobile_module(x, y, z)
def test_load_mobile_module_with_debug_info(self):
class MyTestModule(torch.nn.Module):
def forward(self, x):
return x + 5
input = torch.tensor([3])
script_module = torch.jit.script(MyTestModule())
script_module_result = script_module(input)
buffer = io.BytesIO(
script_module._save_to_buffer_for_lite_interpreter(
_save_mobile_debug_info=True
)
)
buffer.seek(0)
mobile_module = _load_for_lite_interpreter(buffer)
mobile_module_result = mobile_module(input)
torch.testing.assert_close(script_module_result, mobile_module_result)
mobile_module_forward_result = mobile_module.forward(input)
torch.testing.assert_close(script_module_result, mobile_module_forward_result)
mobile_module_run_method_result = mobile_module.run_method("forward", input)
torch.testing.assert_close(
script_module_result, mobile_module_run_method_result
)
def test_find_and_run_method(self):
class MyTestModule(torch.nn.Module):
def forward(self, arg):
return arg
input = (torch.tensor([1]),)
script_module = torch.jit.script(MyTestModule())
script_module_result = script_module(*input)
buffer = io.BytesIO(script_module._save_to_buffer_for_lite_interpreter())
buffer.seek(0)
mobile_module = _load_for_lite_interpreter(buffer)
has_bundled_inputs = mobile_module.find_method("get_all_bundled_inputs")
self.assertFalse(has_bundled_inputs)
torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
script_module, [input], []
)
buffer = io.BytesIO(script_module._save_to_buffer_for_lite_interpreter())
buffer.seek(0)
mobile_module = _load_for_lite_interpreter(buffer)
has_bundled_inputs = mobile_module.find_method("get_all_bundled_inputs")
self.assertTrue(has_bundled_inputs)
bundled_inputs = mobile_module.run_method("get_all_bundled_inputs")
mobile_module_result = mobile_module.forward(*bundled_inputs[0])
torch.testing.assert_close(script_module_result, mobile_module_result)
def test_method_calls_with_optional_arg(self):
class A(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
# opt arg in script-to-script invocation
def forward(self, x, two: int = 2):
return x + two
class B(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.A0 = A()
# opt arg in Python-to-script invocation
def forward(self, x, one: int = 1):
return self.A0(x) + one
script_module = torch.jit.script(B())
buffer = io.BytesIO(script_module._save_to_buffer_for_lite_interpreter())
mobile_module = _load_for_lite_interpreter(buffer)
input = torch.tensor([5])
script_module_forward_result = script_module.forward(input)
mobile_module_forward_result = mobile_module.forward(input)
torch.testing.assert_close(
script_module_forward_result, mobile_module_forward_result
)
# change ref only
script_module_forward_result = script_module.forward(input, 2)
self.assertFalse(
(script_module_forward_result == mobile_module_forward_result).all().item()
)
# now both match again
mobile_module_forward_result = mobile_module.forward(input, 2)
torch.testing.assert_close(
script_module_forward_result, mobile_module_forward_result
)
def test_unsupported_classtype(self):
class Foo:
def __init__(self) -> None:
return
def func(self, x: int, y: int):
return x + y
class MyTestModule(torch.nn.Module):
def forward(self, arg):
f = Foo()
return f.func(1, 2)
script_module = torch.jit.script(MyTestModule())
with self.assertRaisesRegex(
RuntimeError,
r"Workaround: instead of using arbitrary class type \(class Foo\(\)\), "
r"define a pytorch class \(class Foo\(torch\.nn\.Module\)\)\. "
r"The problematic type is: ",
):
script_module._save_to_buffer_for_lite_interpreter()
def test_unsupported_return_list_with_module_class(self):
class Foo(torch.nn.Module):
pass
class MyTestModuleForListWithModuleClass(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.foo = Foo()
def forward(self):
my_list: List[Foo] = [self.foo]
return my_list
script_module = torch.jit.script(MyTestModuleForListWithModuleClass())
with self.assertRaisesRegex(
RuntimeError,
r"^Returning a list or dictionary with pytorch class type "
r"is not supported in mobile module "
r"\(List\[Foo\] or Dict\[int\, Foo\] for class Foo\(torch\.nn\.Module\)\)\. "
r"Workaround\: instead of using pytorch class as their element type\, "
r"use a combination of list\, dictionary\, and single types\.$",
):
script_module._save_to_buffer_for_lite_interpreter()
def test_unsupported_return_dict_with_module_class(self):
class Foo(torch.nn.Module):
pass
class MyTestModuleForDictWithModuleClass(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.foo = Foo()
def forward(self):
my_dict: Dict[int, Foo] = {1: self.foo}
return my_dict
script_module = torch.jit.script(MyTestModuleForDictWithModuleClass())
with self.assertRaisesRegex(
RuntimeError,
r"^Returning a list or dictionary with pytorch class type "
r"is not supported in mobile module "
r"\(List\[Foo\] or Dict\[int\, Foo\] for class Foo\(torch\.nn\.Module\)\)\. "
r"Workaround\: instead of using pytorch class as their element type\, "
r"use a combination of list\, dictionary\, and single types\.$",
):
script_module._save_to_buffer_for_lite_interpreter()
def test_module_export_operator_list(self):
class Foo(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.weight = torch.ones((20, 1, 5, 5))
self.bias = torch.ones(20)
def forward(self, input):
x1 = torch.zeros(2, 2)
x2 = torch.empty_like(torch.empty(2, 2))
x3 = torch._convolution(
input,
self.weight,
self.bias,
[1, 1],
[0, 0],
[1, 1],
False,
[0, 0],
1,
False,
False,
True,
True,
)
return (x1, x2, x3)
m = torch.jit.script(Foo())
buffer = io.BytesIO(m._save_to_buffer_for_lite_interpreter())
buffer.seek(0)
mobile_module = _load_for_lite_interpreter(buffer)
expected_ops = {
"aten::_convolution",
"aten::empty.memory_format",
"aten::empty_like",
"aten::zeros",
}
actual_ops = _export_operator_list(mobile_module)
self.assertEqual(actual_ops, expected_ops)
def test_source_range_simple(self):
class FooTest(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, x, w):
return torch.mm(x, w.t())
ft = FooTest()
loaded = self.getScriptExportImportCopy(ft)
_, lineno = inspect.getsourcelines(FooTest)
with self.assertRaisesRegex(
RuntimeError, f'test_lite_script_module.py", line {lineno + 3}'
):
loaded(torch.rand(3, 4), torch.rand(30, 40))
def test_source_range_raise_exception(self):
class FooTest2(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self):
raise RuntimeError("foo")
_, lineno = inspect.getsourcelines(FooTest2)
# In C++ code, the type of exception thrown is torch::jit::JITException
# which does not extend c10::Error, and hence it isn't possible to add
# additional context to the exception message and preserve the correct
# C++ stack trace for symbolication. i.e. it isn't possible to add
# the debug handle string to show where in the Python code the exception
# occured w/o first changing
# torch::jit::JITException to extend c10::Error.
with self.assertRaisesRegex(torch.jit.Error, "foo"):
ft = FooTest2()
loaded = self.getScriptExportImportCopy(ft)
loaded()
def test_source_range_function_call(self):
class FooTest3(torch.jit.ScriptModule):
@torch.jit.script_method
def add_method(self, x, w):
return x + w
@torch.jit.script_method
def forward(self, x, y, w):
x = x * y
x = x + 2
return self.add_method(x, w)
ft = FooTest3()
loaded = self.getScriptExportImportCopy(ft)
_, lineno = inspect.getsourcelines(FooTest3)
try:
loaded(torch.rand(3, 4), torch.rand(3, 4), torch.rand(30, 40))
except RuntimeError as e:
error_message = f"{e}"
self.assertTrue(
f'test_lite_script_module.py", line {lineno + 3}' in error_message
)
self.assertTrue(
f'test_lite_script_module.py", line {lineno + 9}' in error_message
)
self.assertTrue("top(FooTest3)" in error_message)
def test_source_range_no_debug_info(self):
class FooTest4(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, x, w):
return torch.mm(x, w.t())
ft = FooTest4()
loaded = self.getScriptExportImportCopy(ft, save_mobile_debug_info=False)
try:
loaded(torch.rand(3, 4), torch.rand(30, 40))
except RuntimeError as e:
error_message = f"{e}"
self.assertTrue("test_lite_script_module.py" not in error_message)
def test_source_range_raise_exc(self):
class FooTest5(torch.jit.ScriptModule):
def __init__(self, val: int):
super().__init__()
self.val = val
@torch.jit.script_method
def add_method(self, val: int, x, w):
if val == self.val:
raise RuntimeError("self.val and val are same")
return x + w
@torch.jit.script_method
def forward(self, val: int, x, y, w):
x = x * y
x = x + 2
return self.add_method(val, x, w)
ft = FooTest5(42)
loaded = self.getScriptExportImportCopy(ft)
_, lineno = inspect.getsourcelines(FooTest5)
try:
loaded(42, torch.rand(3, 4), torch.rand(3, 4), torch.rand(30, 40))
except torch.jit.Error as e:
error_message = f"{e}"
# In C++ code, the type of exception thrown is torch::jit::JITException
# which does not extend c10::Error, and hence it isn't possible to add
# additional context to the exception message and preserve the correct
# C++ stack trace for symbolication. i.e. it isn't possible to add
# the debug handle string to show where in the Python code the exception
# occured w/o first changing
# torch::jit::JITException to extend c10::Error.
self.assertTrue("self.val and val are same" in error_message)
def test_stacktrace_interface_call(self):
@torch.jit.interface
class Forward(torch.nn.Module):
def forward(self, x) -> torch.Tensor:
pass
def forwardError(self, x) -> torch.Tensor:
pass
class B(torch.nn.Module):
def forward(self, x):
return x
def forwardError(self, x):
return self.call() + x
def call(self):
return torch.ones(-1)
class A(torch.nn.Module):
b: Forward
def __init__(self) -> None:
super().__init__()
self.b = B()
def forward(self):
self.b.forward(torch.ones(1))
self.b.forwardError(torch.ones(1))
a = torch.jit.script(A())
torch._C._enable_mobile_interface_call_export()
buffer = io.BytesIO(
a._save_to_buffer_for_lite_interpreter(_save_mobile_debug_info=True)
)
buffer.seek(0)
mobile_module = _load_for_lite_interpreter(buffer)
try:
mobile_module()
self.assertTrue(False)
except RuntimeError as exp:
FileCheck().check("Trying to create tensor with negative dimension").check(
"Traceback of TorchScript"
).check("self.b.forwardError").check_next(
"~~~~~~~~~~~~~~~~~~~ <--- HERE"
).check(
"return self.call"
).check_next(
"~~~~~~~~~ <--- HERE"
).check(
"return torch.ones"
).check_next(
"~~~~~~~~~~ <--- HERE"
).run(
str(exp)
)
class TestLiteScriptQuantizedModule(QuantizationLiteTestCase):
def test_single_layer(self):
input = torch.rand(2, 5, dtype=torch.float)
quantized_model = self._create_quantized_model(
model_class=AnnotatedSingleLayerLinearModel, qengine="qnnpack"
)
self._compare_script_and_mobile(model=quantized_model, input=input)
def test_two_layer(self):
input = torch.rand(2, 5, dtype=torch.float)
quantized_model = self._create_quantized_model(model_class=TwoLayerLinearModel)
self._compare_script_and_mobile(model=quantized_model, input=input)
def test_annotated_nested(self):
input = torch.rand(2, 5, dtype=torch.float)
quantized_model = self._create_quantized_model(
model_class=AnnotatedNestedModel, qengine="qnnpack"
)
self._compare_script_and_mobile(model=quantized_model, input=input)
def test_quantization_example(self):
# From the example in Static Quantization section of https://pytorch.org/docs/stable/quantization.html
class M(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.quant = torch.ao.quantization.QuantStub()
self.conv = torch.nn.Conv2d(1, 1, 1)
self.relu = torch.nn.ReLU()
self.dequant = torch.ao.quantization.DeQuantStub()
def forward(self, x):
x = self.quant(x)
x = self.conv(x)
x = self.relu(x)
x = self.dequant(x)
return x
model_fp32 = M()
model_fp32.eval()
model_fp32.qconfig = torch.ao.quantization.get_default_qconfig("qnnpack")
model_fp32_fused = torch.ao.quantization.fuse_modules(
model_fp32, [["conv", "relu"]]
)
model_fp32_prepared = torch.ao.quantization.prepare(model_fp32_fused)
input_fp32 = torch.randn(4, 1, 4, 4)
model_fp32_prepared(input_fp32)
model_int8 = torch.ao.quantization.convert(model_fp32_prepared)
input = torch.randn(4, 1, 4, 4)
self._compare_script_and_mobile(model=model_int8, input=input)
def test_bundled_input_with_dynamic_type(self):
class Model(torch.nn.Module):
def forward(
self,
x: Dict[int, torch.Tensor],
y: Dict[int, torch.Tensor],
z: Dict[int, torch.Tensor],
):
return x
model = Model()
script_module = torch.jit.script(model)
sample_input = {
script_module.forward: [
(
{0: torch.ones(1)},
{1: torch.ones(1)},
{2: torch.ones(1)},
)
]
}
bundled_model = torch.utils.bundled_inputs.bundle_inputs(
script_module, sample_input
)
buf = bundled_model._save_to_buffer_for_lite_interpreter()
mobile_module = _load_for_lite_interpreter(io.BytesIO(buf))
i = mobile_module.run_method("get_all_bundled_inputs")
self.assertEqual(
i[0],
(
{0: torch.ones(1)},
{1: torch.ones(1)},
{2: torch.ones(1)},
),
)
if __name__ == "__main__":
run_tests()