| # Owner(s): ["oncall: jit"] |
| |
| import io |
| import os |
| import sys |
| import unittest |
| |
| import torch |
| import torch._C |
| from torch.jit.mobile import _load_for_lite_interpreter |
| from torch.testing import FileCheck |
| from torch.testing._internal.common_utils import ( |
| find_library_location, |
| IS_FBCODE, |
| IS_MACOS, |
| IS_SANDCASTLE, |
| IS_WINDOWS, |
| skipIfRocm, |
| TEST_WITH_ROCM, |
| ) |
| from torch.testing._internal.jit_utils import JitTestCase |
| |
| |
# Make the helper files in test/ importable.
# The real path of this file is resolved, two directory levels are stripped,
# and the resulting directory is appended to sys.path.
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)

# Refuse direct execution: as the message below says, the tests in this file
# are meant to be driven through test/test_jit.py.
if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
| |
| |
def to_test_backend(module, method_compile_spec):
    """
    Lower only the 'forward' method of 'module' to the "test_backend" backend.

    'method_compile_spec' is the compile spec associated with 'forward'. The
    result of torch._C._jit_to_backend is returned unchanged.
    """
    forward_only_spec = {"forward": method_compile_spec}
    return torch._C._jit_to_backend("test_backend", module, forward_only_spec)
| |
| |
def to_test_backend_multi(module, method_compile_spec):
    """
    Lower 'module' to the "test_backend" backend using 'method_compile_spec' as is.

    Unlike a forward-only lowering, the spec is forwarded untouched, so the
    caller decides which methods it covers (callers in this file pass a dict
    keyed by method name).
    """
    backend_name = "test_backend"
    return torch._C._jit_to_backend(backend_name, module, method_compile_spec)
| |
| |
def to_test_backend_selective(module, method_compile_spec, submodules):
    """
    Selectively lower the listed 'submodules' of 'module' to "test_backend".

    torch._C._jit_to_backend_selective is given a callback that lowers a single
    module's 'forward' with 'method_compile_spec'; 'submodules' is passed to it
    unchanged (callers in this file use dotted names such as "sub1.submodule").
    """

    def _lower_one(target):
        # Every selected submodule is lowered with the same compile spec.
        return to_test_backend(target, method_compile_spec)

    return torch._C._jit_to_backend_selective(module, _lower_one, submodules)
| |
| |
class BasicModule(torch.nn.Module):
    """
    Minimal Module exercised by the to_backend lowering tests.

    It exposes two elementwise helpers and a forward that returns both results.
    """

    def accum(self, x, h):
        # Elementwise sum of the two inputs.
        return x + h

    def sub_accum(self, x, h):
        # Elementwise difference of the two inputs.
        return x - h

    def forward(self, x, h):
        # Returns the pair (x + h, x - h), computed through the helpers above.
        summed = self.accum(x, h)
        differenced = self.sub_accum(x, h)
        return summed, differenced
