| import contextlib |
| import errno |
| import hashlib |
| import json |
| import os |
| import re |
| import shutil |
| import sys |
| import tempfile |
| import torch |
| import uuid |
| import warnings |
| import zipfile |
| from pathlib import Path |
| from typing import Dict, Optional, Any |
| from urllib.error import HTTPError, URLError |
| from urllib.request import urlopen, Request |
| from urllib.parse import urlparse # noqa: F401 |
| from torch.serialization import MAP_LOCATION |
| |
| class _Faketqdm: # type: ignore[no-redef] |
| |
| def __init__(self, total=None, disable=False, |
| unit=None, *args, **kwargs): |
| self.total = total |
| self.disable = disable |
| self.n = 0 |
| # Ignore all extra *args and **kwargs lest you want to reinvent tqdm |
| |
| def update(self, n): |
| if self.disable: |
| return |
| |
| self.n += n |
| if self.total is None: |
| sys.stderr.write(f"\r{self.n:.1f} bytes") |
| else: |
| sys.stderr.write(f"\r{100 * self.n / float(self.total):.1f}%") |
| sys.stderr.flush() |
| |
| # Don't bother implementing; use real tqdm if you want |
| def set_description(self, *args, **kwargs): |
| pass |
| |
| def write(self, s): |
| sys.stderr.write(f"{s}\n") |
| |
| def close(self): |
| self.disable = True |
| |
| def __enter__(self): |
| return self |
| |
| def __exit__(self, exc_type, exc_val, exc_tb): |
| if self.disable: |
| return |
| |
| sys.stderr.write('\n') |
| |
# Bind the module-level name ``tqdm``: the real progress bar when the package
# can be imported, otherwise the stderr-based _Faketqdm stand-in.
try:
    from tqdm import tqdm  # If tqdm is installed use it, otherwise use the fake wrapper
except ImportError:
    tqdm = _Faketqdm
| |
# Public API of this module.
__all__ = [
    'download_url_to_file',
    'get_dir',
    'help',
    'list',
    'load',
    'load_state_dict_from_url',
    'set_dir',
]

# matches bfd8deac from resnet18-bfd8deac.pth
# (the captured group is the run of lowercase hex characters between '-' and '.')
HASH_REGEX = re.compile(r'-([a-f0-9]*)\.')

# Repo owners that _check_repo_is_trusted accepts without consulting the trusted list.
_TRUSTED_REPO_OWNERS = ("facebookresearch", "facebookincubator", "pytorch", "fairinternal")
# Environment variable read for the GitHub API token in _validate_not_a_forked_repo.
ENV_GITHUB_TOKEN = 'GITHUB_TOKEN'
# Environment variables and fallback consulted by _get_torch_home.
ENV_TORCH_HOME = 'TORCH_HOME'
ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
DEFAULT_CACHE_DIR = '~/.cache'
# Attribute looked up on a hubconf module by _check_dependencies.
VAR_DEPENDENCY = 'dependencies'
# File name of the entrypoint module expected inside a hub repo directory.
MODULE_HUBCONF = 'hubconf.py'
# Same value as the per-iteration read size in download_url_to_file.
READ_DATA_CHUNK = 8192
# Directory override assigned by set_dir(); get_dir() computes a default while it is None.
_hub_dir = None
| |
| |
| @contextlib.contextmanager |
| def _add_to_sys_path(path): |
| sys.path.insert(0, path) |
| try: |
| yield |
| finally: |
| sys.path.remove(path) |
| |
| |
| # Copied from tools/shared/module_loader to be included in torch package |
| def _import_module(name, path): |
| import importlib.util |
| from importlib.abc import Loader |
| spec = importlib.util.spec_from_file_location(name, path) |
| assert spec is not None |
| module = importlib.util.module_from_spec(spec) |
| assert isinstance(spec.loader, Loader) |
| spec.loader.exec_module(module) |
| return module |
| |
| |
| def _remove_if_exists(path): |
| if os.path.exists(path): |
| if os.path.isfile(path): |
| os.remove(path) |
| else: |
| shutil.rmtree(path) |
| |
| |
| def _git_archive_link(repo_owner, repo_name, ref): |
| # See https://docs.github.com/en/rest/reference/repos#download-a-repository-archive-zip |
| return f"https://github.com/{repo_owner}/{repo_name}/zipball/{ref}" |
| |
| |
| def _load_attr_from_module(module, func_name): |
| # Check if callable is defined in the module |
| if func_name not in dir(module): |
| return None |
| return getattr(module, func_name) |
| |
| |
def _get_torch_home():
    """Resolve the torch home directory.

    Uses the ``TORCH_HOME`` environment variable when it is set; otherwise the
    ``torch`` subdirectory of ``XDG_CACHE_HOME`` (falling back to ``~/.cache``).
    A leading ``~`` in the result is expanded.
    """
    xdg_cache = os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR)
    fallback_home = os.path.join(xdg_cache, 'torch')
    return os.path.expanduser(os.getenv(ENV_TORCH_HOME, fallback_home))
