| import argparse |
| import os |
| import sys |
| import torch |
| |
| # grab modules from test_jit_hooks.cpp |
| pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) |
| sys.path.append(pytorch_test_dir) |
| from jit.test_hooks_modules import ( |
| create_forward_tuple_input, create_module_forward_multiple_inputs, |
| create_module_forward_single_input, create_module_hook_return_nothing, |
| create_module_multiple_hooks_multiple_inputs, |
| create_module_multiple_hooks_single_input, create_module_no_forward_input, |
| create_module_same_hook_repeated, create_submodule_forward_multiple_inputs, |
| create_submodule_forward_single_input, |
| create_submodule_hook_return_nothing, |
| create_submodule_multiple_hooks_multiple_inputs, |
| create_submodule_multiple_hooks_single_input, |
| create_submodule_same_hook_repeated, |
| create_submodule_to_call_directly_with_hooks) |
| |
| # Create saved modules for JIT forward hooks and pre-hooks |
def main():
    """Script every hook test module and save it to disk.

    Parses the required ``--export-script-module-to`` command-line
    argument and uses its value followed by ``"_"`` as a filename prefix.
    Each ``(name, module)`` pair in the test table is compiled with
    ``torch.jit.script`` and written to ``<prefix><name>.pt`` with
    ``torch.jit.save``. A confirmation line is printed once every module
    has been saved.
    """
    parser = argparse.ArgumentParser(
        description="Serialize script modules with hooks attached"
    )
    parser.add_argument(
        "--export-script-module-to",
        required=True,
        help="path prefix for the saved modules; each one is written to "
        "<prefix>_<test name>.pt",
    )
    options = parser.parse_args()

    # NOTE(review): the prefix is still published as a module-level global,
    # exactly as before. Nothing in this file reads it outside main();
    # confirm there is no external reader before turning it into a local.
    global save_name
    save_name = options.export_script_module_to + "_"

    # All modules are constructed up front, before any of them is scripted
    # or saved, so a failing factory aborts the run before any file is written.
    tests = [
        ("test_submodule_forward_single_input", create_submodule_forward_single_input()),
        ("test_submodule_forward_multiple_inputs", create_submodule_forward_multiple_inputs()),
        ("test_submodule_multiple_hooks_single_input", create_submodule_multiple_hooks_single_input()),
        ("test_submodule_multiple_hooks_multiple_inputs", create_submodule_multiple_hooks_multiple_inputs()),
        ("test_submodule_hook_return_nothing", create_submodule_hook_return_nothing()),
        ("test_submodule_same_hook_repeated", create_submodule_same_hook_repeated()),

        ("test_module_forward_single_input", create_module_forward_single_input()),
        ("test_module_forward_multiple_inputs", create_module_forward_multiple_inputs()),
        ("test_module_multiple_hooks_single_input", create_module_multiple_hooks_single_input()),
        ("test_module_multiple_hooks_multiple_inputs", create_module_multiple_hooks_multiple_inputs()),
        ("test_module_hook_return_nothing", create_module_hook_return_nothing()),
        ("test_module_same_hook_repeated", create_module_same_hook_repeated()),

        ("test_module_no_forward_input", create_module_no_forward_input()),
        ("test_forward_tuple_input", create_forward_tuple_input()),
        ("test_submodule_to_call_directly_with_hooks", create_submodule_to_call_directly_with_hooks()),
    ]

    for name, model in tests:
        m_scripted = torch.jit.script(model)
        filename = save_name + name + ".pt"
        torch.jit.save(m_scripted, filename)

    print("OK: completed saving modules with hooks!")
| |
| |
# Run the export only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()