| |
| |
| # This is ignored in IS_WINDOWS or IS_MACOS cases. Hence we need the one in TestBackends. |
| @unittest.skipIf( |
| TEST_WITH_ROCM or IS_SANDCASTLE or IS_WINDOWS or IS_MACOS or IS_FBCODE, |
| "Non-portable load_library call used in test", |
| ) |
| class JitBackendTestCase(JitTestCase): |
| """ |
| A common base class for JIT backend tests that contains common utility |
| functions for output comparison and serialization/deserialization. |
| """ |
| |
| def setUp(self): |
| super().setUp() |
| lib_file_path = find_library_location("libjitbackend_test.so") |
| torch.ops.load_library(str(lib_file_path)) |
| # Subclasses are expected to set up three variables in their setUp methods: |
| # module - a regular, Python version of the module being tested |
| # scripted_module - a scripted version of module |
| # lowered_module - a version of module lowered to a backend |
| |
| def check_function(self, function_name, input): |
| """ |
| Check that the function named 'function_name' produces the same output using |
| Python, regular JIT and the backend for the given 'input'. |
| """ |
| # Get handles for Python, JIT and backend methods. |
| python_method = self.module.__getattribute__(function_name) |
| jit_method = self.scripted_module.__getattr__(function_name) |
| backend_method = self.lowered_module.__getattr__(function_name) |
| |
| # Run methods. |
| python_output = python_method(*input) |
| jit_output = jit_method(*input) |
| backend_output = backend_method(*input) |
| |
| # The answers returned by Python, JIT and to_backend should all match. |
| self.assertEqual(python_output, backend_output) |
| self.assertEqual(jit_output, backend_output) |
| |
| def save_load(self): |
| """ |
| Save and load the lowered module. |
| """ |
| self.lowered_module = self.getExportImportCopy(self.lowered_module) |
| |
| def test_execution(self): |
| """ |
| Stub for correctness tests. |
| """ |
| |
| def test_save_load(self): |
| """ |
| Stub for serialization tests. |
| """ |
| |
| def test_errors(self): |
| """ |
| Stub for testing error checking. |
| """ |
| |
| |
class BasicModuleTest(JitBackendTestCase):
    """
    Execution and save/load tests for BasicModule lowered to the test backend.
    """

    def setUp(self):
        super().setUp()
        # The Python and scripted versions are separate BasicModule instances;
        # the lowered version is produced from the scripted one.
        self.module = BasicModule()
        self.scripted_module = torch.jit.script(BasicModule())
        # Each of the three methods is lowered with the same placeholder spec.
        per_method_spec = {
            method_name: {"": ""}
            for method_name in ("accum", "sub_accum", "forward")
        }
        self.lowered_module = to_test_backend_multi(
            self.scripted_module, per_method_spec
        )

    def test_execution(self):
        # Compare backend execution against Python and JIT for every method,
        # feeding the same random tensor as both arguments.
        sample = torch.randn(5)
        for method_name in ("accum", "sub_accum", "forward"):
            self.check_function(method_name, (sample, sample))

    @skipIfRocm
    def test_save_load(self):
        def read_compile_spec():
            # The spec is stored on the inner "__loweredModule__" attribute
            # under the name "__method_compile_spec".
            inner = self.lowered_module.__getattr__("__loweredModule__")
            return inner.__getattr__("__method_compile_spec")

        # Outputs must agree before serialization.
        self.test_execution()

        # Remember the compile spec, round-trip the module, and read it again.
        spec_before = read_compile_spec()
        self.save_load()
        spec_after = read_compile_spec()

        # The compile spec must survive the round trip.
        self.assertEqual(spec_before, spec_after)

        # Outputs must still agree after loading.
        self.test_execution()
| |
| |
class BasicModuleUnavailableTest(JitBackendTestCase):
    """
    Tests for BasicModule lowered to a backend that is not available.
    The expectations are:
        * _jit_to_backend succeeds.
        * Executing the lowered module raises.
        * Saving succeeds.
        * Loading raises.
    """

    def setUp(self):
        super().setUp()
        # Build the Python, scripted and lowered versions of BasicModule; the
        # lowering targets "test_backend_unavailable" and covers 'forward' only.
        self.module = BasicModule()
        self.scripted_module = torch.jit.script(BasicModule())
        forward_spec = {"forward": {"": ""}}
        self.lowered_module = torch._C._jit_to_backend(
            "test_backend_unavailable", self.scripted_module, forward_spec
        )

    def test_execution(self):
        # Running the lowered forward must fail because the backend is not available.
        sample = torch.randn(5)

        expect_unavailable = self.assertRaisesRegexWithHighlight(
            Exception,
            r"Backend is not available.",
            'raise Exception("Backend is not available."',
        )
        with expect_unavailable:
            lowered_forward = self.lowered_module.__getattr__("forward")
            lowered_forward(sample, sample)

    @skipIfRocm
    def test_save_load(self):
        # Saving the lowered module must work; loading it back must fail
        # because the backend is not available.
        serialized = io.BytesIO()
        torch.jit.save(self.lowered_module, serialized)
        serialized.seek(0)

        expect_unavailable = self.assertRaisesRegexWithHighlight(
            Exception,
            r"Backend is not available.",
            'raise Exception("Backend is not available."',
        )
        with expect_unavailable:
            torch.jit.load(serialized)
