| import collections |
| import importlib.machinery |
| import io |
| import linecache |
| import pickletools |
| import platform |
| import types |
| from collections import defaultdict, OrderedDict |
| from dataclasses import dataclass |
| from enum import Enum |
| from importlib.machinery import SourceFileLoader |
| from pathlib import Path |
| from typing import ( |
| Any, |
| BinaryIO, |
| Callable, |
| cast, |
| DefaultDict, |
| Dict, |
| List, |
| Optional, |
| Sequence, |
| Set, |
| Union, |
| ) |
| |
| import torch |
| from torch.serialization import location_tag, normalize_storage_type |
| from torch.types import Storage |
| from torch.utils.hooks import RemovableHandle |
| |
| from ._digraph import DiGraph |
| from ._importlib import _normalize_path |
| from ._mangling import demangle, is_mangled |
| from ._package_pickler import create_pickler |
| from ._stdlib import is_stdlib_module |
| from .find_file_dependencies import find_files_source_depends_on |
| from .glob_group import GlobGroup, GlobPattern |
| from .importer import Importer, OrderedImporter, sys_importer |
| |
| __all__ = [ |
| "PackagingErrorReason", |
| "EmptyMatchError", |
| "PackagingError", |
| "PackageExporter", |
| ] |
| |
| _gate_torchscript_serialization = True |
| |
| ActionHook = Callable[["PackageExporter", str], None] |
| |
| |
| class _ModuleProviderAction(Enum): |
| """Represents one of the actions that :class:`PackageExporter` can take on a module. |
| |
| See :meth:`PackageExporter.extern` and friends for a description of what the actions do. |
| """ |
| |
| INTERN = 1 |
| EXTERN = 2 |
| MOCK = 3 |
| DENY = 4 |
| # Special case: when a module is mocked, PackageExporter writes out a |
| # `_mock` module that implements our mocking stubs. If we re-package code, |
| # we may encounter a `_mock` module from the original package. If we do, |
| # just ignore it and write a `_mock` module once. |
| REPACKAGED_MOCK_MODULE = 5 |
| # Special case: PackageImporter adds a fake module |
| # (`torch_package_importer`) that allows packaged code to access it. Don't |
| # re-export this. |
| SKIP = 6 |
| |
| |
class PackagingErrorReason(Enum):
    """Reasons a dependency can fail to package.

    Each member's value is the human-readable explanation that
    :class:`PackagingError` prints when it reports the affected modules.
    """

    def __repr__(self):
        # Compact form such as ``<PackagingErrorReason.DENIED>``.
        return "<{}.{}>".format(type(self).__name__, self.name)

    IS_EXTENSION_MODULE = (
        "Module is a C extension module. torch.package supports Python modules only."
    )
    NO_DUNDER_FILE = "Module had no __file__ defined."
    SOURCE_FILE_NOT_FOUND = (
        "Module had a __file__, but we could not find it in your filesystem."
    )
    DEPENDENCY_RESOLUTION_FAILED = "Dependency resolution failed."
    NO_ACTION = (
        "Module did not match against any action pattern. Extern, mock, or intern it."
    )
    DENIED = "Module was denied by a pattern."
    MOCKED_BUT_STILL_USED = (
        "Module was mocked out, but is still being used in the package. "
        "Please intern or extern the mocked modules if objects are supposed to be in "
        "the package."
    )
| |
| |
@dataclass
class _PatternInfo:
    """Bookkeeping :class:`PackageExporter` keeps for one registered pattern."""

    # Action applied to every module that matches the pattern.
    action: _ModuleProviderAction
    # The ``allow_empty`` flag supplied when the pattern was registered.
    allow_empty: bool
    # Set to True once some module has matched the pattern during packaging.
    was_matched: bool

    def __init__(self, action, allow_empty):
        # A class that defines its own __init__ keeps it under @dataclass.
        # That is why ``was_matched`` is not a constructor argument and
        # always starts out False.
        self.action, self.allow_empty, self.was_matched = action, allow_empty, False
| |
| |
class EmptyMatchError(Exception):
    """Raised for a pattern registered with ``allow_empty=False`` that never matches.

    It applies to ``intern``, ``mock`` or ``extern`` patterns that did not match
    any module during packaging.
    """
| |
| |
class PackagingError(Exception):
    """Raised when one or more modules could not be exported.

    ``PackageExporter`` collects every failure recorded in its dependency
    graph and reports them together in one message, grouped by
    :class:`PackagingErrorReason`.

    Attributes:
        dependency_graph: The graph the errors were collected from. It is kept
            on the exception so that tooling can inspect it afterwards.
    """

    def __init__(self, dependency_graph: DiGraph, debug=False):
        """Build the aggregated error message.

        Args:
            dependency_graph: Graph whose node attributes may carry ``error``
                (a :class:`PackagingErrorReason`) and an optional
                ``error_context``.
            debug: If True, also print one dependency path leading to each
                broken module (obtained from ``dependency_graph.first_path``).
        """
        # Group the names of broken modules by failure reason.
        broken: Dict[PackagingErrorReason, List[str]] = defaultdict(list)
        for module_name, attrs in dependency_graph.nodes.items():
            error = attrs.get("error")
            if error is None:
                continue
            if error == PackagingErrorReason.NO_ACTION:
                # Sanity check: a NO_ACTION error is only expected on modules
                # that were never assigned an action.
                assert "action" not in attrs
            broken[error].append(module_name)

        message = io.StringIO()
        message.write("\n")

        for reason, module_names in broken.items():
            message.write(f"* {reason.value}\n")
            for module_name in module_names:
                message.write(f" {module_name}\n")

                # Print additional context if it's provided.
                error_context = dependency_graph.nodes[module_name].get("error_context")
                if error_context is not None:
                    message.write(f" Context: {error_context}\n")
                # _DISALLOWED_MODULES is defined elsewhere in this module. Per
                # the note text, it lists stdlib modules with broad system access.
                if module_name in _DISALLOWED_MODULES:
                    message.write(
                        " Note: While we usually use modules in the python standard library "
                        f"from the local environment, `{module_name}` has a lot of system "
                        "level access and therefore can pose a security risk. We heavily "
                        f"recommend removing `{module_name}` from your packaged code. However, if that "
                        "is not possible, add it to the extern list by calling "
                        # The suggested call must contain the bare module name;
                        # backticks inside the quotes would not match anything.
                        f'PackageExporter.extern("{module_name}")\n'
                    )
                if debug:
                    module_path = dependency_graph.first_path(module_name)
                    # Newline-terminated so the next entry starts on its own line.
                    message.write(
                        f" A path to {module_name}: {' -> '.join(module_path)}\n"
                    )
        if not debug:
            message.write("\n")
            message.write(
                "Set debug=True when invoking PackageExporter for a visualization of where "
                "broken modules are coming from!\n"
            )
        # Save the dependency graph so that tooling can get at it.
        self.dependency_graph = dependency_graph
        super().__init__(message.getvalue())