| |
| |
| def _parse_repo_info(github): |
| if ':' in github: |
| repo_info, ref = github.split(':') |
| else: |
| repo_info, ref = github, None |
| repo_owner, repo_name = repo_info.split('/') |
| |
| if ref is None: |
| # The ref wasn't specified by the user, so we need to figure out the |
| # default branch: main or master. Our assumption is that if main exists |
| # then it's the default branch, otherwise it's master. |
| try: |
| with urlopen(f"https://github.com/{repo_owner}/{repo_name}/tree/main/"): |
| ref = 'main' |
| except HTTPError as e: |
| if e.code == 404: |
| ref = 'master' |
| else: |
| raise |
| except URLError as e: |
| # No internet connection, need to check for cache as last resort |
| for possible_ref in ("main", "master"): |
| if os.path.exists(f"{get_dir()}/{repo_owner}_{repo_name}_{possible_ref}"): |
| ref = possible_ref |
| break |
| if ref is None: |
| raise RuntimeError( |
| "It looks like there is no internet connection and the " |
| f"repo could not be found in the cache ({get_dir()})" |
| ) from e |
| return repo_owner, repo_name, ref |
| |
| |
| def _read_url(url): |
| with urlopen(url) as r: |
| return r.read().decode(r.headers.get_content_charset('utf-8')) |
| |
| |
def _validate_not_a_forked_repo(repo_owner, repo_name, ref):
    """Check that ``ref`` is a branch or tag of ``repo_owner/repo_name``.

    The repo's branches and then its tags are paged through via the GitHub
    API (100 per page). The function returns as soon as an entry's name
    equals ``ref`` or its commit SHA starts with ``ref``. A token from the
    ``GITHUB_TOKEN`` environment variable is sent when present.

    Raises:
        ValueError: if no branch or tag matches ``ref``.
    """
    # Use urlopen to avoid depending on local git.
    request_headers = {'Accept': 'application/vnd.github.v3+json'}
    github_token = os.environ.get(ENV_GITHUB_TOKEN)
    if github_token is not None:
        request_headers['Authorization'] = f'token {github_token}'

    api_root = f'https://api.github.com/repos/{repo_owner}/{repo_name}'
    for listing in ('branches', 'tags'):
        page = 1
        while True:
            url = f'{api_root}/{listing}?per_page=100&page={page}'
            entries = json.loads(_read_url(Request(url, headers=request_headers)))
            # Empty response means no more data to process
            if not entries:
                break
            if any(entry['name'] == ref or entry['commit']['sha'].startswith(ref)
                   for entry in entries):
                return
            page += 1

    raise ValueError(f'Cannot find {ref} in https://github.com/{repo_owner}/{repo_name}. '
                     'If it\'s a commit from a forked repo, please call hub.load() with forked repo directly.')
