| # mypy: allow-untyped-defs |
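"""Serialization helpers for moving objects between interpreters (used by
``torch::deploy``).

``_save_storages`` pickles an object while externalizing its tensor storages
and any ``__reduce_deploy__``-aware objects; ``_load_storages`` reverses the
process. Keeping storages out of the pickle byte stream lets them be shared
across interpreters without copying the underlying data.

A minimal round-trip sketch (``obj`` is any picklable object; the first
argument to ``_load_storages`` is an arbitrary caller-chosen cache key)::

    data, storages, dtypes, zip_reader = _save_storages(sys_importer, obj)
    restored = _load_storages(0, zip_reader, data, storages, dtypes)
"""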
| import io |
| |
| import torch |
| from torch.package import Importer, OrderedImporter, PackageImporter, sys_importer |
| from torch.package._package_pickler import create_pickler |
| from torch.package._package_unpickler import PackageUnpickler |
| from torch.serialization import _maybe_decode_ascii |
| |
| |
| def _save_storages(importer, obj): |
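    """Pickle ``obj``, externalizing its tensor storages.

    Returns a ``(pickle_bytes, storages, dtypes, zip_reader)`` tuple, where
    ``zip_reader`` identifies the package archive ``obj`` was imported from
    (``None`` if ``importer`` is not a ``PackageImporter``).
    """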
| serialized_storages = [] |
| serialized_dtypes = [] |
| |
    importer = importer if isinstance(importer, PackageImporter) else None
| importers: Importer |
| if importer is not None: |
| importers = OrderedImporter(importer, sys_importer) |
| else: |
| importers = sys_importer |
| |
| def persistent_id(obj): |
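        """Swap storages and ``__reduce_deploy__`` objects out of the pickle stream."""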
        if torch.is_storage(obj) or isinstance(obj, torch.storage.TypedStorage):
            if isinstance(obj, torch.storage.TypedStorage):
                # TODO: Once we decide to break serialization forward
                # compatibility (FC), we can remove this case
                dtype = obj.dtype
            else:
                # Untyped storages are serialized as raw bytes.
                dtype = torch.uint8

            serialized_storages.append(obj)
            serialized_dtypes.append(dtype)
            return ("storage", len(serialized_storages) - 1)
| |
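        # Objects can opt in to deploy-specific serialization by defining
        # __reduce_deploy__; cache the reduction by object identity so that
        # repeated references serialize to a single entry.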
| if hasattr(obj, "__reduce_deploy__"): |
| if _serialized_reduces.get(id(obj)) is None: |
| _serialized_reduces[id(obj)] = ( |
| "reduce_deploy", |
| id(obj), |
| *obj.__reduce_deploy__(importers), |
| ) |
| return _serialized_reduces[id(obj)] |
| |
| return None |
| |
| # Write the pickle data for `obj` |
| data_buf = io.BytesIO() |
| pickler = create_pickler(data_buf, importers) |
| pickler.persistent_id = persistent_id |
| pickler.dump(obj) |
| data_value = data_buf.getvalue() |
| return ( |
| data_value, |
| serialized_storages, |
| serialized_dtypes, |
| importer.zip_reader if importer else None, |
| ) |
| |
| |
| def _load_storages(id, zip_reader, obj_bytes, serialized_storages, serialized_dtypes): |
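    """Unpickle an object produced by ``_save_storages``.

    ``id`` is a caller-chosen key under which the result is cached in
    ``_deploy_objects``. ``zip_reader`` selects the package archive to import
    from; pass ``None`` to resolve modules with ``sys_importer`` alone.
    """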
| def persistent_load(saved_id): |
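        """Reattach the storages and reduce results externalized at save time."""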
| assert isinstance(saved_id, tuple) |
| typename = _maybe_decode_ascii(saved_id[0]) |
| data = saved_id[1:] |
| |
| if typename == "storage": |
            # TODO: Once we decide to break serialization forward
            # compatibility (FC), we can stop wrapping with TypedStorage
| storage = serialized_storages[data[0]] |
| dtype = serialized_dtypes[data[0]] |
| return torch.storage.TypedStorage( |
| wrap_storage=storage.untyped(), dtype=dtype |
| ) |
| |
| if typename == "reduce_deploy": |
| reduce_id, func, args = data |
| if reduce_id not in _loaded_reduces: |
| _loaded_reduces[reduce_id] = func(_raw_packages[zip_reader], *args) |
| return _loaded_reduces[reduce_id] |
| |
| return None |
| |
| importer: Importer |
| if zip_reader is not None: |
| importer = OrderedImporter(_get_package(zip_reader), sys_importer) |
| else: |
| importer = sys_importer |
| |
| unpickler = PackageUnpickler(importer, io.BytesIO(obj_bytes)) |
| unpickler.persistent_load = persistent_load # type: ignore[method-assign] |
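    # Cache the loaded object under the caller-provided key before returning it.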
| result = _deploy_objects[id] = unpickler.load() |
| return result |
| |
| |
| def _get_package(zip_reader): |
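    """Return the ``PackageImporter`` for ``zip_reader``, creating and
    caching one on first use."""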
| if zip_reader not in _raw_packages: |
| _raw_packages[zip_reader] = PackageImporter(zip_reader) |
| return _raw_packages[zip_reader] |
| |
| |
# zip_reader -> PackageImporter, so each archive is only opened once
_raw_packages: dict = {}
# caller-provided id -> object returned by _load_storages
_deploy_objects: dict = {}
# id(obj) -> cached ("reduce_deploy", ...) tuple produced at save time
_serialized_reduces: dict = {}
# reduce id -> object reconstructed by persistent_load
_loaded_reduces: dict = {}