| # mypy: allow-untyped-defs |
| |
| import functools |
| import itertools |
| import warnings |
| import weakref |
| from collections import namedtuple, OrderedDict |
| from typing import ( |
| Any, |
| Callable, |
| Dict, |
| Iterator, |
| List, |
| Mapping, |
| Optional, |
| overload, |
| Set, |
| Tuple, |
| TypeVar, |
| Union, |
| ) |
| from typing_extensions import Self |
| |
| import torch |
| from torch import device, dtype, Tensor |
| from torch._prims_common import DeviceLikeType |
| from torch.nn.parameter import Parameter |
| from torch.utils._python_dispatch import is_traceable_wrapper_subclass |
| from torch.utils.hooks import BackwardHook, RemovableHandle |
| |
| |
| __all__ = [ |
| "register_module_forward_pre_hook", |
| "register_module_forward_hook", |
| "register_module_full_backward_pre_hook", |
| "register_module_backward_hook", |
| "register_module_full_backward_hook", |
| "register_module_buffer_registration_hook", |
| "register_module_module_registration_hook", |
| "register_module_parameter_registration_hook", |
| "Module", |
| ] |
| |
| _grad_t = Union[Tuple[Tensor, ...], Tensor] |
| # See https://mypy.readthedocs.io/en/latest/generics.html#generic-methods-and-generic-self for the use |
| # of `T` to annotate `self`. Many methods of `Module` return `self` and we want those return values to be |
| # the type of the subclass, not the looser type of `Module`. |
| T = TypeVar("T", bound="Module") |
| |
| |
| class _IncompatibleKeys( |
| namedtuple("IncompatibleKeys", ["missing_keys", "unexpected_keys"]), |
| ): |
| def __repr__(self): |
| if not self.missing_keys and not self.unexpected_keys: |
| return "<All keys matched successfully>" |
| return super().__repr__() |
| |
| __str__ = __repr__ |
| |
| |
| def _addindent(s_, numSpaces): |
| s = s_.split("\n") |
| # don't do anything for single-line stuff |
| if len(s) == 1: |
| return s_ |
| first = s.pop(0) |
| s = [(numSpaces * " ") + line for line in s] |
| s = "\n".join(s) |
| s = first + "\n" + s |
| return s |
| |
| |
| r"""This tracks hooks common to all modules that are executed immediately before |
registering the buffer/module/parameter"""
| _global_buffer_registration_hooks: Dict[int, Callable] = OrderedDict() |
| _global_module_registration_hooks: Dict[int, Callable] = OrderedDict() |
| _global_parameter_registration_hooks: Dict[int, Callable] = OrderedDict() |
| |
| |
| class _WrappedHook: |
| def __init__(self, hook: Callable, module: Optional["Module"] = None): |
| self.hook: Callable = hook |
| functools.update_wrapper(self, hook) |
| |
| self.with_module: bool = False |
| |
| if module is not None: |
| self.module: weakref.ReferenceType[Module] = weakref.ref(module) |
| self.with_module = True |
| |
| def __call__(self, *args: Any, **kwargs: Any) -> Any: |
| if self.with_module: |
| module = self.module() |
| if module is None: |
| raise RuntimeError("You are trying to call the hook of a dead Module!") |
| return self.hook(module, *args, **kwargs) |
| return self.hook(*args, **kwargs) |
| |
| def __getstate__(self) -> Dict: |
| result = {"hook": self.hook, "with_module": self.with_module} |
| if self.with_module: |
| result["module"] = self.module() |
| |
| return result |
| |
| def __setstate__(self, state: Dict): |
| self.hook = state["hook"] |
| self.with_module = state["with_module"] |
| |
| if self.with_module: |
| if state["module"] is None: |
| raise RuntimeError( |
| "You are trying to revive the hook of a dead Module!" |
| ) |
| self.module = weakref.ref(state["module"]) |
| |
| |
| r"""This tracks hooks common to all modules that are executed before/after |
| calling forward and backward. This is global state used for debugging/profiling |
| purposes""" |
| _global_backward_pre_hooks: Dict[int, Callable] = OrderedDict() |
| _global_backward_hooks: Dict[int, Callable] = OrderedDict() |
| _global_is_full_backward_hook: Optional[bool] = None |
| _global_forward_pre_hooks: Dict[int, Callable] = OrderedDict() |
| _global_forward_hooks: Dict[int, Callable] = OrderedDict() |
| _global_forward_hooks_always_called: Dict[int, bool] = OrderedDict() |
| |
| _EXTRA_STATE_KEY_SUFFIX = "_extra_state" |
| |
| |
| def register_module_buffer_registration_hook( |
| hook: Callable[..., None], |
| ) -> RemovableHandle: |
| r"""Register a buffer registration hook common to all modules. |
| |
| .. warning :: |
| |
| This adds global state to the `nn.Module` module |
| |
| The hook will be called every time :func:`register_buffer` is invoked. |
| It should have the following signature:: |
| |
| hook(module, name, buffer) -> None or new buffer |
| |
| The hook can modify the input or return a single modified value in the hook. |
| |
| Returns: |
| :class:`torch.utils.hooks.RemovableHandle`: |
| a handle that can be used to remove the added hook by calling |
| ``handle.remove()`` |
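
    Below is a minimal sketch of such a hook; ``record_buffer`` and
    ``names_seen`` are illustrative helpers, not part of the API::

        >>> names_seen = []
        >>> def record_buffer(module, name, buffer):
        ...     # returning None leaves the buffer unchanged
        ...     names_seen.append(name)
        >>> handle = torch.nn.modules.module.register_module_buffer_registration_hook(record_buffer)
        >>> bn = torch.nn.BatchNorm1d(4)  # registers running_mean, running_var, ...
        >>> handle.remove()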
| """ |
| handle = RemovableHandle(_global_buffer_registration_hooks) |
| _global_buffer_registration_hooks[handle.id] = hook |
| return handle |
| |
| |
| def register_module_module_registration_hook( |
| hook: Callable[..., None], |
| ) -> RemovableHandle: |
| r"""Register a module registration hook common to all modules. |
| |
| .. warning :: |
| |
| This adds global state to the `nn.Module` module |
| |
| The hook will be called every time :func:`register_module` is invoked. |
| It should have the following signature:: |
| |
| hook(module, name, submodule) -> None or new submodule |
| |
| The hook can modify the input or return a single modified value in the hook. |
| |
| Returns: |
| :class:`torch.utils.hooks.RemovableHandle`: |
| a handle that can be used to remove the added hook by calling |
| ``handle.remove()`` |
| """ |
| handle = RemovableHandle(_global_module_registration_hooks) |
| _global_module_registration_hooks[handle.id] = hook |
| return handle |
| |
| |
| def register_module_parameter_registration_hook( |
| hook: Callable[..., None], |
| ) -> RemovableHandle: |
| r"""Register a parameter registration hook common to all modules. |
| |
| .. warning :: |
| |
| This adds global state to the `nn.Module` module |
| |
| The hook will be called every time :func:`register_parameter` is invoked. |
| It should have the following signature:: |
| |
| hook(module, name, param) -> None or new parameter |
| |
| The hook can modify the input or return a single modified value in the hook. |
| |
| Returns: |
| :class:`torch.utils.hooks.RemovableHandle`: |
| a handle that can be used to remove the added hook by calling |
| ``handle.remove()`` |
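
    As a minimal sketch, the hook below replaces every newly registered
    parameter with a ``float64`` copy (``to_double`` is an illustrative name)::

        >>> def to_double(module, name, param):
        ...     # returning a new Parameter replaces the one being registered
        ...     return torch.nn.Parameter(param.detach().double(), requires_grad=param.requires_grad)
        >>> handle = torch.nn.modules.module.register_module_parameter_registration_hook(to_double)
        >>> lin = torch.nn.Linear(2, 2)
        >>> lin.weight.dtype
        torch.float64
        >>> handle.remove()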
| """ |
| handle = RemovableHandle(_global_parameter_registration_hooks) |
| _global_parameter_registration_hooks[handle.id] = hook |
| return handle |
| |
| |
| def register_module_forward_pre_hook(hook: Callable[..., None]) -> RemovableHandle: |
| r"""Register a forward pre-hook common to all modules. |
| |
| .. warning :: |
| |
| This adds global state to the `nn.module` module |
| and it is only intended for debugging/profiling purposes. |
| |
| The hook will be called every time before :func:`forward` is invoked. |
| It should have the following signature:: |
| |
| hook(module, input) -> None or modified input |
| |
| The input contains only the positional arguments given to the module. |
| Keyword arguments won't be passed to the hooks and only to the ``forward``. |
    The hook can modify the input. The user can either return a tuple or a
    single modified value in the hook. We will wrap the value into a tuple
    if a single value is returned (unless that value is already a tuple).
| |
| This hook has precedence over the specific module hooks registered with |
| ``register_forward_pre_hook``. |
| |
| Returns: |
| :class:`torch.utils.hooks.RemovableHandle`: |
| a handle that can be used to remove the added hook by calling |
| ``handle.remove()`` |
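
    A minimal sketch of a global pre-hook that only inspects the positional
    arguments and leaves them unchanged::

        >>> def global_pre_hook(module, args):
        ...     # args is always a tuple of the positional inputs
        ...     assert isinstance(args, tuple)
        >>> handle = torch.nn.modules.module.register_module_forward_pre_hook(global_pre_hook)
        >>> out = torch.nn.ReLU()(torch.randn(2))  # the hook fires before forward
        >>> handle.remove()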
| """ |
| handle = RemovableHandle(_global_forward_pre_hooks) |
| _global_forward_pre_hooks[handle.id] = hook |
| return handle |
| |
| |
| def register_module_forward_hook( |
| hook: Callable[..., None], |
| *, |
| always_call: bool = False, |
| ) -> RemovableHandle: |
| r"""Register a global forward hook for all the modules. |
| |
| .. warning :: |
| |
| This adds global state to the `nn.module` module |
| and it is only intended for debugging/profiling purposes. |
| |
| The hook will be called every time after :func:`forward` has computed an output. |
| It should have the following signature:: |
| |
| hook(module, input, output) -> None or modified output |
| |
| The input contains only the positional arguments given to the module. |
| Keyword arguments won't be passed to the hooks and only to the ``forward``. |
    The hook can modify the output. It can modify the input inplace but
    it will not have an effect on forward since this is called after
    :func:`forward` is called.
| |
| Parameters: |
| hook (Callable): The user defined hook to be registered. |
| always_call (bool): If ``True`` the ``hook`` will be run regardless of |
| whether an exception is raised while calling the Module. |
| Default: ``False`` |
| Returns: |
| :class:`torch.utils.hooks.RemovableHandle`: |
| a handle that can be used to remove the added hook by calling |
| ``handle.remove()`` |
| |
| This hook will be executed before specific module hooks registered with |
| ``register_forward_hook``. |
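
    A minimal profiling-style sketch; ``shapes`` and the hook body are
    illustrative only::

        >>> shapes = []
        >>> def record_output_shape(module, args, output):
        ...     # returning None keeps the original output
        ...     if isinstance(output, torch.Tensor):
        ...         shapes.append((type(module).__name__, tuple(output.shape)))
        >>> handle = torch.nn.modules.module.register_module_forward_hook(record_output_shape)
        >>> _ = torch.nn.Linear(3, 4)(torch.randn(2, 3))
        >>> handle.remove()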
| """ |
| handle = RemovableHandle( |
| _global_forward_hooks, extra_dict=_global_forward_hooks_always_called |
| ) |
| _global_forward_hooks[handle.id] = hook |
| if always_call: |
| _global_forward_hooks_always_called[handle.id] = True |
| return handle |
| |
| |
| def register_module_backward_hook( |
| hook: Callable[["Module", _grad_t, _grad_t], Union[None, _grad_t]], |
| ) -> RemovableHandle: |
| r"""Register a backward hook common to all the modules. |
| |
| This function is deprecated in favor of |
| :func:`torch.nn.modules.module.register_module_full_backward_hook` |
| and the behavior of this function will change in future versions. |
| |
| Returns: |
| :class:`torch.utils.hooks.RemovableHandle`: |
| a handle that can be used to remove the added hook by calling |
| ``handle.remove()`` |
| |
| """ |
| global _global_is_full_backward_hook |
| if _global_is_full_backward_hook is True: |
| raise RuntimeError( |
| "Cannot use both regular backward hooks and full backward hooks as a " |
| "global Module hook. Please use only one of them." |
| ) |
| |
| _global_is_full_backward_hook = False |
| |
| handle = RemovableHandle(_global_backward_hooks) |
| _global_backward_hooks[handle.id] = hook |
| return handle |
| |
| |
| def register_module_full_backward_pre_hook( |
| hook: Callable[["Module", _grad_t], Union[None, _grad_t]], |
| ) -> RemovableHandle: |
| r"""Register a backward pre-hook common to all the modules. |
| |
| .. warning :: |
| This adds global state to the `nn.module` module |
| and it is only intended for debugging/profiling purposes. |
| |
| Hooks registered using this function behave in the same way as those |
| registered by :meth:`torch.nn.Module.register_full_backward_pre_hook`. |
| Refer to its documentation for more details. |
| |
| Hooks registered using this function will be called before hooks registered |
| using :meth:`torch.nn.Module.register_full_backward_pre_hook`. |
| |
| Returns: |
| :class:`torch.utils.hooks.RemovableHandle`: |
| a handle that can be used to remove the added hook by calling |
| ``handle.remove()`` |
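
    As a minimal sketch, the hook below scales every gradient flowing into the
    module's backward pass (``halve_grad_output`` is an illustrative name)::

        >>> def halve_grad_output(module, grad_output):
        ...     # return a tuple to replace grad_output in subsequent computations
        ...     return tuple(g * 0.5 if g is not None else None for g in grad_output)
        >>> handle = torch.nn.modules.module.register_module_full_backward_pre_hook(halve_grad_output)
        >>> torch.nn.Linear(2, 2)(torch.randn(1, 2)).sum().backward()
        >>> handle.remove()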
| |
| """ |
| handle = RemovableHandle(_global_backward_pre_hooks) |
| _global_backward_pre_hooks[handle.id] = hook |
| return handle |
| |
| |
| def register_module_full_backward_hook( |
| hook: Callable[["Module", _grad_t, _grad_t], Union[None, _grad_t]], |
| ) -> RemovableHandle: |
| r"""Register a backward hook common to all the modules. |
| |
| .. warning :: |
| This adds global state to the `nn.module` module |
| and it is only intended for debugging/profiling purposes. |
| |
| Hooks registered using this function behave in the same way as those |
| registered by :meth:`torch.nn.Module.register_full_backward_hook`. |
| Refer to its documentation for more details. |
| |
| Hooks registered using this function will be called before hooks registered |
| using :meth:`torch.nn.Module.register_full_backward_hook`. |
| |
| Returns: |
| :class:`torch.utils.hooks.RemovableHandle`: |
| a handle that can be used to remove the added hook by calling |
| ``handle.remove()`` |
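
    A minimal sketch that records the norm of the output gradients for every
    module; ``grad_norms`` is an illustrative helper::

        >>> grad_norms = []
        >>> def record_grad_norms(module, grad_input, grad_output):
        ...     # returning None keeps grad_input unchanged
        ...     grad_norms.append(sum(g.norm() for g in grad_output if g is not None))
        >>> handle = torch.nn.modules.module.register_module_full_backward_hook(record_grad_norms)
        >>> torch.nn.Linear(2, 2)(torch.randn(1, 2)).sum().backward()
        >>> handle.remove()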
| |
| """ |
| global _global_is_full_backward_hook |
| if _global_is_full_backward_hook is False: |
| raise RuntimeError( |
| "Cannot use both regular backward hooks and full backward hooks as a " |
| "global Module hook. Please use only one of them." |
| ) |
| |
| _global_is_full_backward_hook = True |
| |
| handle = RemovableHandle(_global_backward_hooks) |
| _global_backward_hooks[handle.id] = hook |
| return handle |
| |
| |
| # Trick mypy into not applying contravariance rules to inputs by defining |
| # forward as a value, rather than a function. See also |
| # https://github.com/python/mypy/issues/8795 |
| def _forward_unimplemented(self, *input: Any) -> None: |
| r"""Define the computation performed at every call. |
| |
| Should be overridden by all subclasses. |
| |
| .. note:: |
        Although the recipe for the forward pass needs to be defined within
| this function, one should call the :class:`Module` instance afterwards |
| instead of this since the former takes care of running the |
| registered hooks while the latter silently ignores them. |
| """ |
| raise NotImplementedError( |
| f'Module [{type(self).__name__}] is missing the required "forward" function' |
| ) |
| |
| |
| class Module: |
| r"""Base class for all neural network modules. |
| |
| Your models should also subclass this class. |
| |
    Modules can also contain other Modules, allowing them to be nested in
    a tree structure. You can assign the submodules as regular attributes::
| |
| import torch.nn as nn |
| import torch.nn.functional as F |
| |
| class Model(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.conv1 = nn.Conv2d(1, 20, 5) |
| self.conv2 = nn.Conv2d(20, 20, 5) |
| |
| def forward(self, x): |
| x = F.relu(self.conv1(x)) |
| return F.relu(self.conv2(x)) |
| |
| Submodules assigned in this way will be registered, and will have their |
| parameters converted too when you call :meth:`to`, etc. |
| |
| .. note:: |
| As per the example above, an ``__init__()`` call to the parent class |
| must be made before assignment on the child. |
| |
| :ivar training: Boolean represents whether this module is in training or |
| evaluation mode. |
| :vartype training: bool |
| """ |
| |
| dump_patches: bool = False |
| |
| _version: int = 1 |
| r"""This allows better BC support for :meth:`load_state_dict`. In |
    :meth:`state_dict`, the version number will be saved in the attribute
| `_metadata` of the returned state dict, and thus pickled. `_metadata` is a |
| dictionary with keys that follow the naming convention of state dict. See |
| ``_load_from_state_dict`` on how to use this information in loading. |
| |
| If new parameters/buffers are added/removed from a module, this number shall |
| be bumped, and the module's `_load_from_state_dict` method can compare the |
| version number and do appropriate changes if the state dict is from before |
| the change.""" |
| |
| training: bool |
| _parameters: Dict[str, Optional[Parameter]] |
| _buffers: Dict[str, Optional[Tensor]] |
| _non_persistent_buffers_set: Set[str] |
| _backward_pre_hooks: Dict[int, Callable] |
| _backward_hooks: Dict[int, Callable] |
| _is_full_backward_hook: Optional[bool] |
| _forward_hooks: Dict[int, Callable] |
| # Marks whether the corresponding _forward_hooks accept kwargs or not. |
| # As JIT does not support Set[int], this dict is used as a set, where all |
| # hooks represented in this dict accept kwargs. |
| _forward_hooks_with_kwargs: Dict[int, bool] |
| # forward hooks that should always be called even if an exception is raised |
| _forward_hooks_always_called: Dict[int, bool] |
| _forward_pre_hooks: Dict[int, Callable] |
    # Marks whether the corresponding _forward_pre_hooks accept kwargs or not.
| # As JIT does not support Set[int], this dict is used as a set, where all |
| # hooks represented in this dict accept kwargs. |
| _forward_pre_hooks_with_kwargs: Dict[int, bool] |
| _state_dict_hooks: Dict[int, Callable] |
| _load_state_dict_pre_hooks: Dict[int, Callable] |
| _state_dict_pre_hooks: Dict[int, Callable] |
| _load_state_dict_post_hooks: Dict[int, Callable] |
| _modules: Dict[str, Optional["Module"]] |
| call_super_init: bool = False |
| _compiled_call_impl: Optional[Callable] = None |
| |
| def __init__(self, *args, **kwargs) -> None: |
| """Initialize internal Module state, shared by both nn.Module and ScriptModule.""" |
| torch._C._log_api_usage_once("python.nn_module") |
| |
| # Backward compatibility: no args used to be allowed when call_super_init=False |
| if self.call_super_init is False and bool(kwargs): |
            raise TypeError(
                f"{type(self).__name__}.__init__() got an unexpected keyword argument '{next(iter(kwargs))}'"
            )
| |
| if self.call_super_init is False and bool(args): |
| raise TypeError( |
| f"{type(self).__name__}.__init__() takes 1 positional argument but {len(args) + 1} were" |
| " given" |
| ) |
| |
| """ |
| Calls super().__setattr__('a', a) instead of the typical self.a = a |
| to avoid Module.__setattr__ overhead. Module's __setattr__ has special |
| handling for parameters, submodules, and buffers but simply calls into |
| super().__setattr__ for all other attributes. |
| """ |
| super().__setattr__("training", True) |
| super().__setattr__("_parameters", dict()) |
| super().__setattr__("_buffers", dict()) |
| super().__setattr__("_non_persistent_buffers_set", set()) |
| super().__setattr__("_backward_pre_hooks", OrderedDict()) |
| super().__setattr__("_backward_hooks", OrderedDict()) |
| super().__setattr__("_is_full_backward_hook", None) |
| super().__setattr__("_forward_hooks", OrderedDict()) |
| super().__setattr__("_forward_hooks_with_kwargs", OrderedDict()) |
| super().__setattr__("_forward_hooks_always_called", OrderedDict()) |
| super().__setattr__("_forward_pre_hooks", OrderedDict()) |
| super().__setattr__("_forward_pre_hooks_with_kwargs", OrderedDict()) |
| super().__setattr__("_state_dict_hooks", OrderedDict()) |
| super().__setattr__("_state_dict_pre_hooks", OrderedDict()) |
| super().__setattr__("_load_state_dict_pre_hooks", OrderedDict()) |
| super().__setattr__("_load_state_dict_post_hooks", OrderedDict()) |
| super().__setattr__("_modules", dict()) |
| |
| if self.call_super_init: |
| super().__init__(*args, **kwargs) |
| |
| forward: Callable[..., Any] = _forward_unimplemented |
| |
| def register_buffer( |
| self, name: str, tensor: Optional[Tensor], persistent: bool = True |
| ) -> None: |
| r"""Add a buffer to the module. |
| |
        This is typically used to register a buffer that should not be
        considered a model parameter. For example, BatchNorm's ``running_mean``
| is not a parameter, but is part of the module's state. Buffers, by |
| default, are persistent and will be saved alongside parameters. This |
| behavior can be changed by setting :attr:`persistent` to ``False``. The |
| only difference between a persistent buffer and a non-persistent buffer |
| is that the latter will not be a part of this module's |
| :attr:`state_dict`. |
| |
        Buffers can be accessed as attributes using the given names.
| |
| Args: |
| name (str): name of the buffer. The buffer can be accessed |
| from this module using the given name |
| tensor (Tensor or None): buffer to be registered. If ``None``, then operations |
| that run on buffers, such as :attr:`cuda`, are ignored. If ``None``, |
| the buffer is **not** included in the module's :attr:`state_dict`. |
| persistent (bool): whether the buffer is part of this module's |
| :attr:`state_dict`. |
| |
| Example:: |
| |
| >>> # xdoctest: +SKIP("undefined vars") |
| >>> self.register_buffer('running_mean', torch.zeros(num_features)) |
| |
| """ |
| if persistent is False and isinstance(self, torch.jit.ScriptModule): |
| raise RuntimeError("ScriptModule does not support non-persistent buffers") |
| |
| if "_buffers" not in self.__dict__: |
| raise AttributeError("cannot assign buffer before Module.__init__() call") |
| elif not isinstance(name, str): |
| raise TypeError( |
| f"buffer name should be a string. Got {torch.typename(name)}" |
| ) |
| elif "." in name: |
| raise KeyError('buffer name can\'t contain "."') |
| elif name == "": |
| raise KeyError('buffer name can\'t be empty string ""') |
| elif hasattr(self, name) and name not in self._buffers: |
| raise KeyError(f"attribute '{name}' already exists") |
| elif tensor is not None and not isinstance(tensor, torch.Tensor): |
| raise TypeError( |
| f"cannot assign '{torch.typename(tensor)}' object to buffer '{name}' " |
| "(torch Tensor or None required)" |
| ) |
| else: |
| for hook in _global_buffer_registration_hooks.values(): |
| output = hook(self, name, tensor) |
| if output is not None: |
| tensor = output |
| self._buffers[name] = tensor |
| if persistent: |
| self._non_persistent_buffers_set.discard(name) |
| else: |
| self._non_persistent_buffers_set.add(name) |
| |
| def register_parameter(self, name: str, param: Optional[Parameter]) -> None: |
| r"""Add a parameter to the module. |
| |
        The parameter can be accessed as an attribute using the given name.
| |
| Args: |
| name (str): name of the parameter. The parameter can be accessed |
| from this module using the given name |
| param (Parameter or None): parameter to be added to the module. If |
| ``None``, then operations that run on parameters, such as :attr:`cuda`, |
| are ignored. If ``None``, the parameter is **not** included in the |
| module's :attr:`state_dict`. |
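
        Example (a minimal, self-contained sketch; the names ``m`` and
        ``scale`` are arbitrary)::

            >>> m = torch.nn.Module()
            >>> m.register_parameter('scale', torch.nn.Parameter(torch.ones(3)))
            >>> m.scale.shape
            torch.Size([3])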
| """ |
| if "_parameters" not in self.__dict__: |
| raise AttributeError( |
| "cannot assign parameter before Module.__init__() call" |
| ) |
| |
| elif not isinstance(name, str): |
| raise TypeError( |
| f"parameter name should be a string. Got {torch.typename(name)}" |
| ) |
| elif "." in name: |
| raise KeyError('parameter name can\'t contain "."') |
| elif name == "": |
| raise KeyError('parameter name can\'t be empty string ""') |
| elif hasattr(self, name) and name not in self._parameters: |
| raise KeyError(f"attribute '{name}' already exists") |
| |
| if param is None: |
| self._parameters[name] = None |
| elif not isinstance(param, Parameter): |
| raise TypeError( |
| f"cannot assign '{torch.typename(param)}' object to parameter '{name}' " |
| "(torch.nn.Parameter or None required)" |
| ) |
| elif param.grad_fn: |
| raise ValueError( |
| f"Cannot assign non-leaf Tensor to parameter '{name}'. Model " |
| f"parameters must be created explicitly. To express '{name}' " |
| "as a function of another Tensor, compute the value in " |
| "the forward() method." |
| ) |
| else: |
| for hook in _global_parameter_registration_hooks.values(): |
| output = hook(self, name, param) |
| if output is not None: |
| param = output |
| self._parameters[name] = param |
| |
| def add_module(self, name: str, module: Optional["Module"]) -> None: |
| r"""Add a child module to the current module. |
| |
| The module can be accessed as an attribute using the given name. |
| |
| Args: |
| name (str): name of the child module. The child module can be |
| accessed from this module using the given name |
| module (Module): child module to be added to the module. |
| """ |
| if not isinstance(module, Module) and module is not None: |
| raise TypeError(f"{torch.typename(module)} is not a Module subclass") |
| elif not isinstance(name, str): |
| raise TypeError( |
| f"module name should be a string. Got {torch.typename(name)}" |
| ) |
| elif hasattr(self, name) and name not in self._modules: |
| raise KeyError(f"attribute '{name}' already exists") |
| elif "." in name: |
| raise KeyError(f'module name can\'t contain ".", got: {name}') |
| elif name == "": |
| raise KeyError('module name can\'t be empty string ""') |
| for hook in _global_module_registration_hooks.values(): |
| output = hook(self, name, module) |
| if output is not None: |
| module = output |
| self._modules[name] = module |
| |
| def register_module(self, name: str, module: Optional["Module"]) -> None: |
| r"""Alias for :func:`add_module`.""" |
| self.add_module(name, module) |
| |
| def get_submodule(self, target: str) -> "Module": |
| """Return the submodule given by ``target`` if it exists, otherwise throw an error. |
| |
| For example, let's say you have an ``nn.Module`` ``A`` that |
| looks like this: |
| |
| .. code-block:: text |
| |
| A( |
| (net_b): Module( |
| (net_c): Module( |
| (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2)) |
| ) |
| (linear): Linear(in_features=100, out_features=200, bias=True) |
| ) |
| ) |
| |
| (The diagram shows an ``nn.Module`` ``A``. ``A`` has a nested |
| submodule ``net_b``, which itself has two submodules ``net_c`` |
| and ``linear``. ``net_c`` then has a submodule ``conv``.) |
| |
| To check whether or not we have the ``linear`` submodule, we |
| would call ``get_submodule("net_b.linear")``. To check whether |
| we have the ``conv`` submodule, we would call |
| ``get_submodule("net_b.net_c.conv")``. |
| |
| The runtime of ``get_submodule`` is bounded by the degree |
| of module nesting in ``target``. A query against |
| ``named_modules`` achieves the same result, but it is O(N) in |
| the number of transitive modules. So, for a simple check to see |
| if some submodule exists, ``get_submodule`` should always be |
| used. |
| |
| Args: |
| target: The fully-qualified string name of the submodule |
| to look for. (See above example for how to specify a |
| fully-qualified string.) |
| |
| Returns: |
| torch.nn.Module: The submodule referenced by ``target`` |
| |
| Raises: |
| AttributeError: If the target string references an invalid |
| path or resolves to something that is not an |
| ``nn.Module`` |
| """ |
| if target == "": |
| return self |
| |
| atoms: List[str] = target.split(".") |
| mod: torch.nn.Module = self |
| |
| for item in atoms: |
| if not hasattr(mod, item): |
                raise AttributeError(
                    mod._get_name() + " has no attribute `" + item + "`"
                )
| |
| mod = getattr(mod, item) |
| |
| if not isinstance(mod, torch.nn.Module): |
                raise AttributeError("`" + item + "` is not an nn.Module")
| |
| return mod |
| |
| def set_submodule(self, target: str, module: "Module") -> None: |
| """ |
| Set the submodule given by ``target`` if it exists, otherwise throw an error. |
| |
| For example, let's say you have an ``nn.Module`` ``A`` that |
| looks like this: |
| |
| .. code-block:: text |
| |
| A( |
| (net_b): Module( |
| (net_c): Module( |
| (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2)) |
| ) |
| (linear): Linear(in_features=100, out_features=200, bias=True) |
| ) |
| ) |
| |
| (The diagram shows an ``nn.Module`` ``A``. ``A`` has a nested |
| submodule ``net_b``, which itself has two submodules ``net_c`` |
| and ``linear``. ``net_c`` then has a submodule ``conv``.) |
| |
        To override the ``Conv2d`` with a new submodule ``Linear``, you
| would call |
| ``set_submodule("net_b.net_c.conv", nn.Linear(33, 16))``. |
| |
| Args: |
| target: The fully-qualified string name of the submodule |
| to look for. (See above example for how to specify a |
| fully-qualified string.) |
| module: The module to set the submodule to. |
| |
| Raises: |
| ValueError: If the target string is empty |
| AttributeError: If the target string references an invalid |
| path or resolves to something that is not an |
| ``nn.Module`` |
| """ |
| if target == "": |
| raise ValueError("Cannot set the submodule without a target name!") |
| |
| atoms: List[str] = target.split(".") |
| name = atoms.pop(-1) |
| mod: torch.nn.Module = self |
| |
| for item in atoms: |
| if not hasattr(mod, item): |
| raise AttributeError( |
| mod._get_name() + " has no attribute `" + item + "`" |
| ) |
| |
| mod = getattr(mod, item) |
| |
| # Use isinstance instead of type here to also handle subclass of nn.Module |
| if not isinstance(mod, torch.nn.Module): |
| raise AttributeError("`" + item + "` is not an nn.Module") |
| |
| setattr(mod, name, module) |
| |
| def get_parameter(self, target: str) -> "Parameter": |
| """Return the parameter given by ``target`` if it exists, otherwise throw an error. |
| |
| See the docstring for ``get_submodule`` for a more detailed |
| explanation of this method's functionality as well as how to |
| correctly specify ``target``. |
| |
| Args: |
| target: The fully-qualified string name of the Parameter |
| to look for. (See ``get_submodule`` for how to specify a |
| fully-qualified string.) |
| |
| Returns: |
| torch.nn.Parameter: The Parameter referenced by ``target`` |
| |
| Raises: |
| AttributeError: If the target string references an invalid |
| path or resolves to something that is not an |
| ``nn.Parameter`` |
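
        Example (a minimal sketch; ``net`` is an arbitrary model)::

            >>> net = torch.nn.Sequential(torch.nn.Linear(2, 3))
            >>> net.get_parameter("0.weight").shape
            torch.Size([3, 2])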
| """ |
| module_path, _, param_name = target.rpartition(".") |
| |
| mod: torch.nn.Module = self.get_submodule(module_path) |
| |
| if not hasattr(mod, param_name): |
| raise AttributeError( |
| mod._get_name() + " has no attribute `" + param_name + "`" |
| ) |
| |
| param: torch.nn.Parameter = getattr(mod, param_name) |
| |
| if not isinstance(param, torch.nn.Parameter): |
            raise AttributeError("`" + param_name + "` is not an nn.Parameter")
| |
| return param |
| |
| def get_buffer(self, target: str) -> "Tensor": |
| """Return the buffer given by ``target`` if it exists, otherwise throw an error. |
| |
| See the docstring for ``get_submodule`` for a more detailed |
| explanation of this method's functionality as well as how to |
| correctly specify ``target``. |
| |
| Args: |
| target: The fully-qualified string name of the buffer |
| to look for. (See ``get_submodule`` for how to specify a |
| fully-qualified string.) |
| |
| Returns: |
| torch.Tensor: The buffer referenced by ``target`` |
| |
| Raises: |
| AttributeError: If the target string references an invalid |
| path or resolves to something that is not a |
| buffer |
| """ |
| module_path, _, buffer_name = target.rpartition(".") |
| |
| mod: torch.nn.Module = self.get_submodule(module_path) |
| |
| if not hasattr(mod, buffer_name): |
| raise AttributeError( |
| mod._get_name() + " has no attribute `" + buffer_name + "`" |
| ) |
| |
| buffer: torch.Tensor = getattr(mod, buffer_name) |
| |
| if buffer_name not in mod._buffers: |
| raise AttributeError("`" + buffer_name + "` is not a buffer") |
| |
| return buffer |
| |
| def get_extra_state(self) -> Any: |
| """Return any extra state to include in the module's state_dict. |
| |
| Implement this and a corresponding :func:`set_extra_state` for your module |
| if you need to store extra state. This function is called when building the |
| module's `state_dict()`. |
| |
| Note that extra state should be picklable to ensure working serialization |
        of the state_dict. We only provide backwards compatibility guarantees
| for serializing Tensors; other objects may break backwards compatibility if |
| their serialized pickled form changes. |
| |
| Returns: |
| object: Any extra state to store in the module's state_dict |
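
        Example (a minimal sketch of a module that round-trips extra state
        through its ``state_dict``; the class and attribute names are
        illustrative)::

            >>> class MyModule(torch.nn.Module):
            ...     def __init__(self):
            ...         super().__init__()
            ...         self.note = "hello"
            ...     def get_extra_state(self):
            ...         return {"note": self.note}
            ...     def set_extra_state(self, state):
            ...         self.note = state["note"]
            >>> m = MyModule()
            >>> m.load_state_dict(MyModule().state_dict())
            <All keys matched successfully>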
| """ |
| raise RuntimeError( |
| "Reached a code path in Module.get_extra_state() that should never be called. " |
| "Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml " |
| "to report this bug." |
| ) |
| |
| def set_extra_state(self, state: Any) -> None: |
| """Set extra state contained in the loaded `state_dict`. |
| |
| This function is called from :func:`load_state_dict` to handle any extra state |
| found within the `state_dict`. Implement this function and a corresponding |
| :func:`get_extra_state` for your module if you need to store extra state within its |
| `state_dict`. |
| |
| Args: |
| state (dict): Extra state from the `state_dict` |
| """ |
| raise RuntimeError( |
| "Reached a code path in Module.set_extra_state() that should never be called. " |
| "Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml " |
| "to report this bug." |
| ) |
| |
| def _apply(self, fn, recurse=True): |
| if recurse: |
| for module in self.children(): |
| module._apply(fn) |
| |
| def compute_should_use_set_data(tensor, tensor_applied): |
| if torch._has_compatible_shallow_copy_type(tensor, tensor_applied): |
| # If the new tensor has compatible tensor type as the existing tensor, |
| # the current behavior is to change the tensor in-place using `.data =`, |
| # and the future behavior is to overwrite the existing tensor. However, |
| # changing the current behavior is a BC-breaking change, and we want it |
| # to happen in future releases. So for now we introduce the |
| # `torch.__future__.get_overwrite_module_params_on_conversion()` |
| # global flag to let the user control whether they want the future |
| # behavior of overwriting the existing tensor or not. |
| return not torch.__future__.get_overwrite_module_params_on_conversion() |
| else: |
| return False |
| |
| should_use_swap_tensors = ( |
| torch.__future__.get_swap_module_params_on_conversion() |
| ) |
| |
| for key, param in self._parameters.items(): |
| if param is None: |
| continue |
| # Tensors stored in modules are graph leaves, and we don't want to |
| # track autograd history of `param_applied`, so we have to use |
| # `with torch.no_grad():` |
| with torch.no_grad(): |
| param_applied = fn(param) |
| p_should_use_set_data = compute_should_use_set_data(param, param_applied) |
| |
| # subclasses may have multiple child tensors so we need to use swap_tensors |
| p_should_use_swap_tensors = ( |
| should_use_swap_tensors or is_traceable_wrapper_subclass(param_applied) |
| ) |
| |
| param_grad = param.grad |
| if p_should_use_swap_tensors: |
| try: |
| if param_grad is not None: |
| # Accessing param.grad makes its at::Tensor's use_count 2, which will prevent swapping. |
| # Decrement use count of the gradient by setting to None |
| param.grad = None |
| param_applied = torch.nn.Parameter( |
| param_applied, requires_grad=param.requires_grad |
| ) |
| torch.utils.swap_tensors(param, param_applied) |
| except Exception as e: |
| if param_grad is not None: |
| param.grad = param_grad |
| raise RuntimeError( |
| f"_apply(): Couldn't swap {self._get_name()}.{key}" |
| ) from e |
| out_param = param |
| elif p_should_use_set_data: |
| param.data = param_applied |
| out_param = param |
| else: |
| assert isinstance(param, Parameter) |
| assert param.is_leaf |
| out_param = Parameter(param_applied, param.requires_grad) |
| self._parameters[key] = out_param |
| |
| if param_grad is not None: |
| with torch.no_grad(): |
| grad_applied = fn(param_grad) |
| g_should_use_set_data = compute_should_use_set_data( |
| param_grad, grad_applied |
| ) |
| if p_should_use_swap_tensors: |
| grad_applied.requires_grad_(param_grad.requires_grad) |
| try: |
| torch.utils.swap_tensors(param_grad, grad_applied) |
| except Exception as e: |
| raise RuntimeError( |
| f"_apply(): Couldn't swap {self._get_name()}.{key}.grad" |
| ) from e |
| out_param.grad = param_grad |
| elif g_should_use_set_data: |
| assert out_param.grad is not None |
| out_param.grad.data = grad_applied |
| else: |
| assert param_grad.is_leaf |
| out_param.grad = grad_applied.requires_grad_( |
| param_grad.requires_grad |
| ) |
| |
| for key, buf in self._buffers.items(): |
| if buf is not None: |
| self._buffers[key] = fn(buf) |
| |
| return self |
| |
| def apply(self: T, fn: Callable[["Module"], None]) -> T: |
| r"""Apply ``fn`` recursively to every submodule (as returned by ``.children()``) as well as self. |
| |
| Typical use includes initializing the parameters of a model |
| (see also :ref:`nn-init-doc`). |
| |
| Args: |
| fn (:class:`Module` -> None): function to be applied to each submodule |
| |
| Returns: |
| Module: self |
| |
| Example:: |
| |
| >>> @torch.no_grad() |
| >>> def init_weights(m): |
| >>> print(m) |
| >>> if type(m) == nn.Linear: |
| >>> m.weight.fill_(1.0) |
| >>> print(m.weight) |
| >>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2)) |
| >>> net.apply(init_weights) |
| Linear(in_features=2, out_features=2, bias=True) |
| Parameter containing: |
| tensor([[1., 1.], |
| [1., 1.]], requires_grad=True) |
| Linear(in_features=2, out_features=2, bias=True) |
| Parameter containing: |
| tensor([[1., 1.], |
| [1., 1.]], requires_grad=True) |
| Sequential( |
| (0): Linear(in_features=2, out_features=2, bias=True) |
| (1): Linear(in_features=2, out_features=2, bias=True) |
| ) |
| |
| """ |
| for module in self.children(): |
| module.apply(fn) |
| fn(self) |
| return self |
| |
| def cuda(self: T, device: Optional[Union[int, device]] = None) -> T: |
| r"""Move all model parameters and buffers to the GPU. |
| |
| This also makes associated parameters and buffers different objects. So |
        it should be called before constructing the optimizer if the module will
| live on GPU while being optimized. |
| |
| .. note:: |
| This method modifies the module in-place. |
| |
| Args: |
| device (int, optional): if specified, all parameters will be |
| copied to that device |
| |
| Returns: |
| Module: self |
| """ |
| return self._apply(lambda t: t.cuda(device)) |
| |
| def ipu(self: T, device: Optional[Union[int, device]] = None) -> T: |
| r"""Move all model parameters and buffers to the IPU. |
| |
| This also makes associated parameters and buffers different objects. So |
        it should be called before constructing the optimizer if the module will
| live on IPU while being optimized. |
| |
| .. note:: |
| This method modifies the module in-place. |
| |
| Arguments: |
| device (int, optional): if specified, all parameters will be |
| copied to that device |
| |
| Returns: |
| Module: self |
| """ |
| return self._apply(lambda t: t.ipu(device)) |
| |
| def xpu(self: T, device: Optional[Union[int, device]] = None) -> T: |
| r"""Move all model parameters and buffers to the XPU. |
| |
| This also makes associated parameters and buffers different objects. So |
        it should be called before constructing the optimizer if the module will
| live on XPU while being optimized. |
| |
| .. note:: |
| This method modifies the module in-place. |
| |
| Arguments: |
| device (int, optional): if specified, all parameters will be |
| copied to that device |
| |
| Returns: |
| Module: self |
| """ |
| return self._apply(lambda t: t.xpu(device)) |
| |
| def cpu(self: T) -> T: |
| r"""Move all model parameters and buffers to the CPU. |
| |
| .. note:: |
| This method modifies the module in-place. |
| |
| Returns: |
| Module: self |
| """ |
| return self._apply(lambda t: t.cpu()) |
| |
| def type(self: T, dst_type: Union[dtype, str]) -> T: |
| r"""Casts all parameters and buffers to :attr:`dst_type`. |
| |
| .. note:: |
| This method modifies the module in-place. |
| |
| Args: |
| dst_type (type or string): the desired type |
| |
| Returns: |
| Module: self |
| """ |
| return self._apply(lambda t: t.type(dst_type)) |
| |
| def float(self: T) -> T: |
| r"""Casts all floating point parameters and buffers to ``float`` datatype. |
| |
| .. note:: |
| This method modifies the module in-place. |
| |
| Returns: |
| Module: self |
| """ |
| return self._apply(lambda t: t.float() if t.is_floating_point() else t) |
| |
| def double(self: T) -> T: |
| r"""Casts all floating point parameters and buffers to ``double`` datatype. |
| |
| .. note:: |
| This method modifies the module in-place. |
| |
| Returns: |
| Module: self |
| """ |
| return self._apply(lambda t: t.double() if t.is_floating_point() else t) |
| |
| def half(self: T) -> T: |
| r"""Casts all floating point parameters and buffers to ``half`` datatype. |
| |
| .. note:: |
| This method modifies the module in-place. |
| |
| Returns: |
| Module: self |
| """ |
| return self._apply(lambda t: t.half() if t.is_floating_point() else t) |
| |
| def bfloat16(self: T) -> T: |
| r"""Casts all floating point parameters and buffers to ``bfloat16`` datatype. |
| |
| .. note:: |
| This method modifies the module in-place. |
| |
| Returns: |
| Module: self |
| """ |
| return self._apply(lambda t: t.bfloat16() if t.is_floating_point() else t) |
| |
| def to_empty( |
| self: T, *, device: Optional[DeviceLikeType], recurse: bool = True |
| ) -> T: |
| r"""Move the parameters and buffers to the specified device without copying storage. |
| |
| Args: |
| device (:class:`torch.device`): The desired device of the parameters |
| and buffers in this module. |
| recurse (bool): Whether parameters and buffers of submodules should |
| be recursively moved to the specified device. |
| |
| Returns: |
| Module: self |
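
        Example (a minimal sketch of materializing a module created on the
        ``meta`` device)::

            >>> with torch.device("meta"):
            ...     m = torch.nn.Linear(3, 3)
            >>> m = m.to_empty(device="cpu")  # allocates uninitialized storage on CPU
            >>> m.weight.device
            device(type='cpu')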
| """ |
| return self._apply( |
| lambda t: torch.empty_like(t, device=device), recurse=recurse |
| ) |
| |
| @overload |
| def to( |
| self, |
| device: Optional[DeviceLikeType] = ..., |
| dtype: Optional[dtype] = ..., |
| non_blocking: bool = ..., |
| ) -> Self: |
| ... |
| |
| @overload |
| def to(self, dtype: dtype, non_blocking: bool = ...) -> Self: |
| ... |
| |
| @overload |
| def to(self, tensor: Tensor, non_blocking: bool = ...) -> Self: |
| ... |
| |
| def to(self, *args, **kwargs): |
| r"""Move and/or cast the parameters and buffers. |
| |
| This can be called as |
| |
| .. function:: to(device=None, dtype=None, non_blocking=False) |
| :noindex: |
| |
| .. function:: to(dtype, non_blocking=False) |
| :noindex: |
| |
| .. function:: to(tensor, non_blocking=False) |
| :noindex: |
| |
| .. function:: to(memory_format=torch.channels_last) |
| :noindex: |
| |
| Its signature is similar to :meth:`torch.Tensor.to`, but only accepts |
| floating point or complex :attr:`dtype`\ s. In addition, this method will |
| only cast the floating point or complex parameters and buffers to :attr:`dtype` |
        (if given). The integral parameters and buffers will be moved to
        :attr:`device`, if that is given, but with dtypes unchanged. When
| :attr:`non_blocking` is set, it tries to convert/move asynchronously |
| with respect to the host if possible, e.g., moving CPU Tensors with |
| pinned memory to CUDA devices. |
| |
| See below for examples. |
| |
| .. note:: |
| This method modifies the module in-place. |
| |
| Args: |
| device (:class:`torch.device`): the desired device of the parameters |
| and buffers in this module |
| dtype (:class:`torch.dtype`): the desired floating point or complex dtype of |
| the parameters and buffers in this module |
| tensor (torch.Tensor): Tensor whose dtype and device are the desired |
| dtype and device for all parameters and buffers in this module |
| memory_format (:class:`torch.memory_format`): the desired memory |
| format for 4D parameters and buffers in this module (keyword |
| only argument) |
| |
| Returns: |
| Module: self |
| |
| Examples:: |
| |
| >>> # xdoctest: +IGNORE_WANT("non-deterministic") |
| >>> linear = nn.Linear(2, 2) |
| >>> linear.weight |
| Parameter containing: |
| tensor([[ 0.1913, -0.3420], |
| [-0.5113, -0.2325]]) |
| >>> linear.to(torch.double) |
| Linear(in_features=2, out_features=2, bias=True) |
| >>> linear.weight |
| Parameter containing: |
| tensor([[ 0.1913, -0.3420], |
| [-0.5113, -0.2325]], dtype=torch.float64) |
| >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA1) |
| >>> gpu1 = torch.device("cuda:1") |
| >>> linear.to(gpu1, dtype=torch.half, non_blocking=True) |
| Linear(in_features=2, out_features=2, bias=True) |
| >>> linear.weight |
| Parameter containing: |
| tensor([[ 0.1914, -0.3420], |
| [-0.5112, -0.2324]], dtype=torch.float16, device='cuda:1') |
| >>> cpu = torch.device("cpu") |
| >>> linear.to(cpu) |
| Linear(in_features=2, out_features=2, bias=True) |
| >>> linear.weight |
| Parameter containing: |
| tensor([[ 0.1914, -0.3420], |
| [-0.5112, -0.2324]], dtype=torch.float16) |
| |
| >>> linear = nn.Linear(2, 2, bias=None).to(torch.cdouble) |
| >>> linear.weight |
| Parameter containing: |
| tensor([[ 0.3741+0.j, 0.2382+0.j], |
| [ 0.5593+0.j, -0.4443+0.j]], dtype=torch.complex128) |
| >>> linear(torch.ones(3, 2, dtype=torch.cdouble)) |
| tensor([[0.6122+0.j, 0.1150+0.j], |
| [0.6122+0.j, 0.1150+0.j], |
| [0.6122+0.j, 0.1150+0.j]], dtype=torch.complex128) |
| |
| """ |
| device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to( |
| *args, **kwargs |
| ) |
| |
| if dtype is not None: |
| if not (dtype.is_floating_point or dtype.is_complex): |
| raise TypeError( |
| "nn.Module.to only accepts floating point or complex " |
| f"dtypes, but got desired dtype={dtype}" |
| ) |
| if dtype.is_complex: |
| warnings.warn( |
| "Complex modules are a new feature under active development whose design may change, " |
| "and some modules might not work as expected when using complex tensors as parameters or buffers. " |
| "Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml " |
| "if a complex module does not work as expected." |
| ) |
| |
| def convert(t): |
| try: |
| if convert_to_format is not None and t.dim() in (4, 5): |
| return t.to( |
| device, |
| dtype if t.is_floating_point() or t.is_complex() else None, |
| non_blocking, |
| memory_format=convert_to_format, |
| ) |
| return t.to( |
| device, |
| dtype if t.is_floating_point() or t.is_complex() else None, |
| non_blocking, |
| ) |
| except NotImplementedError as e: |
| if str(e) == "Cannot copy out of meta tensor; no data!": |
| raise NotImplementedError( |
| f"{e} Please use torch.nn.Module.to_empty() instead of torch.nn.Module.to() " |
| f"when moving module from meta to a different device." |
| ) from None |
| else: |
| raise |
| |
| return self._apply(convert) |
| |
| def register_full_backward_pre_hook( |
| self, |
| hook: Callable[["Module", _grad_t], Union[None, _grad_t]], |
| prepend: bool = False, |
| ) -> RemovableHandle: |
| r"""Register a backward pre-hook on the module. |
| |
| The hook will be called every time the gradients for the module are computed. |
| The hook should have the following signature:: |
| |
| hook(module, grad_output) -> tuple[Tensor] or None |
| |
| The :attr:`grad_output` is a tuple. The hook should |
| not modify its arguments, but it can optionally return a new gradient with |
| respect to the output that will be used in place of :attr:`grad_output` in |
| subsequent computations. Entries in :attr:`grad_output` will be ``None`` for |
| all non-Tensor arguments. |
| |
| For technical reasons, when this hook is applied to a Module, its forward function will |
| receive a view of each Tensor passed to the Module. Similarly the caller will receive a view |
| of each Tensor returned by the Module's forward function. |
| |
| .. warning :: |
| Modifying inputs inplace is not allowed when using backward hooks and |
| will raise an error. |
| |
| Args: |
| hook (Callable): The user-defined hook to be registered. |
| prepend (bool): If true, the provided ``hook`` will be fired before |
| all existing ``backward_pre`` hooks on this |
| :class:`torch.nn.modules.Module`. Otherwise, the provided |
| ``hook`` will be fired after all existing ``backward_pre`` hooks |
| on this :class:`torch.nn.modules.Module`. Note that global |
| ``backward_pre`` hooks registered with |
| :func:`register_module_full_backward_pre_hook` will fire before |
| all hooks registered by this method. |
| |
| Returns: |
| :class:`torch.utils.hooks.RemovableHandle`: |
| a handle that can be used to remove the added hook by calling |
| ``handle.remove()`` |
| |
| """ |
| handle = RemovableHandle(self._backward_pre_hooks) |
| self._backward_pre_hooks[handle.id] = hook |
| if prepend: |
| self._backward_pre_hooks.move_to_end(handle.id, last=False) # type: ignore[attr-defined] |
| return handle |
| |
| def register_backward_hook( |
| self, hook: Callable[["Module", _grad_t, _grad_t], Union[None, _grad_t]] |
| ) -> RemovableHandle: |
| r"""Register a backward hook on the module. |
| |
| This function is deprecated in favor of :meth:`~torch.nn.Module.register_full_backward_hook` and |
| the behavior of this function will change in future versions. |
| |
| Returns: |
| :class:`torch.utils.hooks.RemovableHandle`: |
| a handle that can be used to remove the added hook by calling |
| ``handle.remove()`` |
| |
| """ |
| if self._is_full_backward_hook is True: |
| raise RuntimeError( |
| "Cannot use both regular backward hooks and full backward hooks on a " |
| "single Module. Please use only one of them." |
| ) |
| |
| self._is_full_backward_hook = False |
| |
| handle = RemovableHandle(self._backward_hooks) |
| self._backward_hooks[handle.id] = hook |
| return handle |
| |
| def register_full_backward_hook( |
| self, |
| hook: Callable[["Module", _grad_t, _grad_t], Union[None, _grad_t]], |
| prepend: bool = False, |
| ) -> RemovableHandle: |
| r"""Register a backward hook on the module. |
| |
| The hook will be called every time the gradients with respect to a module |
| are computed, i.e. the hook will execute if and only if the gradients with |
| respect to module outputs are computed. The hook should have the following |
| signature:: |
| |
| hook(module, grad_input, grad_output) -> tuple(Tensor) or None |
| |
| The :attr:`grad_input` and :attr:`grad_output` are tuples that contain the gradients |
| with respect to the inputs and outputs respectively. The hook should |
| not modify its arguments, but it can optionally return a new gradient with |
| respect to the input that will be used in place of :attr:`grad_input` in |
| subsequent computations. :attr:`grad_input` will only correspond to the inputs given |
| as positional arguments and all kwarg arguments are ignored. Entries |
| in :attr:`grad_input` and :attr:`grad_output` will be ``None`` for all non-Tensor |
| arguments. |
| |
| For technical reasons, when this hook is applied to a Module, its forward function will |
| receive a view of each Tensor passed to the Module. Similarly the caller will receive a view |
| of each Tensor returned by the Module's forward function. |
| |
| .. warning :: |
| Modifying inputs or outputs inplace is not allowed when using backward hooks and |
| will raise an error. |
| |
| Args: |
| hook (Callable): The user-defined hook to be registered. |
| prepend (bool): If true, the provided ``hook`` will be fired before |
| all existing ``backward`` hooks on this |
| :class:`torch.nn.modules.Module`. Otherwise, the provided |
| ``hook`` will be fired after all existing ``backward`` hooks on |
| this :class:`torch.nn.modules.Module`. Note that global |
| ``backward`` hooks registered with |
| :func:`register_module_full_backward_hook` will fire before |
| all hooks registered by this method. |
| |
| Returns: |
| :class:`torch.utils.hooks.RemovableHandle`: |
| a handle that can be used to remove the added hook by calling |
| ``handle.remove()`` |
| |
| """ |
| if self._is_full_backward_hook is False: |
| raise RuntimeError( |
| "Cannot use both regular backward hooks and full backward hooks on a " |
| "single Module. Please use only one of them." |
| ) |
| |
| self._is_full_backward_hook = True |
| |
| handle = RemovableHandle(self._backward_hooks) |
| self._backward_hooks[handle.id] = hook |
| if prepend: |
| self._backward_hooks.move_to_end(handle.id, last=False) # type: ignore[attr-defined] |
| return handle |
| |
| def _get_backward_hooks(self): |
| r"""Return the backward hooks for use in the call function. |
| |
| It returns two lists, one with the full backward hooks and one with the non-full |
| backward hooks. |
| """ |
| full_backward_hooks: List[Callable] = [] |
| if _global_is_full_backward_hook is True: |
| full_backward_hooks += _global_backward_hooks.values() |
| if self._is_full_backward_hook is True: |
| full_backward_hooks += self._backward_hooks.values() |
| |
| non_full_backward_hooks: List[Callable] = [] |
| if _global_is_full_backward_hook is False: |
| non_full_backward_hooks += _global_backward_hooks.values() |
| if self._is_full_backward_hook is False: |
| non_full_backward_hooks += self._backward_hooks.values() |
| |
| return full_backward_hooks, non_full_backward_hooks |
| |
| def _get_backward_pre_hooks(self): |
| backward_pre_hooks: List[Callable] = [] |
| backward_pre_hooks += _global_backward_pre_hooks.values() |
| backward_pre_hooks += self._backward_pre_hooks.values() |
| |
| return backward_pre_hooks |
| |
| def _maybe_warn_non_full_backward_hook(self, inputs, result, grad_fn): |
| if not isinstance(result, torch.Tensor): |
| if not ( |
| isinstance(result, tuple) |
| and all(isinstance(r, torch.Tensor) for r in result) |
| ): |
| warnings.warn( |
| "Using non-full backward hooks on a Module that does not return a " |
| "single Tensor or a tuple of Tensors is deprecated and will be removed " |
| "in future versions. This hook will be missing some of the grad_output. " |
| "Please use register_full_backward_hook to get the documented behavior.", |
| FutureWarning, |
| stacklevel=2, |
| ) |
| return |
| else: |
| result = (result,) |
| |
| if not isinstance(inputs, torch.Tensor): |
| if not ( |
| isinstance(inputs, tuple) |
| and all(isinstance(i, torch.Tensor) for i in inputs) |
| ): |
| warnings.warn( |
| "Using non-full backward hooks on a Module that does not take as input a " |
| "single Tensor or a tuple of Tensors is deprecated and will be removed " |
| "in future versions. This hook will be missing some of the grad_input. " |
| "Please use register_full_backward_hook to get the documented behavior.", |
| FutureWarning, |
| stacklevel=2, |
| ) |
| return |
| else: |
| inputs = (inputs,) |
| |
        # At this point we are sure that inputs and result are tuples of Tensors
| out_grad_fn = {r.grad_fn for r in result if r.grad_fn is not None} |
| if len(out_grad_fn) == 0 or ( |
| len(out_grad_fn) == 1 and grad_fn not in out_grad_fn |
| ): |
| warnings.warn( |
| "Using a non-full backward hook when outputs are nested in python data structure " |
| "is deprecated and will be removed in future versions. This hook will be missing " |
| "some grad_output.", |
| FutureWarning, |
| stacklevel=2, |
| ) |
| elif len(out_grad_fn) > 1: |
| warnings.warn( |
| "Using a non-full backward hook when outputs are generated by different autograd Nodes " |
| "is deprecated and will be removed in future versions. This hook will be missing " |
| "some grad_output. Please use register_full_backward_hook to get the documented behavior.", |
| FutureWarning, |
| stacklevel=2, |
| ) |
| else: |
| # At this point the grad_output part of the hook will most likely be correct |
| inputs_grad_fn = {i.grad_fn for i in inputs if i.grad_fn is not None} |
| |
| next_functions = {n[0] for n in grad_fn.next_functions} |
| |
| if inputs_grad_fn != next_functions: |
| warnings.warn( |
| "Using a non-full backward hook when the forward contains multiple autograd Nodes " |
| "is deprecated and will be removed in future versions. This hook will be missing " |
| "some grad_input. Please use register_full_backward_hook to get the documented " |
| "behavior.", |
| FutureWarning, |
| stacklevel=2, |
| ) |
| |
| def register_forward_pre_hook( |
| self, |
| hook: Union[ |
| Callable[[T, Tuple[Any, ...]], Optional[Any]], |
| Callable[ |
| [T, Tuple[Any, ...], Dict[str, Any]], |
| Optional[Tuple[Any, Dict[str, Any]]], |
| ], |
| ], |
| *, |
| prepend: bool = False, |
| with_kwargs: bool = False, |
| ) -> RemovableHandle: |
| r"""Register a forward pre-hook on the module. |
| |
| The hook will be called every time before :func:`forward` is invoked. |
| |
| |
| If ``with_kwargs`` is false or not specified, the input contains only |
| the positional arguments given to the module. Keyword arguments won't be |
| passed to the hooks and only to the ``forward``. The hook can modify the |
| input. User can either return a tuple or a single modified value in the |
| hook. We will wrap the value into a tuple if a single value is returned |
| (unless that value is already a tuple). The hook should have the |
| following signature:: |
| |
| hook(module, args) -> None or modified input |
| |
| If ``with_kwargs`` is true, the forward pre-hook will be passed the |
| kwargs given to the forward function. And if the hook modifies the |
| input, both the args and kwargs should be returned. The hook should have |
| the following signature:: |
| |
| hook(module, args, kwargs) -> None or a tuple of modified input and kwargs |
| |
| Args: |
| hook (Callable): The user defined hook to be registered. |
| prepend (bool): If true, the provided ``hook`` will be fired before |
| all existing ``forward_pre`` hooks on this |
| :class:`torch.nn.modules.Module`. Otherwise, the provided |
| ``hook`` will be fired after all existing ``forward_pre`` hooks |
| on this :class:`torch.nn.modules.Module`. Note that global |
| ``forward_pre`` hooks registered with |
| :func:`register_module_forward_pre_hook` will fire before all |
| hooks registered by this method. |
| Default: ``False`` |
            with_kwargs (bool): If ``True``, the ``hook`` will be passed the kwargs
| given to the forward function. |
| Default: ``False`` |
| |
| Returns: |
| :class:`torch.utils.hooks.RemovableHandle`: |
| a handle that can be used to remove the added hook by calling |
| ``handle.remove()`` |
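
        The sketch below is illustrative: ``net`` stands in for any existing
        module and the scaling factor is arbitrary.

        Example::

            >>> # xdoctest: +SKIP("undefined vars")
            >>> def scale_first_input(module, args):
            ...     # args is the tuple of positional inputs; returning a tuple
            ...     # replaces it before forward runs
            ...     return (args[0] * 2.0,) + args[1:]
            >>> handle = net.register_forward_pre_hook(scale_first_input)
            >>> # remove the hook once it is no longer needed
            >>> handle.remove()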
| """ |
| handle = RemovableHandle( |
| self._forward_pre_hooks, extra_dict=self._forward_pre_hooks_with_kwargs |
| ) |
| self._forward_pre_hooks[handle.id] = hook |
| if with_kwargs: |
| self._forward_pre_hooks_with_kwargs[handle.id] = True |
| |
| if prepend: |
| self._forward_pre_hooks.move_to_end(handle.id, last=False) # type: ignore[attr-defined] |
| return handle |
| |
| def register_forward_hook( |
| self, |
| hook: Union[ |
| Callable[[T, Tuple[Any, ...], Any], Optional[Any]], |
| Callable[[T, Tuple[Any, ...], Dict[str, Any], Any], Optional[Any]], |
| ], |
| *, |
| prepend: bool = False, |
| with_kwargs: bool = False, |
| always_call: bool = False, |
| ) -> RemovableHandle: |
| r"""Register a forward hook on the module. |
| |
| The hook will be called every time after :func:`forward` has computed an output. |
| |
        If ``with_kwargs`` is ``False`` or not specified, the input contains only
        the positional arguments given to the module. Keyword arguments won't be
        passed to the hooks, only to the ``forward``. The hook can modify the
        output. It can also modify the input in-place, but that has no effect on
        the forward computation since this hook is called after :func:`forward`
        has run. The hook should have the following signature::
| |
| hook(module, args, output) -> None or modified output |
| |
        If ``with_kwargs`` is ``True``, the forward hook will be passed the
        ``kwargs`` given to the forward function and is expected to return the
        output, possibly modified. The hook should have the following signature::
| |
| hook(module, args, kwargs, output) -> None or modified output |
| |
| Args: |
| hook (Callable): The user defined hook to be registered. |
| prepend (bool): If ``True``, the provided ``hook`` will be fired |
| before all existing ``forward`` hooks on this |
| :class:`torch.nn.modules.Module`. Otherwise, the provided |
| ``hook`` will be fired after all existing ``forward`` hooks on |
| this :class:`torch.nn.modules.Module`. Note that global |
| ``forward`` hooks registered with |
| :func:`register_module_forward_hook` will fire before all hooks |
| registered by this method. |
| Default: ``False`` |
| with_kwargs (bool): If ``True``, the ``hook`` will be passed the |
| kwargs given to the forward function. |
| Default: ``False`` |
            always_call (bool): If ``True``, the ``hook`` will be run regardless of
| whether an exception is raised while calling the Module. |
| Default: ``False`` |
| |
| Returns: |
| :class:`torch.utils.hooks.RemovableHandle`: |
| a handle that can be used to remove the added hook by calling |
| ``handle.remove()`` |
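
        As an illustrative sketch (``net`` is assumed to be an existing module),
        the hook below replaces the module's output with a shifted copy.

        Example::

            >>> # xdoctest: +SKIP("undefined vars")
            >>> def add_one_to_output(module, args, output):
            ...     # returning a value replaces the output seen by callers
            ...     return output + 1.0
            >>> handle = net.register_forward_hook(add_one_to_output)
            >>> y = net(x)  # y already includes the +1.0 applied by the hook
            >>> handle.remove()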
| """ |
| handle = RemovableHandle( |
| self._forward_hooks, |
| extra_dict=[ |
| self._forward_hooks_with_kwargs, |
| self._forward_hooks_always_called, |
| ], |
| ) |
| self._forward_hooks[handle.id] = hook |
| if with_kwargs: |
| self._forward_hooks_with_kwargs[handle.id] = True |
| if always_call: |
| self._forward_hooks_always_called[handle.id] = True |
| if prepend: |
| self._forward_hooks.move_to_end(handle.id, last=False) # type: ignore[attr-defined] |
| return handle |
| |
| def _slow_forward(self, *input, **kwargs): |
| tracing_state = torch._C._get_tracing_state() |
| if not tracing_state or isinstance(self.forward, torch._C.ScriptMethod): |
| return self.forward(*input, **kwargs) |
| recording_scopes = torch.jit._trace._trace_module_map is not None |
| if recording_scopes: |
| # type ignore was added because at this point one knows that |
| # torch.jit._trace._trace_module_map is not Optional and has type Dict[Any, Any] |
| name = torch.jit._trace._trace_module_map[self] if self in torch.jit._trace._trace_module_map else None # type: ignore[index, operator] # noqa: B950 |
| if name: |
| tracing_state.push_scope(name) |
| else: |
| recording_scopes = False |
| try: |
| result = self.forward(*input, **kwargs) |
| finally: |
| if recording_scopes: |
| tracing_state.pop_scope() |
| return result |
| |
| def _wrapped_call_impl(self, *args, **kwargs): |
| if self._compiled_call_impl is not None: |
| return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc] |
| else: |
| return self._call_impl(*args, **kwargs) |
| |
| # torchrec tests the code consistency with the following code |
| # fmt: off |
| def _call_impl(self, *args, **kwargs): |
| forward_call = (self._slow_forward if torch._C._get_tracing_state() else self.forward) |
| # If we don't have any hooks, we want to skip the rest of the logic in |
| # this function, and just call forward. |
| if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks |
| or _global_backward_pre_hooks or _global_backward_hooks |
| or _global_forward_hooks or _global_forward_pre_hooks): |
| return forward_call(*args, **kwargs) |
| |
| try: |
| result = None |
| called_always_called_hooks = set() |
| |
| full_backward_hooks, non_full_backward_hooks = [], [] |
| backward_pre_hooks = [] |
| if self._backward_pre_hooks or _global_backward_pre_hooks: |
| backward_pre_hooks = self._get_backward_pre_hooks() |
| |
| if self._backward_hooks or _global_backward_hooks: |
| full_backward_hooks, non_full_backward_hooks = self._get_backward_hooks() |
| |
| if _global_forward_pre_hooks or self._forward_pre_hooks: |
| for hook_id, hook in ( |
| *_global_forward_pre_hooks.items(), |
| *self._forward_pre_hooks.items(), |
| ): |
| if hook_id in self._forward_pre_hooks_with_kwargs: |
| args_kwargs_result = hook(self, args, kwargs) # type: ignore[misc] |
| if args_kwargs_result is not None: |
| if isinstance(args_kwargs_result, tuple) and len(args_kwargs_result) == 2: |
| args, kwargs = args_kwargs_result |
| else: |
| raise RuntimeError( |
| "forward pre-hook must return None or a tuple " |
| f"of (new_args, new_kwargs), but got {args_kwargs_result}." |
| ) |
| else: |
| args_result = hook(self, args) |
| if args_result is not None: |
| if not isinstance(args_result, tuple): |
| args_result = (args_result,) |
| args = args_result |
| |
| bw_hook = None |
| if full_backward_hooks or backward_pre_hooks: |
| bw_hook = BackwardHook(self, full_backward_hooks, backward_pre_hooks) |
| args = bw_hook.setup_input_hook(args) |
| |
| result = forward_call(*args, **kwargs) |
| if _global_forward_hooks or self._forward_hooks: |
| for hook_id, hook in ( |
| *_global_forward_hooks.items(), |
| *self._forward_hooks.items(), |
| ): |
| # mark that always called hook is run |
| if hook_id in self._forward_hooks_always_called or hook_id in _global_forward_hooks_always_called: |
| called_always_called_hooks.add(hook_id) |
| |
| if hook_id in self._forward_hooks_with_kwargs: |
| hook_result = hook(self, args, kwargs, result) |
| else: |
| hook_result = hook(self, args, result) |
| |
| if hook_result is not None: |
| result = hook_result |
| |
| if bw_hook: |
| if not isinstance(result, (torch.Tensor, tuple)): |
| warnings.warn("For backward hooks to be called," |
| " module output should be a Tensor or a tuple of Tensors" |
| f" but received {type(result)}") |
| result = bw_hook.setup_output_hook(result) |
| |
| # Handle the non-full backward hooks |
| if non_full_backward_hooks: |
| var = result |
| while not isinstance(var, torch.Tensor): |
| if isinstance(var, dict): |
| var = next(v for v in var.values() if isinstance(v, torch.Tensor)) |
| else: |
| var = var[0] |
| grad_fn = var.grad_fn |
| if grad_fn is not None: |
| for hook in non_full_backward_hooks: |
| grad_fn.register_hook(_WrappedHook(hook, self)) |
| self._maybe_warn_non_full_backward_hook(args, result, grad_fn) |
| |
| return result |
| |
| except Exception: |
| # run always called hooks if they have not already been run |
| # For now only forward hooks have the always_call option but perhaps |
| # this functionality should be added to full backward hooks as well. |
| for hook_id, hook in _global_forward_hooks.items(): |
| if hook_id in _global_forward_hooks_always_called and hook_id not in called_always_called_hooks: # type: ignore[possibly-undefined] |
| try: |
| hook_result = hook(self, args, result) # type: ignore[possibly-undefined] |
| if hook_result is not None: |
| result = hook_result |
| except Exception as e: |
| warnings.warn("global module forward hook with ``always_call=True`` raised an exception " |
| f"that was silenced as another error was raised in forward: {str(e)}") |
| continue |
| |
| for hook_id, hook in self._forward_hooks.items(): |
| if hook_id in self._forward_hooks_always_called and hook_id not in called_always_called_hooks: # type: ignore[possibly-undefined] |
| try: |
| if hook_id in self._forward_hooks_with_kwargs: |
| hook_result = hook(self, args, kwargs, result) # type: ignore[possibly-undefined] |
| else: |
| hook_result = hook(self, args, result) # type: ignore[possibly-undefined] |
| if hook_result is not None: |
| result = hook_result |
| except Exception as e: |
| warnings.warn("module forward hook with ``always_call=True`` raised an exception " |
| f"that was silenced as another error was raised in forward: {str(e)}") |
| continue |
| # raise exception raised in try block |
| raise |
| # fmt: on |
| |
| __call__: Callable[..., Any] = _wrapped_call_impl |
| |
| def __getstate__(self): |
| state = self.__dict__.copy() |
| state.pop("_compiled_call_impl", None) |
| return state |
| |
| def __setstate__(self, state): |
| self.__dict__.update(state) |
| |
| # Support loading old checkpoints that don't have the following attrs: |
| if "_forward_pre_hooks" not in self.__dict__: |
| self._forward_pre_hooks = OrderedDict() |
| if "_forward_pre_hooks_with_kwargs" not in self.__dict__: |
| self._forward_pre_hooks_with_kwargs = OrderedDict() |
| if "_forward_hooks_with_kwargs" not in self.__dict__: |
| self._forward_hooks_with_kwargs = OrderedDict() |
| if "_forward_hooks_always_called" not in self.__dict__: |
| self._forward_hooks_always_called = OrderedDict() |
| if "_state_dict_hooks" not in self.__dict__: |
| self._state_dict_hooks = OrderedDict() |
| if "_state_dict_pre_hooks" not in self.__dict__: |
| self._state_dict_pre_hooks = OrderedDict() |
| if "_load_state_dict_pre_hooks" not in self.__dict__: |
| self._load_state_dict_pre_hooks = OrderedDict() |
| if "_load_state_dict_post_hooks" not in self.__dict__: |
| self._load_state_dict_post_hooks = OrderedDict() |
| if "_non_persistent_buffers_set" not in self.__dict__: |
| self._non_persistent_buffers_set = set() |
| if "_is_full_backward_hook" not in self.__dict__: |
| self._is_full_backward_hook = None |
| if "_backward_pre_hooks" not in self.__dict__: |
| self._backward_pre_hooks = OrderedDict() |
| |
| # On the return type: |
| # We choose to return `Any` in the `__getattr__` type signature instead of a more strict `Union[Tensor, Module]`. |
| # This is done for better interop with various type checkers for the end users. |
| # Having a stricter return type doesn't play nicely with `register_buffer()` and forces |
| # people to excessively use type-ignores, asserts, casts, etc. |
| # See full discussion on the problems with returning `Union` here |
| # https://github.com/microsoft/pyright/issues/4213 |
| def __getattr__(self, name: str) -> Any: |
| if "_parameters" in self.__dict__: |
| _parameters = self.__dict__["_parameters"] |
| if name in _parameters: |
| return _parameters[name] |
| if "_buffers" in self.__dict__: |
| _buffers = self.__dict__["_buffers"] |
| if name in _buffers: |
| return _buffers[name] |
| if "_modules" in self.__dict__: |
| modules = self.__dict__["_modules"] |
| if name in modules: |
| return modules[name] |
| raise AttributeError( |
| f"'{type(self).__name__}' object has no attribute '{name}'" |
| ) |
| |
| def __setattr__(self, name: str, value: Union[Tensor, "Module"]) -> None: |
| def remove_from(*dicts_or_sets): |
| for d in dicts_or_sets: |
| if name in d: |
| if isinstance(d, dict): |
| del d[name] |
| else: |
| d.discard(name) |
| |
| params = self.__dict__.get("_parameters") |
| if isinstance(value, Parameter): |
| if params is None: |
| raise AttributeError( |
| "cannot assign parameters before Module.__init__() call" |
| ) |
| remove_from( |
| self.__dict__, |
| self._buffers, |
| self._modules, |
| self._non_persistent_buffers_set, |
| ) |
| self.register_parameter(name, value) |
| elif params is not None and name in params: |
| if value is not None: |
| raise TypeError( |
| f"cannot assign '{torch.typename(value)}' as parameter '{name}' " |
| "(torch.nn.Parameter or None expected)" |
| ) |
| self.register_parameter(name, value) |
| else: |
| modules = self.__dict__.get("_modules") |
| if isinstance(value, Module): |
| if modules is None: |
| raise AttributeError( |
| "cannot assign module before Module.__init__() call" |
| ) |
| remove_from( |
| self.__dict__, |
| self._parameters, |
| self._buffers, |
| self._non_persistent_buffers_set, |
| ) |
| for hook in _global_module_registration_hooks.values(): |
| output = hook(self, name, value) |
| if output is not None: |
| value = output |
| modules[name] = value |
| elif modules is not None and name in modules: |
| if value is not None: |
| raise TypeError( |
| f"cannot assign '{torch.typename(value)}' as child module '{name}' " |
| "(torch.nn.Module or None expected)" |
| ) |
| for hook in _global_module_registration_hooks.values(): |
| output = hook(self, name, value) |
| if output is not None: |
| value = output |
| modules[name] = value |
| else: |
| buffers = self.__dict__.get("_buffers") |
| if buffers is not None and name in buffers: |
| if value is not None and not isinstance(value, torch.Tensor): |
| raise TypeError( |
| f"cannot assign '{torch.typename(value)}' as buffer '{name}' " |
| "(torch.Tensor or None expected)" |
| ) |
| for hook in _global_buffer_registration_hooks.values(): |
| output = hook(self, name, value) |
| if output is not None: |
| value = output |
| buffers[name] = value |
| else: |
| super().__setattr__(name, value) |
| |
| def __delattr__(self, name): |
| if name in self._parameters: |
| del self._parameters[name] |
| elif name in self._buffers: |
| del self._buffers[name] |
| self._non_persistent_buffers_set.discard(name) |
| elif name in self._modules: |
| del self._modules[name] |
| else: |
| super().__delattr__(name) |
| |
| def _register_state_dict_hook(self, hook): |
| r"""Register a state-dict hook. |
| |
| These hooks will be called with arguments: `self`, `state_dict`, |
| `prefix`, `local_metadata`, after the `state_dict` of `self` is set. |
| Note that only parameters and buffers of `self` or its children are |
| guaranteed to exist in `state_dict`. The hooks may modify `state_dict` |
| inplace or return a new one. |
| """ |
| handle = RemovableHandle(self._state_dict_hooks) |
| self._state_dict_hooks[handle.id] = hook |
| return handle |
| |
| def register_state_dict_pre_hook(self, hook): |
| r"""Register a pre-hook for the :meth:`~torch.nn.Module.state_dict` method. |
| |
| These hooks will be called with arguments: ``self``, ``prefix``, |
| and ``keep_vars`` before calling ``state_dict`` on ``self``. The registered |
| hooks can be used to perform pre-processing before the ``state_dict`` |
| call is made. |
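
        As a minimal sketch (``net`` is a placeholder for an existing module),
        a pre-hook only observes the upcoming call and returns nothing.

        Example::

            >>> # xdoctest: +SKIP("undefined vars")
            >>> def log_state_dict_call(module, prefix, keep_vars):
            ...     print(f"state_dict() on {type(module).__name__} with prefix '{prefix}'")
            >>> handle = net.register_state_dict_pre_hook(log_state_dict_call)
            >>> sd = net.state_dict()  # the pre-hook fires before saving starts
            >>> handle.remove()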
| """ |
| handle = RemovableHandle(self._state_dict_pre_hooks) |
| self._state_dict_pre_hooks[handle.id] = hook |
| return handle |
| |
| def _save_to_state_dict(self, destination, prefix, keep_vars): |
| r"""Save module state to the `destination` dictionary. |
| |
| The `destination` dictionary will contain the state |
| of the module, but not its descendants. This is called on every |
| submodule in :meth:`~torch.nn.Module.state_dict`. |
| |
| In rare cases, subclasses can achieve class-specific behavior by |
| overriding this method with custom logic. |
| |
| Args: |
| destination (dict): a dict where state will be stored |
            prefix (str): the prefix for parameters and buffers used in this
                module
            keep_vars (bool): if ``False``, the tensors placed in ``destination``
                are detached from autograd; if ``True``, they are stored as-is
| """ |
| for name, param in self._parameters.items(): |
| if param is not None: |
| destination[prefix + name] = param if keep_vars else param.detach() |
| for name, buf in self._buffers.items(): |
| if buf is not None and name not in self._non_persistent_buffers_set: |
| destination[prefix + name] = buf if keep_vars else buf.detach() |
| extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX |
| if ( |
| getattr(self.__class__, "get_extra_state", Module.get_extra_state) |
| is not Module.get_extra_state |
| ): |
| destination[extra_state_key] = self.get_extra_state() |
| |
    # The user can pass an optional arbitrary mappable object to `state_dict`, in which case `state_dict` returns
    # that same object. But if they pass nothing, an `OrderedDict` is created and returned.
| T_destination = TypeVar("T_destination", bound=Dict[str, Any]) |
| |
| @overload |
| def state_dict( |
| self, *, destination: T_destination, prefix: str = ..., keep_vars: bool = ... |
| ) -> T_destination: |
| ... |
| |
| @overload |
| def state_dict(self, *, prefix: str = ..., keep_vars: bool = ...) -> Dict[str, Any]: |
| ... |
| |
| # TODO: Change `*args` to `*` and remove the corresponding warning in docs when BC allows. |
| # Also remove the logic for arg parsing together. |
| def state_dict(self, *args, destination=None, prefix="", keep_vars=False): |
| r"""Return a dictionary containing references to the whole state of the module. |
| |
| Both parameters and persistent buffers (e.g. running averages) are |
| included. Keys are corresponding parameter and buffer names. |
| Parameters and buffers set to ``None`` are not included. |
| |
| .. note:: |
| The returned object is a shallow copy. It contains references |
| to the module's parameters and buffers. |
| |
| .. warning:: |
| Currently ``state_dict()`` also accepts positional arguments for |
| ``destination``, ``prefix`` and ``keep_vars`` in order. However, |
| this is being deprecated and keyword arguments will be enforced in |
| future releases. |
| |
| .. warning:: |
| Please avoid the use of argument ``destination`` as it is not |
| designed for end-users. |
| |
| Args: |
| destination (dict, optional): If provided, the state of module will |
| be updated into the dict and the same object is returned. |
| Otherwise, an ``OrderedDict`` will be created and returned. |
| Default: ``None``. |
| prefix (str, optional): a prefix added to parameter and buffer |
| names to compose the keys in state_dict. Default: ``''``. |
| keep_vars (bool, optional): by default the :class:`~torch.Tensor` s |
| returned in the state dict are detached from autograd. If it's |
| set to ``True``, detaching will not be performed. |
| Default: ``False``. |
| |
| Returns: |
| dict: |
| a dictionary containing a whole state of the module |
| |
| Example:: |
| |
| >>> # xdoctest: +SKIP("undefined vars") |
| >>> module.state_dict().keys() |
| ['bias', 'weight'] |
| |
| """ |
| # TODO: Remove `args` and the parsing logic when BC allows. |
| if len(args) > 0: |
            # DeprecationWarning is ignored by default, so use FutureWarning instead
| warnings.warn( |
| "Positional args are being deprecated, use kwargs instead. Refer to " |
| "https://pytorch.org/docs/main/generated/torch.nn.Module.html#torch.nn.Module.state_dict" |
| " for details.", |
| FutureWarning, |
| stacklevel=2, |
| ) |
| if destination is None: |
| destination = args[0] |
| if len(args) > 1 and prefix == "": |
| prefix = args[1] |
| if len(args) > 2 and keep_vars is False: |
| keep_vars = args[2] |
| |
| if destination is None: |
| destination = OrderedDict() |
| destination._metadata = OrderedDict() |
| |
| local_metadata = dict(version=self._version) |
| if hasattr(destination, "_metadata"): |
| destination._metadata[prefix[:-1]] = local_metadata |
| |
| for hook in self._state_dict_pre_hooks.values(): |
| hook(self, prefix, keep_vars) |
| self._save_to_state_dict(destination, prefix, keep_vars) |
| for name, module in self._modules.items(): |
| if module is not None: |
| module.state_dict( |
| destination=destination, |
| prefix=prefix + name + ".", |
| keep_vars=keep_vars, |
| ) |
| for hook in self._state_dict_hooks.values(): |
| hook_result = hook(self, destination, prefix, local_metadata) |
| if hook_result is not None: |
| destination = hook_result |
| return destination |
| |
| def _register_load_state_dict_pre_hook(self, hook, with_module=False): |
| r"""Register a pre-hook for the :meth:`~torch.nn.Module.load_state_dict` method. |
| |
| These hooks will be called with arguments: `state_dict`, `prefix`, |
| `local_metadata`, `strict`, `missing_keys`, `unexpected_keys`, |
| `error_msgs`, before loading `state_dict` into `self`. These arguments |
| are exactly the same as those of `_load_from_state_dict`. |
| |
| If ``with_module`` is ``True``, then the first argument to the hook is |
| an instance of the module. |
| |
| Arguments: |
| hook (Callable): Callable hook that will be invoked before |
| loading the state dict. |
| with_module (bool, optional): Whether or not to pass the module |
| instance to the hook as the first parameter. |
| """ |
| handle = RemovableHandle(self._load_state_dict_pre_hooks) |
| self._load_state_dict_pre_hooks[handle.id] = _WrappedHook( |
| hook, self if with_module else None |
| ) |
| return handle |
| |
| def register_load_state_dict_post_hook(self, hook): |
| r"""Register a post hook to be run after module's ``load_state_dict`` is called. |
| |
        It should have the following signature::

            hook(module, incompatible_keys) -> None
| |
| The ``module`` argument is the current module that this hook is registered |
| on, and the ``incompatible_keys`` argument is a ``NamedTuple`` consisting |
| of attributes ``missing_keys`` and ``unexpected_keys``. ``missing_keys`` |
| is a ``list`` of ``str`` containing the missing keys and |
| ``unexpected_keys`` is a ``list`` of ``str`` containing the unexpected keys. |
| |
| The given incompatible_keys can be modified inplace if needed. |
| |
| Note that the checks performed when calling :func:`load_state_dict` with |
| ``strict=True`` are affected by modifications the hook makes to |
| ``missing_keys`` or ``unexpected_keys``, as expected. Additions to either |
| set of keys will result in an error being thrown when ``strict=True``, and |
| clearing out both missing and unexpected keys will avoid an error. |
| |
| Returns: |
| :class:`torch.utils.hooks.RemovableHandle`: |
| a handle that can be used to remove the added hook by calling |
| ``handle.remove()`` |
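
        The sketch below is illustrative: ``net`` is assumed to exist and
        ``head.`` is a made-up submodule prefix whose missing keys we choose
        to tolerate under ``strict=True``.

        Example::

            >>> # xdoctest: +SKIP("undefined vars")
            >>> def ignore_missing_head(module, incompatible_keys):
            ...     # modify the lists in place; do not return a new object
            ...     incompatible_keys.missing_keys[:] = [
            ...         k for k in incompatible_keys.missing_keys
            ...         if not k.startswith("head.")
            ...     ]
            >>> handle = net.register_load_state_dict_post_hook(ignore_missing_head)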
| """ |
| handle = RemovableHandle(self._load_state_dict_post_hooks) |
| self._load_state_dict_post_hooks[handle.id] = hook |
| return handle |
| |
| def _load_from_state_dict( |
| self, |
| state_dict, |
| prefix, |
| local_metadata, |
| strict, |
| missing_keys, |
| unexpected_keys, |
| error_msgs, |
| ): |
| r"""Copy parameters and buffers from :attr:`state_dict` into only this module, but not its descendants. |
| |
| This is called on every submodule |
| in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this |
| module in input :attr:`state_dict` is provided as :attr:`local_metadata`. |
| For state dicts without metadata, :attr:`local_metadata` is empty. |
| Subclasses can achieve class-specific backward compatible loading using |
| the version number at `local_metadata.get("version", None)`. |
| Additionally, :attr:`local_metadata` can also contain the key |
| `assign_to_params_buffers` that indicates whether keys should be |
| assigned their corresponding tensor in the state_dict. |
| |
| .. note:: |
| :attr:`state_dict` is not the same object as the input |
| :attr:`state_dict` to :meth:`~torch.nn.Module.load_state_dict`. So |
| it can be modified. |
| |
| Args: |
| state_dict (dict): a dict containing parameters and |
| persistent buffers. |
| prefix (str): the prefix for parameters and buffers used in this |
| module |
            local_metadata (dict): a dict containing the metadata for this module.
| strict (bool): whether to strictly enforce that the keys in |
| :attr:`state_dict` with :attr:`prefix` match the names of |
| parameters and buffers in this module |
| missing_keys (list of str): if ``strict=True``, add missing keys to |
| this list |
| unexpected_keys (list of str): if ``strict=True``, add unexpected |
| keys to this list |
| error_msgs (list of str): error messages should be added to this |
| list, and will be reported together in |
| :meth:`~torch.nn.Module.load_state_dict` |
| """ |
| for hook in self._load_state_dict_pre_hooks.values(): |
| hook( |
| state_dict, |
| prefix, |
| local_metadata, |
| strict, |
| missing_keys, |
| unexpected_keys, |
| error_msgs, |
| ) |
| |
| persistent_buffers = { |
| k: v |
| for k, v in self._buffers.items() |
| if k not in self._non_persistent_buffers_set |
| } |
| local_name_params = itertools.chain( |
| self._parameters.items(), persistent_buffers.items() |
| ) |
| local_state = {k: v for k, v in local_name_params if v is not None} |
| assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False) |
| use_swap_tensors = torch.__future__.get_swap_module_params_on_conversion() |
| |
| for name, param in local_state.items(): |
| key = prefix + name |
| if key in state_dict: |
| input_param = state_dict[key] |
| if not torch.overrides.is_tensor_like(input_param): |
| error_msgs.append( |
| f'While copying the parameter named "{key}", ' |
| "expected torch.Tensor or Tensor-like object from checkpoint but " |
| f"received {type(input_param)}" |
| ) |
| continue |
| |
                # This is used to avoid copying uninitialized parameters into
                # non-lazy modules, since they don't have the hook to do the checks.
                # In such a case, accessing the .shape attribute would error.
| is_param_lazy = torch.nn.parameter.is_lazy(param) |
| # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+ |
| if ( |
| not is_param_lazy |
| and len(param.shape) == 0 |
| and len(input_param.shape) == 1 |
| ): |
| input_param = input_param[0] |
| |
| if not is_param_lazy and input_param.shape != param.shape: |
| # local shape should match the one in checkpoint |
| error_msgs.append( |
| f"size mismatch for {key}: copying a param with shape {input_param.shape} from checkpoint, " |
| f"the shape in current model is {param.shape}." |
| ) |
| continue |
| |
| if ( |
| param.is_meta |
| and not input_param.is_meta |
| and not assign_to_params_buffers |
| ): |
| warnings.warn( |
| f"for {key}: copying from a non-meta parameter in the checkpoint to a meta " |
| "parameter in the current model, which is a no-op. (Did you mean to " |
| "pass `assign=True` to assign items in the state dictionary to their " |
| "corresponding key in the module instead of copying them in place?)" |
| ) |
| |
| try: |
| with torch.no_grad(): |
| if use_swap_tensors: |
| new_input_param = param.module_load( |
| input_param, assign=assign_to_params_buffers |
| ) |
| if id(new_input_param) == id(input_param) or id( |
| new_input_param |
| ) == id(param): |
| raise RuntimeError( |
| "module_load returned one of self or other, please .detach() " |
| "the result if returning one of the inputs in module_load" |
| ) |
| if isinstance(param, torch.nn.Parameter): |
| if not isinstance(new_input_param, torch.nn.Parameter): |
| new_input_param = torch.nn.Parameter( |
| new_input_param, |
| requires_grad=param.requires_grad, |
| ) |
| else: |
| new_input_param.requires_grad_(param.requires_grad) |
| torch.utils.swap_tensors(param, new_input_param) |
| del new_input_param |
| elif assign_to_params_buffers: |
| # Shape checks are already done above |
| if isinstance(param, torch.nn.Parameter): |
| if not isinstance(input_param, torch.nn.Parameter): |
| input_param = torch.nn.Parameter( |
| input_param, requires_grad=param.requires_grad |
| ) |
| else: |
| input_param.requires_grad_(param.requires_grad) |
| setattr(self, name, input_param) |
| else: |
| param.copy_(input_param) |
| except Exception as ex: |
| action = "swapping" if use_swap_tensors else "copying" |
| error_msgs.append( |
| f'While {action} the parameter named "{key}", ' |
| f"whose dimensions in the model are {param.size()} and " |
| f"whose dimensions in the checkpoint are {input_param.size()}, " |
| f"an exception occurred : {ex.args}." |
| ) |
| elif strict: |
| missing_keys.append(key) |
| |
| extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX |
| if ( |
| getattr(self.__class__, "set_extra_state", Module.set_extra_state) |
| is not Module.set_extra_state |
| ): |
| if extra_state_key in state_dict: |
| self.set_extra_state(state_dict[extra_state_key]) |
| elif strict: |
| missing_keys.append(extra_state_key) |
| elif strict and (extra_state_key in state_dict): |
| unexpected_keys.append(extra_state_key) |
| |
| if strict: |
| for key in state_dict.keys(): |
| if key.startswith(prefix) and key != extra_state_key: |
| input_name = key[len(prefix) :].split(".", 1) |
                    # Must be a Module if it has attributes
| if len(input_name) > 1: |
| if input_name[0] not in self._modules: |
| unexpected_keys.append(key) |
| elif input_name[0] not in local_state: |
| unexpected_keys.append(key) |
| |
| def load_state_dict( |
| self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False |
| ): |
| r"""Copy parameters and buffers from :attr:`state_dict` into this module and its descendants. |
| |
| If :attr:`strict` is ``True``, then |
| the keys of :attr:`state_dict` must exactly match the keys returned |
| by this module's :meth:`~torch.nn.Module.state_dict` function. |
| |
| .. warning:: |
| If :attr:`assign` is ``True`` the optimizer must be created after |
| the call to :attr:`load_state_dict` unless |
| :func:`~torch.__future__.get_swap_module_params_on_conversion` is ``True``. |
| |
| Args: |
| state_dict (dict): a dict containing parameters and |
| persistent buffers. |
| strict (bool, optional): whether to strictly enforce that the keys |
| in :attr:`state_dict` match the keys returned by this module's |
| :meth:`~torch.nn.Module.state_dict` function. Default: ``True`` |
| assign (bool, optional): When ``False``, the properties of the tensors |
| in the current module are preserved while when ``True``, the |
| properties of the Tensors in the state dict are preserved. The only |
| exception is the ``requires_grad`` field of :class:`~torch.nn.Parameter`s |
| for which the value from the module is preserved. |
| Default: ``False`` |
| |
| Returns: |
| ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: |
| * **missing_keys** is a list of str containing any keys that are expected |
| by this module but missing from the provided ``state_dict``. |
| * **unexpected_keys** is a list of str containing the keys that are not |
| expected by this module but present in the provided ``state_dict``. |
| |
| Note: |
| If a parameter or buffer is registered as ``None`` and its corresponding key |
| exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a |
| ``RuntimeError``. |
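
        The sketch below is illustrative; ``net`` and the checkpoint path are
        placeholders.

        Example::

            >>> # xdoctest: +SKIP("undefined vars")
            >>> state_dict = torch.load("checkpoint.pth", map_location="cpu")
            >>> incompatible = net.load_state_dict(state_dict, strict=False)
            >>> # inspect what did not line up instead of failing hard
            >>> print(incompatible.missing_keys, incompatible.unexpected_keys)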
| """ |
| if not isinstance(state_dict, Mapping): |
| raise TypeError( |
| f"Expected state_dict to be dict-like, got {type(state_dict)}." |
| ) |
| |
| missing_keys: List[str] = [] |
| unexpected_keys: List[str] = [] |
| error_msgs: List[str] = [] |
| |
| # copy state_dict so _load_from_state_dict can modify it |
| metadata = getattr(state_dict, "_metadata", None) |
| state_dict = OrderedDict(state_dict) |
| if metadata is not None: |
| # mypy isn't aware that "_metadata" exists in state_dict |
| state_dict._metadata = metadata # type: ignore[attr-defined] |
| |
| def load(module, local_state_dict, prefix=""): |
| local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) |
| if assign: |
| local_metadata["assign_to_params_buffers"] = assign |
| module._load_from_state_dict( |
| local_state_dict, |
| prefix, |
| local_metadata, |
| True, |
| missing_keys, |
| unexpected_keys, |
| error_msgs, |
| ) |
| for name, child in module._modules.items(): |
| if child is not None: |
| child_prefix = prefix + name + "." |
| child_state_dict = { |
| k: v |
| for k, v in local_state_dict.items() |
| if k.startswith(child_prefix) |
| } |
| load(child, child_state_dict, child_prefix) # noqa: F821 |
| |
| # Note that the hook can modify missing_keys and unexpected_keys. |
| incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys) |
| for hook in module._load_state_dict_post_hooks.values(): |
| out = hook(module, incompatible_keys) |
                assert out is None, (
                    "Hooks registered with ``register_load_state_dict_post_hook`` are not "
                    "expected to return new values; if incompatible_keys need to be modified, "
                    "it should be done in place."
                )
| |
| load(self, state_dict) |
| del load |
| |
| if strict: |
| if len(unexpected_keys) > 0: |
| error_msgs.insert( |
| 0, |
| "Unexpected key(s) in state_dict: {}. ".format( |
| ", ".join(f'"{k}"' for k in unexpected_keys) |
| ), |
| ) |
| if len(missing_keys) > 0: |
| error_msgs.insert( |
| 0, |
| "Missing key(s) in state_dict: {}. ".format( |
| ", ".join(f'"{k}"' for k in missing_keys) |
| ), |
| ) |
| |
| if len(error_msgs) > 0: |
| raise RuntimeError( |
| "Error(s) in loading state_dict for {}:\n\t{}".format( |
| self.__class__.__name__, "\n\t".join(error_msgs) |
| ) |
| ) |
| return _IncompatibleKeys(missing_keys, unexpected_keys) |
| |
| def _named_members( |
| self, get_members_fn, prefix="", recurse=True, remove_duplicate: bool = True |
| ): |
| r"""Help yield various names + members of modules.""" |
| memo = set() |
| modules = ( |
| self.named_modules(prefix=prefix, remove_duplicate=remove_duplicate) |
| if recurse |
| else [(prefix, self)] |
| ) |
| for module_prefix, module in modules: |
| members = get_members_fn(module) |
| for k, v in members: |
| if v is None or v in memo: |
| continue |
| if remove_duplicate: |
| memo.add(v) |
| name = module_prefix + ("." if module_prefix else "") + k |
| yield name, v |
| |
| def parameters(self, recurse: bool = True) -> Iterator[Parameter]: |
| r"""Return an iterator over module parameters. |
| |
| This is typically passed to an optimizer. |
| |
| Args: |
| recurse (bool): if True, then yields parameters of this module |
| and all submodules. Otherwise, yields only parameters that |
| are direct members of this module. |
| |
| Yields: |
| Parameter: module parameter |
| |
| Example:: |
| |
| >>> # xdoctest: +SKIP("undefined vars") |
| >>> for param in model.parameters(): |
| >>> print(type(param), param.size()) |
| <class 'torch.Tensor'> (20L,) |
| <class 'torch.Tensor'> (20L, 1L, 5L, 5L) |
| |
| """ |
        for _, param in self.named_parameters(recurse=recurse):
            yield param
| |
| def named_parameters( |
| self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True |
| ) -> Iterator[Tuple[str, Parameter]]: |
| r"""Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself. |
| |
| Args: |
| prefix (str): prefix to prepend to all parameter names. |
| recurse (bool): if True, then yields parameters of this module |
| and all submodules. Otherwise, yields only parameters that |
| are direct members of this module. |
| remove_duplicate (bool, optional): whether to remove the duplicated |
| parameters in the result. Defaults to True. |
| |
| Yields: |
| (str, Parameter): Tuple containing the name and parameter |
| |
| Example:: |
| |
| >>> # xdoctest: +SKIP("undefined vars") |
| >>> for name, param in self.named_parameters(): |
| >>> if name in ['bias']: |
| >>> print(param.size()) |
| |
| """ |
| gen = self._named_members( |
| lambda module: module._parameters.items(), |
| prefix=prefix, |
| recurse=recurse, |
| remove_duplicate=remove_duplicate, |
| ) |
| yield from gen |
| |
| def buffers(self, recurse: bool = True) -> Iterator[Tensor]: |
| r"""Return an iterator over module buffers. |
| |
| Args: |
| recurse (bool): if True, then yields buffers of this module |
| and all submodules. Otherwise, yields only buffers that |
| are direct members of this module. |
| |
| Yields: |
| torch.Tensor: module buffer |
| |
| Example:: |
| |
| >>> # xdoctest: +SKIP("undefined vars") |
| >>> for buf in model.buffers(): |
| >>> print(type(buf), buf.size()) |
| <class 'torch.Tensor'> (20L,) |
| <class 'torch.Tensor'> (20L, 1L, 5L, 5L) |
| |
| """ |
| for _, buf in self.named_buffers(recurse=recurse): |
| yield buf |
| |
| def named_buffers( |
| self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True |
| ) -> Iterator[Tuple[str, Tensor]]: |
| r"""Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself. |
| |
| Args: |
| prefix (str): prefix to prepend to all buffer names. |
| recurse (bool, optional): if True, then yields buffers of this module |
| and all submodules. Otherwise, yields only buffers that |
| are direct members of this module. Defaults to True. |
| remove_duplicate (bool, optional): whether to remove the duplicated buffers in the result. Defaults to True. |
| |
| Yields: |
| (str, torch.Tensor): Tuple containing the name and buffer |
| |
| Example:: |
| |
| >>> # xdoctest: +SKIP("undefined vars") |
| >>> for name, buf in self.named_buffers(): |
| >>> if name in ['running_var']: |
| >>> print(buf.size()) |
| |
| """ |
| gen = self._named_members( |
| lambda module: module._buffers.items(), |
| prefix=prefix, |
| recurse=recurse, |
| remove_duplicate=remove_duplicate, |
| ) |
| yield from gen |
| |
| def children(self) -> Iterator["Module"]: |
| r"""Return an iterator over immediate children modules. |
| |
| Yields: |
| Module: a child module |
| """ |
| for name, module in self.named_children(): |
| yield module |
| |
| def named_children(self) -> Iterator[Tuple[str, "Module"]]: |
| r"""Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself. |
| |
| Yields: |
| (str, Module): Tuple containing a name and child module |
| |
| Example:: |
| |
| >>> # xdoctest: +SKIP("undefined vars") |
| >>> for name, module in model.named_children(): |
| >>> if name in ['conv4', 'conv5']: |
| >>> print(module) |
| |
| """ |
| memo = set() |
| for name, module in self._modules.items(): |
| if module is not None and module not in memo: |
| memo.add(module) |
| yield name, module |
| |
| def modules(self) -> Iterator["Module"]: |
| r"""Return an iterator over all modules in the network. |
| |
| Yields: |
| Module: a module in the network |
| |
| Note: |
| Duplicate modules are returned only once. In the following |
| example, ``l`` will be returned only once. |
| |
| Example:: |
| |
| >>> l = nn.Linear(2, 2) |
| >>> net = nn.Sequential(l, l) |
| >>> for idx, m in enumerate(net.modules()): |
| ... print(idx, '->', m) |
| |
| 0 -> Sequential( |
| (0): Linear(in_features=2, out_features=2, bias=True) |
| (1): Linear(in_features=2, out_features=2, bias=True) |
| ) |
| 1 -> Linear(in_features=2, out_features=2, bias=True) |
| |
| """ |
| for _, module in self.named_modules(): |
| yield module |
| |
| def named_modules( |
| self, |
| memo: Optional[Set["Module"]] = None, |
| prefix: str = "", |
| remove_duplicate: bool = True, |
| ): |
| r"""Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself. |
| |
| Args: |
| memo: a memo to store the set of modules already added to the result |
| prefix: a prefix that will be added to the name of the module |
| remove_duplicate: whether to remove the duplicated module instances in the result |
| or not |
| |
| Yields: |
| (str, Module): Tuple of name and module |
| |
| Note: |
| Duplicate modules are returned only once. In the following |
| example, ``l`` will be returned only once. |
| |
| Example:: |
| |
| >>> l = nn.Linear(2, 2) |
| >>> net = nn.Sequential(l, l) |
| >>> for idx, m in enumerate(net.named_modules()): |
| ... print(idx, '->', m) |
| |
| 0 -> ('', Sequential( |
| (0): Linear(in_features=2, out_features=2, bias=True) |
| (1): Linear(in_features=2, out_features=2, bias=True) |
| )) |
| 1 -> ('0', Linear(in_features=2, out_features=2, bias=True)) |
| |
| """ |
| if memo is None: |
| memo = set() |
| if self not in memo: |
| if remove_duplicate: |
| memo.add(self) |
| yield prefix, self |
| for name, module in self._modules.items(): |
| if module is None: |
| continue |
| submodule_prefix = prefix + ("." if prefix else "") + name |
| yield from module.named_modules( |
| memo, submodule_prefix, remove_duplicate |
| ) |
| |
| def train(self: T, mode: bool = True) -> T: |
| r"""Set the module in training mode. |
| |
        This has an effect only on certain modules. See the documentation of
        particular modules for details of their behavior in training/evaluation
        mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,
        etc.
| |
| Args: |
| mode (bool): whether to set training mode (``True``) or evaluation |
| mode (``False``). Default: ``True``. |
| |
| Returns: |
| Module: self |
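
        The sketch below is illustrative; ``model`` stands for any existing module.

        Example::

            >>> # xdoctest: +SKIP("undefined vars")
            >>> model.train()       # enable training-time behavior (e.g. dropout)
            >>> assert model.training
            >>> model.train(False)  # equivalent to model.eval()
            >>> assert not model.training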
| """ |
| if not isinstance(mode, bool): |
| raise ValueError("training mode is expected to be boolean") |
| self.training = mode |
| for module in self.children(): |
| module.train(mode) |
| return self |
| |
| def eval(self: T) -> T: |
| r"""Set the module in evaluation mode. |
| |
        This has an effect only on certain modules. See the documentation of
        particular modules for details of their behavior in training/evaluation
        mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,
        etc.
| |
        This is equivalent to :meth:`self.train(False) <torch.nn.Module.train>`.
| |
| See :ref:`locally-disable-grad-doc` for a comparison between |
| `.eval()` and several similar mechanisms that may be confused with it. |
| |
| Returns: |
| Module: self |
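
        A minimal sketch, assuming ``model`` and ``inputs`` already exist; note
        that ``eval()`` does not disable gradient tracking by itself.

        Example::

            >>> # xdoctest: +SKIP("undefined vars")
            >>> model.eval()
            >>> with torch.no_grad():
            ...     predictions = model(inputs)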
| """ |
| return self.train(False) |
| |
| def requires_grad_(self: T, requires_grad: bool = True) -> T: |
| r"""Change if autograd should record operations on parameters in this module. |
| |
| This method sets the parameters' :attr:`requires_grad` attributes |
| in-place. |
| |
| This method is helpful for freezing part of the module for finetuning |
| or training parts of a model individually (e.g., GAN training). |
| |
| See :ref:`locally-disable-grad-doc` for a comparison between |
| `.requires_grad_()` and several similar mechanisms that may be confused with it. |
| |
| Args: |
| requires_grad (bool): whether autograd should record operations on |
| parameters in this module. Default: ``True``. |
| |
| Returns: |
| Module: self |
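
        A common freezing pattern, sketched with placeholder names (``model``
        and its ``classifier`` submodule are assumptions).

        Example::

            >>> # xdoctest: +SKIP("undefined vars")
            >>> model.requires_grad_(False)            # freeze every parameter
            >>> model.classifier.requires_grad_(True)  # then unfreeze the head
            >>> trainable = [p for p in model.parameters() if p.requires_grad]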
| """ |
| for p in self.parameters(): |
| p.requires_grad_(requires_grad) |
| return self |
| |
| def zero_grad(self, set_to_none: bool = True) -> None: |
| r"""Reset gradients of all model parameters. |
| |
| See similar function under :class:`torch.optim.Optimizer` for more context. |
| |
| Args: |
| set_to_none (bool): instead of setting to zero, set the grads to None. |
| See :meth:`torch.optim.Optimizer.zero_grad` for details. |
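
        A minimal sketch, assuming ``model`` is an existing module whose
        parameters already have gradients.

        Example::

            >>> # xdoctest: +SKIP("undefined vars")
            >>> model.zero_grad(set_to_none=False)  # existing grads are zero-filled in place
            >>> model.zero_grad()                   # existing grads are reset to None (default)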
| """ |
| if getattr(self, "_is_replica", False): |
| warnings.warn( |
| "Calling .zero_grad() from a module created with nn.DataParallel() has no effect. " |
| "The parameters are copied (in a differentiable manner) from the original module. " |
| "This means they are not leaf nodes in autograd and so don't accumulate gradients. " |
| "If you need gradients in your forward method, consider using autograd.grad instead." |
| ) |
| |
| for p in self.parameters(): |
| if p.grad is not None: |
| if set_to_none: |
| p.grad = None |
| else: |
| if p.grad.grad_fn is not None: |
| p.grad.detach_() |
| else: |
| p.grad.requires_grad_(False) |
| p.grad.zero_() |
| |
| def share_memory(self: T) -> T: |
| r"""See :meth:`torch.Tensor.share_memory_`.""" |
| return self._apply(lambda t: t.share_memory_()) |
| |
| def _get_name(self): |
| return self.__class__.__name__ |
| |
| def extra_repr(self) -> str: |
| r"""Set the extra representation of the module. |
| |
| To print customized extra information, you should re-implement |
| this method in your own modules. Both single-line and multi-line |
| strings are acceptable. |
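
        A sketch of how a custom module might use this; the class and attribute
        names are illustrative.

        Example::

            >>> # xdoctest: +SKIP("undefined vars")
            >>> class MyLinear(nn.Module):
            ...     def __init__(self, in_features, out_features):
            ...         super().__init__()
            ...         self.in_features = in_features
            ...         self.out_features = out_features
            ...     def extra_repr(self) -> str:
            ...         # shown inside the parentheses of repr(module)
            ...         return f"in_features={self.in_features}, out_features={self.out_features}"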
| """ |
| return "" |
| |
| def __repr__(self): |
        # We treat the extra repr like a submodule, one item per line
| extra_lines = [] |
| extra_repr = self.extra_repr() |
| # empty string will be split into list [''] |
| if extra_repr: |
| extra_lines = extra_repr.split("\n") |
| child_lines = [] |
| for key, module in self._modules.items(): |
| mod_str = repr(module) |
| mod_str = _addindent(mod_str, 2) |
| child_lines.append("(" + key + "): " + mod_str) |
| lines = extra_lines + child_lines |
| |
| main_str = self._get_name() + "(" |
| if lines: |
| # simple one-liner info, which most builtin Modules will use |
| if len(extra_lines) == 1 and not child_lines: |
| main_str += extra_lines[0] |
| else: |
| main_str += "\n " + "\n ".join(lines) + "\n" |
| |
| main_str += ")" |
| return main_str |
| |
| def __dir__(self): |
| module_attrs = dir(self.__class__) |
| attrs = list(self.__dict__.keys()) |
| parameters = list(self._parameters.keys()) |
| modules = list(self._modules.keys()) |
| buffers = list(self._buffers.keys()) |
| keys = module_attrs + attrs + parameters + modules + buffers |
| |
        # Eliminate attrs that start with a digit, which are not legal Python identifiers
| keys = [key for key in keys if not key[0].isdigit()] |
| |
| return sorted(keys) |
| |
| def _replicate_for_data_parallel(self): |
| replica = self.__new__(type(self)) |
| replica.__dict__ = self.__dict__.copy() |
| |
        # Replicas do not have parameters themselves; they reference the original
        # module's parameters.
| replica._parameters = dict() |
| replica._buffers = replica._buffers.copy() |
| replica._modules = replica._modules.copy() |
| replica._is_replica = True # type: ignore[assignment] |
| |
| return replica |
| |
| def compile(self, *args, **kwargs): |
| """ |
| Compile this Module's forward using :func:`torch.compile`. |
| |
| This Module's `__call__` method is compiled and all arguments are passed as-is |
| to :func:`torch.compile`. |
| |
| See :func:`torch.compile` for details on the arguments for this function. |
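
        A minimal sketch, assuming ``model`` and ``example_input`` already exist;
        the ``mode`` shown is an ordinary :func:`torch.compile` argument.

        Example::

            >>> # xdoctest: +SKIP("undefined vars")
            >>> model.compile(mode="reduce-overhead")
            >>> out = model(example_input)  # the first call triggers compilation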
| """ |
| self._compiled_call_impl = torch.compile(self._call_impl, *args, **kwargs) |