| |
| |
def _get_cache_or_reload(github, force_reload, trust_repo, calling_fn, verbose=True, skip_validation=False):
    """Return the local directory holding the requested repo, downloading it if needed.

    Args:
        github (str): repo specification ``"repo_owner/repo_name[:ref]"``.
        force_reload (bool): download again even if a cached copy exists.
        trust_repo: forwarded to ``_check_repo_is_trusted``.
        calling_fn (str): name of the public function on whose behalf this
            runs; forwarded to ``_check_repo_is_trusted``.
        verbose (bool): when ``False``, suppress the "Using cache" message.
            The "Downloading" message is always written.
        skip_validation (bool): when ``False``, verify before downloading
            that the ref belongs to the repo itself rather than to a fork.

    Returns:
        str: path of the repo directory inside the hub directory, named
        ``owner_name_ref`` (with ``/`` in the ref replaced by ``_``).
    """
    # Setup hub_dir to save downloaded files
    hub_dir = get_dir()
    # exist_ok avoids the check-then-create race when several processes start together.
    os.makedirs(hub_dir, exist_ok=True)
    # Parse github repo information
    repo_owner, repo_name, ref = _parse_repo_info(github)
    # Github allows branch name with slash '/',
    # this causes confusion with path on both Linux and Windows.
    # Backslash is not allowed in Github branch name so no need
    # to worry about it.
    normalized_br = ref.replace('/', '_')
    # Github renames folder repo-v1.x.x to repo-1.x.x
    # We don't know the repo name before downloading the zip file
    # and inspect name from it.
    # To check if cached repo exists, we need to normalize folder names.
    owner_name_branch = '_'.join([repo_owner, repo_name, normalized_br])
    repo_dir = os.path.join(hub_dir, owner_name_branch)
    # Check that the repo is in the trusted list
    _check_repo_is_trusted(repo_owner, repo_name, owner_name_branch, trust_repo=trust_repo, calling_fn=calling_fn)

    use_cache = (not force_reload) and os.path.exists(repo_dir)

    if use_cache:
        if verbose:
            sys.stderr.write(f'Using cache found in {repo_dir}\n')
    else:
        # Validate the tag/branch is from the original repo instead of a forked repo
        if not skip_validation:
            _validate_not_a_forked_repo(repo_owner, repo_name, ref)

        cached_file = os.path.join(hub_dir, normalized_br + '.zip')
        _remove_if_exists(cached_file)

        try:
            url = _git_archive_link(repo_owner, repo_name, ref)
            sys.stderr.write(f'Downloading: "{url}" to {cached_file}\n')
            download_url_to_file(url, cached_file, progress=False)
        except HTTPError as err:
            if err.code == 300:
                # Getting a 300 Multiple Choices error likely means that the ref is both a tag and a branch
                # in the repo. This can be disambiguated by explicitly using refs/heads/ or refs/tags
                # See https://git-scm.com/book/en/v2/Git-Internals-Git-References
                # Here, we do the same as git: we throw a warning, and assume the user wanted the branch
                warnings.warn(
                    f"The ref {ref} is ambiguous. Perhaps it is both a tag and a branch in the repo? "
                    "Torchhub will now assume that it's a branch. "
                    "You can disambiguate tags and branches by explicitly passing refs/heads/branch_name or "
                    "refs/tags/tag_name as the ref. That might require using skip_validation=True."
                )
                disambiguated_branch_ref = f"refs/heads/{ref}"
                url = _git_archive_link(repo_owner, repo_name, ref=disambiguated_branch_ref)
                download_url_to_file(url, cached_file, progress=False)
            else:
                raise

        with zipfile.ZipFile(cached_file) as cached_zipfile:
            # The first archive entry is taken to be the top-level folder of the repo.
            extracted_repo_name = cached_zipfile.infolist()[0].filename
            extracted_repo = os.path.join(hub_dir, extracted_repo_name)
            _remove_if_exists(extracted_repo)
            # Unzip the code and rename the base folder
            cached_zipfile.extractall(hub_dir)

        _remove_if_exists(cached_file)
        _remove_if_exists(repo_dir)
        shutil.move(extracted_repo, repo_dir)  # rename the repo

    return repo_dir