| |
| |
class NestedModuleTest(JitBackendTestCase):
    """
    Checks that a module lowered to a backend keeps working when it is used as
    a submodule of another module.
    """

    class NestedModule(torch.nn.Module):
        """
        A Module holding exactly one submodule and delegating forward to it.
        """

        def __init__(self, submodule):
            super().__init__()
            self.submodule = submodule

        def forward(self, x, h):
            inner_result = self.submodule.forward(x, h)
            return inner_result

    def setUp(self):
        super().setUp()
        NestedModule = NestedModuleTest.NestedModule

        # Python version: outer and inner modules are both regular modules.
        self.module = NestedModule(BasicModule())
        # JIT version: the whole hierarchy is scripted.
        self.scripted_module = torch.jit.script(NestedModule(BasicModule()))

        # Backend version: a scripted BasicModule is lowered on its own (all
        # three methods, placeholder spec) and then wrapped in a NestedModule
        # that is scripted. The outer module is a ScriptModule whose submodule
        # is the lowered module.
        inner_spec = {
            method_name: {"": ""}
            for method_name in ("accum", "sub_accum", "forward")
        }
        lowered_inner = to_test_backend_multi(
            torch.jit.script(BasicModule()), inner_spec
        )
        self.lowered_module = torch.jit.script(NestedModule(lowered_inner))

    def test_execution(self):
        # Compare forward across Python, JIT and backend with one random
        # tensor used for both arguments.
        sample = torch.randn(5)
        self.check_function("forward", (sample, sample))

    def test_save_load(self):
        # Outputs must agree before the round trip ...
        self.test_execution()

        # ... the lowered module is then saved and loaded ...
        self.save_load()

        # ... and outputs must still agree afterwards.
        self.test_execution()