| |
| |
| class PackageExporter: |
| """Exporters allow you to write packages of code, pickled Python data, and |
| arbitrary binary and text resources into a self-contained package. |
| |
| Imports can load this code in a hermetic way, such that code is loaded |
| from the package rather than the normal Python import system. This allows |
| for the packaging of PyTorch model code and data so that it can be run |
| on a server or used in the future for transfer learning. |
| |
| The code contained in packages is copied file-by-file from the original |
| source when it is created, and the file format is a specially organized |
| zip file. Future users of the package can unzip the package, and edit the code |
| in order to perform custom modifications to it. |
| |
| The importer for packages ensures that code in the module can only be loaded from |
| within the package, except for modules explicitly listed as external using :meth:`extern`. |
| The file ``extern_modules`` in the zip archive lists all the modules that a package externally depends on. |
| This prevents "implicit" dependencies where the package runs locally because it is importing |
| a locally-installed package, but then fails when the package is copied to another machine. |
| |
| When source code is added to the package, the exporter can optionally scan it |
| for further code dependencies (``dependencies=True``). It looks for import statements, |
| resolves relative references to qualified module names, and performs an action specified by the user |
| (See: :meth:`extern`, :meth:`mock`, and :meth:`intern`). |
| """ |
| |
| """An importer that will be searched in order to find the modules referenced by other modules or by |
| pickled objects. The default module environment just uses sys_importer, which searches the Python environment. |
| """ |
| importer: Importer |
| |
def __init__(
    self,
    f: Union[str, Path, BinaryIO],
    importer: Union[Importer, Sequence[Importer]] = sys_importer,
    debug: bool = False,
):
    """
    Create an exporter.

    Args:
        f: The location to export to. Can be a ``string``/``Path`` object containing a filename
            or a binary I/O object.
        importer: If a single Importer is passed, use that to search for modules.
            If a sequence of importers are passed, an ``OrderedImporter`` will be constructed out of them.
        debug: If set to True, add path of broken modules to PackagingErrors.

    Raises:
        TypeError: If ``importer`` is neither an ``Importer`` nor a sequence of them.
    """
    # `import collections` alone does not guarantee that the `collections.abc`
    # submodule is loaded, so import it explicitly before it is used below.
    import collections.abc

    torch._C._log_api_usage_once("torch.package.PackageExporter")
    self.debug = debug
    if isinstance(f, (Path, str)):
        f = str(f)
        self.buffer: Optional[BinaryIO] = None
    else:  # is a byte buffer
        self.buffer = f

    self.zip_file = torch._C.PyTorchFileWriter(f)
    # Minimum archive format version; the value 6 is kept from the original code.
    self.zip_file.set_min_version(6)
    self._written_files: Set[str] = set()

    self.serialized_reduces: Dict[int, Any] = {}

    # A graph tracking all the modules and pickle objects added to this
    # package and the dependencies between them.
    # - Each node is a module name (or a pickle name that looks like '<foo.obj.pkl>')
    # - Each directed edge (u, v) means u depends on v.
    # - Nodes may contain metadata that describe how to write the thing to the zipfile.
    self.dependency_graph = DiGraph()
    self.script_module_serializer = torch._C.ScriptModuleSerializer(self.zip_file)
    self.storage_context = self.script_module_serializer.storage_context()

    # These are OrderedDicts for compatibility with RemovableHandle.
    # Generic OrderedDict type annotations are not present until 3.7.
    # The real type signature is OrderedDict[int, Callable[[PackageExporter, str], None]]
    self._extern_hooks: OrderedDict = OrderedDict()
    self._mock_hooks: OrderedDict = OrderedDict()
    self._intern_hooks: OrderedDict = OrderedDict()

    if isinstance(importer, Importer):
        self.importer = importer
    else:
        if not isinstance(importer, collections.abc.Sequence):
            raise TypeError(
                "importer arg should be an Importer or a sequence of Importers, "
                f"got {type(importer)} instead."
            )
        self.importer = OrderedImporter(*importer)

    # Registered intern/extern/mock/deny patterns and their bookkeeping.
    self.patterns: Dict[GlobGroup, _PatternInfo] = {}
    # Counter backing get_unique_id().
    self._unique_id = 0