| |
| |
def _check_repo_is_trusted(repo_owner, repo_name, owner_name_branch, trust_repo, calling_fn="load"):
    """Apply the ``trust_repo`` policy to a repo before it is downloaded or loaded.

    A repo counts as already trusted when ``owner_name`` is listed in the
    ``trusted_list`` file of the hub directory, when a directory named
    ``owner_name_branch`` already exists in the hub directory, or when the
    owner is one of ``_TRUSTED_REPO_OWNERS``.

    Args:
        repo_owner (str): owner part of the repo specification.
        repo_name (str): name part of the repo specification.
        owner_name_branch (str): cache directory name of the repo.
        trust_repo (bool, str or None):
            - ``None``: warn if the repo is not trusted, then return without
              changing the trusted list.
            - ``False``: always prompt for confirmation.
            - ``"check"``: prompt only if the repo is not already trusted.
            - ``True``: trust without prompting.
        calling_fn (str): name of the public function shown in the warning.

    Raises:
        Exception: if the user answers the prompt with "n", "no" or nothing.
        ValueError: if the answer is not recognised.
    """
    hub_dir = get_dir()
    filepath = os.path.join(hub_dir, "trusted_list")

    if not os.path.exists(filepath):
        Path(filepath).touch()
    with open(filepath) as file:
        trusted_repos = tuple(line.strip() for line in file)

    # To minimize friction of introducing the new trust_repo mechanism, we consider that
    # if a repo was already downloaded by torchhub, then it is already trusted (even if it's not in the allowlist)
    trusted_repos_legacy = next(os.walk(hub_dir))[1]

    owner_name = '_'.join([repo_owner, repo_name])
    is_trusted = (
        owner_name in trusted_repos
        or owner_name_branch in trusted_repos_legacy
        or repo_owner in _TRUSTED_REPO_OWNERS
    )

    # TODO: Remove `None` option in 2.0 and change the default to "check"
    if trust_repo is None:
        if not is_trusted:
            # Every literal that contains {calling_fn} needs its own f prefix;
            # adjacent literals do not inherit it from their neighbours.
            warnings.warn(
                "You are about to download and run code from an untrusted repository. In a future release, this won't "
                f"be allowed. To add the repository to your trusted list, change the command to {calling_fn}(..., "
                "trust_repo=False) and a command prompt will appear asking for an explicit confirmation of trust, "
                f"or {calling_fn}(..., trust_repo=True), which will assume that the prompt is to be answered with "
                f"'yes'. You can also use {calling_fn}(..., trust_repo='check') which will only prompt for "
                f"confirmation if the repo is not already trusted. This will eventually be the default behaviour")
        return

    if (trust_repo is False) or (trust_repo == "check" and not is_trusted):
        response = input(
            f"The repository {owner_name} does not belong to the list of trusted repositories and as such cannot be downloaded. "
            "Do you trust this repository and wish to add it to the trusted list of repositories (y/N)?")
        if response.lower() in ("y", "yes"):
            if is_trusted:
                print("The repository is already trusted.")
        elif response.lower() in ("n", "no", ""):
            raise Exception("Untrusted repository.")
        else:
            raise ValueError(f"Unrecognized response {response}.")

    # At this point we're sure that the user trusts the repo (or wants to trust it)
    if not is_trusted:
        with open(filepath, "a") as file:
            file.write(owner_name + "\n")
| |
| |
| def _check_module_exists(name): |
| import importlib.util |
| return importlib.util.find_spec(name) is not None |
| |
| |
def _check_dependencies(m):
    """Ensure every package named in module ``m``'s ``dependencies`` attribute is importable.

    A module without that attribute passes trivially.

    Raises:
        RuntimeError: naming all packages that could not be found.
    """
    required = _load_attr_from_module(m, VAR_DEPENDENCY)
    if required is None:
        return

    unavailable = [pkg for pkg in required if not _check_module_exists(pkg)]
    if unavailable:
        raise RuntimeError(f"Missing dependencies: {', '.join(unavailable)}")
| |
| |
def _load_entry_from_hubconf(m, model):
    """Look up the entrypoint called ``model`` in the hubconf module ``m``.

    The module's declared dependencies are checked first.

    Raises:
        ValueError: if ``model`` is not a string.
        RuntimeError: if a dependency is missing, or ``m`` has no callable
            attribute named ``model``.
    """
    if not isinstance(model, str):
        raise ValueError('Invalid input: model should be a string of function name')

    # Note that if a missing dependency is imported at top level of hubconf, it will
    # throw before this function. It's a chicken and egg situation where we have to
    # load hubconf to know what're the dependencies, but to import hubconf it requires
    # a missing package. This is fine, Python will throw proper error message for users.
    _check_dependencies(m)

    entry = _load_attr_from_module(m, model)
    # A missing attribute comes back as None, which is not callable either.
    if callable(entry):
        return entry

    raise RuntimeError(f'Cannot find callable {model} in hubconf')
| |
| |
def get_dir():
    r"""
    Return the Torch Hub cache directory, where downloaded models & weights are stored.

    A directory configured through :func:`~torch.hub.set_dir` takes precedence.
    Otherwise the result is ``$TORCH_HOME/hub``, where ``$TORCH_HOME`` defaults
    to ``$XDG_CACHE_HOME/torch`` and ``$XDG_CACHE_HOME`` in turn defaults to
    ``~/.cache`` when the environment variable is not set (following the XDG
    Base Directory Specification).
    """
    # Issue warning to move data if old env is set
    if os.getenv('TORCH_HUB'):
        warnings.warn('TORCH_HUB is deprecated, please use env TORCH_HOME instead')

    return os.path.join(_get_torch_home(), 'hub') if _hub_dir is None else _hub_dir