| |
| |
class SelectiveLoweringTest(JitBackendTestCase):
    """
    Tests for the selective lowering API.
    """

    class OuterModule(torch.nn.Module):
        def __init__(self, sub1, sub2, other):
            super().__init__()
            self.sub1 = sub1
            self.sub2 = sub2
            self.other = other

        def forward(self, x, y):
            # sub1's inner module is invoked directly (bypassing sub1.forward)
            # to exercise type remapping in a module that is not the lowered
            # module's parent.
            a, b = self.sub1.submodule.forward(x, y)
            c, d = self.sub2.forward(x, y)
            e, f = self.other.forward(x, y)
            return a + c + e, b + d + f

    class MiddleModule(torch.nn.Module):
        def __init__(self, submodule):
            super().__init__()
            self.submodule = submodule

        def forward(self, x, y):
            return self.submodule.forward(x, y)

    def setUp(self):
        super().setUp()
        OuterModule = SelectiveLoweringTest.OuterModule
        MiddleModule = SelectiveLoweringTest.MiddleModule

        def build_hierarchy():
            # A fresh, unscripted instance of the hierarchy pictured below.
            return OuterModule(
                MiddleModule(BasicModule()),
                MiddleModule(BasicModule()),
                MiddleModule(BasicModule()),
            )

        def script_without_type_sharing(mod):
            return torch.jit._recursive.create_script_module(
                mod, torch.jit._recursive.infer_methods_to_compile, share_types=False
            )

        # Python, JIT and backend versions of a hierarchy that looks like this:
        #                 --------- OuterModule --------
        #                 |              |              |
        #           MiddleModule    MiddleModule   MiddleModule
        #                 |              |              |
        #           BasicModule     BasicModule    BasicModule
        #
        # Two BasicModules will be lowered and the third will not.
        self.module = build_hierarchy()
        self.scripted_module = script_without_type_sharing(build_hierarchy())
        not_yet_lowered = script_without_type_sharing(build_hierarchy())
        self.lowered_module = to_test_backend_selective(
            not_yet_lowered, {"forward": ""}, ["sub1.submodule", "sub2.submodule"]
        )

    def test_execution(self):
        sample = torch.randn(5)
        self.check_function("forward", (sample, sample))

        self.test_selective_lowering_type_remap()

    def test_save_load(self):
        self.test_execution()
        self.save_load()
        self.test_execution()

        self.test_selective_lowering_type_remap()

    def test_selective_lowering_type_remap(self):
        """
        Verify, by pattern-matching graphs, that selective lowering remapped and
        replaced types where expected.
        """
        backend_class = "__torch__.torch.classes.__backends__.test_backend"
        wrapper_name = "LoweredWrapper.test_backend"
        lowered_name = "LoweredModule.test_backend"

        # Top level: the scripted graph mentions BasicModule; the selectively
        # lowered graph must not mention the backend class but must mention the
        # lowered wrapper, since OuterModule calls a lowered module directly.
        FileCheck().check("OuterModule").check("BasicModule").run(
            self.scripted_module.graph
        )
        FileCheck().check("OuterModule").check_not(backend_class).check(
            wrapper_name
        ).run(self.lowered_module.graph)

        # sub1 / sub2 themselves are not lowered, but BasicModule has been
        # replaced by the lowered wrapper in their graphs.
        for child_name in ("sub1", "sub2"):
            scripted_child = getattr(self.scripted_module, child_name)
            lowered_child = getattr(self.lowered_module, child_name)
            FileCheck().check("MiddleModule").check("BasicModule").check_not(
                wrapper_name
            ).run(scripted_child.graph)
            FileCheck().check("MiddleModule").check_not(backend_class).check(
                wrapper_name
            ).run(lowered_child.graph)

        # sub1.submodule / sub2.submodule were lowered: each carries a
        # __loweredModule__ attribute whose graph mentions the backend class
        # __torch__.torch.classes.__backends__.test_backend.
        for child_name in ("sub1", "sub2"):
            lowered_inner = getattr(
                self.lowered_module, child_name
            ).submodule.__loweredModule__
            FileCheck().check(lowered_name).check(backend_class).run(
                lowered_inner.graph
            )

        # 'other' and 'other.submodule' must show no sign of lowering.
        # NOTE(review): these two checks read self.scripted_module rather than
        # self.lowered_module - confirm that this is the intended target.
        untouched = self.scripted_module.other
        FileCheck().check("MiddleModule").check("BasicModule").check_not(
            backend_class
        ).check_not(wrapper_name).run(untouched.graph)
        FileCheck().check("BasicModule").check_not(backend_class).check_not(
            lowered_name
        ).run(untouched.submodule.graph)

    def test_errors(self):
        """
        Check the errors raised by misuse of selective lowering.
        """
        MiddleModule = SelectiveLoweringTest.MiddleModule
        OuterModule = SelectiveLoweringTest.OuterModule

        # Lowering something that is not a ScriptModule.
        with self.assertRaisesRegexWithHighlight(
            RuntimeError, r"Object .* is not a ScriptModule", ""
        ):
            to_test_backend_selective(torch.nn.ReLU(), {"forward": ""}, ["submodule"])

        # Selecting an attribute that is not a Module.
        with_plain_attr = MiddleModule(BasicModule())
        with_plain_attr.new_attr = 3

        with self.assertRaisesRegexWithHighlight(
            RuntimeError, r"Attribute named new_attr is not a Module", ""
        ):
            to_test_backend_selective(
                torch.jit.script(with_plain_attr), {"forward": ""}, ["new_attr"]
            )

        # Module hierarchy whose types are not unique (scripted with the
        # default torch.jit.script here).
        shared_types = OuterModule(
            MiddleModule(BasicModule()),
            MiddleModule(BasicModule()),
            MiddleModule(BasicModule()),
        )

        with self.assertRaisesRegexWithHighlight(
            RuntimeError,
            r"Selective lowering is only supported for module hierarchies with unique types",
            "",
        ):
            to_test_backend_selective(
                torch.jit.script(shared_types), {"forward": ""}, ["sub1.submodule"]
            )
