import contextlib
from typing import Any, Callable, Dict, Iterator, List, Tuple, Union, Set, Optional
import warnings
import torch
from torch import Tensor
__all__ = ["functional_call"]
# We avoid typing module here because module attributes are declared as Union[Parameter, Tensor] by default
# and using other types causes mypy errors
def _change_class(module, params_and_buffers) -> None:
cls = module.__class__
    attr_to_path: Dict[str, str] = module._attr_to_path
def _getattribute(self, name: str) -> Any:
if name in attr_to_path:
return params_and_buffers[attr_to_path[name]]
return cls.__getattribute__(self, name)
def _setattr(self, name: str, value: Any) -> None:
if name in attr_to_path:
params_and_buffers[attr_to_path[name]] = value
else:
return cls.__setattr__(self, name, value)
param_cls = type(
f"StatelessReplacer{cls.__name__}",
(cls,),
{
"__getattribute__": _getattribute,
"__setattr__": _setattr,
},
)
module.__class__ = param_cls
module._orig_class = cls
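
# A minimal sketch of the effect, assuming a hypothetical mod = nn.Linear(2, 2) on
# which _swap_parameters has already set mod._attr_to_path = {'weight': 'weight'}:
#
#   _change_class(mod, params_and_buffers)
#   mod.weight            # now returns params_and_buffers['weight']
#   mod.weight = t        # now writes params_and_buffers['weight'] = t
#   type(mod).__name__    # 'StatelessReplacerLinear', a dynamic subclass of nn.Linear
#
# The original class is stashed on mod._orig_class so _remove_swap can restore it.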
def _create_tied_weights_map(module: 'torch.nn.Module', params_and_buffers: Dict[str, Tensor]) -> Dict[str, str]:
"""
    _create_tied_weights_map(module: Module, params_and_buffers: Dict[str, Tensor]) -> Dict[str, str]

    Creates a map of {tied_name: name_given_by_user} for every weight whose tied counterpart was passed by the user.

    ex: Foo() has self.foo and self.tied_foo, which are tied. If a user passes {'foo': ...} as the
    reparameterization, this returns {'tied_foo': 'foo'}. Similarly, if a user passes {'tied_foo': ...}, this
    returns {'foo': 'tied_foo'}.

    ex: If there aren't any tied weights, the returned map is empty no matter how many values the user passes.
    For example, if module = nn.Linear(...) and the user only passes a new value for 'bias', this returns {}.

    This is useful because we first reparameterize all the keys of params_and_buffers and then all the keys of this
    returned dictionary, pointing each tied name at the value the user gave for its counterpart.
    """
    # The basic algorithm:
    # - index all weights by tensor identity so that tied weights land on the same key
    # - when we encounter a name not passed by the user, save it in a set (the second element of the tuple)
    # - when we encounter a name passed by the user, save it separately as the first element of the tuple
    # - the resulting map looks like {tensor: (name_given_by_user, set(all_tied_names))}
    # - then loop through the values of this map (name_given_by_user and set(all_tied_names))
    # - for each element of all_tied_names, add {tied_name: name_given_by_user} to a new map
names = params_and_buffers.keys()
weight_to_name_and_tied_names: Dict[torch.Tensor, Tuple[Optional[str], Set[str]]] = {}
    # Create a map keyed by tensor so that tied weights share a single entry. Each value is the interesting part:
    # a pair of (used_name, set_of_tied_names).
    # For example, with tied weights self.foo and self.tied_foo where the user passes a value for self.foo, this map
    # ends up as {torch.Tensor(...): ('foo', {'tied_foo'})}
def add_to_name_map(n: str, t: torch.Tensor):
# if the tensor hasn't been seen before, add it to the map
if t not in weight_to_name_and_tied_names:
weight_to_name_and_tied_names[t] = (n, set()) if n in names else (None, {n})
return
# if the name is not used by the user, we add it to the tied set
if n not in names:
weight_to_name_and_tied_names[t][1].add(n)
return
# check that the user didn't pass two different tensors for the same tied weight
first_seen_name = weight_to_name_and_tied_names[t][0]
# if they didn't pass multiple names for tied weights or used the same tensor, we set the used name
if first_seen_name is None or params_and_buffers[n] is params_and_buffers[first_seen_name]:
weight_to_name_and_tied_names[t] = (n, weight_to_name_and_tied_names[t][1])
return
raise ValueError(f"functional_call got values for both {n} and {first_seen_name}, which are tied. " +
"Consider using tie_weights=False")
tensor: Tensor
for name, tensor in module.named_parameters(remove_duplicate=False):
add_to_name_map(name, tensor)
for name, tensor in module.named_buffers(remove_duplicate=False):
add_to_name_map(name, tensor)
# make {tied_name: name_given_by_user} from pairs of (name_given_by_user, set(all_tied_names))
tied_weights_to_given_name = {}
for name_given_by_user, tied_names in weight_to_name_and_tied_names.values():
if name_given_by_user is None: # no mapping was passed for this tensor, use original tensor
continue
for tied_name in tied_names:
tied_weights_to_given_name[tied_name] = name_given_by_user
return tied_weights_to_given_name
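
# Worked example (a sketch): let TiedFoo be a hypothetical module whose parameters
# self.foo and self.tied_foo share one tensor:
#
#   _create_tied_weights_map(TiedFoo(), {'foo': torch.zeros(())})
#   # -> {'tied_foo': 'foo'}: reparametrizing 'foo' must also redirect 'tied_foo'
#   _create_tied_weights_map(TiedFoo(), {'tied_foo': torch.zeros(())})
#   # -> {'foo': 'tied_foo'}
#   _create_tied_weights_map(torch.nn.Linear(2, 2), {'bias': torch.zeros(2)})
#   # -> {} (no tied weights, so nothing extra to redirect)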
def _create_swap_params(params_and_buffers):
    def _swap_parameters(module, tensor_name: str, full_path: str, tensor: Optional[Tensor]) -> None:
        # Changes the module class to get new __getattribute__ and __setattr__ dunder
        # methods that look up the reparametrized tensors in params_and_buffers
        if hasattr(module, "_attr_to_path"):
            module._attr_to_path[tensor_name] = full_path
        else:
            module._attr_to_path = {tensor_name: full_path}
            _change_class(module, params_and_buffers)
    return _swap_parameters
def _remove_swap(module, name: str, full_path: str) -> None:
if hasattr(module, "_orig_class"):
module.__class__ = module._orig_class
delattr(module, "_orig_class")
delattr(module, "_attr_to_path")
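
# For a dotted name like 'l1.weight', _apply_func_submodules walks down to the leaf
# module first, so the swap is recorded on the submodule that owns the tensor:
# after swapping, module.l1._attr_to_path == {'weight': 'l1.weight'} and reading
# module.l1.weight resolves to params_and_buffers['l1.weight'].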
@contextlib.contextmanager
def _reparametrize_module(
module: 'torch.nn.Module',
parameters_and_buffers: Dict[str, Tensor],
tie_weights: bool = False,
) -> Iterator[None]:
tied_weights_map = _create_tied_weights_map(module, parameters_and_buffers) if tie_weights else {}
for name, tensor in parameters_and_buffers.items():
_apply_func_submodules(
_create_swap_params(parameters_and_buffers),
module, name.split("."), name, (tensor,))
for tied_name, user_given_name in tied_weights_map.items():
_apply_func_submodules(
_create_swap_params(parameters_and_buffers),
module, tied_name.split("."), user_given_name, (None,))
    try:
        yield
    finally:
        for name in parameters_and_buffers:
            _apply_func_submodules(
                _remove_swap,
                module, name.split("."), name, ())
        # also restore submodules that were only reached through a tied name
        for tied_name in tied_weights_map:
            _apply_func_submodules(
                _remove_swap,
                module, tied_name.split("."), tied_name, ())
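
# Usage sketch: inside the block the module computes with the provided tensors; on
# exit (even if the forward raises) every swapped submodule gets its original class back.
#
#   with _reparametrize_module(mod, {'weight': w, 'bias': b}, tie_weights=True):
#       out = mod(x)
#   # mod's own parameters and buffers are visible again here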
def _apply_func_submodules(
func: Callable[..., None],
module: 'torch.nn.Module',
path: List[str],
full_path: str,
args: Tuple,
):
if len(path) == 1:
func(module, path[0], full_path, *args)
else:
_apply_func_submodules(func, getattr(module, path[0]), path[1:], full_path, args)
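
# e.g. path=['l1', 'weight'] recurses into getattr(module, 'l1') and ultimately
# calls func(module.l1, 'weight', 'l1.weight', *args).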
def functional_call(
module: 'torch.nn.Module',
parameters_and_buffers: Dict[str, Tensor],
args: Union[Any, Tuple],
    kwargs: Optional[Dict[str, Any]] = None,
*,
tie_weights: bool = True,
):
r"""Performs a functional call on the module by replacing the module parameters
    and buffers with the provided ones.

    .. warning::

        This API is deprecated as of PyTorch 2.0 and will be removed in a future
        version of PyTorch. Please use :func:`torch.func.functional_call` instead,
        which is a drop-in replacement for this API.
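
        A migration sketch, keeping the same positional arguments::

            >>> # xdoctest: +SKIP
            >>> torch.func.functional_call(module, parameters_and_buffers, args, kwargs)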

    .. note:: If the module has active parametrizations, passing a value in the
        :attr:`parameters_and_buffers` argument with the name set to the regular parameter
        name will completely disable the parametrization.
        If you want to apply the parametrization function to the value passed
        please set the key as ``{submodule_name}.parametrizations.{parameter_name}.original``.
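
        For instance (a sketch, where ``MyParametrization`` is a placeholder for any
        parametrization registered via :func:`torch.nn.utils.parametrize.register_parametrization`)::

            >>> # xdoctest: +SKIP
            >>> mod = torch.nn.Linear(2, 2)
            >>> torch.nn.utils.parametrize.register_parametrization(mod, "weight", MyParametrization())
            >>> # replaces 'weight' outright; the parametrization is skipped
            >>> functional_call(mod, {'weight': torch.zeros(2, 2)}, (torch.ones(2),))
            >>> # runs the parametrization on the value passed instead
            >>> functional_call(mod, {'parametrizations.weight.original': torch.zeros(2, 2)}, (torch.ones(2),))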

    .. note:: If the module performs in-place operations on parameters/buffers, these will be reflected
        in the ``parameters_and_buffers`` input.

        Example::

            >>> a = {'foo': torch.zeros(())}
            >>> # xdoctest: +SKIP
            >>> mod = Foo()  # does self.foo = self.foo + 1
            >>> print(mod.foo)  # tensor(0.)
            >>> functional_call(mod, a, torch.ones(()))
            >>> print(mod.foo)  # tensor(0.)
            >>> print(a['foo'])  # tensor(1.)

    .. note:: If the module has tied weights, whether or not functional_call respects the tying is determined by the
        tie_weights flag.

        Example::

            >>> a = {'foo': torch.zeros(())}
            >>> # xdoctest: +SKIP
            >>> mod = Foo()  # has both self.foo and self.foo_tied which are tied. Returns x + self.foo + self.foo_tied
            >>> print(mod.foo)  # tensor(1.)
            >>> mod(torch.zeros(()))  # tensor(2.)
            >>> functional_call(mod, a, torch.zeros(()))  # tensor(0.) since it will change self.foo_tied too
            >>> functional_call(mod, a, torch.zeros(()), tie_weights=False)  # tensor(1.) since self.foo_tied is not updated
            >>> new_a = {'foo': torch.zeros(()), 'foo_tied': torch.zeros(())}
            >>> functional_call(mod, new_a, torch.zeros(()))  # tensor(0.)

    Args:
        module (torch.nn.Module): the module to call
        parameters_and_buffers (dict of str and Tensor): the parameters that will be used in
            the module call.
        args (Any or tuple): arguments to be passed to the module call. If not a tuple, considered a single argument.
        kwargs (dict): keyword arguments to be passed to the module call
        tie_weights (bool, optional): If True, then parameters and buffers tied in the original model will be treated
            as tied in the reparameterized version. Therefore, if True and different values are passed for the tied
            parameters and buffers, it will error. If False, it will not respect the originally tied parameters and
            buffers unless the values passed for both weights are the same. Default: True.

    Returns:
        Any: the result of calling ``module``.
    """
warnings.warn(
"This API is deprecated as of PyTorch 2.0 and will be removed in a future "
"version of PyTorch. Please use torch.func.functional_call instead "
"which is a drop-in replacement for this API.")
return _functional_call(module, parameters_and_buffers, args, kwargs,
tie_weights=tie_weights)
def _functional_call(
module: 'torch.nn.Module',
parameters_and_buffers: Dict[str, Tensor],
args: Union[Any, Tuple],
    kwargs: Optional[Dict[str, Any]] = None,
*,
tie_weights: bool = True,
):
# TODO allow kwargs such as unsafe and others for parametrization
if (
torch.jit.is_tracing()
or torch.jit.is_scripting()
or isinstance(module, (
torch.jit.RecursiveScriptModule,
torch.jit.ScriptModule,
torch.jit.ScriptFunction)
)
):
raise RuntimeError("The stateless API can't be used with Jitted modules")
if kwargs is None:
kwargs = {}
with _reparametrize_module(module, parameters_and_buffers, tie_weights):
if isinstance(args, tuple):
out = module(*args, **kwargs)
else:
out = module(args, **kwargs)
return out