| |
| |
def set_dir(d):
    r"""
    Override the Torch Hub directory in which downloaded models & weights are saved.

    Args:
        d (str): path to a local folder to save downloaded models & weights.
            A leading ``~`` is expanded.
    """
    global _hub_dir
    expanded = os.path.expanduser(d)
    _hub_dir = expanded
| |
| |
def list(github, force_reload=False, skip_validation=False, trust_repo=None):
    r"""
    Enumerate the entrypoints exposed by the ``hubconf.py`` of the repo given by ``github``.

    Args:
        github (str): repo specification of the form "repo_owner/repo_name[:ref]", where the
            optional ref is a tag or a branch. Without a ``ref``, ``main`` is used if that
            branch exists, and ``master`` otherwise.
            Example: 'pytorch/vision:0.10'
        force_reload (bool, optional): drop any cached copy and download the repo again.
            Default is ``False``.
        skip_validation (bool, optional): when ``False``, torchhub verifies that the branch or
            commit named by ``github`` really belongs to the repo owner. The check queries the
            GitHub API; set the ``GITHUB_TOKEN`` environment variable to use a non-default
            GitHub token. Default is ``False``.
        trust_repo (bool, str or None): ``"check"``, ``True``, ``False`` or ``None``.
            Introduced in v1.12 to help ensure that users only run code from repos
            they trust.

            - ``False``: the user is prompted to confirm that the repo should be
              trusted.
            - ``True``: the repo is added to the trusted list and loaded without
              asking for confirmation.
            - ``"check"``: the repo is looked up in the cached list of trusted
              repos; when it is absent, behaviour falls back to
              ``trust_repo=False``.
            - ``None``: a warning invites the user to pick ``False``, ``True`` or
              ``"check"``. Kept for backward compatibility only; to be removed in
              v2.0.

            Default is ``None`` and will eventually change to ``"check"`` in v2.0.

    Returns:
        list: names of the callable attributes of ``hubconf.py`` that do not start with ``_``

    Example:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_HUB)
        >>> entrypoints = torch.hub.list('pytorch/vision', force_reload=True)
    """
    repo_dir = _get_cache_or_reload(github, force_reload, trust_repo, "list", verbose=True,
                                    skip_validation=skip_validation)

    hubconf_path = os.path.join(repo_dir, MODULE_HUBCONF)
    with _add_to_sys_path(repo_dir):
        hub_module = _import_module(MODULE_HUBCONF, hubconf_path)

    # Names starting with '_' are treated as internal helper functions
    entrypoints = []
    for attr_name in dir(hub_module):
        if callable(getattr(hub_module, attr_name)) and not attr_name.startswith('_'):
            entrypoints.append(attr_name)

    return entrypoints
| |
| |
def help(github, model, force_reload=False, skip_validation=False, trust_repo=None):
    r"""
    Return the docstring of the entrypoint ``model`` of a GitHub repo.

    Args:
        github (str): repo specification of the form "repo_owner/repo_name[:ref]", where the
            optional ref is a tag or a branch. Without a ``ref``, ``main`` is used if that
            branch exists, and ``master`` otherwise.
            Example: 'pytorch/vision:0.10'
        model (str): name of an entrypoint defined in the repo's ``hubconf.py``
        force_reload (bool, optional): drop any cached copy and download the repo again.
            Default is ``False``.
        skip_validation (bool, optional): when ``False``, torchhub verifies that the ref
            named by ``github`` really belongs to the repo owner. The check queries the
            GitHub API; set the ``GITHUB_TOKEN`` environment variable to use a non-default
            GitHub token. Default is ``False``.
        trust_repo (bool, str or None): ``"check"``, ``True``, ``False`` or ``None``.
            Introduced in v1.12 to help ensure that users only run code from repos
            they trust.

            - ``False``: the user is prompted to confirm that the repo should be
              trusted.
            - ``True``: the repo is added to the trusted list and loaded without
              asking for confirmation.
            - ``"check"``: the repo is looked up in the cached list of trusted
              repos; when it is absent, behaviour falls back to
              ``trust_repo=False``.
            - ``None``: a warning invites the user to pick ``False``, ``True`` or
              ``"check"``. Kept for backward compatibility only; to be removed in
              v2.0.

            Default is ``None`` and will eventually change to ``"check"`` in v2.0.

    Returns:
        The ``__doc__`` attribute of the entrypoint.

    Example:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_HUB)
        >>> print(torch.hub.help('pytorch/vision', 'resnet18', force_reload=True))
    """
    repo_dir = _get_cache_or_reload(github, force_reload, trust_repo, "help", verbose=True,
                                    skip_validation=skip_validation)

    hubconf_path = os.path.join(repo_dir, MODULE_HUBCONF)
    with _add_to_sys_path(repo_dir):
        hub_module = _import_module(MODULE_HUBCONF, hubconf_path)

    return _load_entry_from_hubconf(hub_module, model).__doc__