| |
| |
# This is needed for IS_WINDOWS or IS_MACOS to skip the tests.
@unittest.skipIf(
    TEST_WITH_ROCM or IS_SANDCASTLE or IS_WINDOWS or IS_MACOS or IS_FBCODE,
    "Non-portable load_library call used in test",
)
class TestBackends(JitTestCase):
    """
    Aggregates the JitBackendTestCase subclasses and drives them, so that
    test_jit.py does not have to import each of them individually.
    """

    def __init__(self, name):
        super().__init__(name)
        # Each wrapped test case is constructed with the same test method name.
        self.basic_module_test = BasicModuleTest(name)
        self.basic_module_unavailable_test = BasicModuleUnavailableTest(name)
        self.nested_module_test = NestedModuleTest(name)
        self.selective_lowering_test = SelectiveLoweringTest(name)

    def setUp(self):
        super().setUp()
        # The wrapped cases are not set up at all when testing with ROCm.
        if TEST_WITH_ROCM:
            return
        wrapped_cases = (
            self.basic_module_test,
            self.basic_module_unavailable_test,
            self.nested_module_test,
            self.selective_lowering_test,
        )
        for case in wrapped_cases:
            case.setUp()

    @skipIfRocm
    def test_execution(self):
        wrapped_cases = (
            self.basic_module_test,
            self.basic_module_unavailable_test,
            self.nested_module_test,
            self.selective_lowering_test,
        )
        for case in wrapped_cases:
            case.test_execution()

    @skipIfRocm
    def test_save_load(self):
        wrapped_cases = (
            self.basic_module_test,
            self.basic_module_unavailable_test,
            self.nested_module_test,
            self.selective_lowering_test,
        )
        for case in wrapped_cases:
            case.test_save_load()

    @skipIfRocm
    def test_errors(self):
        # Only the selective lowering case is run for error checking.
        self.selective_lowering_test.test_errors()
| |
| |
| """ |
| Unit Tests for backend with compiler |
| This test case and the existing TestBackends are separate because they cover different aspects. |
| The actual backend implementation in this test is different. |
| It has a simple demo compiler to test the end-to-end flow in mobile. |
| However, this test cannot cover the selective_lowering for now, which is covered in TestBackends. |
| """ |
| |
| |
class BasicModuleAdd(torch.nn.Module):
    """
    Minimal addition Module exercised by the to_backend lowering tests.
    """

    def forward(self, x, h):
        # Elementwise sum of the two inputs.
        summed = x + h
        return summed