| |
def save_source_file(
    self, module_name: str, file_or_directory: str, dependencies=True
):
    """Adds the local file system ``file_or_directory`` to the source package to provide the code
    for ``module_name``.

    Args:
        module_name (str): e.g. ``"my_package.my_subpackage"``, code will be saved to provide code for this package.
        file_or_directory (str): the path to a file or directory of code. When a directory, all python files in the
            directory are recursively added using :meth:`save_source_string`, in sorted path order. A file named
            ``__init__.py`` is treated as the source of a package.
        dependencies (bool, optional): If ``True``, we scan the source for dependencies.
    """
    path = Path(file_or_directory)
    if not path.is_dir():
        self.save_source_string(
            module_name,
            _read_file(file_or_directory),
            path.name == "__init__.py",
            dependencies,
        )
        return

    module_path = module_name.replace(".", "/")
    # Arguments for save_source_string, collected first. The calls are delayed
    # so that every module provided by this directory is recorded _before_ any
    # of them has its dependencies resolved. This keeps us from copying over
    # modules that this directory blob will overwrite anyway.
    to_save = []
    # Path.glob yields entries in filesystem order. Sorting makes the order in
    # which modules enter the dependency graph reproducible.
    for filename in sorted(path.glob("**/*.py")):
        relative_path = filename.relative_to(path).as_posix()
        archivename = f"{module_path}/{relative_path}"
        if filename.name == "__init__.py":
            submodule_name = archivename[: -len("/__init__.py")].replace("/", ".")
            is_package = True
        else:
            submodule_name = archivename[: -len(".py")].replace("/", ".")
            is_package = False
        to_save.append(
            (submodule_name, _read_file(str(filename)), is_package, dependencies)
        )

    for item in to_save:
        self.save_source_string(*item)
| |
def get_unique_id(self) -> str:
    """Return a fresh id string; each id is handed out only once per exporter."""
    current = self._unique_id
    self._unique_id = current + 1
    return f"{current}"
| |
def _get_dependencies(
    self, src: str, module_name: str, is_package: bool
) -> List[str]:
    """Return the modules that ``src`` directly depends on.

    Import-like statements in the source are scanned to find dependencies.
    Only names that actually resolve as modules are kept.

    Arguments:
        src: The Python source code to analyze for dependencies.
        module_name: The name of the module that ``src`` corresponds to.
        is_package: Whether this module should be treated as a package.
            See :py:meth:`save_source_string` for more info.

    Returns:
        Unique module names in first-seen order. If scanning fails, the
        failure is recorded on the dependency graph and ``[]`` is returned.
    """
    # The package relative imports are resolved against: the module itself
    # when it is a package, otherwise its parent.
    if is_package:
        package_name = module_name
    else:
        package_name = module_name.rsplit(".", maxsplit=1)[0]

    try:
        dep_pairs = find_files_source_depends_on(src, package_name)
    except Exception as e:
        # Record the failure on the graph rather than aborting, so it can be
        # reported alongside any other packaging errors.
        self.dependency_graph.add_node(
            module_name,
            error=PackagingErrorReason.DEPENDENCY_RESOLUTION_FAILED,
            error_context=str(e),
        )
        return []

    # Dict keys give de-duplication with a deterministic (insertion) order.
    found: Dict[str, bool] = {}
    for dep_module_name, dep_module_obj in dep_pairs:
        if dep_module_obj is not None:
            # `from pack import sub`: if `sub` is itself a submodule, depend on
            # `pack.sub` only, so `pack` does not drag in extra dependencies.
            # If `sub` is an ordinary object, fall through and depend on `pack`.
            candidate = f"{dep_module_name}.{dep_module_obj}"
            if self._module_exists(candidate):
                found[candidate] = True
                continue
        if self._module_exists(dep_module_name):
            found[dep_module_name] = True

    return list(found)
| |
def save_source_string(
    self,
    module_name: str,
    src: str,
    is_package: bool = False,
    dependencies: bool = True,
):
    """Record ``src`` as the interned source code for ``module_name``.

    Args:
        module_name (str): e.g. ``my_package.my_subpackage``, code will be saved to provide code for this package.
        src (str): The Python source code to save for this package.
        is_package (bool, optional): If ``True``, this module is treated as a package. Packages are allowed to have
            submodules (e.g. ``my_package.my_subpackage.my_subsubpackage``), and resources can be saved inside them.
            Defaults to ``False``.
        dependencies (bool, optional): If ``True``, we scan the source for dependencies.
    """
    self.dependency_graph.add_node(
        module_name,
        source=src,
        is_package=is_package,
        provided=True,
        action=_ModuleProviderAction.INTERN,
    )

    if not dependencies:
        return

    # Link each discovered dependency and let add_dependency classify it.
    for dep in self._get_dependencies(src, module_name, is_package):
        self.dependency_graph.add_edge(module_name, dep)
        self.add_dependency(dep)
