| # Owner(s): ["oncall: package/deploy"] |
| |
| from io import BytesIO |
| |
| from torch.package import PackageExporter, PackageImporter |
| from torch.package._mangling import ( |
| demangle, |
| get_mangle_prefix, |
| is_mangled, |
| PackageMangler, |
| ) |
| from torch.testing._internal.common_utils import run_tests |
| |
| try: |
| from .common import PackageTestCase |
| except ImportError: |
| # Support the case where we run this file directly. |
| from common import PackageTestCase |
| |
| |
class TestMangling(PackageTestCase):
    """Tests for the name-mangling helpers in ``torch.package._mangling``."""

    def test_unique_manglers(self):
        """
        Two separate manglers should not produce the same mangled name for one input.
        """
        first = PackageMangler()
        second = PackageMangler()
        mangled_by_first = first.mangle("foo.bar")
        mangled_by_second = second.mangle("foo.bar")
        self.assertNotEqual(mangled_by_first, mangled_by_second)

    def test_mangler_is_consistent(self):
        """
        A single mangler should map a given name to the same mangled name every time.
        """
        mangler = PackageMangler()
        first_result = mangler.mangle("abc.def")
        second_result = mangler.mangle("abc.def")
        self.assertEqual(first_result, second_result)

    def test_roundtrip_mangling(self):
        """
        Demangling a freshly mangled name should give back the original name.
        """
        mangler = PackageMangler()
        mangled = mangler.mangle("foo")
        self.assertEqual("foo", demangle(mangled))

    def test_is_mangled(self):
        """
        is_mangled should accept mangled names and reject plain or demangled ones.
        """
        first = PackageMangler()
        second = PackageMangler()

        # Names coming out of any mangler are recognised as mangled.
        mangled_by_first = first.mangle("foo.bar")
        self.assertTrue(is_mangled(mangled_by_first))
        mangled_by_second = second.mangle("foo.bar")
        self.assertTrue(is_mangled(mangled_by_second))

        # A plain name, or one that went through demangle, is not.
        self.assertFalse(is_mangled("foo.bar"))
        roundtripped = demangle(first.mangle("foo.bar"))
        self.assertFalse(is_mangled(roundtripped))

    def test_demangler_multiple_manglers(self):
        """
        The module-level demangle should recover names produced by any PackageMangler.
        """
        first = PackageMangler()
        second = PackageMangler()

        mangled_by_first = first.mangle("foo.bar")
        self.assertEqual("foo.bar", demangle(mangled_by_first))
        mangled_by_second = second.mangle("bar.foo")
        self.assertEqual("bar.foo", demangle(mangled_by_second))

    def test_mangle_empty_errors(self):
        """
        Mangling an empty string should raise an AssertionError.
        """
        mangler = PackageMangler()
        with self.assertRaises(AssertionError):
            mangler.mangle("")

    def test_demangle_base(self):
        """
        Demangling only the mangle parent (the part before the first dot) should
        currently yield an empty string.
        """
        mangler = PackageMangler()
        mangled = mangler.mangle("foo")
        # Everything before the first "." is the mangle parent.
        parent, _, _ = mangled.partition(".")
        self.assertEqual("", demangle(parent))

    def test_mangle_prefix(self):
        """
        The mangle prefix joined to the original name with a dot should equal the
        mangled name.
        """
        mangler = PackageMangler()
        mangled = mangler.mangle("foo.bar")
        prefix = get_mangle_prefix(mangled)
        rebuilt = ".".join((prefix, "foo.bar"))
        self.assertEqual(rebuilt, mangled)

    def test_unique_module_names(self):
        """
        Importing the same package twice should give the loaded classes module names
        that differ from each other and from the original module name.
        """
        import package_a.subpackage

        inner = package_a.subpackage.PackageASubpackageObject()
        original = package_a.PackageAObject(inner)

        # Export the object, interning every module it depends on.
        buffer = BytesIO()
        with PackageExporter(buffer) as exporter:
            exporter.intern("**")
            exporter.save_pickle("obj", "obj.pkl", original)

        # Load the same archive through two independent importers; both importers
        # stay referenced until the assertions below have run.
        importers = []
        loaded_objects = []
        for _ in range(2):
            buffer.seek(0)
            importer = PackageImporter(buffer)
            importers.append(importer)
            loaded_objects.append(importer.load_pickle("obj", "obj.pkl"))
        first_loaded, second_loaded = loaded_objects

        # Modules from loaded packages should not shadow the names of modules.
        # See mangling.md for more info.
        original_module = type(original).__module__
        first_module = type(first_loaded).__module__
        second_module = type(second_loaded).__module__
        self.assertNotEqual(original_module, first_module)
        self.assertNotEqual(first_module, second_module)

    def test_package_mangler(self):
        """
        A mangler's own demangle should undo only the names that mangler produced.
        """
        owner = PackageMangler()
        other = PackageMangler()
        mangled = owner.mangle("foo.bar")
        # `owner` produced this name, so it should demangle it properly.
        self.assertEqual(owner.demangle(mangled), "foo.bar")
        # `other` did not produce this name, so demangling should leave it alone.
        self.assertEqual(other.demangle(mangled), mangled)
| |
| |
# Invoke run_tests() only when this file is executed directly as a script.
if __name__ == "__main__":
    run_tests()