| |
| |
| # This is ignored in IS_WINDOWS or IS_MACOS cases. Hence we need the one in TestBackends. |
| @unittest.skipIf( |
| TEST_WITH_ROCM or IS_SANDCASTLE or IS_WINDOWS or IS_MACOS or IS_FBCODE, |
| "Non-portable load_library call used in test", |
| ) |
| class JitBackendTestCaseWithCompiler(JitTestCase): |
| """ |
| A common base class for JIT backend tests with compilers that contains common utility |
| functions for output comparison. |
| """ |
| |
| def setUp(self): |
| super().setUp() |
| lib_file_path = find_library_location("libbackend_with_compiler.so") |
| torch.ops.load_library(str(lib_file_path)) |
| # Subclasses are expected to set up four variables in their setUp methods: |
| # module - a regular, Python version of the module being tested |
| # scripted_module - a scripted version of module |
| # lowered_module - a version of module lowered to a backend |
| # mobile_module - a module with a format that Pytorch Mobile can execute |
| |
| def check_forward(self, input): |
| """ |
| Check that the forward function produces the same output using |
| Python, regular JIT, the backend, and mobile for the given 'input'. |
| """ |
| |
| # Get outputs from forward. |
| python_output = self.module.forward(*input) |
| jit_output = self.scripted_module.forward(*input) |
| backend_output = self.lowered_module(*input) |
| mobile_output = self.mobile_module(*input) |
| |
| # The answers returned by Python, JIT, to_backend, and mobile should all match. |
| self.assertEqual(python_output, backend_output) |
| self.assertEqual(jit_output, backend_output) |
| self.assertEqual(mobile_output, backend_output) |
| |
| def test_execution(self): |
| """ |
| Stub for correctness tests. |
| """ |
| |
| def test_errors(self): |
| """ |
| Stub for testing error checking. |
| """ |
| |
| |
class BasicModuleTestWithCompiler(JitBackendTestCaseWithCompiler):
    """
    Execution test for BasicModuleAdd lowered to the demo compiler backend.
    """

    def setUp(self):
        super().setUp()
        # Python and scripted versions are separate BasicModuleAdd instances.
        self.module = BasicModuleAdd()
        self.scripted_module = torch.jit.script(BasicModuleAdd())

        # Backend version: lower 'forward' of the scripted module.
        forward_options = {
            "input_shapes": "((1, 1, 320, 240), (1, 3))",
            "some_other_option": "True",
        }
        self.lowered_module = torch._C._jit_to_backend(
            "backend_with_compiler_demo",
            self.scripted_module,
            {"forward": forward_options},
        )

        # Mobile version: serialize the lowered module for the lite interpreter
        # and load it back from the in-memory buffer.
        serialized = self.lowered_module._save_to_buffer_for_lite_interpreter()
        buffer = io.BytesIO(serialized)
        buffer.seek(0)
        self.mobile_module = _load_for_lite_interpreter(buffer)

    def test_execution(self):
        # Compare forward across Python, JIT, backend and mobile using a
        # single-element float tensor of ones for both arguments.
        ones = torch.ones(1, dtype=torch.float)
        self.check_forward((ones, ones))
| |
| |
class ErrorMessagesWithCompiler(JitBackendTestCase):
    """
    Tests for errors that occur with compiler, specifically:
    * an operator is not supported by the backend
    """

    class ModuleNotSupported(torch.nn.Module):
        """
        A module with an operator that is not supported.
        """

        # The statement after 'return' is unreachable, but the regex in
        # test_errors expects it in the highlighted source text. Do not remove
        # it, and do not add comments or other lines inside this method.
        def forward(self, x, h):
            return x * h
            self._loweredmodule.forward()

    def test_errors(self):
        """
        Lowering a module that uses aten::mul to "backend_with_compiler_demo"
        must raise a RuntimeError whose message matches the pattern below.
        """
        scripted_module_n = torch.jit.script(
            ErrorMessagesWithCompiler.ModuleNotSupported()
        )
        # Test exception is thrown when lowering a module with an unsupported operator
        # NOTE(review): the whitespace inside the raw string below is part of the
        # pattern and presumably has to mirror the indentation of
        # ModuleNotSupported.forward in this file - confirm against the real
        # error text before relying on it.
        with self.assertRaisesRegexWithHighlight(
            RuntimeError,
            # Special escape characters are replaced with '.'
            r"""The node of aten::mul is not supported in this compiler. .*
        def forward.self, x, h.:
            return x . h
                   ~~~~~ <--- HERE
            self._loweredmodule.forward..
""",
            "",
        ):
            lowered_module_n = torch._C._jit_to_backend(
                "backend_with_compiler_demo", scripted_module_n, {"forward": {"": ""}}
            )