| |
| def _write_source_string( |
| self, |
| module_name: str, |
| src: str, |
| is_package: bool = False, |
| ): |
| """Write ``src`` as the source code for ``module_name`` in the zip archive. |
| |
| Arguments are otherwise the same as for :meth:`save_source_string`. |
| """ |
| extension = "/__init__.py" if is_package else ".py" |
| filename = module_name.replace(".", "/") + extension |
| |
| self._write(filename, src) |
| |
| def _import_module(self, module_name: str): |
| try: |
| return self.importer.import_module(module_name) |
| except ModuleNotFoundError as e: |
| if not is_mangled(module_name): |
| raise |
| msg = ( |
| f"Module not found: '{module_name}'. Make sure the PackageImporter that " |
| "created this module is present in `self.importer`" |
| ) |
| raise ModuleNotFoundError(msg) from None |
| |
| def _module_exists(self, module_name: str) -> bool: |
| try: |
| self._import_module(module_name) |
| return True |
| except Exception: |
| return False |
| |
| def _get_source_of_module(self, module: types.ModuleType) -> Optional[str]: |
| filename = None |
| spec = getattr(module, "__spec__", None) |
| if spec is not None: |
| loader = getattr(spec, "loader", None) |
| if loader is not None and isinstance(loader, SourceFileLoader): |
| try: |
| filename = loader.get_filename(module.__name__) |
| except ImportError: |
| pass |
| if filename is None: |
| filename = getattr(module, "__file__", None) |
| if isinstance(filename, str) and filename.endswith(".py"): |
| return "".join(linecache.getlines(filename, module.__dict__)) |
| return None |
| |
def add_dependency(self, module_name: str, dependencies=True):
    """Classify ``module_name`` against the user's patterns and record it in the dependency graph.

    The checks run in this order:
      1. A module already marked ``provided`` is left untouched.
      2. ``torch_package_importer`` gets SKIP.
      3. ``_mock`` gets REPACKAGED_MOCK_MODULE.
      4. Modules accepted by ``_can_implicitly_extern`` are externed.
      5. Otherwise the first matching entry of ``self.patterns`` decides.
         DENY additionally records an error; INTERN recurses through
         :meth:`_intern_module`.
      6. If nothing matched, a NO_ACTION error is recorded.

    Args:
        module_name: Name of the module to classify.
        dependencies: Forwarded to :meth:`_intern_module` when the module is interned.
    """
    if (
        module_name in self.dependency_graph
        and self.dependency_graph.nodes[module_name].get("provided") is True
    ):
        return

    # Built-in special cases, handled before any user pattern is consulted.
    # - `torch_package_importer`: synthetic module PackageImporter exposes to
    #   packaged code so it can reference its importer; never re-exported.
    # - `_mock`: a mocking-stub module left over from an earlier package.
    if module_name == "torch_package_importer":
        special_action = _ModuleProviderAction.SKIP
    elif module_name == "_mock":
        special_action = _ModuleProviderAction.REPACKAGED_MOCK_MODULE
    else:
        special_action = None
    if special_action is not None:
        self.dependency_graph.add_node(
            module_name,
            action=special_action,
            provided=True,
        )
        return

    if self._can_implicitly_extern(module_name):
        self.dependency_graph.add_node(
            module_name, action=_ModuleProviderAction.EXTERN, provided=True
        )
        return

    # The first pattern (in self.patterns iteration order) that matches wins.
    matched = next(
        (
            info
            for pattern, info in self.patterns.items()
            if pattern.matches(module_name)
        ),
        None,
    )
    if matched is None:
        # No patterns have matched. Explicitly add this as an error.
        self.dependency_graph.add_node(
            module_name, error=PackagingErrorReason.NO_ACTION
        )
        return

    matched.was_matched = True
    self.dependency_graph.add_node(
        module_name, action=matched.action, provided=True
    )
    if matched.action == _ModuleProviderAction.DENY:
        # Requiring a denied module just adds an error to the graph.
        self.dependency_graph.add_node(
            module_name, error=PackagingErrorReason.DENIED
        )
    if matched.action == _ModuleProviderAction.INTERN:
        # Interned modules must have their own dependencies packaged too.
        self._intern_module(module_name, dependencies)
| |
def save_module(self, module_name: str, dependencies=True):
    """Intern the module named ``module_name`` into the package.

    The module object is found through the configured importers, and its
    source is then located from the module (see :meth:`_intern_module`).

    Args:
        module_name (str): e.g. ``my_package.my_subpackage``, code will be saved to provide code
            for this package.
        dependencies (bool, optional): If ``True``, we scan the source for dependencies.

    Raises:
        TypeError: If ``module_name`` is not a string (e.g. a module object was passed).
    """
    if isinstance(module_name, str):
        self._intern_module(module_name, dependencies)
        return
    raise TypeError(
        "save_module() expects a string input, did you perhaps mean to pass `__name__`?"
    )
| |
def _intern_module(
    self,
    module_name: str,
    dependencies: bool,
):
    """Record ``module_name`` in the dependency graph as an interned module.

    The node carries the metadata needed to write the module out later. If
    its source cannot be found, the node is instead marked with a
    :class:`PackagingErrorReason` explaining why.

    Args:
        module_name: Possibly mangled name of the module to intern.
        dependencies: If True, scan the module's source and add its dependencies.
    """
    module_obj = self._import_module(module_name)
    # If the import succeeded, then either the name was not mangled, or one
    # of the importers recognized the mangling. Either way it is now safe to
    # demangle, so the mangled form is never serialized into the package.
    module_name = demangle(module_name)

    is_package = hasattr(module_obj, "__path__")
    source = self._get_source_of_module(module_obj)

    if source is not None:
        self.dependency_graph.add_node(
            module_name,
            action=_ModuleProviderAction.INTERN,
            is_package=is_package,
            source=source,
            provided=True,
        )
        if dependencies:
            for dep in self._get_dependencies(source, module_name, is_package):
                self.dependency_graph.add_edge(module_name, dep)
                self.add_dependency(dep)
        return

    # No source available: record the module as broken, with a reason derived
    # from its __file__, and carry on.
    filename = getattr(module_obj, "__file__", None)
    error_context = None
    if filename is None:
        packaging_error = PackagingErrorReason.NO_DUNDER_FILE
    elif filename.endswith(tuple(importlib.machinery.EXTENSION_SUFFIXES)):
        packaging_error = PackagingErrorReason.IS_EXTENSION_MODULE
    else:
        packaging_error = PackagingErrorReason.SOURCE_FILE_NOT_FOUND
        error_context = f"filename: {filename}"
    self.dependency_graph.add_node(
        module_name,
        action=_ModuleProviderAction.INTERN,
        is_package=is_package,
        error=packaging_error,
        error_context=error_context,
        provided=True,
    )
