blob: 112c43b4dc9c186a065d927a649593c3c792e4e3 [file] [log] [blame]
# Owner(s): ["oncall: package/deploy"]
import pickle
from io import BytesIO
from textwrap import dedent
from torch.package import PackageExporter, PackageImporter, sys_importer
from torch.testing._internal.common_utils import run_tests
try:
from .common import PackageTestCase
except ImportError:
# Support the case where we run this file directly.
from common import PackageTestCase
from pathlib import Path
# Directory containing this test file; tests below resolve fixture sources
# such as module_a.py and package_a/ relative to it.
packaging_directory = Path(__file__).parents[0]
class TestSaveLoad(PackageTestCase):
    """Core save_* and loading API tests."""

    def test_saving_source(self):
        """``save_source_file`` accepts both a single ``.py`` file and a package directory.

        The sources are stored under new names ("foo" / "foodir"), and the
        importer resolves them by those packaged names.
        """
        buffer = BytesIO()
        with PackageExporter(buffer) as he:
            he.save_source_file("foo", str(packaging_directory / "module_a.py"))
            he.save_source_file("foodir", str(packaging_directory / "package_a"))
        buffer.seek(0)
        hi = PackageImporter(buffer)
        foo = hi.import_module("foo")
        s = hi.import_module("foodir.subpackage")
        self.assertEqual(foo.result, "module_a")
        self.assertEqual(s.result, "package_a.subpackage")

    def test_saving_string(self):
        """A module saved from a source string resolves ``math`` to the host's module."""
        buffer = BytesIO()
        with PackageExporter(buffer) as he:
            src = dedent(
                """\
                import math
                the_math = math
                """
            )
            he.save_source_string("my_mod", src)
        buffer.seek(0)
        hi = PackageImporter(buffer)
        m = hi.import_module("math")
        import math

        self.assertIs(m, math)
        my_mod = hi.import_module("my_mod")
        self.assertIs(my_mod.math, math)
        # The alias defined by the saved source must point at the same module.
        self.assertIs(my_mod.the_math, math)

    def _check_save_module_roundtrip(self, buffer):
        """Save ``module_a`` and ``package_a`` into *buffer*, then reload and verify them.

        The reloaded modules must carry the expected ``result`` values and be
        distinct objects from the host interpreter's copies.
        """
        with PackageExporter(buffer) as he:
            import module_a
            import package_a

            he.save_module(module_a.__name__)
            he.save_module(package_a.__name__)
        buffer.seek(0)
        hi = PackageImporter(buffer)
        module_a_i = hi.import_module("module_a")
        self.assertEqual(module_a_i.result, "module_a")
        self.assertIsNot(module_a, module_a_i)
        package_a_i = hi.import_module("package_a")
        self.assertEqual(package_a_i.result, "package_a")
        self.assertIsNot(package_a_i, package_a)

    def test_save_module(self):
        """``save_module`` packages a plain module and a package by name."""
        self._check_save_module_roundtrip(BytesIO())

    def test_save_module_binary(self):
        """Same ``save_module`` round trip as ``test_save_module``.

        NOTE(review): both tests target a ``BytesIO``; the name suggests this one
        was meant to exercise a different (binary file) destination -- confirm
        whether it is still needed.
        """
        self._check_save_module_roundtrip(BytesIO())

    def test_dunder_imports(self):
        """Pickling ``package_b.PackageBObject`` with ``intern("**")`` packages its dependencies.

        The resulting package must be able to import ``package_b``, the stdlib
        modules ``math`` and ``xml.sax.xmlreader``, and package_b's subpackages.
        Presumably package_b reaches these via ``__import__`` calls (hence the
        test name) -- its source is not visible here.
        """
        buffer = BytesIO()
        with PackageExporter(buffer) as he:
            import package_b

            obj = package_b.PackageBObject
            he.intern("**")
            he.save_pickle("res", "obj.pkl", obj)

        buffer.seek(0)
        hi = PackageImporter(buffer)
        loaded_obj = hi.load_pickle("res", "obj.pkl")
        # package_b was interned, so the loaded object comes from the package's
        # copy of package_b rather than the host interpreter's.
        self.assertIsNot(loaded_obj, obj)

        package_b_i = hi.import_module("package_b")
        self.assertEqual(package_b_i.result, "package_b")

        math_i = hi.import_module("math")
        self.assertEqual(math_i.__name__, "math")

        xml_sub_sub_package = hi.import_module("xml.sax.xmlreader")
        self.assertEqual(xml_sub_sub_package.__name__, "xml.sax.xmlreader")

        subpackage_1 = hi.import_module("package_b.subpackage_1")
        self.assertEqual(subpackage_1.result, "subpackage_1")

        subpackage_2 = hi.import_module("package_b.subpackage_2")
        self.assertEqual(subpackage_2.result, "subpackage_2")

        subsubpackage_0 = hi.import_module("package_b.subpackage_0.subsubpackage_0")
        self.assertEqual(subsubpackage_0.result, "subsubpackage_0")

    def test_bad_dunder_imports(self):
        """Unresolvable ``__import__`` calls must not make PackageExporter fail."""
        buffer = BytesIO()
        with PackageExporter(buffer) as e:
            e.save_source_string(
                "m", '__import__(these, unresolvable, "things", wont, crash, me)'
            )

    def test_pickle(self):
        """Pickling an object pulls in exactly the modules it depends on."""
        import package_a.subpackage

        obj = package_a.subpackage.PackageASubpackageObject()
        obj2 = package_a.PackageAObject(obj)

        buffer = BytesIO()
        with PackageExporter(buffer) as he:
            he.intern("**")
            he.save_pickle("obj", "obj.pkl", obj2)
        buffer.seek(0)
        hi = PackageImporter(buffer)

        # check we got dependencies
        sp = hi.import_module("package_a.subpackage")
        # check we didn't get other stuff
        with self.assertRaises(ImportError):
            hi.import_module("module_a")

        obj_loaded = hi.load_pickle("obj", "obj.pkl")
        self.assertIsNot(obj2, obj_loaded)
        self.assertIsInstance(obj_loaded.obj, sp.PackageASubpackageObject)
        self.assertIsNot(
            package_a.subpackage.PackageASubpackageObject, sp.PackageASubpackageObject
        )

    def test_pickle_long_name_with_protocol_4(self):
        """Pickle protocol 4 round-trips a callable whose name is very long."""
        import package_a.long_name

        container = []

        # Indirectly grab the function to avoid pasting a 256 character
        # function into the test
        package_a.long_name.add_function(container)

        buffer = BytesIO()
        with PackageExporter(buffer) as exporter:
            exporter.intern("**")
            exporter.save_pickle("container", "container.pkl", container, pickle_protocol=4)

        buffer.seek(0)
        importer = PackageImporter(buffer)
        unpickled_container = importer.load_pickle("container", "container.pkl")
        self.assertIsNot(container, unpickled_container)
        self.assertEqual(len(unpickled_container), 1)
        self.assertEqual(container[0](), unpickled_container[0]())

    def test_exporting_mismatched_code(self):
        """
        If an object with the same qualified name is loaded from different
        packages, the user should get an error if they try to re-save the
        object with the wrong package's source code.
        """
        import package_a.subpackage

        obj = package_a.subpackage.PackageASubpackageObject()
        obj2 = package_a.PackageAObject(obj)

        b1 = BytesIO()
        with PackageExporter(b1) as pe:
            pe.intern("**")
            pe.save_pickle("obj", "obj.pkl", obj2)

        b1.seek(0)
        importer1 = PackageImporter(b1)
        loaded1 = importer1.load_pickle("obj", "obj.pkl")

        b1.seek(0)
        importer2 = PackageImporter(b1)
        loaded2 = importer2.load_pickle("obj", "obj.pkl")

        def make_exporter():
            # importer1 is listed first so the exporter finds the
            # 'PackageAObject' defined in importer1 before sys_importer's.
            return PackageExporter(BytesIO(), importer=[importer1, sys_importer])

        # This should fail. The 'PackageAObject' type defined from 'importer1'
        # is not necessarily the same as obj2's version of 'PackageAObject'.
        pe = make_exporter()
        with self.assertRaises(pickle.PicklingError):
            pe.save_pickle("obj", "obj.pkl", obj2)

        # This should also fail. The 'PackageAObject' type defined from 'importer1'
        # is not necessarily the same as the one defined from 'importer2'.
        pe = make_exporter()
        with self.assertRaises(pickle.PicklingError):
            pe.save_pickle("obj", "obj.pkl", loaded2)

        # This should succeed. The 'PackageAObject' type defined from
        # 'importer1' is a match for the one used by loaded1.
        pe = make_exporter()
        pe.save_pickle("obj", "obj.pkl", loaded1)

    def test_save_imported_module(self):
        """Saving a module that came from another PackageImporter should work."""
        import package_a.subpackage

        obj = package_a.subpackage.PackageASubpackageObject()
        obj2 = package_a.PackageAObject(obj)

        buffer = BytesIO()
        with PackageExporter(buffer) as exporter:
            exporter.intern("**")
            exporter.save_pickle("model", "model.pkl", obj2)

        buffer.seek(0)

        importer = PackageImporter(buffer)
        imported_obj2 = importer.load_pickle("model", "model.pkl")
        imported_obj2_module = imported_obj2.__class__.__module__

        # Should export without error.
        buffer2 = BytesIO()
        with PackageExporter(buffer2, importer=(importer, sys_importer)) as exporter:
            exporter.intern("**")
            exporter.save_module(imported_obj2_module)

    def test_save_imported_module_using_package_importer(self):
        """Exercise a corner case: re-packaging a module that uses `torch_package_importer`"""
        import package_a.use_torch_package_importer  # noqa: F401

        buffer = BytesIO()
        with PackageExporter(buffer) as exporter:
            exporter.intern("**")
            exporter.save_module("package_a.use_torch_package_importer")

        buffer.seek(0)

        importer = PackageImporter(buffer)

        # Should export without error.
        buffer2 = BytesIO()
        with PackageExporter(buffer2, importer=(importer, sys_importer)) as exporter:
            exporter.intern("**")
            exporter.save_module("package_a.use_torch_package_importer")
if __name__ == "__main__":
    # Running this file directly hands off to PyTorch's shared test runner
    # (run_tests from torch.testing._internal.common_utils).
    run_tests()