| |
| |
class CompModuleTestWithCompiler(JitBackendTestCase):
    """
    Tests for CompModule, a module composed of two lowered submodules.
    """

    class BasicModuleSub(torch.nn.Module):
        """
        Minimal subtraction Module used as one half of CompModule.
        """

        def forward(self, x, h):
            difference = x - h
            return difference

    class CompModule(torch.nn.Module):
        """
        A module holding two lowered submodules, one adding and one subtracting.
        """

        def __init__(self, addmodule, submodule):
            super().__init__()
            self.lowered_add = addmodule
            self.lowered_sub = submodule

        def forward(self, a, b, s):
            # Scale the product of the add and sub results by s.
            c = self.lowered_add.forward(a, b)
            d = self.lowered_sub.forward(a, b)
            return s * (c * d)

    def setUp(self):
        super().setUp()
        backend_name = "backend_with_compiler_demo"
        CompModule = CompModuleTestWithCompiler.CompModule

        # Lower the two leaf modules; each gets its own compile spec.
        add_spec = {
            "forward": {
                "input_shapes": "((1, 1, 320, 240), (1, 3))",
                "some_other_option": "True",
            },
        }
        lowered_add = torch._C._jit_to_backend(
            backend_name, torch.jit.script(BasicModuleAdd()), add_spec
        )
        lowered_sub = torch._C._jit_to_backend(
            backend_name,
            torch.jit.script(CompModuleTestWithCompiler.BasicModuleSub()),
            {"forward": {"": ""}},
        )

        # Python and JIT versions of CompModule share the lowered submodules.
        self.module = CompModule(lowered_add, lowered_sub)
        self.scripted_module = torch.jit.script(CompModule(lowered_add, lowered_sub))
        # CompModule itself has no backend version yet, so the scripted module
        # is used as filler.
        self.lowered_module = self.scripted_module

        # Mobile version, produced from the JIT version via an in-memory buffer.
        serialized = self.scripted_module._save_to_buffer_for_lite_interpreter()
        buffer = io.BytesIO(serialized)
        buffer.seek(0)
        self.mobile_module = _load_for_lite_interpreter(buffer)

    def test_execution(self):
        # Compare forward across Python, JIT and the filler backend module.
        first = torch.ones(1, dtype=torch.float)
        second = torch.ones(1, dtype=torch.float)

        # 'second' doubles as the scale argument s.
        self.check_function("forward", (first, second, second))
| |
| |
# This is needed for IS_WINDOWS or IS_MACOS to skip the tests.
@unittest.skipIf(
    IS_SANDCASTLE or IS_WINDOWS or IS_MACOS or IS_FBCODE,
    "Non-portable load_library call used in test",
)
class TestBackendsWithCompiler(JitTestCase):
    """
    Aggregates the compiler-backed test cases and drives them, so that
    test_jit.py does not have to import each of them individually.
    """

    def __init__(self, name):
        super().__init__(name)
        # Each wrapped test case is constructed with the same test method name.
        self.basic_module_compiler_test = BasicModuleTestWithCompiler(name)
        self.error_module_compiler_test = ErrorMessagesWithCompiler(name)
        self.comp_module_compiler_test = CompModuleTestWithCompiler(name)

    def setUp(self):
        super().setUp()
        wrapped_cases = (
            self.basic_module_compiler_test,
            self.error_module_compiler_test,
            self.comp_module_compiler_test,
        )
        for case in wrapped_cases:
            case.setUp()

    def test_execution(self):
        # The error-message case has no execution test of its own to run here.
        for case in (self.basic_module_compiler_test, self.comp_module_compiler_test):
            case.test_execution()

    def test_errors(self):
        # Only the error-message case is run for error checking.
        self.error_module_compiler_test.test_errors()