| |
def save_pickle(
    self,
    package: str,
    resource: str,
    obj: Any,
    dependencies: bool = True,
    pickle_protocol: int = 3,
):
    """Save a python object to the archive using pickle. Equivalent to :func:`torch.save` but saving into
    the archive rather than a stand-alone file. Standard pickle does not save the code, only the objects.
    If ``dependencies`` is true, this method will also scan the pickled objects for which modules are required
    to reconstruct them and save the relevant code.

    To be able to save an object where ``type(obj).__name__`` is ``my_module.MyObject``,
    ``my_module.MyObject`` must resolve to the class of the object according to the ``importer`` order. When saving objects that
    have previously been packaged, the importer's ``import_module`` method will need to be present in the ``importer`` list
    for this to work.

    Args:
        package (str): The name of module package this resource should go in (e.g. ``"my_package.my_subpackage"``).
        resource (str): A unique name for the resource, used to identify it to load.
        obj (Any): The object to save, must be picklable.
        dependencies (bool, optional): If ``True``, we scan the source for dependencies.
        pickle_protocol (int, optional): Pickle protocol to use; only 3 and 4 are supported. Defaults to 3.
    """

    assert pickle_protocol in (3, 4), "torch.package only supports pickle protocols 3 and 4"

    filename = self._filename(package, resource)
    # Write the pickle data for `obj`
    data_buf = io.BytesIO()
    pickler = create_pickler(data_buf, self.importer, protocol=pickle_protocol)
    pickler.persistent_id = self._persistent_id
    pickler.dump(obj)
    data_value = data_buf.getvalue()
    # Maps a mocked module name to the object names taken from it in this pickle.
    mocked_modules = defaultdict(list)
    name_in_dependency_graph = f"<{package}.{resource}>"
    self.dependency_graph.add_node(
        name_in_dependency_graph,
        action=_ModuleProviderAction.INTERN,
        provided=True,
        is_pickle=True,
    )

    def _check_mocked_error(module: Optional[str], field: Optional[str]):
        """
        checks if an object (field) comes from a mocked module and then adds
        the pair to mocked_modules which contains mocked modules paired with their
        list of mocked objects present in the pickle.

        We also hold the invariant that the first user defined rule that applies
        to the module is the one we use.
        """

        assert isinstance(module, str)
        assert isinstance(field, str)
        if self._can_implicitly_extern(module):
            return
        for pattern, pattern_info in self.patterns.items():
            if pattern.matches(module):
                if pattern_info.action == _ModuleProviderAction.MOCK:
                    mocked_modules[module].append(field)
                return

    if dependencies:
        all_dependencies = []
        module = None
        field = None
        memo: DefaultDict[int, str] = defaultdict(None)
        memo_count = 0
        for opcode, arg, _pos in pickletools.genops(data_value):
            if pickle_protocol == 4:
                # Protocol 4 names globals with STACK_GLOBAL, which takes the
                # two most recent strings (module, then qualified name). Track
                # those two strings, whether pushed directly or fetched from the
                # memo, and mirror the memo indices assigned by MEMOIZE.
                if opcode.name in ("SHORT_BINUNICODE", "BINUNICODE", "BINUNICODE8"):
                    assert isinstance(arg, str)
                    module = field
                    field = arg
                    memo[memo_count] = arg
                elif opcode.name in ("LONG_BINGET", "BINGET", "GET"):
                    assert isinstance(arg, int)
                    module = field
                    field = memo.get(arg, None)
                elif opcode.name == "MEMOIZE":
                    memo_count += 1
                elif opcode.name == "STACK_GLOBAL":
                    if module is None:
                        # If no module was passed on in the entries preceding this one, continue.
                        continue
                    assert isinstance(module, str)
                    if module not in all_dependencies:
                        all_dependencies.append(module)
                    _check_mocked_error(module, field)
            elif pickle_protocol == 3 and opcode.name == "GLOBAL":  # a global reference
                # genops reports GLOBAL's argument as "module name".
                assert isinstance(arg, str)
                module, field = arg.split(" ")
                if module not in all_dependencies:
                    all_dependencies.append(module)
                _check_mocked_error(module, field)

        for module_name in all_dependencies:
            self.dependency_graph.add_edge(name_in_dependency_graph, module_name)

            # If objects come from a mocked module, record an error so it is
            # reported with the other errors found by the exporter. The check
            # must use this loop's module_name, not `module` (the last module
            # the opcode scan happened to see).
            if module_name in mocked_modules:
                fields = mocked_modules[module_name]
                self.dependency_graph.add_node(
                    module_name,
                    action=_ModuleProviderAction.MOCK,
                    error=PackagingErrorReason.MOCKED_BUT_STILL_USED,
                    error_context=f"Object(s) '{fields}' from module `{module_name}` was mocked out during packaging "
                    f"but is being used in resource - `{resource}` in package `{package}`. ",
                    provided=True,
                )
            else:
                self.add_dependency(module_name)

    self._write(filename, data_value)
| |
def save_text(self, package: str, resource: str, text: str):
    """Save text data to the package, encoded as UTF-8.

    Args:
        package (str): The name of module package this resource should go in (e.g. ``"my_package.my_subpackage"``).
        resource (str): A unique name for the resource, used to identify it to load.
        text (str): The contents to save.
    """
    return self.save_binary(package, resource, text.encode("utf-8"))
| |
def save_binary(self, package: str, resource: str, binary: bytes):
    """Save raw bytes to the package.

    Args:
        package (str): The name of module package this resource should go in (e.g. ``"my_package.my_subpackage"``).
        resource (str): A unique name for the resource, used to identify it to load.
        binary (bytes): The data to save.
    """
    filename = self._filename(package, resource)
    self._write(filename, binary)
| |
def register_extern_hook(self, hook: ActionHook) -> RemovableHandle:
    """Register ``hook`` to run whenever a module matches an :meth:`extern` pattern.

    Expected signature::

        hook(exporter: PackageExporter, module_name: str) -> None

    Hooks are called in the order they were registered.

    Returns:
        :class:`torch.utils.hooks.RemovableHandle`:
            Call ``handle.remove()`` to unregister the hook.
    """
    hooks = self._extern_hooks
    handle = RemovableHandle(hooks)
    hooks[handle.id] = hook
    return handle
| |
def register_mock_hook(self, hook: ActionHook) -> RemovableHandle:
    """Register ``hook`` to run whenever a module matches a :meth:`mock` pattern.

    Expected signature::

        hook(exporter: PackageExporter, module_name: str) -> None

    Hooks are called in the order they were registered.

    Returns:
        :class:`torch.utils.hooks.RemovableHandle`:
            Call ``handle.remove()`` to unregister the hook.
    """
    hooks = self._mock_hooks
    handle = RemovableHandle(hooks)
    hooks[handle.id] = hook
    return handle
