| # Owner(s): ["oncall: package/deploy"] |
| |
| from pathlib import Path |
| from unittest import skipIf |
| |
| from torch.package import PackageImporter |
| from torch.testing._internal.common_utils import IS_FBCODE, IS_SANDCASTLE, run_tests |
| |
| try: |
| from .common import PackageTestCase |
| except ImportError: |
| # Support the case where we run this file directly. |
| from common import PackageTestCase |
| |
# Directory (as a "/"-joined string) holding the package archives that the
# tests in this file open with PackageImporter.
packaging_directory = str(Path(__file__).parent) + "/package_bc"
| |
| |
class TestLoadBCPackages(PackageTestCase):
    """Tests for checking loading has backwards compatibility.

    Each test opens a package archive stored under ``packaging_directory``
    and unpickles the object saved in it. A test fails if loading raises or
    if the unpickled object is ``None``.
    """

    def _load_bc_package(self, filename, package, resource):
        """Load a pickled object from a backwards-compatibility package.

        Args:
            filename: Archive file name inside ``packaging_directory``.
            package: Package name passed to ``PackageImporter.load_pickle``.
            resource: Pickle resource name inside ``package``.

        Returns:
            The unpickled object.
        """
        importer = PackageImporter(f"{packaging_directory}/{filename}")
        loaded = importer.load_pickle(package, resource)
        # Loading without raising is the main check; also make sure the
        # package actually produced an object instead of discarding it.
        self.assertIsNotNone(loaded)
        return loaded

    @skipIf(
        IS_FBCODE or IS_SANDCASTLE,
        "Tests that use temporary files are disabled in fbcode",
    )
    def test_load_bc_packages_nn_module(self):
        """Tests for backwards compatible nn module"""
        self._load_bc_package("test_nn_module.pt", "nn_module", "nn_module.pkl")

    @skipIf(
        IS_FBCODE or IS_SANDCASTLE,
        "Tests that use temporary files are disabled in fbcode",
    )
    def test_load_bc_packages_torchscript_module(self):
        """Tests for backwards compatible torchscript module"""
        self._load_bc_package(
            "test_torchscript_module.pt",
            "torchscript_module",
            "torchscript_module.pkl",
        )

    @skipIf(
        IS_FBCODE or IS_SANDCASTLE,
        "Tests that use temporary files are disabled in fbcode",
    )
    def test_load_bc_packages_fx_module(self):
        """Tests for backwards compatible fx module"""
        self._load_bc_package("test_fx_module.pt", "fx_module", "fx_module.pkl")
| |
| |
if __name__ == "__main__":
    # Entry point when this file is executed directly instead of being
    # collected by a test runner; delegates to the shared torch test driver.
    run_tests()