| # Owner(s): ["oncall: package/deploy"] |
| |
import importlib
import importlib.machinery
import importlib.util
from io import BytesIO
from sys import version_info
from textwrap import dedent
from unittest import skipIf

import torch.nn

from torch.package import EmptyMatchError, Importer, PackageExporter, PackageImporter
from torch.package.package_exporter import PackagingError
from torch.testing._internal.common_utils import IS_WINDOWS, run_tests
| |
| try: |
| from .common import PackageTestCase |
| except ImportError: |
| # Support the case where we run this file directly. |
| from common import PackageTestCase |
| |
| |
class TestDependencyAPI(PackageTestCase):
    """Dependency management API tests for ``PackageExporter``.

    Covers the dependency action APIs:
    - mock()
    - extern()
    - deny()
    - intern()

    as well as the ``PackagingError`` / ``EmptyMatchError`` reports produced
    when dependencies cannot be resolved or packaged.
    """

    def test_extern(self):
        """Externed modules should resolve to the caller's own module objects."""
        buffer = BytesIO()
        with PackageExporter(buffer) as he:
            he.extern(["package_a.subpackage", "module_a"])
            he.save_source_string("foo", "import package_a.subpackage; import module_a")
        buffer.seek(0)
        hi = PackageImporter(buffer)
        import module_a
        import package_a.subpackage

        module_a_im = hi.import_module("module_a")
        hi.import_module("package_a.subpackage")
        package_a_im = hi.import_module("package_a")

        self.assertIs(module_a, module_a_im)
        # Only package_a.subpackage was externed: the importer's package_a is a
        # different object, but its `subpackage` attribute is the externed module.
        self.assertIsNot(package_a, package_a_im)
        self.assertIs(package_a.subpackage, package_a_im.subpackage)

    def test_extern_glob(self):
        """Extern patterns may be globs rather than exact module names."""
        buffer = BytesIO()
        with PackageExporter(buffer) as he:
            he.extern(["package_a.*", "module_*"])
            he.save_module("package_a")
            he.save_source_string(
                "test_module",
                dedent(
                    """\
                    import package_a.subpackage
                    import module_a
                    """
                ),
            )
        buffer.seek(0)
        hi = PackageImporter(buffer)
        import module_a
        import package_a.subpackage

        module_a_im = hi.import_module("module_a")
        hi.import_module("package_a.subpackage")
        package_a_im = hi.import_module("package_a")

        self.assertIs(module_a, module_a_im)
        # package_a itself was saved into the package (not externed), so the
        # importer's copy differs, while its externed subpackage is shared.
        self.assertIsNot(package_a, package_a_im)
        self.assertIs(package_a.subpackage, package_a_im.subpackage)

    def test_extern_glob_allow_empty(self):
        """
        Test that an error is thrown when an extern glob is specified with allow_empty=False
        and no matching module is required during packaging.
        """
        import package_a.subpackage  # noqa: F401

        buffer = BytesIO()
        with self.assertRaisesRegex(EmptyMatchError, r"did not match any modules"):
            with PackageExporter(buffer) as exporter:
                exporter.extern(include=["package_b.*"], allow_empty=False)
                exporter.save_module("package_a.subpackage")

    def test_deny(self):
        """
        Test marking packages as "deny" during export.
        """
        buffer = BytesIO()

        with self.assertRaisesRegex(PackagingError, "denied"):
            with PackageExporter(buffer) as exporter:
                exporter.deny(["package_a.subpackage", "module_a"])
                exporter.save_source_string("foo", "import package_a.subpackage")

    def test_deny_glob(self):
        """
        Test marking packages as "deny" using globs instead of package names.
        """
        buffer = BytesIO()
        with self.assertRaises(PackagingError):
            with PackageExporter(buffer) as exporter:
                exporter.deny(["package_a.*", "module_*"])
                exporter.save_source_string(
                    "test_module",
                    dedent(
                        """\
                        import package_a.subpackage
                        import module_a
                        """
                    ),
                )

    @skipIf(version_info < (3, 7), "mock uses __getattr__ a 3.7 feature")
    def test_mock(self):
        """Mocked modules import fine, but calling their attributes raises."""
        buffer = BytesIO()
        with PackageExporter(buffer) as he:
            he.mock(["package_a.subpackage", "module_a"])
            # Import something that depends on package_a.subpackage
            he.save_source_string("foo", "import package_a.subpackage")
        buffer.seek(0)
        hi = PackageImporter(buffer)
        import package_a.subpackage

        _ = package_a.subpackage
        import module_a

        _ = module_a

        m = hi.import_module("package_a.subpackage")
        r = m.result
        with self.assertRaisesRegex(NotImplementedError, "was mocked out"):
            r()

    @skipIf(version_info < (3, 7), "mock uses __getattr__ a 3.7 feature")
    def test_mock_glob(self):
        """Mock patterns may be globs rather than exact module names."""
        buffer = BytesIO()
        with PackageExporter(buffer) as he:
            he.mock(["package_a.*", "module*"])
            he.save_module("package_a")
            he.save_source_string(
                "test_module",
                dedent(
                    """\
                    import package_a.subpackage
                    import module_a
                    """
                ),
            )
        buffer.seek(0)
        hi = PackageImporter(buffer)
        import package_a.subpackage

        _ = package_a.subpackage
        import module_a

        _ = module_a

        m = hi.import_module("package_a.subpackage")
        r = m.result
        with self.assertRaisesRegex(NotImplementedError, "was mocked out"):
            r()

    def test_mock_glob_allow_empty(self):
        """
        Test that an error is thrown when a mock glob is specified with allow_empty=False
        and no matching module is required during packaging.
        """
        import package_a.subpackage  # noqa: F401

        buffer = BytesIO()
        with self.assertRaisesRegex(EmptyMatchError, r"did not match any modules"):
            with PackageExporter(buffer) as exporter:
                exporter.mock(include=["package_b.*"], allow_empty=False)
                exporter.save_module("package_a.subpackage")

    @skipIf(version_info < (3, 7), "mock uses __getattr__ a 3.7 feature")
    def test_pickle_mocked(self):
        """Pickling an object whose class lives in a mocked module should fail."""
        import package_a.subpackage

        obj = package_a.subpackage.PackageASubpackageObject()
        obj2 = package_a.PackageAObject(obj)

        buffer = BytesIO()
        with self.assertRaises(PackagingError):
            with PackageExporter(buffer) as he:
                he.mock(include="package_a.subpackage")
                he.intern("**")
                he.save_pickle("obj", "obj.pkl", obj2)

    @skipIf(version_info < (3, 7), "mock uses __getattr__ a 3.7 feature")
    def test_pickle_mocked_all(self):
        """Pickling succeeds when the object's package is interned and all else is mocked."""
        import package_a.subpackage

        obj = package_a.subpackage.PackageASubpackageObject()
        obj2 = package_a.PackageAObject(obj)

        buffer = BytesIO()
        with PackageExporter(buffer) as he:
            he.intern(include="package_a.**")
            he.mock("**")
            he.save_pickle("obj", "obj.pkl", obj2)

    def test_allow_empty_with_error(self):
        """If an error occurs during packaging, it should not be shadowed by the allow_empty error."""
        buffer = BytesIO()
        with self.assertRaises(ModuleNotFoundError):
            with PackageExporter(buffer) as pe:
                # Even though we did not extern a module that matches this
                # pattern, we want to show the save_module error, not the allow_empty error.

                pe.extern("foo", allow_empty=False)
                pe.save_module("aodoifjodisfj")  # will error

                # we never get here, so technically the allow_empty check
                # should raise an error. However, the error above is more
                # informative to what's actually going wrong with packaging.
                pe.save_source_string("bar", "import foo\n")

    def test_implicit_intern(self):
        """The save_module APIs should implicitly intern the module being saved."""
        import package_a  # noqa: F401

        buffer = BytesIO()
        with PackageExporter(buffer) as he:
            he.save_module("package_a")

    def test_intern_error(self):
        """Failure to handle all dependencies should lead to an error."""
        import package_a.subpackage

        obj = package_a.subpackage.PackageASubpackageObject()
        obj2 = package_a.PackageAObject(obj)

        buffer = BytesIO()

        with self.assertRaises(PackagingError) as e:
            with PackageExporter(buffer) as he:
                he.save_pickle("obj", "obj.pkl", obj2)

        self.assertEqual(
            str(e.exception),
            dedent(
                """
                * Module did not match against any action pattern. Extern, mock, or intern it.
                    package_a
                    package_a.subpackage
                """
            ),
        )

        # Interning all dependencies should work. Use a fresh buffer so this
        # export does not share an output stream with the failed export above.
        buffer2 = BytesIO()
        with PackageExporter(buffer2) as he:
            he.intern(["package_a", "package_a.subpackage"])
            he.save_pickle("obj", "obj.pkl", obj2)

    @skipIf(IS_WINDOWS, "extension modules have a different file extension on windows")
    def test_broken_dependency(self):
        """A unpackageable dependency should raise a PackagingError."""

        def create_module(name):
            # Build a synthetic module whose __file__ ends in ".so"; the expected
            # error below shows the exporter classifies it as a C extension.
            spec = importlib.machinery.ModuleSpec(name, self, is_package=False)  # type: ignore[arg-type]
            module = importlib.util.module_from_spec(spec)
            ns = module.__dict__
            ns["__spec__"] = spec
            ns["__loader__"] = self
            ns["__file__"] = f"{name}.so"
            ns["__cached__"] = None
            return module

        class BrokenImporter(Importer):
            """Importer that serves only the synthetic "foo" and "bar" modules."""

            def __init__(self):
                self.modules = {
                    "foo": create_module("foo"),
                    "bar": create_module("bar"),
                }

            def import_module(self, module_name):
                return self.modules[module_name]

        buffer = BytesIO()

        with self.assertRaises(PackagingError) as e:
            with PackageExporter(buffer, importer=BrokenImporter()) as exporter:
                exporter.intern(["foo", "bar"])
                exporter.save_source_string("my_module", "import foo; import bar")

        self.assertEqual(
            str(e.exception),
            dedent(
                """
                * Module is a C extension module. torch.package supports Python modules only.
                    foo
                    bar
                """
            ),
        )

    def test_invalid_import(self):
        """An incorrectly-formed import should raise a PackagingError."""
        buffer = BytesIO()
        with self.assertRaises(PackagingError) as e:
            with PackageExporter(buffer) as exporter:
                # This import will fail to load.
                exporter.save_source_string("foo", "from ........ import lol")

        self.assertEqual(
            str(e.exception),
            dedent(
                """
                * Dependency resolution failed.
                    foo
                      Context: attempted relative import beyond top-level package
                """
            ),
        )

    @skipIf(version_info < (3, 7), "mock uses __getattr__ a 3.7 feature")
    def test_repackage_mocked_module(self):
        """Re-packaging a package that contains a mocked module should work correctly."""
        buffer = BytesIO()
        with PackageExporter(buffer) as exporter:
            exporter.mock("package_a")
            exporter.save_source_string("foo", "import package_a")

        buffer.seek(0)
        importer = PackageImporter(buffer)
        foo = importer.import_module("foo")

        # "package_a" should be mocked out.
        with self.assertRaises(NotImplementedError):
            foo.package_a.get_something()

        # Re-package the model, but intern the previously-mocked module and mock
        # everything else.
        buffer2 = BytesIO()
        with PackageExporter(buffer2, importer=importer) as exporter:
            exporter.intern("package_a")
            exporter.mock("**")
            exporter.save_source_string("foo", "import package_a")

        buffer2.seek(0)
        importer2 = PackageImporter(buffer2)
        foo2 = importer2.import_module("foo")

        # "package_a" should still be mocked out.
        with self.assertRaises(NotImplementedError):
            foo2.package_a.get_something()

    def test_externing_c_extension(self):
        """Externing C extension modules should allow us to still access them, especially those found in torch._C."""

        buffer = BytesIO()
        # The C extension function in question is F.gelu, which comes from torch._C._nn
        model = torch.nn.TransformerEncoderLayer(
            d_model=64,
            nhead=2,
            dim_feedforward=64,
            dropout=1.0,
            batch_first=True,
            activation="gelu",
            norm_first=True,
        )
        with PackageExporter(buffer) as e:
            e.extern("torch.**")
            e.intern("**")

            e.save_pickle("model", "model.pkl", model)
        buffer.seek(0)
        imp = PackageImporter(buffer)
        imp.load_pickle("model", "model.pkl")
| |
| |
if __name__ == "__main__":
    # Allow running this file directly; run_tests is torch's common test entry point.
    run_tests()