| |
def register_intern_hook(self, hook: ActionHook) -> RemovableHandle:
    """Register ``hook`` to run whenever a module matches an :meth:`intern` pattern.

    Expected signature::

        hook(exporter: PackageExporter, module_name: str) -> None

    Hooks are called in the order they were registered.

    Returns:
        :class:`torch.utils.hooks.RemovableHandle`:
            Call ``handle.remove()`` to unregister the hook.
    """
    hooks = self._intern_hooks
    handle = RemovableHandle(hooks)
    hooks[handle.id] = hook
    return handle
| |
def intern(
    self,
    include: "GlobPattern",
    *,
    exclude: "GlobPattern" = (),
    allow_empty: bool = True,
):
    """Specify modules that should be packaged. A module must match some ``intern`` pattern in order to be
    included in the package and have its dependencies processed recursively.

    Args:
        include (Union[List[str], str]): A string e.g. "my_package.my_subpackage", or list of strings
            for the names of the modules to be interned. This can also be a glob-style pattern, as described in :meth:`mock`.

        exclude (Union[List[str], str]): An optional pattern that excludes some modules that match the include pattern.

        allow_empty (bool): An optional flag that specifies whether the intern modules specified by this call
            to the ``intern`` method must be matched to some module during packaging. If an ``intern`` module glob
            pattern is added with ``allow_empty=False``, and :meth:`close` is called (either explicitly or via ``__exit__``)
            before any modules match that pattern, an exception is thrown. If ``allow_empty=True``, no such exception is thrown.

    """
    self.patterns[GlobGroup(include, exclude=exclude)] = _PatternInfo(
        _ModuleProviderAction.INTERN, allow_empty
    )
| |
def mock(
    self,
    include: "GlobPattern",
    *,
    exclude: "GlobPattern" = (),
    allow_empty: bool = True,
):
    """Replace some required modules with a mock implementation. Mocked modules will return a fake
    object for any attribute accessed from it. Because we copy file-by-file, the dependency resolution will sometimes
    find files that are imported by model files but whose functionality is never used
    (e.g. custom serialization code or training helpers).
    Use this function to mock this functionality out without having to modify the original code.

    Args:
        include (Union[List[str], str]): A string e.g. ``"my_package.my_subpackage"``, or list of strings
            for the names of the modules to be mocked out. Strings can also be a glob-style pattern
            string that may match multiple modules. Any required dependencies that match this pattern
            string will be mocked out automatically.

            Examples :
                ``'torch.**'`` -- matches ``torch`` and all submodules of torch, e.g. ``'torch.nn'``
                and ``'torch.nn.functional'``

                ``'torch.*'`` -- matches ``'torch.nn'`` or ``'torch.functional'``, but not
                ``'torch.nn.functional'``

        exclude (Union[List[str], str]): An optional pattern that excludes some modules that match the include pattern.
            e.g. ``include='torch.**', exclude='torch.foo'`` will mock all torch packages except ``'torch.foo'``.
            Defaults to ``()`` (nothing excluded).

        allow_empty (bool): An optional flag that specifies whether the mock implementation(s) specified by this call
            to the :meth:`mock` method must be matched to some module during packaging. If a mock is added with
            ``allow_empty=False``, and :meth:`close` is called (either explicitly or via ``__exit__``) and the mock has
            not been matched to a module used by the package being exported, an exception is thrown.
            If ``allow_empty=True``, no such exception is thrown.

    """
    self.patterns[GlobGroup(include, exclude=exclude)] = _PatternInfo(
        _ModuleProviderAction.MOCK, allow_empty
    )