| |
| |
class CompModuleTestSameNameWithCompiler(JitBackendTestCase):
    """
    Tests for CompModule, which is a module with two lowered submodules with same module name
    """

    class ModuleAdd(torch.nn.Module):
        """
        A simple Module used to test to_backend lowering machinery.
        """

        def forward(self, x, h):
            return x + h

    class CompModule(torch.nn.Module):
        """
        A module with two lowered submodules, both built from ModuleAdd.
        """

        def __init__(self) -> None:
            super().__init__()
            compile_spec = {
                "forward": {
                    "some_other_option": "True",
                },
            }
            # Names bound in a class body are not visible from inside methods,
            # so the sibling nested class has to be reached through the
            # enclosing class instead of by its bare name.
            ModuleAdd = CompModuleTestSameNameWithCompiler.ModuleAdd
            # NOTE(review): other tests in this file that use this backend name
            # load libbackend_with_compiler.so, while the base setUp used here
            # loads libjitbackend_test.so - confirm the backend is registered
            # when this class runs on its own.
            self.add = torch._C._jit_to_backend(
                "backend_with_compiler_demo",
                torch.jit.script(ModuleAdd()),
                compile_spec,
            )
            self.sub = torch._C._jit_to_backend(
                "backend_with_compiler_demo",
                torch.jit.script(ModuleAdd()),
                compile_spec,
            )

        def forward(self, a, b, s: int):
            c = self.add.forward(a, b)
            d = self.sub.forward(a, b)
            y = s * (c * d)
            return y

    def setUp(self):
        super().setUp()

        # Same scoping rule as above: qualify the nested class name.
        self.module = CompModuleTestSameNameWithCompiler.CompModule()
        self.scripted_module = torch.jit.script(self.module)
        # check_function reads self.lowered_module. There is no separately
        # lowered version of CompModule, so the scripted module stands in.
        self.lowered_module = self.scripted_module
        # Mobile version, produced from the JIT version via an in-memory buffer.
        buffer = io.BytesIO(self.scripted_module._save_to_buffer_for_lite_interpreter())
        buffer.seek(0)
        self.mobile_module = _load_for_lite_interpreter(buffer)

    def test_execution(self):
        a = torch.ones(1)
        b = 3 * torch.ones(1)
        s = 3
        # Test forward.
        self.check_function("forward", (a, b, s))
| |
| |
class AddedAttributesTest(JitBackendTestCase):
    """
    Tests for adding attributes to a model after lowering.
    """

    def setUp(self):
        super().setUp()
        # Create Python, JIT and backend versions of BasicModule.
        self.module = BasicModule()
        self.scripted_module = torch.jit.script(BasicModule())
        self.lowered_module = to_test_backend_multi(
            self.scripted_module,
            {"accum": {"": ""}, "sub_accum": {"": ""}, "forward": {"": ""}},
        )

    def test_attribute(self):
        """
        Bundle sample inputs into the lowered module and check that its output
        is the same before bundling, after bundling, and after a save/load
        round trip.
        """
        # Imported here with a from-import so the file needs no new top-level
        # import and the module-level name 'torch' is not rebound locally.
        from torch.utils.bundled_inputs import augment_model_with_bundled_inputs

        # BasicModule.forward takes two tensors (x, h), so every bundled input
        # has to be a pair.
        input = [(torch.ones(5), torch.ones(5))]
        pre_bundled = self.lowered_module(*input[0])
        # Attach bundled inputs which adds several attributes and functions to the model.
        # augment_model_with_bundled_inputs mutates the model in place and
        # returns None, so its result must not be assigned to self.lowered_module.
        augment_model_with_bundled_inputs(self.lowered_module, input)
        post_bundled = self.lowered_module(
            *self.lowered_module.get_all_bundled_inputs()[0]
        )
        # Save and load the lowered module.
        self.save_load()
        # Use bundled inputs after save and load to prove they are preserved.
        post_load = self.lowered_module(
            *self.lowered_module.get_all_bundled_inputs()[0]
        )
        self.assertEqual(pre_bundled, post_bundled)
        self.assertEqual(post_bundled, post_load)