| |
| |
def load(repo_or_dir, model, *args, source='github', trust_repo=None, force_reload=False, verbose=True,
         skip_validation=False,
         **kwargs):
    r"""
    Load a model from a github repo or a local directory.

    Note: Loading a model is the typical use case, but this can also be used
    for loading other objects such as tokenizers, loss functions, etc.

    With ``source='github'``, ``repo_or_dir`` must have the form
    ``repo_owner/repo_name[:ref]``, where the optional ref is a tag or a branch.

    With ``source='local'``, ``repo_or_dir`` must be a path to a local directory.

    Args:
        repo_or_dir (str): for ``source='github'``, a github repo in the format
            ``repo_owner/repo_name[:ref]`` with an optional ref (tag or branch), for example
            'pytorch/vision:0.10'. Without a ``ref``, ``main`` is used if that branch exists,
            and ``master`` otherwise.
            For ``source='local'``, a path to a local directory.
        model (str): the name of a callable (entrypoint) defined in the
            repo/dir's ``hubconf.py``.
        *args (optional): positional arguments passed on to the callable ``model``.
        source (str, optional): 'github' or 'local' (case-insensitive). Specifies how
            ``repo_or_dir`` is to be interpreted. Default is 'github'.
        trust_repo (bool, str or None): ``"check"``, ``True``, ``False`` or ``None``.
            Introduced in v1.12 to help ensure that users only run code from repos
            they trust.

            - ``False``: the user is prompted to confirm that the repo should be
              trusted.
            - ``True``: the repo is added to the trusted list and loaded without
              asking for confirmation.
            - ``"check"``: the repo is looked up in the cached list of trusted
              repos; when it is absent, behaviour falls back to
              ``trust_repo=False``.
            - ``None``: a warning invites the user to pick ``False``, ``True`` or
              ``"check"``. Kept for backward compatibility only; to be removed in
              v2.0.

            Default is ``None`` and will eventually change to ``"check"`` in v2.0.
        force_reload (bool, optional): whether to force a fresh download of
            the github repo unconditionally. Has no effect when
            ``source = 'local'``. Default is ``False``.
        verbose (bool, optional): if ``False``, mute messages about hitting
            local caches. The message about the first download cannot be
            muted. Has no effect when ``source = 'local'``.
            Default is ``True``.
        skip_validation (bool, optional): when ``False``, torchhub verifies that the branch or
            commit named by ``repo_or_dir`` really belongs to the repo owner. The check queries
            the GitHub API; set the ``GITHUB_TOKEN`` environment variable to use a non-default
            GitHub token. Default is ``False``.
        **kwargs (optional): keyword arguments passed on to the callable ``model``.

    Returns:
        The output of the ``model`` callable when called with the given
        ``*args`` and ``**kwargs``.

    Raises:
        ValueError: if ``source`` is neither 'github' nor 'local'.

    Example:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_HUB)
        >>> # from a github repo
        >>> repo = 'pytorch/vision'
        >>> model = torch.hub.load(repo, 'resnet50', weights='ResNet50_Weights.IMAGENET1K_V1')
        >>> # from a local directory
        >>> path = '/some/local/path/pytorch/vision'
        >>> # xdoctest: +SKIP
        >>> model = torch.hub.load(path, 'resnet50', weights='ResNet50_Weights.DEFAULT')
    """
    source = source.lower()

    if source == 'github':
        # Swap the repo specification for the directory of its local copy.
        repo_or_dir = _get_cache_or_reload(repo_or_dir, force_reload, trust_repo, "load",
                                           verbose=verbose, skip_validation=skip_validation)
    elif source != 'local':
        raise ValueError(
            f'Unknown source: "{source}". Allowed values: "github" | "local".')

    return _load_local(repo_or_dir, model, *args, **kwargs)
