| # Owner(s): ["oncall: mobile"] |
| |
| import fnmatch |
| import io |
| import shutil |
| import tempfile |
| from pathlib import Path |
| |
| import torch |
| import torch.utils.show_pickle |
| |
| # from torch.utils.mobile_optimizer import optimize_for_mobile |
| from torch.jit.mobile import ( |
| _backport_for_mobile, |
| _backport_for_mobile_to_buffer, |
| _get_mobile_model_contained_types, |
| _get_model_bytecode_version, |
| _get_model_ops_and_info, |
| _load_for_lite_interpreter, |
| ) |
| from torch.testing._internal.common_utils import run_tests, TestCase |
| |
| |
| pytorch_test_dir = Path(__file__).resolve().parents[1] |
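# pytorch_test_dir is the top-level test/ directory (two levels up from this
# file); the checked-in models used below live under its cpp/jit/ subfolder.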
| |
# Source code used to generate script_module_v4.ptl and script_module_v5.ptl:
| # class TestModule(torch.nn.Module): |
| # def __init__(self, v): |
| # super().__init__() |
| # self.x = v |
| |
| # def forward(self, y: int): |
| # increment = torch.ones([2, 4], dtype=torch.float64) |
| # return self.x + y + increment |
| |
| # output_model_path = Path(tmpdirname, "script_module_v5.ptl") |
| # script_module = torch.jit.script(TestModule(1)) |
| # optimized_scripted_module = optimize_for_mobile(script_module) |
| # exported_optimized_scripted_module = optimized_scripted_module._save_for_lite_interpreter( |
| # str(output_model_path)) |
| |
| SCRIPT_MODULE_V4_BYTECODE_PKL = """ |
| (4, |
| ('__torch__.*.TestModule.forward', |
| (('instructions', |
| (('STOREN', 1, 2), |
| ('DROPR', 1, 0), |
| ('LOADC', 0, 0), |
| ('LOADC', 1, 0), |
| ('MOVE', 2, 0), |
| ('OP', 0, 0), |
| ('LOADC', 1, 0), |
| ('OP', 1, 0), |
| ('RET', 0, 0))), |
| ('operators', (('aten::add', 'int'), ('aten::add', 'Scalar'))), |
| ('constants', |
| (torch._utils._rebuild_tensor_v2(pers.obj(('storage', torch.DoubleStorage, '0', 'cpu', 8),), |
| 0, |
| (2, 4), |
| (4, 1), |
| False, |
| collections.OrderedDict()), |
| 1)), |
| ('types', ()), |
| ('register_size', 2)), |
| (('arguments', |
| ((('name', 'self'), |
| ('type', '__torch__.*.TestModule'), |
| ('default_value', None)), |
| (('name', 'y'), ('type', 'int'), ('default_value', None)))), |
| ('returns', |
| ((('name', ''), ('type', 'Tensor'), ('default_value', None)),))))) |
| """ |
| |
| SCRIPT_MODULE_V5_BYTECODE_PKL = """ |
| (5, |
| ('__torch__.*.TestModule.forward', |
| (('instructions', |
| (('STOREN', 1, 2), |
| ('DROPR', 1, 0), |
| ('LOADC', 0, 0), |
| ('LOADC', 1, 0), |
| ('MOVE', 2, 0), |
| ('OP', 0, 0), |
| ('LOADC', 1, 0), |
| ('OP', 1, 0), |
| ('RET', 0, 0))), |
| ('operators', (('aten::add', 'int'), ('aten::add', 'Scalar'))), |
| ('constants', |
| (torch._utils._rebuild_tensor_v2(pers.obj(('storage', torch.DoubleStorage, 'constants/0', 'cpu', 8),), |
| 0, |
| (2, 4), |
| (4, 1), |
| False, |
| collections.OrderedDict()), |
| 1)), |
| ('types', ()), |
| ('register_size', 2)), |
| (('arguments', |
| ((('name', 'self'), |
| ('type', '__torch__.*.TestModule'), |
| ('default_value', None)), |
| (('name', 'y'), ('type', 'int'), ('default_value', None)))), |
| ('returns', |
| ((('name', ''), ('type', 'Tensor'), ('default_value', None)),))))) |
| """ |
| |
| SCRIPT_MODULE_V6_BYTECODE_PKL = """ |
| (6, |
| ('__torch__.*.TestModule.forward', |
| (('instructions', |
| (('STOREN', 1, 2), |
| ('DROPR', 1, 0), |
| ('LOADC', 0, 0), |
| ('LOADC', 1, 0), |
| ('MOVE', 2, 0), |
| ('OP', 0, 0), |
| ('OP', 1, 0), |
| ('RET', 0, 0))), |
| ('operators', (('aten::add', 'int', 2), ('aten::add', 'Scalar', 2))), |
| ('constants', |
| (torch._utils._rebuild_tensor_v2(pers.obj(('storage', torch.DoubleStorage, '0', 'cpu', 8),), |
| 0, |
| (2, 4), |
| (4, 1), |
| False, |
| collections.OrderedDict()), |
| 1)), |
| ('types', ()), |
| ('register_size', 2)), |
| (('arguments', |
| ((('name', 'self'), |
| ('type', '__torch__.*.TestModule'), |
| ('default_value', None)), |
| (('name', 'y'), ('type', 'int'), ('default_value', None)))), |
| ('returns', |
| ((('name', ''), ('type', 'Tensor'), ('default_value', None)),))))) |
| """ |
| |
| SCRIPT_MODULE_BYTECODE_PKL = { |
| 4: { |
| "bytecode_pkl": SCRIPT_MODULE_V4_BYTECODE_PKL, |
| "model_name": "script_module_v4.ptl", |
| }, |
| } |
| |
# The minimum bytecode version a model can be backported to.
# Needs to be updated when a bytecode version is completely retired.
| MINIMUM_TO_VERSION = 4 |
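

# The bytecode.pkl content checks in the tests below are not byte-for-byte
# comparisons: all whitespace is stripped, then the actual pickle dump is
# matched against the expected one with fnmatch, so '*' in the expected string
# acts as a wildcard for content that varies between runs (for example the
# module's qualified name in '__torch__.*.TestModule').
def bytecode_pkl_matches(actual: str, expected: str) -> bool:
    actual_clean = "".join(actual.split())
    expected_clean = "".join(expected.split())
    return fnmatch.fnmatch(actual_clean, expected_clean)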
| |
| |
class TestVariousModelVersions(TestCase):
| def test_get_model_bytecode_version(self): |
        def check_model_version(model_path, expected_version):
            actual_version = _get_model_bytecode_version(model_path)
            assert actual_version == expected_version
| |
| for version, model_info in SCRIPT_MODULE_BYTECODE_PKL.items(): |
| model_path = pytorch_test_dir / "cpp" / "jit" / model_info["model_name"] |
| check_model_version(model_path, version) |
| |
| def test_bytecode_values_for_all_backport_functions(self): |
        # Find the maximum version of the checked-in models, backport step by
        # step down to the minimum supported version, and compare the
        # bytecode.pkl content at each step. This can't be merged into
        # `test_all_backport_functions`, because optimization is dynamic and
        # the content might change when the optimize function changes. This
        # test focuses on bytecode.pkl content validation; the check is not
        # byte-for-byte but wildcard-based (see bytecode_pkl_matches above),
        # so '*' can skip content that varies between runs.
| maximum_checked_in_model_version = max(SCRIPT_MODULE_BYTECODE_PKL.keys()) |
| current_from_version = maximum_checked_in_model_version |
| |
| with tempfile.TemporaryDirectory() as tmpdirname: |
| while current_from_version > MINIMUM_TO_VERSION: |
                # Look up the checked-in model for current_from_version
| model_name = SCRIPT_MODULE_BYTECODE_PKL[current_from_version][ |
| "model_name" |
| ] |
| input_model_path = pytorch_test_dir / "cpp" / "jit" / model_name |
| |
                # A temporary model file will be exported to this path and run
                # through the bytecode.pkl content check.
| tmp_output_model_path_backport = Path( |
| tmpdirname, "tmp_script_module_backport.ptl" |
| ) |
| |
| current_to_version = current_from_version - 1 |
| backport_success = _backport_for_mobile( |
| input_model_path, tmp_output_model_path_backport, current_to_version |
| ) |
| assert backport_success |
| |
| expect_bytecode_pkl = SCRIPT_MODULE_BYTECODE_PKL[current_to_version][ |
| "bytecode_pkl" |
| ] |
| |
| buf = io.StringIO() |
| torch.utils.show_pickle.main( |
| [ |
| "", |
| tmpdirname |
| + "/" |
| + tmp_output_model_path_backport.name |
| + "@*/bytecode.pkl", |
| ], |
| output_stream=buf, |
| ) |
| output = buf.getvalue() |
| |
                assert bytecode_pkl_matches(output, expect_bytecode_pkl)
| |
| current_from_version -= 1 |
| shutil.rmtree(tmpdirname) |
| |
| # Please run this test manually when working on backport. |
    # This test passes in OSS but fails internally, likely due to a missing
    # step in the internal build.
| # def test_all_backport_functions(self): |
| # # Backport from the latest bytecode version to the minimum support version |
| # # Load, run the backport model, and check version |
| # class TestModule(torch.nn.Module): |
| # def __init__(self, v): |
| # super().__init__() |
| # self.x = v |
| |
| # def forward(self, y: int): |
| # increment = torch.ones([2, 4], dtype=torch.float64) |
| # return self.x + y + increment |
| |
| # module_input = 1 |
| # expected_mobile_module_result = 3 * torch.ones([2, 4], dtype=torch.float64) |
| |
| # # temporary input model file and output model file will be exported in the temporary folder |
| # with tempfile.TemporaryDirectory() as tmpdirname: |
| # tmp_input_model_path = Path(tmpdirname, "tmp_script_module.ptl") |
| # script_module = torch.jit.script(TestModule(1)) |
| # optimized_scripted_module = optimize_for_mobile(script_module) |
| # exported_optimized_scripted_module = optimized_scripted_module._save_for_lite_interpreter(str(tmp_input_model_path)) |
| |
| # current_from_version = _get_model_bytecode_version(tmp_input_model_path) |
| # current_to_version = current_from_version - 1 |
| # tmp_output_model_path = Path(tmpdirname, "tmp_script_module_backport.ptl") |
| |
| # while current_to_version >= MINIMUM_TO_VERSION: |
| # # Backport the latest model to `to_version` to a tmp file "tmp_script_module_backport" |
| # backport_success = _backport_for_mobile(tmp_input_model_path, tmp_output_model_path, current_to_version) |
| # assert(backport_success) |
| |
| # backport_version = _get_model_bytecode_version(tmp_output_model_path) |
| # assert(backport_version == current_to_version) |
| |
| # # Load model and run forward method |
| # mobile_module = _load_for_lite_interpreter(str(tmp_input_model_path)) |
| # mobile_module_result = mobile_module(module_input) |
| # torch.testing.assert_close(mobile_module_result, expected_mobile_module_result) |
| # current_to_version -= 1 |
| |
| # # Check backport failure case |
| # backport_success = _backport_for_mobile(tmp_input_model_path, tmp_output_model_path, MINIMUM_TO_VERSION - 1) |
| # assert(not backport_success) |
| # # need to clean the folder before it closes, otherwise will run into git not clean error |
| # shutil.rmtree(tmpdirname) |
| |
    # Check just the file-to-file backport mechanism (_backport_for_mobile),
    # not the function implementations.
| def test_backport_bytecode_from_file_to_file(self): |
| maximum_checked_in_model_version = max(SCRIPT_MODULE_BYTECODE_PKL.keys()) |
        script_module_path = (
            pytorch_test_dir
            / "cpp"
            / "jit"
            / SCRIPT_MODULE_BYTECODE_PKL[maximum_checked_in_model_version]["model_name"]
        )
| |
| if maximum_checked_in_model_version > MINIMUM_TO_VERSION: |
| with tempfile.TemporaryDirectory() as tmpdirname: |
                tmp_backport_model_path = Path(
                    tmpdirname, "tmp_script_module_backported.ptl"
                )
| # backport from file |
| success = _backport_for_mobile( |
                    script_module_path,
| tmp_backport_model_path, |
| maximum_checked_in_model_version - 1, |
| ) |
| assert success |
| |
| buf = io.StringIO() |
| torch.utils.show_pickle.main( |
| [ |
| "", |
| tmpdirname |
| + "/" |
| + tmp_backport_model_path.name |
| + "@*/bytecode.pkl", |
| ], |
| output_stream=buf, |
| ) |
| output = buf.getvalue() |
| |
                assert bytecode_pkl_matches(output, SCRIPT_MODULE_V4_BYTECODE_PKL)
| |
                # Load the backported model and run forward method
| mobile_module = _load_for_lite_interpreter(str(tmp_backport_model_path)) |
| module_input = 1 |
| mobile_module_result = mobile_module(module_input) |
| expected_mobile_module_result = 3 * torch.ones( |
| [2, 4], dtype=torch.float64 |
| ) |
| torch.testing.assert_close( |
| mobile_module_result, expected_mobile_module_result |
| ) |
| shutil.rmtree(tmpdirname) |
| |
    # Check just the file-to-buffer backport mechanism
    # (_backport_for_mobile_to_buffer), not the function implementations.
| def test_backport_bytecode_from_file_to_buffer(self): |
| maximum_checked_in_model_version = max(SCRIPT_MODULE_BYTECODE_PKL.keys()) |
        script_module_path = (
            pytorch_test_dir
            / "cpp"
            / "jit"
            / SCRIPT_MODULE_BYTECODE_PKL[maximum_checked_in_model_version]["model_name"]
        )
| |
| if maximum_checked_in_model_version > MINIMUM_TO_VERSION: |
            # Backport the model one bytecode version down, into an in-memory buffer
            backported_buffer = _backport_for_mobile_to_buffer(
                script_module_path, maximum_checked_in_model_version - 1
            )
| |
            # Check the version of the backported model
            bytesio = io.BytesIO(backported_buffer)
| backport_version = _get_model_bytecode_version(bytesio) |
| assert backport_version == maximum_checked_in_model_version - 1 |
| |
            # Load the backported model and run forward method
            bytesio = io.BytesIO(backported_buffer)
| mobile_module = _load_for_lite_interpreter(bytesio) |
| module_input = 1 |
| mobile_module_result = mobile_module(module_input) |
| expected_mobile_module_result = 3 * torch.ones([2, 4], dtype=torch.float64) |
| torch.testing.assert_close( |
| mobile_module_result, expected_mobile_module_result |
| ) |
| |
| def test_get_model_ops_and_info(self): |
| # TODO update this to be more in the style of the above tests after a backport from 6 -> 5 exists |
| script_module_v6 = pytorch_test_dir / "cpp" / "jit" / "script_module_v6.ptl" |
| ops_v6 = _get_model_ops_and_info(script_module_v6) |
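        # The expected values correspond to the trailing 2 in each 'operators'
        # tuple of SCRIPT_MODULE_V6_BYTECODE_PKL above: the number of arguments
        # the bytecode specifies for the operator.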
| assert ops_v6["aten::add.int"].num_schema_args == 2 |
| assert ops_v6["aten::add.Scalar"].num_schema_args == 2 |
| |
| def test_get_mobile_model_contained_types(self): |
| class MyTestModule(torch.nn.Module): |
| def forward(self, x): |
| return x + 10 |
| |
| sample_input = torch.tensor([1]) |
| |
| script_module = torch.jit.script(MyTestModule()) |
| script_module_result = script_module(sample_input) |
| |
| buffer = io.BytesIO(script_module._save_to_buffer_for_lite_interpreter()) |
| buffer.seek(0) |
| type_list = _get_mobile_model_contained_types(buffer) |
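        # Smoke test: only checks that the call succeeds and returns a sized
        # collection; an empty result is acceptable for this simple module.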
| assert len(type_list) >= 0 |
| |
| |
| if __name__ == "__main__": |
| run_tests() |