| |
| def extern( |
| self, |
| include: "GlobPattern", |
| *, |
| exclude: "GlobPattern" = (), |
| allow_empty: bool = True, |
| ): |
| """Include ``module`` in the list of external modules the package can import. |
| This will prevent dependency discovery from saving |
| it in the package. The importer will load an external module directly from the standard import system. |
| Code for extern modules must also exist in the process loading the package. |
| |
| Args: |
| include (Union[List[str], str]): A string e.g. ``"my_package.my_subpackage"``, or list of strings |
| for the names of the modules to be externed. This can also be a glob-style pattern, as |
| described in :meth:`mock`. |
| |
| exclude (Union[List[str], str]): An optional pattern that excludes some patterns that match the |
| include string. |
| |
| allow_empty (bool): An optional flag that specifies whether the extern modules specified by this call |
| to the ``extern`` method must be matched to some module during packaging. If an extern module glob |
| pattern is added with ``allow_empty=False``, and :meth:`close` is called (either explicitly or via |
| ``__exit__``) before any modules match that pattern, an exception is thrown. If ``allow_empty=True``, |
| no such exception is thrown. |
| |
| """ |
| self.patterns[GlobGroup(include, exclude=exclude)] = _PatternInfo( |
| _ModuleProviderAction.EXTERN, allow_empty |
| ) |
| |
| def deny(self, include: "GlobPattern", *, exclude: "GlobPattern" = ()): |
| """Blocklist modules who names match the given glob patterns from the list of modules the package can import. |
| If a dependency on any matching packages is found, a :class:`PackagingError` is raised. |
| |
| Args: |
| include (Union[List[str], str]): A string e.g. ``"my_package.my_subpackage"``, or list of strings |
| for the names of the modules to be externed. This can also be a glob-style pattern, as described in :meth:`mock`. |
| |
| exclude (Union[List[str], str]): An optional pattern that excludes some patterns that match the include string. |
| """ |
| self.patterns[GlobGroup(include, exclude=exclude)] = _PatternInfo( |
| _ModuleProviderAction.DENY, allow_empty=True |
| ) |
| |
    def _persistent_id(self, obj):
        """Compute the persistent id used when pickling ``obj`` into the package.

        NOTE(review): presumably installed as the package pickler's ``persistent_id``
        hook; that wiring happens outside this chunk.

        * Storages (``torch.is_storage`` or ``TypedStorage``) have their bytes written
          once per storage to ``.data/<storage_id>.storage`` in the archive. They are
          replaced by ``("storage", storage_type, storage_id, location, storage_numel)``.
          ``storage_numel`` is ``obj.size()`` for a ``TypedStorage`` and
          ``storage.nbytes()`` for an ``UntypedStorage``.
        * Objects that define ``__reduce_package__`` are replaced by
          ``("reduce_package", id(obj), *obj.__reduce_package__(self))``. That tuple is
          computed once per ``id(obj)`` and cached in ``self.serialized_reduces``.
        * Anything else yields ``None``.

        Args:
            obj: The object currently being pickled.

        Returns:
            One of the tuples above, or ``None``.

        Raises:
            RuntimeError: If ``obj`` looks like a storage but is neither a
                ``TypedStorage`` nor an ``UntypedStorage``.
            Exception: If ``obj`` is a ``torch.jit.RecursiveScriptModule`` while the
                module-level ``_gate_torchscript_serialization`` flag is true.
        """
        if torch.is_storage(obj) or isinstance(obj, torch.storage.TypedStorage):
            storage: Storage
            if isinstance(obj, torch.storage.TypedStorage):
                # TODO: Once we decide to break serialization FC, we can
                # remove this case
                # Serialize the wrapped untyped storage. Record the storage type
                # reported by pickle_storage_type() and the typed element count.
                untyped_storage = obj._untyped_storage
                storage_type_str = obj.pickle_storage_type()
                storage_type = getattr(torch, storage_type_str)
                storage = cast(Storage, untyped_storage)
                storage_numel = obj.size()

            elif isinstance(obj, torch.UntypedStorage):
                untyped_storage = obj
                storage = cast(Storage, untyped_storage)
                storage_type = normalize_storage_type(type(storage))
                # For untyped storages the recorded "numel" is the byte count.
                storage_numel = storage.nbytes()
            else:
                raise RuntimeError(f"storage type not recognized: {type(obj)}")

            # Computed from the original storage, i.e. before the possible
            # .cpu() copy below, so the original device tag is recorded.
            location = location_tag(storage)

            # serialize storage if not already written
            # Check presence *before* get_or_add_storage, which (per its name)
            # registers the storage and hands back its id.
            storage_present = self.storage_context.has_storage(storage)
            storage_id = self.storage_context.get_or_add_storage(storage)
            if not storage_present:
                # The raw bytes are written from storage.data_ptr(), so
                # non-CPU storages are copied to CPU first.
                if storage.device.type != "cpu":
                    storage = storage.cpu()
                num_bytes = storage.nbytes()
                self.zip_file.write_record(
                    f".data/{storage_id}.storage", storage.data_ptr(), num_bytes
                )
            return ("storage", storage_type, storage_id, location, storage_numel)

        if hasattr(obj, "__reduce_package__"):
            # Directly packaging ScriptModules is gated behind a module-level flag.
            if _gate_torchscript_serialization and isinstance(
                obj, torch.jit.RecursiveScriptModule
            ):
                raise Exception(
                    "Serializing ScriptModules directly into a package is a beta feature. "
                    "To use, set global "
                    "`torch.package.package_exporter._gate_torchscript_serialization` to `False`."
                )
            # Call __reduce_package__ only once per object id. Repeat
            # occurrences of the same object reuse the cached tuple.
            if self.serialized_reduces.get(id(obj)) is None:
                self.serialized_reduces[id(obj)] = (
                    "reduce_package",
                    id(obj),
                    *obj.__reduce_package__(self),
                )

            return self.serialized_reduces[id(obj)]

        # Not a storage and no custom package reduction.
        return None