| |
| |
def _load_local(hubconf_dir, model, *args, **kwargs):
    r"""
    Call an entrypoint from the ``hubconf.py`` of a local directory.

    Args:
        hubconf_dir (str): path to a local directory that contains a
            ``hubconf.py``.
        model (str): name of an entrypoint defined in the directory's
            ``hubconf.py``.
        *args (optional): positional arguments passed on to the callable ``model``.
        **kwargs (optional): keyword arguments passed on to the callable ``model``.

    Returns:
        Whatever the entrypoint returns for the given arguments.

    Example:
        >>> # xdoctest: +SKIP("stub local path")
        >>> path = '/some/local/path/pytorch/vision'
        >>> model = _load_local(path, 'resnet50', weights='ResNet50_Weights.IMAGENET1K_V1')
    """
    hubconf_path = os.path.join(hubconf_dir, MODULE_HUBCONF)
    # NOTE(review): the entrypoint is called while hubconf_dir is still on
    # sys.path, presumably so imports made inside it resolve - confirm the
    # nesting against the properly indented source.
    with _add_to_sys_path(hubconf_dir):
        hub_module = _import_module(MODULE_HUBCONF, hubconf_path)
        entry = _load_entry_from_hubconf(hub_module, model)
        result = entry(*args, **kwargs)

    return result
| |
| |
def download_url_to_file(url: str, dst: str, hash_prefix: Optional[str] = None,
                         progress: bool = True) -> None:
    r"""Download object at the given URL to a local path.

    The data is streamed into a temporary ``.partial`` file next to ``dst``
    and only moved onto ``dst`` once the download (and the optional hash
    check) succeeded. The temporary file is removed in every case.

    Args:
        url (str): URL of the object to download
        dst (str): Full path where object will be saved, e.g. ``/tmp/temporary_file``
        hash_prefix (str, optional): If not None, the SHA256 of the downloaded file should start with ``hash_prefix``.
            Default: None
        progress (bool, optional): whether or not to display a progress bar to stderr
            Default: True

    Raises:
        RuntimeError: if ``hash_prefix`` is given and the SHA256 digest of the
            downloaded data does not start with it.
        FileExistsError: if no unused temporary file name could be found.

    Example:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_HUB)
        >>> # xdoctest: +REQUIRES(POSIX)
        >>> torch.hub.download_url_to_file('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth', '/tmp/temporary_file')

    """
    file_size = None
    req = Request(url, headers={"User-Agent": "torch.hub"})
    # The context manager closes the response (and its connection) on every
    # exit path instead of leaving that to garbage collection.
    with urlopen(req) as u:
        meta = u.info()
        if hasattr(meta, 'getheaders'):
            content_length = meta.getheaders("Content-Length")
        else:
            content_length = meta.get_all("Content-Length")
        if content_length is not None and len(content_length) > 0:
            file_size = int(content_length[0])

        # We deliberately save it in a temp file and move it after
        # download is complete. This prevents a local working checkpoint
        # being overridden by a broken download.
        # We deliberately do not use NamedTemporaryFile to avoid restrictive
        # file permissions being applied to the downloaded file.
        dst = os.path.expanduser(dst)
        for _ in range(tempfile.TMP_MAX):
            tmp_dst = dst + '.' + uuid.uuid4().hex + '.partial'
            try:
                # 'x' requests exclusive creation, so an existing file raises
                # FileExistsError and another name is tried; with 'w' the file
                # would be silently truncated and this loop could never retry.
                f = open(tmp_dst, 'x+b')
            except FileExistsError:
                continue
            break
        else:
            raise FileExistsError(errno.EEXIST, 'No usable temporary file name found')

        try:
            sha256 = hashlib.sha256() if hash_prefix is not None else None
            with tqdm(total=file_size, disable=not progress,
                      unit='B', unit_scale=True, unit_divisor=1024) as pbar:
                while True:
                    buffer = u.read(READ_DATA_CHUNK)
                    if len(buffer) == 0:
                        break
                    f.write(buffer)
                    if sha256 is not None:
                        sha256.update(buffer)
                    pbar.update(len(buffer))

            # Close before moving so all data is flushed to disk.
            f.close()
            if sha256 is not None:
                digest = sha256.hexdigest()
                if digest[:len(hash_prefix)] != hash_prefix:
                    raise RuntimeError(f'invalid hash value (expected "{hash_prefix}", got "{digest}")')
            shutil.move(f.name, dst)
        finally:
            # After a successful move the temp name no longer exists; on any
            # failure the partial file is cleaned up here.
            f.close()
            if os.path.exists(f.name):
                os.remove(f.name)
