| # Owner(s): ["oncall: mobile"] |
| |
import inspect
import io
from typing import Dict, List

import torch
import torch.utils.bundled_inputs
from torch.jit.mobile import _export_operator_list, _load_for_lite_interpreter
from torch.testing import FileCheck
from torch.testing._internal.common_quantization import (
    AnnotatedNestedModel,
    AnnotatedSingleLayerLinearModel,
    QuantizationLiteTestCase,
    TwoLayerLinearModel,
)
from torch.testing._internal.common_utils import (
    run_tests,
    TemporaryFileName,
    TestCase,
)
| |
| |
| class TestLiteScriptModule(TestCase): |
| def getScriptExportImportCopy( |
| self, m, save_mobile_debug_info=True, also_test_file=False |
| ): |
| m_scripted = torch.jit.script(m) |
| |
| if not also_test_file: |
| buffer = io.BytesIO( |
| m_scripted._save_to_buffer_for_lite_interpreter( |
| _save_mobile_debug_info=save_mobile_debug_info |
| ) |
| ) |
| buffer.seek(0) |
| mobile_module = _load_for_lite_interpreter(buffer) |
| return mobile_module |
| |
| with TemporaryFileName() as fname: |
| m_scripted._save_for_lite_interpreter( |
| fname, _save_mobile_debug_info=save_mobile_debug_info |
| ) |
| mobile_module = _load_for_lite_interpreter(fname) |
| return mobile_module |
| |
| def test_load_mobile_module(self): |
| class MyTestModule(torch.nn.Module): |
| def forward(self, x): |
| return x + 10 |
| |
| input = torch.tensor([1]) |
| |
| script_module = torch.jit.script(MyTestModule()) |
| script_module_result = script_module(input) |
| |
| buffer = io.BytesIO(script_module._save_to_buffer_for_lite_interpreter()) |
| buffer.seek(0) |
| mobile_module = _load_for_lite_interpreter(buffer) |
| |
| mobile_module_result = mobile_module(input) |
| torch.testing.assert_close(script_module_result, mobile_module_result) |
| |
| mobile_module_forward_result = mobile_module.forward(input) |
| torch.testing.assert_close(script_module_result, mobile_module_forward_result) |
| |
| mobile_module_run_method_result = mobile_module.run_method("forward", input) |
| torch.testing.assert_close( |
| script_module_result, mobile_module_run_method_result |
| ) |
| |
| def test_save_mobile_module_with_debug_info_with_trace(self): |
| class A(torch.nn.Module): |
| def forward(self, x, y): |
| return x * y |
| |
| class B(torch.nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| self.A0 = A() |
| self.A1 = A() |
| |
| def forward(self, x, y, z): |
| return self.A0(x, y) + self.A1(y, z) |
| |
| for export_method in ["trace", "script"]: |
| x = torch.rand((2, 3)) |
| y = torch.rand((2, 3)) |
| z = torch.rand((2, 3)) |
| if export_method == "trace": |
| trace_module = torch.jit.trace(B(), [x, y, z]) |
| else: |
| trace_module = torch.jit.script(B()) |
| exported_module = trace_module._save_to_buffer_for_lite_interpreter( |
| _save_mobile_debug_info=True |
| ) |
| buffer = io.BytesIO(exported_module) |
| buffer.seek(0) |
| |
| assert b"callstack_debug_map.pkl" in exported_module |
| |
| mobile_module = _load_for_lite_interpreter(buffer) |
| with self.assertRaisesRegex( |
| RuntimeError, |
| r"Module hierarchy:top\(B\)::<unknown>.A0\(A\)::forward.aten::mul", |
| ): |
| x = torch.rand((2, 3)) |
| y = torch.rand((8, 10)) |
| z = torch.rand((8, 10)) |
| mobile_module(x, y, z) |
| with self.assertRaisesRegex( |
| RuntimeError, |
| r"Module hierarchy:top\(B\)::<unknown>.A1\(A\)::forward.aten::mul", |
| ): |
| x = torch.rand((2, 3)) |
| y = torch.rand((2, 3)) |
| z = torch.rand((8, 10)) |
| mobile_module(x, y, z) |
| |
| def test_load_mobile_module_with_debug_info(self): |
| class MyTestModule(torch.nn.Module): |
| def forward(self, x): |
| return x + 5 |
| |
| input = torch.tensor([3]) |
| |
| script_module = torch.jit.script(MyTestModule()) |
| script_module_result = script_module(input) |
| |
| buffer = io.BytesIO( |
| script_module._save_to_buffer_for_lite_interpreter( |
| _save_mobile_debug_info=True |
| ) |
| ) |
| buffer.seek(0) |
| mobile_module = _load_for_lite_interpreter(buffer) |
| |
| mobile_module_result = mobile_module(input) |
| torch.testing.assert_close(script_module_result, mobile_module_result) |
| |
| mobile_module_forward_result = mobile_module.forward(input) |
| torch.testing.assert_close(script_module_result, mobile_module_forward_result) |
| |
| mobile_module_run_method_result = mobile_module.run_method("forward", input) |
| torch.testing.assert_close( |
| script_module_result, mobile_module_run_method_result |
| ) |
| |
| def test_find_and_run_method(self): |
| class MyTestModule(torch.nn.Module): |
| def forward(self, arg): |
| return arg |
| |
| input = (torch.tensor([1]),) |
| |
| script_module = torch.jit.script(MyTestModule()) |
| script_module_result = script_module(*input) |
| |
| buffer = io.BytesIO(script_module._save_to_buffer_for_lite_interpreter()) |
| buffer.seek(0) |
| mobile_module = _load_for_lite_interpreter(buffer) |
| |
| has_bundled_inputs = mobile_module.find_method("get_all_bundled_inputs") |
| self.assertFalse(has_bundled_inputs) |
| |
| torch.utils.bundled_inputs.augment_model_with_bundled_inputs( |
| script_module, [input], [] |
| ) |
| |
| buffer = io.BytesIO(script_module._save_to_buffer_for_lite_interpreter()) |
| buffer.seek(0) |
| mobile_module = _load_for_lite_interpreter(buffer) |
| |
| has_bundled_inputs = mobile_module.find_method("get_all_bundled_inputs") |
| self.assertTrue(has_bundled_inputs) |
| |
| bundled_inputs = mobile_module.run_method("get_all_bundled_inputs") |
| mobile_module_result = mobile_module.forward(*bundled_inputs[0]) |
| torch.testing.assert_close(script_module_result, mobile_module_result) |
| |
| def test_method_calls_with_optional_arg(self): |
| class A(torch.nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| |
| # opt arg in script-to-script invocation |
| def forward(self, x, two: int = 2): |
| return x + two |
| |
| class B(torch.nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| self.A0 = A() |
| |
| # opt arg in Python-to-script invocation |
| def forward(self, x, one: int = 1): |
| return self.A0(x) + one |
| |
| script_module = torch.jit.script(B()) |
| buffer = io.BytesIO(script_module._save_to_buffer_for_lite_interpreter()) |
| mobile_module = _load_for_lite_interpreter(buffer) |
| |
| input = torch.tensor([5]) |
| script_module_forward_result = script_module.forward(input) |
| mobile_module_forward_result = mobile_module.forward(input) |
| torch.testing.assert_close( |
| script_module_forward_result, mobile_module_forward_result |
| ) |
| |
| # change ref only |
| script_module_forward_result = script_module.forward(input, 2) |
| self.assertFalse( |
| (script_module_forward_result == mobile_module_forward_result).all().item() |
| ) |
| |
| # now both match again |
| mobile_module_forward_result = mobile_module.forward(input, 2) |
| torch.testing.assert_close( |
| script_module_forward_result, mobile_module_forward_result |
| ) |
| |
| def test_unsupported_classtype(self): |
| class Foo: |
| def __init__(self) -> None: |
| return |
| |
| def func(self, x: int, y: int): |
| return x + y |
| |
| class MyTestModule(torch.nn.Module): |
| def forward(self, arg): |
| f = Foo() |
| return f.func(1, 2) |
| |
| script_module = torch.jit.script(MyTestModule()) |
| with self.assertRaisesRegex( |
| RuntimeError, |
| r"Workaround: instead of using arbitrary class type \(class Foo\(\)\), " |
| r"define a pytorch class \(class Foo\(torch\.nn\.Module\)\)\. " |
| r"The problematic type is: ", |
| ): |
| script_module._save_to_buffer_for_lite_interpreter() |
| |
| def test_unsupported_return_list_with_module_class(self): |
| class Foo(torch.nn.Module): |
| pass |
| |
| class MyTestModuleForListWithModuleClass(torch.nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| self.foo = Foo() |
| |
| def forward(self): |
| my_list: List[Foo] = [self.foo] |
| return my_list |
| |
| script_module = torch.jit.script(MyTestModuleForListWithModuleClass()) |
| with self.assertRaisesRegex( |
| RuntimeError, |
| r"^Returning a list or dictionary with pytorch class type " |
| r"is not supported in mobile module " |
| r"\(List\[Foo\] or Dict\[int\, Foo\] for class Foo\(torch\.nn\.Module\)\)\. " |
| r"Workaround\: instead of using pytorch class as their element type\, " |
| r"use a combination of list\, dictionary\, and single types\.$", |
| ): |
| script_module._save_to_buffer_for_lite_interpreter() |
| |
| def test_unsupported_return_dict_with_module_class(self): |
| class Foo(torch.nn.Module): |
| pass |
| |
| class MyTestModuleForDictWithModuleClass(torch.nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| self.foo = Foo() |
| |
| def forward(self): |
| my_dict: Dict[int, Foo] = {1: self.foo} |
| return my_dict |
| |
| script_module = torch.jit.script(MyTestModuleForDictWithModuleClass()) |
| with self.assertRaisesRegex( |
| RuntimeError, |
| r"^Returning a list or dictionary with pytorch class type " |
| r"is not supported in mobile module " |
| r"\(List\[Foo\] or Dict\[int\, Foo\] for class Foo\(torch\.nn\.Module\)\)\. " |
| r"Workaround\: instead of using pytorch class as their element type\, " |
| r"use a combination of list\, dictionary\, and single types\.$", |
| ): |
| script_module._save_to_buffer_for_lite_interpreter() |
| |
| def test_module_export_operator_list(self): |
| class Foo(torch.nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| self.weight = torch.ones((20, 1, 5, 5)) |
| self.bias = torch.ones(20) |
| |
| def forward(self, input): |
| x1 = torch.zeros(2, 2) |
| x2 = torch.empty_like(torch.empty(2, 2)) |
| x3 = torch._convolution( |
| input, |
| self.weight, |
| self.bias, |
| [1, 1], |
| [0, 0], |
| [1, 1], |
| False, |
| [0, 0], |
| 1, |
| False, |
| False, |
| True, |
| True, |
| ) |
| return (x1, x2, x3) |
| |
| m = torch.jit.script(Foo()) |
| |
| buffer = io.BytesIO(m._save_to_buffer_for_lite_interpreter()) |
| buffer.seek(0) |
| mobile_module = _load_for_lite_interpreter(buffer) |
| |
| expected_ops = { |
| "aten::_convolution", |
| "aten::empty.memory_format", |
| "aten::empty_like", |
| "aten::zeros", |
| } |
| actual_ops = _export_operator_list(mobile_module) |
| self.assertEqual(actual_ops, expected_ops) |
| |
| def test_source_range_simple(self): |
| class FooTest(torch.jit.ScriptModule): |
| @torch.jit.script_method |
| def forward(self, x, w): |
| return torch.mm(x, w.t()) |
| |
| ft = FooTest() |
| loaded = self.getScriptExportImportCopy(ft) |
| _, lineno = inspect.getsourcelines(FooTest) |
| |
| with self.assertRaisesRegex( |
| RuntimeError, f'test_lite_script_module.py", line {lineno + 3}' |
| ): |
| loaded(torch.rand(3, 4), torch.rand(30, 40)) |
| |
| def test_source_range_raise_exception(self): |
| class FooTest2(torch.jit.ScriptModule): |
| @torch.jit.script_method |
| def forward(self): |
| raise RuntimeError("foo") |
| |
| _, lineno = inspect.getsourcelines(FooTest2) |
| |
| # In C++ code, the type of exception thrown is torch::jit::JITException |
| # which does not extend c10::Error, and hence it isn't possible to add |
| # additional context to the exception message and preserve the correct |
| # C++ stack trace for symbolication. i.e. it isn't possible to add |
| # the debug handle string to show where in the Python code the exception |
| # occured w/o first changing |
| # torch::jit::JITException to extend c10::Error. |
| with self.assertRaisesRegex(torch.jit.Error, "foo"): |
| ft = FooTest2() |
| loaded = self.getScriptExportImportCopy(ft) |
| loaded() |
| |
| def test_source_range_function_call(self): |
| class FooTest3(torch.jit.ScriptModule): |
| @torch.jit.script_method |
| def add_method(self, x, w): |
| return x + w |
| |
| @torch.jit.script_method |
| def forward(self, x, y, w): |
| x = x * y |
| x = x + 2 |
| return self.add_method(x, w) |
| |
| ft = FooTest3() |
| loaded = self.getScriptExportImportCopy(ft) |
| _, lineno = inspect.getsourcelines(FooTest3) |
| |
| try: |
| loaded(torch.rand(3, 4), torch.rand(3, 4), torch.rand(30, 40)) |
| except RuntimeError as e: |
| error_message = f"{e}" |
| self.assertTrue( |
| f'test_lite_script_module.py", line {lineno + 3}' in error_message |
| ) |
| self.assertTrue( |
| f'test_lite_script_module.py", line {lineno + 9}' in error_message |
| ) |
| self.assertTrue("top(FooTest3)" in error_message) |
| |
| def test_source_range_no_debug_info(self): |
| class FooTest4(torch.jit.ScriptModule): |
| @torch.jit.script_method |
| def forward(self, x, w): |
| return torch.mm(x, w.t()) |
| |
| ft = FooTest4() |
| loaded = self.getScriptExportImportCopy(ft, save_mobile_debug_info=False) |
| |
| try: |
| loaded(torch.rand(3, 4), torch.rand(30, 40)) |
| except RuntimeError as e: |
| error_message = f"{e}" |
| self.assertTrue("test_lite_script_module.py" not in error_message) |
| |
| def test_source_range_raise_exc(self): |
| class FooTest5(torch.jit.ScriptModule): |
| def __init__(self, val: int): |
| super().__init__() |
| self.val = val |
| |
| @torch.jit.script_method |
| def add_method(self, val: int, x, w): |
| if val == self.val: |
| raise RuntimeError("self.val and val are same") |
| return x + w |
| |
| @torch.jit.script_method |
| def forward(self, val: int, x, y, w): |
| x = x * y |
| x = x + 2 |
| return self.add_method(val, x, w) |
| |
| ft = FooTest5(42) |
| loaded = self.getScriptExportImportCopy(ft) |
| _, lineno = inspect.getsourcelines(FooTest5) |
| |
| try: |
| loaded(42, torch.rand(3, 4), torch.rand(3, 4), torch.rand(30, 40)) |
| except torch.jit.Error as e: |
| error_message = f"{e}" |
| |
| # In C++ code, the type of exception thrown is torch::jit::JITException |
| # which does not extend c10::Error, and hence it isn't possible to add |
| # additional context to the exception message and preserve the correct |
| # C++ stack trace for symbolication. i.e. it isn't possible to add |
| # the debug handle string to show where in the Python code the exception |
| # occured w/o first changing |
| # torch::jit::JITException to extend c10::Error. |
| self.assertTrue("self.val and val are same" in error_message) |
| |
| def test_stacktrace_interface_call(self): |
| @torch.jit.interface |
| class Forward(torch.nn.Module): |
| def forward(self, x) -> torch.Tensor: |
| pass |
| |
| def forwardError(self, x) -> torch.Tensor: |
| pass |
| |
| class B(torch.nn.Module): |
| def forward(self, x): |
| return x |
| |
| def forwardError(self, x): |
| return self.call() + x |
| |
| def call(self): |
| return torch.ones(-1) |
| |
| class A(torch.nn.Module): |
| b: Forward |
| |
| def __init__(self) -> None: |
| super().__init__() |
| self.b = B() |
| |
| def forward(self): |
| self.b.forward(torch.ones(1)) |
| self.b.forwardError(torch.ones(1)) |
| |
| a = torch.jit.script(A()) |
| torch._C._enable_mobile_interface_call_export() |
| buffer = io.BytesIO( |
| a._save_to_buffer_for_lite_interpreter(_save_mobile_debug_info=True) |
| ) |
| buffer.seek(0) |
| mobile_module = _load_for_lite_interpreter(buffer) |
| try: |
| mobile_module() |
| self.assertTrue(False) |
| except RuntimeError as exp: |
| FileCheck().check("Trying to create tensor with negative dimension").check( |
| "Traceback of TorchScript" |
| ).check("self.b.forwardError").check_next( |
| "~~~~~~~~~~~~~~~~~~~ <--- HERE" |
| ).check( |
| "return self.call" |
| ).check_next( |
| "~~~~~~~~~ <--- HERE" |
| ).check( |
| "return torch.ones" |
| ).check_next( |
| "~~~~~~~~~~ <--- HERE" |
| ).run( |
| str(exp) |
| ) |
| |
| |
| class TestLiteScriptQuantizedModule(QuantizationLiteTestCase): |
| def test_single_layer(self): |
| input = torch.rand(2, 5, dtype=torch.float) |
| quantized_model = self._create_quantized_model( |
| model_class=AnnotatedSingleLayerLinearModel, qengine="qnnpack" |
| ) |
| self._compare_script_and_mobile(model=quantized_model, input=input) |
| |
| def test_two_layer(self): |
| input = torch.rand(2, 5, dtype=torch.float) |
| quantized_model = self._create_quantized_model(model_class=TwoLayerLinearModel) |
| self._compare_script_and_mobile(model=quantized_model, input=input) |
| |
| def test_annotated_nested(self): |
| input = torch.rand(2, 5, dtype=torch.float) |
| quantized_model = self._create_quantized_model( |
| model_class=AnnotatedNestedModel, qengine="qnnpack" |
| ) |
| self._compare_script_and_mobile(model=quantized_model, input=input) |
| |
| def test_quantization_example(self): |
| # From the example in Static Quantization section of https://pytorch.org/docs/stable/quantization.html |
| class M(torch.nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| self.quant = torch.ao.quantization.QuantStub() |
| self.conv = torch.nn.Conv2d(1, 1, 1) |
| self.relu = torch.nn.ReLU() |
| self.dequant = torch.ao.quantization.DeQuantStub() |
| |
| def forward(self, x): |
| x = self.quant(x) |
| x = self.conv(x) |
| x = self.relu(x) |
| x = self.dequant(x) |
| return x |
| |
| model_fp32 = M() |
| |
| model_fp32.eval() |
| model_fp32.qconfig = torch.ao.quantization.get_default_qconfig("qnnpack") |
| model_fp32_fused = torch.ao.quantization.fuse_modules( |
| model_fp32, [["conv", "relu"]] |
| ) |
| model_fp32_prepared = torch.ao.quantization.prepare(model_fp32_fused) |
| input_fp32 = torch.randn(4, 1, 4, 4) |
| model_fp32_prepared(input_fp32) |
| model_int8 = torch.ao.quantization.convert(model_fp32_prepared) |
| |
| input = torch.randn(4, 1, 4, 4) |
| self._compare_script_and_mobile(model=model_int8, input=input) |
| |
| def test_bundled_input_with_dynamic_type(self): |
| class Model(torch.nn.Module): |
| def forward( |
| self, |
| x: Dict[int, torch.Tensor], |
| y: Dict[int, torch.Tensor], |
| z: Dict[int, torch.Tensor], |
| ): |
| return x |
| |
| model = Model() |
| script_module = torch.jit.script(model) |
| |
| sample_input = { |
| script_module.forward: [ |
| ( |
| {0: torch.ones(1)}, |
| {1: torch.ones(1)}, |
| {2: torch.ones(1)}, |
| ) |
| ] |
| } |
| |
| bundled_model = torch.utils.bundled_inputs.bundle_inputs( |
| script_module, sample_input |
| ) |
| |
| buf = bundled_model._save_to_buffer_for_lite_interpreter() |
| mobile_module = _load_for_lite_interpreter(io.BytesIO(buf)) |
| |
| i = mobile_module.run_method("get_all_bundled_inputs") |
| |
| self.assertEqual( |
| i[0], |
| ( |
| {0: torch.ones(1)}, |
| {1: torch.ones(1)}, |
| {2: torch.ones(1)}, |
| ), |
| ) |
| |
| |
if __name__ == "__main__":
    # Script entry point: hand off to run_tests() from
    # torch.testing._internal.common_utils instead of unittest.main().
    run_tests()