| |
| def __enter__(self): |
| return self |
| |
| def __exit__(self, exc_type, exc_value, traceback): |
| # If __exit__ was called because an exception was raised, we do not |
| # attempt to finalize the package. Instead, control is returned to the |
| # caller to continue raising the exception. |
| if exc_type is not None: |
| # Do the bare minimum to leave the open buffer in a valid state. |
| self._finalize_zip() |
| return |
| |
| self.close() |
| |
| def _write(self, filename, str_or_bytes): |
| if filename in self._written_files: |
| raise AssertionError( |
| f"Tried to write file '{filename}', but it already exists in this archive. " |
| "Please file a bug." |
| ) |
| self._written_files.add(filename) |
| |
| if is_mangled(filename): |
| raise AssertionError( |
| f"Tried to save a torch.package'd module as '{filename}'. " |
| "Directly saving torch.package'd modules is not allowed." |
| ) |
| if isinstance(str_or_bytes, str): |
| str_or_bytes = str_or_bytes.encode("utf-8") |
| self.zip_file.write_record(filename, str_or_bytes, len(str_or_bytes)) |
| |
| def _validate_dependency_graph(self): |
| # 1. Check the graph for any errors inserted during dependency analysis. |
| for attrs in self.dependency_graph.nodes.values(): |
| if "error" in attrs: |
| raise PackagingError(self.dependency_graph, debug=self.debug) |
| |
| # 2. Check that all patterns for which allow_empty=False have been matched at least once. |
| for pattern, pattern_info in self.patterns.items(): |
| if not pattern_info.allow_empty and not pattern_info.was_matched: |
| raise EmptyMatchError( |
| f"Exporter did not match any modules to {pattern}, which was marked as allow_empty=False" |
| ) |
| |
| def _write_mock_file(self): |
| if "_mock.py" not in self._written_files: |
| mock_file = str(Path(__file__).parent / "_mock.py") |
| self._write_source_string("_mock", _read_file(mock_file), is_package=False) |
| |
| def _execute_dependency_graph(self): |
| """Takes a finalized dependency graph describing how to package all |
| modules and executes it, writing to the ZIP archive. |
| """ |
| self._validate_dependency_graph() |
| |
| extern_modules = [] |
| for module_name, attrs in self.dependency_graph.nodes.items(): |
| action = attrs["action"] |
| |
| if action == _ModuleProviderAction.EXTERN: |
| for hook in self._extern_hooks.values(): |
| hook(self, module_name) |
| |
| extern_modules.append(module_name) |
| |
| elif action == _ModuleProviderAction.MOCK: |
| for hook in self._mock_hooks.values(): |
| hook(self, module_name) |
| |
| self._write_mock_file() |
| |
| is_package = hasattr(self._import_module(module_name), "__path__") |
| self._write_source_string(module_name, _MOCK_IMPL, is_package) |
| |
| elif action == _ModuleProviderAction.INTERN: |
| for hook in self._intern_hooks.values(): |
| hook(self, module_name) |
| |
| # The node in the dependency graph contains metadata that tells us |
| # how to intern the module. |
| if "provided" not in attrs: |
| raise AssertionError( |
| f"Module was marked `intern` but not provided: {module_name}" |
| ) |
| |
| if attrs.get("is_pickle") is True: |
| # This node came from save_pickle, we don't need to write any source for it. |
| continue |
| |
| is_package = attrs["is_package"] |
| source = attrs["source"] |
| self._write_source_string(module_name, source, is_package) |
| |
| elif action == _ModuleProviderAction.REPACKAGED_MOCK_MODULE: |
| self._write_mock_file() |
| elif action == _ModuleProviderAction.SKIP: |
| continue |
| else: |
| raise AssertionError( |
| f"Invalid action: {module_name}, {action}. Please report a bug to PyTorch." |
| ) |
| |
| extern_file_contents = "\n".join(extern_modules) + "\n" |
| self._write(".data/extern_modules", extern_file_contents) |
| |
| def _write_python_version(self): |
| """Writes the python version that the package was created with to .data/python_version""" |
| self._write(".data/python_version", platform.python_version()) |
| |
| def close(self): |
| """Write the package to the filesystem. Any calls after :meth:`close` are now invalid. |
| It is preferable to use resource guard syntax instead:: |
| |
| with PackageExporter("file.zip") as e: |
| ... |
| """ |
| self._execute_dependency_graph() |
| self._write_python_version() |
| |
| self.script_module_serializer.write_files() |
| self._finalize_zip() |
| |
| def _finalize_zip(self): |
| """Called at the very end of packaging to leave the zipfile in a closed but valid state.""" |
| del self.zip_file |
| if self.buffer: |
| self.buffer.flush() |
| |
| def _filename(self, package, resource): |
| package_path = package.replace(".", "/") |
| resource = _normalize_path(resource) |
| return f"{package_path}/{resource}" |
| |
| def _can_implicitly_extern(self, module_name: str): |
| top_level_package_name = module_name.partition(".")[0] |
| return top_level_package_name == "torch" or ( |
| top_level_package_name not in _DISALLOWED_MODULES |
| and is_stdlib_module(top_level_package_name) |
| ) |
| |
| def dependency_graph_string(self) -> str: |
| """Returns digraph string representation of dependencies in package. |
| |
| Returns: |
| A string representation of dependencies in package. |
| """ |
| return self.dependency_graph.to_dot() |
| |
| def _nodes_with_action_type( |
| self, action: Optional[_ModuleProviderAction] |
| ) -> List[str]: |
| result = [] |
| for name, node_dict in self.dependency_graph.nodes.items(): |
| node_action = node_dict.get("action", None) |
| if node_action == action and "is_pickle" not in node_dict: |
| result.append(name) |
| result.sort() |
| return result |
| |
| def externed_modules(self) -> List[str]: |
| """Return all modules that are currently externed. |
| |
| Returns: |
| A list containing the names of modules which will be |
| externed in this package. |
| """ |
| return self._nodes_with_action_type(_ModuleProviderAction.EXTERN) |
| |
| def interned_modules(self) -> List[str]: |
| """Return all modules that are currently interned. |
| |
| Returns: |
| A list containing the names of modules which will be |
| interned in this package. |
| """ |
| return self._nodes_with_action_type(_ModuleProviderAction.INTERN) |
| |
| def mocked_modules(self) -> List[str]: |
| """Return all modules that are currently mocked. |
| |
| Returns: |
| A list containing the names of modules which will be |
| mocked in this package. |
| """ |
| return self._nodes_with_action_type(_ModuleProviderAction.MOCK) |
| |
| def denied_modules(self) -> List[str]: |
| """Return all modules that are currently denied. |
| |
| Returns: |
| A list containing the names of modules which will be |
| denied in this package. |
| """ |
| return self._nodes_with_action_type(_ModuleProviderAction.DENY) |
| |
| def get_rdeps(self, module_name: str) -> List[str]: |
| """Return a list of all modules which depend on the module ``module_name``. |
| |
| Returns: |
| A list containing the names of modules which depend on ``module_name``. |
| """ |
| if module_name in self.dependency_graph._pred.keys(): |
| return list(self.dependency_graph._pred[module_name].keys()) |
| else: |
| return [] |
| |
| def all_paths(self, src: str, dst: str) -> str: |
| """Return a dot representation of the subgraph |
| that has all paths from src to dst. |
| |
| Returns: |
| A dot representation containing all paths from src to dst. |
| (https://graphviz.org/doc/info/lang.html) |
| """ |
| return self.dependency_graph.all_paths(src, dst) |
| |
| |
# Standard-library modules that _can_implicitly_extern still refuses to
# extern automatically, because they offer a lot of system-level access.
_DISALLOWED_MODULES = ["sys", "io"]

# Source written for every mocked module: a module-level ``__getattr__``
# (PEP 562) that returns a ``MockedObject`` from the packaged ``_mock`` module
# for any attribute looked up on the mocked module.
_MOCK_IMPL = (
    "from _mock import MockedObject\n"
    "def __getattr__(attr: str):\n"
    "    return MockedObject(__name__ + '.' + attr, _suppress_err=True)\n"
)
| |
| |
def _read_file(filename: str) -> str:
    """Return the contents of ``filename`` decoded as UTF-8.

    The file is opened in binary mode, so line endings come back untranslated.
    """
    with open(filename, "rb") as handle:
        return handle.read().decode("utf-8")