| # Owner(s): ["oncall: package/deploy"] |
| |
| import inspect |
| import os |
| import platform |
| import sys |
| from io import BytesIO |
| from pathlib import Path |
| from textwrap import dedent |
| from unittest import skipIf |
| |
| from torch.package import is_from_package, PackageExporter, PackageImporter |
| from torch.package.package_exporter import PackagingError |
| from torch.testing._internal.common_utils import ( |
| IS_FBCODE, |
| IS_SANDCASTLE, |
| run_tests, |
| skipIfTorchDynamo, |
| ) |
| |
| |
| try: |
| from .common import PackageTestCase |
| except ImportError: |
| # Support the case where we run this file directly. |
| from common import PackageTestCase |
| |
| |
class TestMisc(PackageTestCase):
    """Tests for one-off or random functionality. Try not to add to this!"""

    def test_file_structure(self):
        """
        Tests package's Directory structure representation of a zip file. Ensures
        that the returned Directory prints what is expected and filters
        inputs/outputs correctly.
        """

        def tree_without_root(file_structure):
            # Drop the first (root) line before comparing, because it is
            # rendered differently on Windows/macOS/Unix; dedent the rest so it
            # lines up with the expected literals below.
            return dedent("\n".join(str(file_structure).split("\n")[1:]))

        buffer = BytesIO()

        export_plain = dedent(
            """\
                \u251c\u2500\u2500 .data
                \u2502   \u251c\u2500\u2500 extern_modules
                \u2502   \u251c\u2500\u2500 python_version
                \u2502   \u251c\u2500\u2500 serialization_id
                \u2502   \u2514\u2500\u2500 version
                \u251c\u2500\u2500 main
                \u2502   \u2514\u2500\u2500 main
                \u251c\u2500\u2500 obj
                \u2502   \u2514\u2500\u2500 obj.pkl
                \u251c\u2500\u2500 package_a
                \u2502   \u251c\u2500\u2500 __init__.py
                \u2502   \u2514\u2500\u2500 subpackage.py
                \u251c\u2500\u2500 byteorder
                \u2514\u2500\u2500 module_a.py
            """
        )
        export_include = dedent(
            """\
                \u251c\u2500\u2500 obj
                \u2502   \u2514\u2500\u2500 obj.pkl
                \u2514\u2500\u2500 package_a
                    \u2514\u2500\u2500 subpackage.py
            """
        )
        # The expected unfiltered tree above contains no ``*.storage`` entries,
        # so excluding them must leave the tree unchanged.
        import_exclude = export_plain

        with PackageExporter(buffer) as he:
            import module_a
            import package_a
            import package_a.subpackage

            obj = package_a.subpackage.PackageASubpackageObject()
            he.intern("**")
            he.save_module(module_a.__name__)
            he.save_module(package_a.__name__)
            he.save_pickle("obj", "obj.pkl", obj)
            he.save_text("main", "main", "my string")

        buffer.seek(0)
        hi = PackageImporter(buffer)

        file_structure = hi.file_structure()
        self.assertEqual(tree_without_root(file_structure), export_plain)

        file_structure = hi.file_structure(include=["**/subpackage.py", "**/*.pkl"])
        self.assertEqual(tree_without_root(file_structure), export_include)

        file_structure = hi.file_structure(exclude="**/*.storage")
        self.assertEqual(tree_without_root(file_structure), import_exclude)

    def test_loaders_that_remap_files_work_ok(self):
        """
        A loader whose ``get_filename`` remaps ``module_a`` to a different file
        should cause the packaged source of ``module_a`` to come from that
        remapped file.
        """
        from importlib.abc import MetaPathFinder
        from importlib.machinery import SourceFileLoader
        from importlib.util import spec_from_loader

        class LoaderThatRemapsModuleA(SourceFileLoader):
            def get_filename(self, name):
                result = super().get_filename(name)
                if name == "module_a":
                    return os.path.join(
                        os.path.dirname(result), "module_a_remapped_path.py"
                    )
                else:
                    return result

        class FinderThatRemapsModuleA(MetaPathFinder):
            def find_spec(self, fullname, path, target):
                """Try to find the original spec for module_a using all the
                remaining meta_path finders."""
                if fullname != "module_a":
                    return None
                spec = None
                for finder in sys.meta_path:
                    if finder is self:
                        continue
                    if hasattr(finder, "find_spec"):
                        spec = finder.find_spec(fullname, path, target=target)
                    elif hasattr(finder, "load_module"):
                        spec = spec_from_loader(fullname, finder)
                    if spec is not None:
                        break
                assert spec is not None and isinstance(spec.loader, SourceFileLoader)
                # Swap in the remapping loader while keeping the spec found above.
                spec.loader = LoaderThatRemapsModuleA(
                    spec.loader.name, spec.loader.path
                )
                return spec

        remapping_finder = FinderThatRemapsModuleA()
        sys.meta_path.insert(0, remapping_finder)
        # clear it from sys.modules so that we use the custom finder next time
        # it gets imported
        sys.modules.pop("module_a", None)
        try:
            buffer = BytesIO()
            with PackageExporter(buffer) as he:
                import module_a

                he.intern("**")
                he.save_module(module_a.__name__)

            buffer.seek(0)
            hi = PackageImporter(buffer)
            self.assertIn("remapped_path", hi.get_source("module_a"))
        finally:
            # pop it again to ensure it does not mess up other tests
            sys.modules.pop("module_a", None)
            # Remove our finder by identity rather than by position, so we never
            # drop a different finder that ended up at index 0.
            sys.meta_path.remove(remapping_finder)

    def test_python_version(self):
        """
        Tests that the current python version is stored in the package and is available
        via PackageImporter's python_version() method.
        """
        buffer = BytesIO()

        with PackageExporter(buffer) as he:
            from package_a.test_module import SimpleTest

            he.intern("**")
            obj = SimpleTest()
            he.save_pickle("obj", "obj.pkl", obj)

        buffer.seek(0)
        hi = PackageImporter(buffer)

        self.assertEqual(hi.python_version(), platform.python_version())

    @skipIf(
        IS_FBCODE or IS_SANDCASTLE,
        "Tests that use temporary files are disabled in fbcode",
    )
    def test_load_python_version_from_package(self):
        """Tests loading a package with a python version embedded"""
        importer1 = PackageImporter(
            str(Path(__file__).parent / "package_e" / "test_nn_module.pt")
        )
        self.assertEqual(importer1.python_version(), "3.9.7")

    def test_file_structure_has_file(self):
        """
        Test Directory's has_file() method.
        """
        buffer = BytesIO()
        with PackageExporter(buffer) as he:
            import package_a.subpackage

            he.intern("**")
            obj = package_a.subpackage.PackageASubpackageObject()
            he.save_pickle("obj", "obj.pkl", obj)

        buffer.seek(0)

        importer = PackageImporter(buffer)
        file_structure = importer.file_structure()
        self.assertTrue(file_structure.has_file("package_a/subpackage.py"))
        self.assertFalse(file_structure.has_file("package_a/subpackage"))

    def test_exporter_content_lists(self):
        """
        Test content list API for PackageExporter's contained modules.
        """

        with PackageExporter(BytesIO()) as he:
            import package_b

            he.extern("package_b.subpackage_1")
            he.mock("package_b.subpackage_2")
            he.intern("**")
            he.save_pickle("obj", "obj.pkl", package_b.PackageBObject(["a"]))
        self.assertEqual(he.externed_modules(), ["package_b.subpackage_1"])
        self.assertEqual(he.mocked_modules(), ["package_b.subpackage_2"])
        self.assertEqual(
            he.interned_modules(),
            ["package_b", "package_b.subpackage_0.subsubpackage_0"],
        )
        self.assertEqual(he.get_rdeps("package_b.subpackage_2"), ["package_b"])

        # Denying a module that the pickle depends on must fail the export; the
        # denied list is checked while the exporter is still open.
        with self.assertRaises(PackagingError):
            with PackageExporter(BytesIO()) as he:
                import package_b

                he.deny("package_b")
                he.save_pickle("obj", "obj.pkl", package_b.PackageBObject(["a"]))
                self.assertEqual(he.denied_modules(), ["package_b"])

    def test_is_from_package(self):
        """is_from_package should work for objects and modules"""
        import package_a.subpackage

        buffer = BytesIO()
        obj = package_a.subpackage.PackageASubpackageObject()

        with PackageExporter(buffer) as pe:
            pe.intern("**")
            pe.save_pickle("obj", "obj.pkl", obj)

        buffer.seek(0)
        pi = PackageImporter(buffer)
        mod = pi.import_module("package_a.subpackage")
        loaded_obj = pi.load_pickle("obj", "obj.pkl")

        self.assertFalse(is_from_package(package_a.subpackage))
        self.assertTrue(is_from_package(mod))

        self.assertFalse(is_from_package(obj))
        self.assertTrue(is_from_package(loaded_obj))

    def test_inspect_class(self):
        """Should be able to retrieve source for a packaged class."""
        import package_a.subpackage

        buffer = BytesIO()
        obj = package_a.subpackage.PackageASubpackageObject()

        with PackageExporter(buffer) as pe:
            pe.intern("**")
            pe.save_pickle("obj", "obj.pkl", obj)

        buffer.seek(0)
        pi = PackageImporter(buffer)
        packaged_class = pi.import_module(
            "package_a.subpackage"
        ).PackageASubpackageObject
        regular_class = package_a.subpackage.PackageASubpackageObject

        packaged_src = inspect.getsourcelines(packaged_class)
        regular_src = inspect.getsourcelines(regular_class)
        self.assertEqual(packaged_src, regular_src)

    def test_dunder_package_present(self):
        """
        The attribute '__torch_package__' should be populated on imported modules.
        """
        import package_a.subpackage

        buffer = BytesIO()
        obj = package_a.subpackage.PackageASubpackageObject()

        with PackageExporter(buffer) as pe:
            pe.intern("**")
            pe.save_pickle("obj", "obj.pkl", obj)

        buffer.seek(0)
        pi = PackageImporter(buffer)
        mod = pi.import_module("package_a.subpackage")
        self.assertTrue(hasattr(mod, "__torch_package__"))

    def test_dunder_package_works_from_package(self):
        """
        The attribute '__torch_package__' should be accessible from within
        the module itself, so that packaged code can detect whether it's
        being used in a packaged context or not.
        """
        import package_a.use_dunder_package as mod

        buffer = BytesIO()

        with PackageExporter(buffer) as pe:
            pe.intern("**")
            pe.save_module(mod.__name__)

        buffer.seek(0)
        pi = PackageImporter(buffer)
        imported_mod = pi.import_module(mod.__name__)
        self.assertTrue(imported_mod.is_from_package())
        self.assertFalse(mod.is_from_package())

    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
    def test_std_lib_sys_hackery_checks(self):
        """
        The standard library performs sys.module assignment hackery which
        causes modules who do this hackery to fail on import. See
        https://github.com/pytorch/pytorch/issues/57490 for more information.
        """
        import package_a.std_sys_module_hacks

        buffer = BytesIO()
        mod = package_a.std_sys_module_hacks.Module()

        with PackageExporter(buffer) as pe:
            pe.intern("**")
            pe.save_pickle("obj", "obj.pkl", mod)

        buffer.seek(0)
        pi = PackageImporter(buffer)
        mod = pi.load_pickle("obj", "obj.pkl")
        mod()
| |
| |
# Script entry point: when this file is executed directly (rather than being
# collected/imported by a test runner), run the tests via the shared runner.
if __name__ == "__main__":
    run_tests()