| |
| |
# Hub used to support automatically extracting a checkpoint from a zipfile manually compressed by users.
# The legacy zip format expects only one file from torch.save() < 1.6 in the zip.
# We should remove this support since zipfile is now the default format for torch.save().
| def _is_legacy_zip_format(filename: str) -> bool: |
| if zipfile.is_zipfile(filename): |
| infolist = zipfile.ZipFile(filename).infolist() |
| return len(infolist) == 1 and not infolist[0].is_dir() |
| return False |
| |
| |
| def _legacy_zip_load(filename: str, model_dir: str, map_location: MAP_LOCATION, weights_only: bool) -> Dict[str, Any]: |
| warnings.warn('Falling back to the old format < 1.6. This support will be ' |
| 'deprecated in favor of default zipfile format introduced in 1.6. ' |
| 'Please redo torch.save() to save it in the new zipfile format.') |
| # Note: extractall() defaults to overwrite file if exists. No need to clean up beforehand. |
| # We deliberately don't handle tarfile here since our legacy serialization format was in tar. |
| # E.g. resnet18-5c106cde.pth which is widely used. |
| with zipfile.ZipFile(filename) as f: |
| members = f.infolist() |
| if len(members) != 1: |
| raise RuntimeError('Only one file(not dir) is allowed in the zipfile') |
| f.extractall(model_dir) |
| extraced_name = members[0].filename |
| extracted_file = os.path.join(model_dir, extraced_name) |
| return torch.load(extracted_file, map_location=map_location, weights_only=weights_only) |
| |
| |
def load_state_dict_from_url(
    url: str,
    model_dir: Optional[str] = None,
    map_location: MAP_LOCATION = None,
    progress: bool = True,
    check_hash: bool = False,
    file_name: Optional[str] = None,
    weights_only: bool = False,
) -> Dict[str, Any]:
    r"""Loads the Torch serialized object at the given URL.

    If downloaded file is a zip file, it will be automatically
    decompressed.

    If the object is already present in `model_dir`, it's deserialized and
    returned without downloading again.
    The default value of ``model_dir`` is the ``checkpoints`` subdirectory of
    the directory returned by :func:`~torch.hub.get_dir`.

    Args:
        url (str): URL of the object to download
        model_dir (str, optional): directory in which to save the object
        map_location (optional): a function or a dict specifying how to remap storage locations (see torch.load)
        progress (bool, optional): whether or not to display a progress bar to stderr.
            Default: True
        check_hash(bool, optional): If True, the file name should follow the naming convention
            ``filename-HASH.ext`` where ``HASH`` is the first eight or more
            digits of the SHA256 hash of the contents of the file. The hash is used to
            ensure unique names and to verify the contents of the file. If the name
            contains no such part, no verification takes place.
            Default: False
        file_name (str, optional): name for the downloaded file. Filename from ``url`` will be used if not set.
        weights_only(bool, optional): If True, only weights will be loaded and no complex pickled objects.
            Recommended for untrusted sources. See :func:`~torch.load` for more details.

    Example:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_HUB)
        >>> state_dict = torch.hub.load_state_dict_from_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth')

    """
    # Issue warning to move data if old env is set
    if os.getenv('TORCH_MODEL_ZOO'):
        warnings.warn('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead')

    if model_dir is None:
        hub_dir = get_dir()
        model_dir = os.path.join(hub_dir, 'checkpoints')

    # An existing directory is fine; any other OSError still propagates.
    os.makedirs(model_dir, exist_ok=True)

    parts = urlparse(url)
    filename = os.path.basename(parts.path)
    if file_name is not None:
        filename = file_name
    cached_file = os.path.join(model_dir, filename)
    if not os.path.exists(cached_file):
        sys.stderr.write(f'Downloading: "{url}" to {cached_file}\n')
        hash_prefix = None
        if check_hash:
            r = HASH_REGEX.search(filename)  # r is Optional[Match[str]]
            hash_prefix = r.group(1) if r else None
        download_url_to_file(url, cached_file, hash_prefix, progress=progress)

    if _is_legacy_zip_format(cached_file):
        return _legacy_zip_load(cached_file, model_dir, map_location, weights_only)
    return torch.load(cached_file, map_location=map_location, weights_only=weights_only)