| # -*- coding: utf-8 -*- |
| # Owner(s): ["oncall: package/deploy"] |
| |
| import inspect |
| import platform |
| from io import BytesIO |
| from pathlib import Path |
| from textwrap import dedent |
| from unittest import skipIf |
| |
| from torch.package import is_from_package, PackageExporter, PackageImporter |
| from torch.package.package_exporter import PackagingError |
| from torch.testing._internal.common_utils import IS_FBCODE, IS_SANDCASTLE, run_tests, skipIfTorchDynamo |
| |
| try: |
| from .common import PackageTestCase |
| except ImportError: |
| # Support the case where we run this file directly. |
| from common import PackageTestCase |
| |
| |
class TestMisc(PackageTestCase):
    """Tests for one-off or random functionality. Try not to add to this!"""

    @staticmethod
    def _tree_without_root(file_structure):
        """Return the printed form of ``file_structure`` minus its first line.

        The first line is dropped from the comparison because
        Windows/iOS/Unix treat the buffer differently. The remaining lines
        are dedented so they can be compared against a ``dedent``-ed literal.
        """
        rendered_lines = str(file_structure).split("\n")
        return dedent("\n".join(rendered_lines[1:]))

    def test_file_structure(self):
        """
        Tests package's Directory structure representation of a zip file. Ensures
        that the returned Directory prints what is expected and filters
        inputs/outputs correctly.
        """
        buffer = BytesIO()

        export_plain = dedent(
            """\
            ├── .data
            │   ├── extern_modules
            │   ├── python_version
            │   └── version
            ├── main
            │   └── main
            ├── obj
            │   └── obj.pkl
            ├── package_a
            │   ├── __init__.py
            │   └── subpackage.py
            └── module_a.py
            """
        )
        export_include = dedent(
            """\
            ├── obj
            │   └── obj.pkl
            └── package_a
                └── subpackage.py
            """
        )
        # Excluding "**/*.storage" is expected to print the same tree as the
        # unfiltered listing, so the expected text is shared.
        import_exclude = export_plain

        with PackageExporter(buffer) as he:
            import module_a
            import package_a
            import package_a.subpackage

            obj = package_a.subpackage.PackageASubpackageObject()
            he.intern("**")
            he.save_module(module_a.__name__)
            he.save_module(package_a.__name__)
            he.save_pickle("obj", "obj.pkl", obj)
            he.save_text("main", "main", "my string")

        buffer.seek(0)
        hi = PackageImporter(buffer)

        # Unfiltered listing.
        self.assertEqual(
            self._tree_without_root(hi.file_structure()),
            export_plain,
        )

        # Only entries matching one of the include globs are listed.
        self.assertEqual(
            self._tree_without_root(
                hi.file_structure(include=["**/subpackage.py", "**/*.pkl"])
            ),
            export_include,
        )

        # Entries matching the exclude glob are filtered out.
        self.assertEqual(
            self._tree_without_root(hi.file_structure(exclude="**/*.storage")),
            import_exclude,
        )

    def test_python_version(self):
        """
        Tests that the current python version is stored in the package and is available
        via PackageImporter's python_version() method.
        """
        buffer = BytesIO()

        with PackageExporter(buffer) as he:
            from package_a.test_module import SimpleTest

            he.intern("**")
            obj = SimpleTest()
            he.save_pickle("obj", "obj.pkl", obj)

        buffer.seek(0)
        hi = PackageImporter(buffer)

        self.assertEqual(hi.python_version(), platform.python_version())

    @skipIf(
        IS_FBCODE or IS_SANDCASTLE,
        "Tests that use temporary files are disabled in fbcode",
    )
    def test_load_python_version_from_package(self):
        """Tests loading a package with a python version embedded."""
        # The fixture archive lives next to this test file.
        package_path = Path(__file__).parent / "package_e" / "test_nn_module.pt"
        importer1 = PackageImporter(package_path)
        self.assertEqual(importer1.python_version(), "3.9.7")

    def test_file_structure_has_file(self):
        """
        Test Directory's has_file() method.
        """
        buffer = BytesIO()
        with PackageExporter(buffer) as he:
            import package_a.subpackage

            he.intern("**")
            obj = package_a.subpackage.PackageASubpackageObject()
            he.save_pickle("obj", "obj.pkl", obj)

        buffer.seek(0)

        importer = PackageImporter(buffer)
        file_structure = importer.file_structure()
        # The full file name (with extension) is found; the bare stem is not.
        self.assertTrue(file_structure.has_file("package_a/subpackage.py"))
        self.assertFalse(file_structure.has_file("package_a/subpackage"))

    def test_exporter_content_lists(self):
        """
        Test content list API for PackageExporter's contained modules.
        """

        with PackageExporter(BytesIO()) as he:
            import package_b

            he.extern("package_b.subpackage_1")
            he.mock("package_b.subpackage_2")
            he.intern("**")
            he.save_pickle("obj", "obj.pkl", package_b.PackageBObject(["a"]))
            self.assertEqual(he.externed_modules(), ["package_b.subpackage_1"])
            self.assertEqual(he.mocked_modules(), ["package_b.subpackage_2"])
            self.assertEqual(
                he.interned_modules(),
                ["package_b", "package_b.subpackage_0.subsubpackage_0"],
            )
            self.assertEqual(he.get_rdeps("package_b.subpackage_2"), ["package_b"])

        # Packaging a denied module must fail with PackagingError.
        # NOTE(review): the error is presumably raised when the exporter is
        # closed on leaving its `with` block, which is why the
        # `denied_modules()` check sits inside that block -- confirm.
        with self.assertRaises(PackagingError):
            with PackageExporter(BytesIO()) as he:
                import package_b

                he.deny("package_b")
                he.save_pickle("obj", "obj.pkl", package_b.PackageBObject(["a"]))
                self.assertEqual(he.denied_modules(), ["package_b"])

    def test_is_from_package(self):
        """is_from_package should work for objects and modules"""
        import package_a.subpackage

        buffer = BytesIO()
        obj = package_a.subpackage.PackageASubpackageObject()

        with PackageExporter(buffer) as pe:
            pe.intern("**")
            pe.save_pickle("obj", "obj.pkl", obj)

        buffer.seek(0)
        pi = PackageImporter(buffer)
        mod = pi.import_module("package_a.subpackage")
        loaded_obj = pi.load_pickle("obj", "obj.pkl")

        # Modules: the regularly imported one vs. the one from the importer.
        self.assertFalse(is_from_package(package_a.subpackage))
        self.assertTrue(is_from_package(mod))

        # Objects: the locally constructed one vs. the unpickled one.
        self.assertFalse(is_from_package(obj))
        self.assertTrue(is_from_package(loaded_obj))

    def test_inspect_class(self):
        """Should be able to retrieve source for a packaged class."""
        import package_a.subpackage

        buffer = BytesIO()
        obj = package_a.subpackage.PackageASubpackageObject()

        with PackageExporter(buffer) as pe:
            pe.intern("**")
            pe.save_pickle("obj", "obj.pkl", obj)

        buffer.seek(0)
        pi = PackageImporter(buffer)
        packaged_class = pi.import_module(
            "package_a.subpackage"
        ).PackageASubpackageObject
        regular_class = package_a.subpackage.PackageASubpackageObject

        # inspect must report the same source lines for both classes.
        packaged_src = inspect.getsourcelines(packaged_class)
        regular_src = inspect.getsourcelines(regular_class)
        self.assertEqual(packaged_src, regular_src)

    def test_dunder_package_present(self):
        """
        The attribute '__torch_package__' should be populated on imported modules.
        """
        import package_a.subpackage

        buffer = BytesIO()
        obj = package_a.subpackage.PackageASubpackageObject()

        with PackageExporter(buffer) as pe:
            pe.intern("**")
            pe.save_pickle("obj", "obj.pkl", obj)

        buffer.seek(0)
        pi = PackageImporter(buffer)
        mod = pi.import_module("package_a.subpackage")
        self.assertTrue(hasattr(mod, "__torch_package__"))

    def test_dunder_package_works_from_package(self):
        """
        The attribute '__torch_package__' should be accessible from within
        the module itself, so that packaged code can detect whether it's
        being used in a packaged context or not.
        """
        import package_a.use_dunder_package as mod

        buffer = BytesIO()

        with PackageExporter(buffer) as pe:
            pe.intern("**")
            pe.save_module(mod.__name__)

        buffer.seek(0)
        pi = PackageImporter(buffer)
        imported_mod = pi.import_module(mod.__name__)
        # The module's own is_from_package() is called on both copies.
        self.assertTrue(imported_mod.is_from_package())
        self.assertFalse(mod.is_from_package())

    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
    def test_std_lib_sys_hackery_checks(self):
        """
        The standard library performs sys.module assignment hackery which
        causes modules who do this hackery to fail on import. See
        https://github.com/pytorch/pytorch/issues/57490 for more information.
        """
        import package_a.std_sys_module_hacks

        buffer = BytesIO()
        mod = package_a.std_sys_module_hacks.Module()

        with PackageExporter(buffer) as pe:
            pe.intern("**")
            pe.save_pickle("obj", "obj.pkl", mod)

        buffer.seek(0)
        pi = PackageImporter(buffer)
        mod = pi.load_pickle("obj", "obj.pkl")
        # Calling the reloaded object is the check: it must not raise.
        mod()
| |
| |
# Run this file's tests via run_tests() when it is executed directly as a script.
if __name__ == "__main__":
    run_tests()