| # mypy: allow-untyped-defs |
| import contextlib |
| import errno |
| import hashlib |
| import json |
| import os |
| import re |
| import shutil |
| import sys |
| import tempfile |
| import uuid |
| import warnings |
| import zipfile |
| from pathlib import Path |
| from typing import Any, Dict, Optional |
| from typing_extensions import deprecated |
| from urllib.error import HTTPError, URLError |
| from urllib.parse import urlparse # noqa: F401 |
| from urllib.request import Request, urlopen |
| |
| import torch |
| from torch.serialization import MAP_LOCATION |
| |
| |
| class _Faketqdm: # type: ignore[no-redef] |
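    # Minimal stand-in for tqdm, used when tqdm is not installed; it implements
    # only the small subset of the tqdm API that this module relies on.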
| def __init__(self, total=None, disable=False, unit=None, *args, **kwargs): |
| self.total = total |
| self.disable = disable |
| self.n = 0 |
        # Ignore all extra *args and **kwargs unless you want to reinvent tqdm
| |
| def update(self, n): |
| if self.disable: |
| return |
| |
| self.n += n |
| if self.total is None: |
| sys.stderr.write(f"\r{self.n:.1f} bytes") |
| else: |
| sys.stderr.write(f"\r{100 * self.n / float(self.total):.1f}%") |
| sys.stderr.flush() |
| |
| # Don't bother implementing; use real tqdm if you want |
| def set_description(self, *args, **kwargs): |
| pass |
| |
| def write(self, s): |
| sys.stderr.write(f"{s}\n") |
| |
| def close(self): |
| self.disable = True |
| |
| def __enter__(self): |
| return self |
| |
| def __exit__(self, exc_type, exc_val, exc_tb): |
| if self.disable: |
| return |
| |
| sys.stderr.write("\n") |
| |
| |
| try: |
| from tqdm import tqdm # If tqdm is installed use it, otherwise use the fake wrapper |
| except ImportError: |
| tqdm = _Faketqdm |
| |
| __all__ = [ |
| "download_url_to_file", |
| "get_dir", |
| "help", |
| "list", |
| "load", |
| "load_state_dict_from_url", |
| "set_dir", |
| ] |
| |
| # matches bfd8deac from resnet18-bfd8deac.pth |
| HASH_REGEX = re.compile(r"-([a-f0-9]*)\.") |
| |
| _TRUSTED_REPO_OWNERS = ( |
| "facebookresearch", |
| "facebookincubator", |
| "pytorch", |
| "fairinternal", |
| ) |
| ENV_GITHUB_TOKEN = "GITHUB_TOKEN" |
| ENV_TORCH_HOME = "TORCH_HOME" |
| ENV_XDG_CACHE_HOME = "XDG_CACHE_HOME" |
| DEFAULT_CACHE_DIR = "~/.cache" |
| VAR_DEPENDENCY = "dependencies" |
| MODULE_HUBCONF = "hubconf.py" |
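# Size, in bytes, of the chunks read from the network while streaming a download.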
| READ_DATA_CHUNK = 128 * 1024 |
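# Set by set_dir(); when None, get_dir() falls back to $TORCH_HOME/hub.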
| _hub_dir: Optional[str] = None |
| |
| |
| @contextlib.contextmanager |
| def _add_to_sys_path(path): |
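    # Temporarily prepend `path` to sys.path so hubconf.py (and any modules it
    # imports from the repo directory) can be resolved; the entry is removed on exit.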
| sys.path.insert(0, path) |
| try: |
| yield |
| finally: |
| sys.path.remove(path) |
| |
| |
| # Copied from tools/shared/module_loader to be included in torch package |
| def _import_module(name, path): |
| import importlib.util |
| from importlib.abc import Loader |
| |
| spec = importlib.util.spec_from_file_location(name, path) |
| assert spec is not None |
| module = importlib.util.module_from_spec(spec) |
| assert isinstance(spec.loader, Loader) |
| spec.loader.exec_module(module) |
| return module |
| |
| |
| def _remove_if_exists(path): |
| if os.path.exists(path): |
| if os.path.isfile(path): |
| os.remove(path) |
| else: |
| shutil.rmtree(path) |
| |
| |
| def _git_archive_link(repo_owner, repo_name, ref): |
| # See https://docs.github.com/en/rest/reference/repos#download-a-repository-archive-zip |
| return f"https://github.com/{repo_owner}/{repo_name}/zipball/{ref}" |
| |
| |
| def _load_attr_from_module(module, func_name): |
| # Check if callable is defined in the module |
| if func_name not in dir(module): |
| return None |
| return getattr(module, func_name) |
| |
| |
| def _get_torch_home(): |
| torch_home = os.path.expanduser( |
| os.getenv( |
| ENV_TORCH_HOME, |
| os.path.join(os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), "torch"), |
| ) |
| ) |
| return torch_home |
| |
| |
| def _parse_repo_info(github): |
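    # `github` has the form "repo_owner/repo_name[:ref]"; the optional ref (a tag,
    # branch, or commit) is resolved below to the default branch when absent.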
| if ":" in github: |
| repo_info, ref = github.split(":") |
| else: |
| repo_info, ref = github, None |
| repo_owner, repo_name = repo_info.split("/") |
| |
| if ref is None: |
| # The ref wasn't specified by the user, so we need to figure out the |
| # default branch: main or master. Our assumption is that if main exists |
| # then it's the default branch, otherwise it's master. |
| try: |
| with urlopen(f"https://github.com/{repo_owner}/{repo_name}/tree/main/"): |
| ref = "main" |
| except HTTPError as e: |
| if e.code == 404: |
| ref = "master" |
| else: |
| raise |
| except URLError as e: |
| # No internet connection, need to check for cache as last resort |
| for possible_ref in ("main", "master"): |
| if os.path.exists( |
| f"{get_dir()}/{repo_owner}_{repo_name}_{possible_ref}" |
| ): |
| ref = possible_ref |
| break |
| if ref is None: |
| raise RuntimeError( |
| "It looks like there is no internet connection and the " |
| f"repo could not be found in the cache ({get_dir()})" |
| ) from e |
| return repo_owner, repo_name, ref |
| |
| |
| def _read_url(url): |
| with urlopen(url) as r: |
| return r.read().decode(r.headers.get_content_charset("utf-8")) |
| |
| |
| def _validate_not_a_forked_repo(repo_owner, repo_name, ref): |
| # Use urlopen to avoid depending on local git. |
| headers = {"Accept": "application/vnd.github.v3+json"} |
| token = os.environ.get(ENV_GITHUB_TOKEN) |
| if token is not None: |
| headers["Authorization"] = f"token {token}" |
| for url_prefix in ( |
| f"https://api.github.com/repos/{repo_owner}/{repo_name}/branches", |
| f"https://api.github.com/repos/{repo_owner}/{repo_name}/tags", |
| ): |
| page = 0 |
| while True: |
| page += 1 |
| url = f"{url_prefix}?per_page=100&page={page}" |
| response = json.loads(_read_url(Request(url, headers=headers))) |
| # Empty response means no more data to process |
| if not response: |
| break |
| for br in response: |
| if br["name"] == ref or br["commit"]["sha"].startswith(ref): |
| return |
| |
| raise ValueError( |
| f"Cannot find {ref} in https://github.com/{repo_owner}/{repo_name}. " |
| "If it's a commit from a forked repo, please call hub.load() with forked repo directly." |
| ) |
| |
| |
| def _get_cache_or_reload( |
| github, |
| force_reload, |
| trust_repo, |
| calling_fn, |
| verbose=True, |
| skip_validation=False, |
| ): |
| # Setup hub_dir to save downloaded files |
| hub_dir = get_dir() |
| os.makedirs(hub_dir, exist_ok=True) |
| # Parse github repo information |
| repo_owner, repo_name, ref = _parse_repo_info(github) |
    # GitHub allows branch names containing a slash '/', which causes
    # confusion with paths on both Linux and Windows. Backslashes are not
    # allowed in GitHub branch names, so there is no need to worry about them.
| normalized_br = ref.replace("/", "_") |
    # GitHub renames the folder repo-v1.x.x to repo-1.x.x.
    # We don't know the repo folder name until we download the zip file
    # and inspect it.
    # To check whether a cached repo exists, we need to normalize folder names.
| owner_name_branch = "_".join([repo_owner, repo_name, normalized_br]) |
| repo_dir = os.path.join(hub_dir, owner_name_branch) |
| # Check that the repo is in the trusted list |
| _check_repo_is_trusted( |
| repo_owner, |
| repo_name, |
| owner_name_branch, |
| trust_repo=trust_repo, |
| calling_fn=calling_fn, |
| ) |
| |
| use_cache = (not force_reload) and os.path.exists(repo_dir) |
| |
| if use_cache: |
| if verbose: |
| sys.stderr.write(f"Using cache found in {repo_dir}\n") |
| else: |
| # Validate the tag/branch is from the original repo instead of a forked repo |
| if not skip_validation: |
| _validate_not_a_forked_repo(repo_owner, repo_name, ref) |
| |
| cached_file = os.path.join(hub_dir, normalized_br + ".zip") |
| _remove_if_exists(cached_file) |
| |
| try: |
| url = _git_archive_link(repo_owner, repo_name, ref) |
| sys.stderr.write(f'Downloading: "{url}" to {cached_file}\n') |
| download_url_to_file(url, cached_file, progress=False) |
| except HTTPError as err: |
| if err.code == 300: |
| # Getting a 300 Multiple Choices error likely means that the ref is both a tag and a branch |
                # in the repo. This can be disambiguated by explicitly using refs/heads/ or refs/tags/.
| # See https://git-scm.com/book/en/v2/Git-Internals-Git-References |
| # Here, we do the same as git: we throw a warning, and assume the user wanted the branch |
| warnings.warn( |
| f"The ref {ref} is ambiguous. Perhaps it is both a tag and a branch in the repo? " |
| "Torchhub will now assume that it's a branch. " |
| "You can disambiguate tags and branches by explicitly passing refs/heads/branch_name or " |
| "refs/tags/tag_name as the ref. That might require using skip_validation=True." |
| ) |
| disambiguated_branch_ref = f"refs/heads/{ref}" |
| url = _git_archive_link( |
| repo_owner, repo_name, ref=disambiguated_branch_ref |
| ) |
| download_url_to_file(url, cached_file, progress=False) |
| else: |
| raise |
| |
| with zipfile.ZipFile(cached_file) as cached_zipfile: |
            extracted_repo_name = cached_zipfile.infolist()[0].filename
            extracted_repo = os.path.join(hub_dir, extracted_repo_name)
| _remove_if_exists(extracted_repo) |
| # Unzip the code and rename the base folder |
| cached_zipfile.extractall(hub_dir) |
| |
| _remove_if_exists(cached_file) |
| _remove_if_exists(repo_dir) |
| shutil.move(extracted_repo, repo_dir) # rename the repo |
| |
| return repo_dir |
| |
| |
| def _check_repo_is_trusted( |
| repo_owner, |
| repo_name, |
| owner_name_branch, |
| trust_repo, |
| calling_fn="load", |
| ): |
| hub_dir = get_dir() |
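    # `trusted_list` is a plain-text file in the hub dir with one "<owner>_<repo>" entry per line.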
| filepath = os.path.join(hub_dir, "trusted_list") |
| |
| if not os.path.exists(filepath): |
| Path(filepath).touch() |
| with open(filepath) as file: |
| trusted_repos = tuple(line.strip() for line in file) |
| |
| # To minimize friction of introducing the new trust_repo mechanism, we consider that |
| # if a repo was already downloaded by torchhub, then it is already trusted (even if it's not in the allowlist) |
| trusted_repos_legacy = next(os.walk(hub_dir))[1] |
| |
| owner_name = "_".join([repo_owner, repo_name]) |
| is_trusted = ( |
| owner_name in trusted_repos |
| or owner_name_branch in trusted_repos_legacy |
| or repo_owner in _TRUSTED_REPO_OWNERS |
| ) |
| |
| # TODO: Remove `None` option in 2.0 and change the default to "check" |
| if trust_repo is None: |
| if not is_trusted: |
| warnings.warn( |
| "You are about to download and run code from an untrusted repository. In a future release, this won't " |
| "be allowed. To add the repository to your trusted list, change the command to {calling_fn}(..., " |
| "trust_repo=False) and a command prompt will appear asking for an explicit confirmation of trust, " |
| f"or {calling_fn}(..., trust_repo=True), which will assume that the prompt is to be answered with " |
| f"'yes'. You can also use {calling_fn}(..., trust_repo='check') which will only prompt for " |
| f"confirmation if the repo is not already trusted. This will eventually be the default behaviour" |
| ) |
| return |
| |
| if (trust_repo is False) or (trust_repo == "check" and not is_trusted): |
| response = input( |
| f"The repository {owner_name} does not belong to the list of trusted repositories and as such cannot be downloaded. " |
| "Do you trust this repository and wish to add it to the trusted list of repositories (y/N)?" |
| ) |
| if response.lower() in ("y", "yes"): |
| if is_trusted: |
| print("The repository is already trusted.") |
| elif response.lower() in ("n", "no", ""): |
| raise Exception("Untrusted repository.") # noqa: TRY002 |
| else: |
| raise ValueError(f"Unrecognized response {response}.") |
| |
| # At this point we're sure that the user trusts the repo (or wants to trust it) |
| if not is_trusted: |
| with open(filepath, "a") as file: |
| file.write(owner_name + "\n") |
| |
| |
| def _check_module_exists(name): |
| import importlib.util |
| |
| return importlib.util.find_spec(name) is not None |
| |
| |
| def _check_dependencies(m): |
| dependencies = _load_attr_from_module(m, VAR_DEPENDENCY) |
| |
| if dependencies is not None: |
| missing_deps = [pkg for pkg in dependencies if not _check_module_exists(pkg)] |
| if len(missing_deps): |
| raise RuntimeError(f"Missing dependencies: {', '.join(missing_deps)}") |
| |
| |
| def _load_entry_from_hubconf(m, model): |
| if not isinstance(model, str): |
| raise ValueError("Invalid input: model should be a string of function name") |
| |
    # Note that if a missing dependency is imported at the top level of hubconf, it will
    # throw before this function runs. It's a chicken-and-egg situation: we have to
    # load hubconf to know what the dependencies are, but importing hubconf requires
    # the missing package. This is fine; Python will raise a proper error message for users.
| _check_dependencies(m) |
| |
| func = _load_attr_from_module(m, model) |
| |
| if func is None or not callable(func): |
| raise RuntimeError(f"Cannot find callable {model} in hubconf") |
| |
| return func |
| |
| |
| def get_dir(): |
| r""" |
| Get the Torch Hub cache directory used for storing downloaded models & weights. |
| |
| If :func:`~torch.hub.set_dir` is not called, default path is ``$TORCH_HOME/hub`` where |
| environment variable ``$TORCH_HOME`` defaults to ``$XDG_CACHE_HOME/torch``. |
    ``$XDG_CACHE_HOME`` follows the X Desktop Group specification of the Linux
| filesystem layout, with a default value ``~/.cache`` if the environment |
| variable is not set. |
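
    Example:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_HUB)
        >>> hub_dir = torch.hub.get_dir()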
| """ |
| # Issue warning to move data if old env is set |
| if os.getenv("TORCH_HUB"): |
| warnings.warn("TORCH_HUB is deprecated, please use env TORCH_HOME instead") |
| |
| if _hub_dir is not None: |
| return _hub_dir |
| return os.path.join(_get_torch_home(), "hub") |
| |
| |
| def set_dir(d): |
| r""" |
| Optionally set the Torch Hub directory used to save downloaded models & weights. |
| |
| Args: |
| d (str): path to a local folder to save downloaded models & weights. |
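
    Example:
        >>> # xdoctest: +SKIP("illustrative path")
        >>> # The path below is only illustrative; point it at any writable folder.
        >>> torch.hub.set_dir("/some/local/cache/path")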
| """ |
| global _hub_dir |
| _hub_dir = os.path.expanduser(d) |
| |
| |
| def list( |
| github, |
| force_reload=False, |
| skip_validation=False, |
| trust_repo=None, |
| verbose=True, |
| ): |
| r""" |
| List all callable entrypoints available in the repo specified by ``github``. |
| |
| Args: |
| github (str): a string with format "repo_owner/repo_name[:ref]" with an optional |
| ref (tag or branch). If ``ref`` is not specified, the default branch is assumed to be ``main`` if |
| it exists, and otherwise ``master``. |
| Example: 'pytorch/vision:0.10' |
| force_reload (bool, optional): whether to discard the existing cache and force a fresh download. |
| Default is ``False``. |
| skip_validation (bool, optional): if ``False``, torchhub will check that the branch or commit |
| specified by the ``github`` argument properly belongs to the repo owner. This will make |
| requests to the GitHub API; you can specify a non-default GitHub token by setting the |
| ``GITHUB_TOKEN`` environment variable. Default is ``False``. |
| trust_repo (bool, str or None): ``"check"``, ``True``, ``False`` or ``None``. |
            This parameter was introduced in v1.12 and helps ensure that users
| only run code from repos that they trust. |
| |
| - If ``False``, a prompt will ask the user whether the repo should |
| be trusted. |
| - If ``True``, the repo will be added to the trusted list and loaded |
| without requiring explicit confirmation. |
| - If ``"check"``, the repo will be checked against the list of |
| trusted repos in the cache. If it is not present in that list, the |
| behaviour will fall back onto the ``trust_repo=False`` option. |
| - If ``None``: this will raise a warning, inviting the user to set |
| ``trust_repo`` to either ``False``, ``True`` or ``"check"``. This |
| is only present for backward compatibility and will be removed in |
| v2.0. |
| |
| Default is ``None`` and will eventually change to ``"check"`` in v2.0. |
| verbose (bool, optional): If ``False``, mute messages about hitting |
| local caches. Note that the message about first download cannot be |
| muted. Default is ``True``. |
| |
| Returns: |
        list: The available callable entrypoints.
| |
| Example: |
| >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_HUB) |
| >>> entrypoints = torch.hub.list("pytorch/vision", force_reload=True) |
| """ |
| repo_dir = _get_cache_or_reload( |
| github, |
| force_reload, |
| trust_repo, |
| "list", |
| verbose=verbose, |
| skip_validation=skip_validation, |
| ) |
| |
| with _add_to_sys_path(repo_dir): |
| hubconf_path = os.path.join(repo_dir, MODULE_HUBCONF) |
| hub_module = _import_module(MODULE_HUBCONF, hubconf_path) |
| |
    # We treat functions starting with '_' as internal helper functions.
| entrypoints = [ |
| f |
| for f in dir(hub_module) |
| if callable(getattr(hub_module, f)) and not f.startswith("_") |
| ] |
| |
| return entrypoints |
| |
| |
| def help(github, model, force_reload=False, skip_validation=False, trust_repo=None): |
| r""" |
| Show the docstring of entrypoint ``model``. |
| |
| Args: |
| github (str): a string with format <repo_owner/repo_name[:ref]> with an optional |
| ref (a tag or a branch). If ``ref`` is not specified, the default branch is assumed |
| to be ``main`` if it exists, and otherwise ``master``. |
| Example: 'pytorch/vision:0.10' |
| model (str): a string of entrypoint name defined in repo's ``hubconf.py`` |
| force_reload (bool, optional): whether to discard the existing cache and force a fresh download. |
| Default is ``False``. |
| skip_validation (bool, optional): if ``False``, torchhub will check that the ref |
| specified by the ``github`` argument properly belongs to the repo owner. This will make |
| requests to the GitHub API; you can specify a non-default GitHub token by setting the |
| ``GITHUB_TOKEN`` environment variable. Default is ``False``. |
| trust_repo (bool, str or None): ``"check"``, ``True``, ``False`` or ``None``. |
            This parameter was introduced in v1.12 and helps ensure that users
| only run code from repos that they trust. |
| |
| - If ``False``, a prompt will ask the user whether the repo should |
| be trusted. |
| - If ``True``, the repo will be added to the trusted list and loaded |
| without requiring explicit confirmation. |
| - If ``"check"``, the repo will be checked against the list of |
| trusted repos in the cache. If it is not present in that list, the |
| behaviour will fall back onto the ``trust_repo=False`` option. |
| - If ``None``: this will raise a warning, inviting the user to set |
| ``trust_repo`` to either ``False``, ``True`` or ``"check"``. This |
| is only present for backward compatibility and will be removed in |
| v2.0. |
| |
| Default is ``None`` and will eventually change to ``"check"`` in v2.0. |
| Example: |
| >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_HUB) |
| >>> print(torch.hub.help("pytorch/vision", "resnet18", force_reload=True)) |
| """ |
| repo_dir = _get_cache_or_reload( |
| github, |
| force_reload, |
| trust_repo, |
| "help", |
| verbose=True, |
| skip_validation=skip_validation, |
| ) |
| |
| with _add_to_sys_path(repo_dir): |
| hubconf_path = os.path.join(repo_dir, MODULE_HUBCONF) |
| hub_module = _import_module(MODULE_HUBCONF, hubconf_path) |
| |
| entry = _load_entry_from_hubconf(hub_module, model) |
| |
| return entry.__doc__ |
| |
| |
| def load( |
| repo_or_dir, |
| model, |
| *args, |
| source="github", |
| trust_repo=None, |
| force_reload=False, |
| verbose=True, |
| skip_validation=False, |
| **kwargs, |
| ): |
| r""" |
| Load a model from a github repo or a local directory. |
| |
    Note: Loading a model is the typical use case, but this can also be used to
    load other objects such as tokenizers, loss functions, etc.
| |
| If ``source`` is 'github', ``repo_or_dir`` is expected to be |
| of the form ``repo_owner/repo_name[:ref]`` with an optional |
| ref (a tag or a branch). |
| |
| If ``source`` is 'local', ``repo_or_dir`` is expected to be a |
| path to a local directory. |
| |
| Args: |
| repo_or_dir (str): If ``source`` is 'github', |
| this should correspond to a github repo with format ``repo_owner/repo_name[:ref]`` with |
| an optional ref (tag or branch), for example 'pytorch/vision:0.10'. If ``ref`` is not specified, |
| the default branch is assumed to be ``main`` if it exists, and otherwise ``master``. |
| If ``source`` is 'local' then it should be a path to a local directory. |
| model (str): the name of a callable (entrypoint) defined in the |
| repo/dir's ``hubconf.py``. |
| *args (optional): the corresponding args for callable ``model``. |
| source (str, optional): 'github' or 'local'. Specifies how |
| ``repo_or_dir`` is to be interpreted. Default is 'github'. |
| trust_repo (bool, str or None): ``"check"``, ``True``, ``False`` or ``None``. |
            This parameter was introduced in v1.12 and helps ensure that users
| only run code from repos that they trust. |
| |
| - If ``False``, a prompt will ask the user whether the repo should |
| be trusted. |
| - If ``True``, the repo will be added to the trusted list and loaded |
| without requiring explicit confirmation. |
| - If ``"check"``, the repo will be checked against the list of |
| trusted repos in the cache. If it is not present in that list, the |
| behaviour will fall back onto the ``trust_repo=False`` option. |
| - If ``None``: this will raise a warning, inviting the user to set |
| ``trust_repo`` to either ``False``, ``True`` or ``"check"``. This |
| is only present for backward compatibility and will be removed in |
| v2.0. |
| |
| Default is ``None`` and will eventually change to ``"check"`` in v2.0. |
| force_reload (bool, optional): whether to force a fresh download of |
| the github repo unconditionally. Does not have any effect if |
| ``source = 'local'``. Default is ``False``. |
| verbose (bool, optional): If ``False``, mute messages about hitting |
| local caches. Note that the message about first download cannot be |
| muted. Does not have any effect if ``source = 'local'``. |
| Default is ``True``. |
| skip_validation (bool, optional): if ``False``, torchhub will check that the branch or commit |
| specified by the ``github`` argument properly belongs to the repo owner. This will make |
| requests to the GitHub API; you can specify a non-default GitHub token by setting the |
| ``GITHUB_TOKEN`` environment variable. Default is ``False``. |
| **kwargs (optional): the corresponding kwargs for callable ``model``. |
| |
| Returns: |
| The output of the ``model`` callable when called with the given |
| ``*args`` and ``**kwargs``. |
| |
| Example: |
| >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_HUB) |
| >>> # from a github repo |
| >>> repo = "pytorch/vision" |
| >>> model = torch.hub.load( |
| ... repo, "resnet50", weights="ResNet50_Weights.IMAGENET1K_V1" |
| ... ) |
| >>> # from a local directory |
| >>> path = "/some/local/path/pytorch/vision" |
| >>> # xdoctest: +SKIP |
| >>> model = torch.hub.load(path, "resnet50", weights="ResNet50_Weights.DEFAULT") |
| """ |
| source = source.lower() |
| |
| if source not in ("github", "local"): |
| raise ValueError( |
| f'Unknown source: "{source}". Allowed values: "github" | "local".' |
| ) |
| |
| if source == "github": |
| repo_or_dir = _get_cache_or_reload( |
| repo_or_dir, |
| force_reload, |
| trust_repo, |
| "load", |
| verbose=verbose, |
| skip_validation=skip_validation, |
| ) |
| |
| model = _load_local(repo_or_dir, model, *args, **kwargs) |
| return model |
| |
| |
| def _load_local(hubconf_dir, model, *args, **kwargs): |
| r""" |
| Load a model from a local directory with a ``hubconf.py``. |
| |
| Args: |
| hubconf_dir (str): path to a local directory that contains a |
| ``hubconf.py``. |
| model (str): name of an entrypoint defined in the directory's |
| ``hubconf.py``. |
| *args (optional): the corresponding args for callable ``model``. |
| **kwargs (optional): the corresponding kwargs for callable ``model``. |
| |
| Returns: |
| a single model with corresponding pretrained weights. |
| |
| Example: |
| >>> # xdoctest: +SKIP("stub local path") |
| >>> path = "/some/local/path/pytorch/vision" |
| >>> model = _load_local(path, "resnet50", weights="ResNet50_Weights.IMAGENET1K_V1") |
| """ |
| with _add_to_sys_path(hubconf_dir): |
| hubconf_path = os.path.join(hubconf_dir, MODULE_HUBCONF) |
| hub_module = _import_module(MODULE_HUBCONF, hubconf_path) |
| |
| entry = _load_entry_from_hubconf(hub_module, model) |
| model = entry(*args, **kwargs) |
| |
| return model |
| |
| |
| def download_url_to_file( |
| url: str, |
| dst: str, |
| hash_prefix: Optional[str] = None, |
| progress: bool = True, |
| ) -> None: |
| r"""Download object at the given URL to a local path. |
| |
| Args: |
| url (str): URL of the object to download |
| dst (str): Full path where object will be saved, e.g. ``/tmp/temporary_file`` |
        hash_prefix (str, optional): If not None, the SHA256 of the downloaded file should start with ``hash_prefix``.
| Default: None |
| progress (bool, optional): whether or not to display a progress bar to stderr |
| Default: True |
| |
| Example: |
| >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_HUB) |
| >>> # xdoctest: +REQUIRES(POSIX) |
| >>> torch.hub.download_url_to_file( |
| ... "https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth", |
| ... "/tmp/temporary_file", |
| ... ) |
| |
| """ |
| file_size = None |
| req = Request(url, headers={"User-Agent": "torch.hub"}) |
| u = urlopen(req) |
| meta = u.info() |
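    # Read Content-Length via whichever header API the response object provides
    # (legacy getheaders() or the newer get_all()).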
| if hasattr(meta, "getheaders"): |
| content_length = meta.getheaders("Content-Length") |
| else: |
| content_length = meta.get_all("Content-Length") |
| if content_length is not None and len(content_length) > 0: |
| file_size = int(content_length[0]) |
| |
| # We deliberately save it in a temp file and move it after |
    # the download is complete. This prevents a local working checkpoint
    # from being overwritten by a broken download.
| # We deliberately do not use NamedTemporaryFile to avoid restrictive |
| # file permissions being applied to the downloaded file. |
| dst = os.path.expanduser(dst) |
| for seq in range(tempfile.TMP_MAX): |
| tmp_dst = dst + "." + uuid.uuid4().hex + ".partial" |
| try: |
| f = open(tmp_dst, "w+b") |
| except FileExistsError: |
| continue |
| break |
| else: |
| raise FileExistsError(errno.EEXIST, "No usable temporary file name found") |
| |
| try: |
| if hash_prefix is not None: |
| sha256 = hashlib.sha256() |
| with tqdm( |
| total=file_size, |
| disable=not progress, |
| unit="B", |
| unit_scale=True, |
| unit_divisor=1024, |
| ) as pbar: |
| while True: |
| buffer = u.read(READ_DATA_CHUNK) |
| if len(buffer) == 0: |
| break |
| f.write(buffer) # type: ignore[possibly-undefined] |
| if hash_prefix is not None: |
| sha256.update(buffer) # type: ignore[possibly-undefined] |
| pbar.update(len(buffer)) |
| |
| f.close() |
| if hash_prefix is not None: |
| digest = sha256.hexdigest() # type: ignore[possibly-undefined] |
| if digest[: len(hash_prefix)] != hash_prefix: |
| raise RuntimeError( |
| f'invalid hash value (expected "{hash_prefix}", got "{digest}")' |
| ) |
| shutil.move(f.name, dst) |
| finally: |
| f.close() |
| if os.path.exists(f.name): |
| os.remove(f.name) |
| |
| |
# Hub used to support automatically extracting from zipfiles manually compressed by users.
# The legacy zip format expects only one file from torch.save() < 1.6 in the zip.
# We should remove this support since zipfile is now the default format for torch.save().
| def _is_legacy_zip_format(filename: str) -> bool: |
| if zipfile.is_zipfile(filename): |
| infolist = zipfile.ZipFile(filename).infolist() |
| return len(infolist) == 1 and not infolist[0].is_dir() |
| return False |
| |
| |
| @deprecated( |
| "Falling back to the old format < 1.6. This support will be " |
| "deprecated in favor of default zipfile format introduced in 1.6. " |
| "Please redo torch.save() to save it in the new zipfile format.", |
| category=FutureWarning, |
| ) |
| def _legacy_zip_load( |
| filename: str, |
| model_dir: str, |
| map_location: MAP_LOCATION, |
| weights_only: bool, |
| ) -> Dict[str, Any]: |
| # Note: extractall() defaults to overwrite file if exists. No need to clean up beforehand. |
| # We deliberately don't handle tarfile here since our legacy serialization format was in tar. |
| # E.g. resnet18-5c106cde.pth which is widely used. |
| with zipfile.ZipFile(filename) as f: |
| members = f.infolist() |
| if len(members) != 1: |
| raise RuntimeError("Only one file(not dir) is allowed in the zipfile") |
| f.extractall(model_dir) |
        extracted_name = members[0].filename
        extracted_file = os.path.join(model_dir, extracted_name)
| return torch.load( |
| extracted_file, map_location=map_location, weights_only=weights_only |
| ) |
| |
| |
| def load_state_dict_from_url( |
| url: str, |
| model_dir: Optional[str] = None, |
| map_location: MAP_LOCATION = None, |
| progress: bool = True, |
| check_hash: bool = False, |
| file_name: Optional[str] = None, |
| weights_only: bool = False, |
| ) -> Dict[str, Any]: |
| r"""Loads the Torch serialized object at the given URL. |
| |
    If the downloaded file is a zip file, it will be automatically
| decompressed. |
| |
| If the object is already present in `model_dir`, it's deserialized and |
| returned. |
| The default value of ``model_dir`` is ``<hub_dir>/checkpoints`` where |
| ``hub_dir`` is the directory returned by :func:`~torch.hub.get_dir`. |
| |
| Args: |
| url (str): URL of the object to download |
| model_dir (str, optional): directory in which to save the object |
| map_location (optional): a function or a dict specifying how to remap storage locations (see torch.load) |
| progress (bool, optional): whether or not to display a progress bar to stderr. |
| Default: True |
        check_hash (bool, optional): If True, the filename part of the URL should follow the naming convention
| ``filename-<sha256>.ext`` where ``<sha256>`` is the first eight or more |
| digits of the SHA256 hash of the contents of the file. The hash is used to |
| ensure unique names and to verify the contents of the file. |
| Default: False |
| file_name (str, optional): name for the downloaded file. Filename from ``url`` will be used if not set. |
        weights_only (bool, optional): If True, only weights will be loaded and no complex pickled objects.
| Recommended for untrusted sources. See :func:`~torch.load` for more details. |
| |
| Example: |
| >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_HUB) |
| >>> state_dict = torch.hub.load_state_dict_from_url( |
| ... "https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth" |
| ... ) |
| |
| """ |
| # Issue warning to move data if old env is set |
| if os.getenv("TORCH_MODEL_ZOO"): |
| warnings.warn( |
| "TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead" |
| ) |
| |
| if model_dir is None: |
| hub_dir = get_dir() |
| model_dir = os.path.join(hub_dir, "checkpoints") |
| |
| os.makedirs(model_dir, exist_ok=True) |
| |
| parts = urlparse(url) |
| filename = os.path.basename(parts.path) |
| if file_name is not None: |
| filename = file_name |
| cached_file = os.path.join(model_dir, filename) |
| if not os.path.exists(cached_file): |
| sys.stderr.write(f'Downloading: "{url}" to {cached_file}\n') |
| hash_prefix = None |
| if check_hash: |
| r = HASH_REGEX.search(filename) # r is Optional[Match[str]] |
| hash_prefix = r.group(1) if r else None |
| download_url_to_file(url, cached_file, hash_prefix, progress=progress) |
| |
| if _is_legacy_zip_format(cached_file): |
| return _legacy_zip_load(cached_file, model_dir, map_location, weights_only) |
| return torch.load(cached_file, map_location=map_location, weights_only=weights_only) |