# blob: be9a4c1f0a653bb291561517fe191d3d4bb38dc4 [file] [log] [blame]
import torch
from collections import OrderedDict
import weakref
import warnings
from typing import Any
__all__ = ["RemovableHandle", "unserializable_hook", "warn_if_has_hooks", "BackwardHook"]
class RemovableHandle(object):
    r"""
    A handle which provides the capability to remove a hook.

    The handle only keeps weak references to the dictionaries it is given, so
    it never keeps them alive by itself. It can be used as a context manager,
    in which case :meth:`remove` is called when the ``with`` block exits.

    Args:
        hooks_dict (dict): A dictionary of hooks, indexed by hook ``id``.
        extra_dict (dict): An additional dictionary whose keys will be deleted
            when the same keys are removed from ``hooks_dict``.
    """

    id: int
    # Class-wide counter used to hand out the next handle id.
    next_id: int = 0

    def __init__(self, hooks_dict: Any, *, extra_dict: Any = None) -> None:
        self.hooks_dict_ref = weakref.ref(hooks_dict)

        # Take the current counter value, then advance the shared counter.
        self.id = RemovableHandle.next_id
        RemovableHandle.next_id += 1

        if extra_dict is None:
            self.extra_dict_ref = None
        else:
            self.extra_dict_ref = weakref.ref(extra_dict)

    def remove(self) -> None:
        """Delete this handle's id from every tracked dictionary that is still alive."""
        tracked_refs = [self.hooks_dict_ref]
        if self.extra_dict_ref is not None:
            tracked_refs.append(self.extra_dict_ref)

        for ref in tracked_refs:
            container = ref()
            # A dead weak reference yields None; an already removed id is skipped.
            if container is not None and self.id in container:
                del container[self.id]

    def __getstate__(self):
        """Return ``(hooks_dict, id)``, plus ``extra_dict`` when one was given."""
        state = (self.hooks_dict_ref(), self.id)
        if self.extra_dict_ref is not None:
            state += (self.extra_dict_ref(),)
        return state

    def __setstate__(self, state) -> None:
        """Restore the handle from a tuple produced by :meth:`__getstate__`."""
        saved_hooks = state[0]
        # When nothing was saved, the temporary OrderedDict is dropped right
        # away, which leaves a dead reference behind.
        self.hooks_dict_ref = weakref.ref(
            OrderedDict() if saved_hooks is None else saved_hooks
        )

        self.id = state[1]
        # Keep the shared counter ahead of any restored id.
        RemovableHandle.next_id = max(RemovableHandle.next_id, self.id + 1)

        if len(state) < 3:
            self.extra_dict_ref = None
        else:
            saved_extra = state[2]
            self.extra_dict_ref = weakref.ref(
                OrderedDict() if saved_extra is None else saved_extra
            )

    def __enter__(self) -> "RemovableHandle":
        return self

    def __exit__(self, type: Any, value: Any, tb: Any) -> None:
        self.remove()
def unserializable_hook(f):
    """
    Decorator which marks a function as an unserializable hook.

    The function is tagged in place with a ``__torch_unserializable__``
    attribute set to ``True`` and returned as-is (no wrapper is created).
    This suppresses warnings that would otherwise arise if you attempt
    to serialize a tensor that has a hook.
    """
    setattr(f, "__torch_unserializable__", True)
    return f
def warn_if_has_hooks(tensor):
    """
    Warn about backward hooks on ``tensor`` that will not be serialized.

    One warning is issued for every hook stored in ``tensor._backward_hooks``
    that has not been marked with ``unserializable_hook``. Nothing happens when
    ``tensor._backward_hooks`` is ``None`` or empty.

    Args:
        tensor: Object exposing a ``_backward_hooks`` mapping (or ``None``).
    """
    hooks = tensor._backward_hooks
    if not hooks:
        return

    for hook in hooks.values():
        # The marker is set on the hook function itself by
        # ``unserializable_hook``, so it must be looked up on the hook and not
        # on the dictionary key.
        if not hasattr(hook, "__torch_unserializable__"):
            warnings.warn("backward hook {} on tensor will not be "
                          "serialized.  If this is expected, you can "
                          "decorate the function with @torch.utils.hooks.unserializable_hook "
                          "to suppress this warning".format(repr(hook)))
class BackwardHook(object):
    """
    A wrapper class to implement nn.Module backward hooks.

    It handles:
      - Ignoring non-Tensor inputs and replacing them by None before calling the user hook
      - Generating the proper Node to capture a set of Tensor's gradients
      - Linking the gradients captures for the outputs with the gradients captured for the input
      - Calling the user hook once both output and input gradients are available
    """

    def __init__(self, module, user_hooks, user_pre_hooks):
        self.module = module
        self.user_hooks = user_hooks
        self.user_pre_hooks = user_pre_hooks

        # Filled in by ``setup_input_hook`` / ``setup_output_hook`` and by the
        # callbacks they register.
        self.grad_outputs = None
        self.n_inputs = -1
        self.input_tensors_index = None
        self.n_outputs = -1
        self.output_tensors_index = None

    def _pack_with_none(self, indices, values, size):
        """Return a ``size``-long tuple holding ``values`` at ``indices`` and None elsewhere."""
        slots = [None] * size
        for position, value in zip(indices, values):
            slots[position] = value
        return tuple(slots)

    def _unpack_none(self, indices, values):
        """Return the entries of ``values`` found at ``indices``, as a tuple."""
        return tuple(values[position] for position in indices)

    def _set_user_hook(self, grad_fn):
        """Register on ``grad_fn`` the callback that runs the user backward hooks."""
        def hook(grad_input, _):
            if self.grad_outputs is None:
                # This happens because the gradient in your nn.Module flows to
                # the Module's input without passing through the Module's
                # output, e.g. when you're doing double backward.
                return

            grads = self._pack_with_none(self.input_tensors_index, grad_input, self.n_inputs)

            for user_hook in self.user_hooks:
                replacement = user_hook(self.module, grads, self.grad_outputs)
                if replacement is not None:
                    if len(replacement) != len(grads):
                        raise RuntimeError("Backward hook returned an invalid number of grad_input, "
                                           "got {}, but expected {}".format(len(replacement), len(grads)))
                    # Later hooks see what earlier hooks returned.
                    grads = replacement

            self.grad_outputs = None

            return self._unpack_none(self.input_tensors_index, grads)

        grad_fn.register_hook(hook)

    def _apply_on_tensors(self, fn, args):
        """
        Route the tensors found in ``args`` through ``BackwardHookFunction`` and
        call ``fn`` on the resulting grad_fn.

        Returns ``(args, None)`` untouched when no tensor in ``args`` requires
        grad or when grad mode is disabled. Otherwise returns the args with the
        tensors replaced, together with the positions of those tensors.

        Raises:
            RuntimeError: if no tensor comes back from ``BackwardHookFunction``
                or none of them carries the expected grad_fn.
        """
        found = [(i, arg) for i, arg in enumerate(args) if isinstance(arg, torch.Tensor)]
        tensors_idx = [i for i, _ in found]
        tensors = [tensor for _, tensor in found]
        # Evaluate every flag (no short-circuit) before combining them.
        requires_grad = any([tensor.requires_grad for tensor in tensors])

        if not requires_grad or not torch.is_grad_enabled():
            return args, None

        new_tensors = torch.nn.modules._functions.BackwardHookFunction.apply(*tensors)
        if len(new_tensors) == 0:
            raise RuntimeError("Cannot set Module backward hook for a Module with no input Tensors.")

        grad_fns = []
        for tensor in new_tensors:
            node = tensor.grad_fn
            if node is not None and node.name() == "BackwardHookFunctionBackward":
                grad_fns.append(node)
        if len(grad_fns) == 0:
            raise RuntimeError("Error while setting up backward hooks. Please open "
                               "an issue with a code sample to reproduce this.")

        fn(grad_fns[0])

        replaced = list(args)
        for position, new_tensor in zip(tensors_idx, new_tensors):
            replaced[position] = new_tensor
        return tuple(replaced), tensors_idx

    def setup_input_hook(self, args):
        """
        Install the user-hook callback for the tensors in ``args``.

        Records the number of args and the positions of the tensors, and
        returns what ``_apply_on_tensors`` produced for ``args``.
        """
        def fn(grad_fn):
            self._set_user_hook(grad_fn)

        wrapped, tensor_positions = self._apply_on_tensors(fn, args)
        self.n_inputs = len(args)
        self.input_tensors_index = tensor_positions
        return wrapped

    def setup_output_hook(self, args):
        """
        Install the callback that captures the gradients of the outputs.

        ``args`` may be a tuple or a single value; a single value is wrapped
        for processing and unwrapped again before being returned.
        """
        def fn(grad_fn):
            def hook(_, grad_output):
                self.grad_outputs = self._pack_with_none(
                    self.output_tensors_index, grad_output, self.n_outputs
                )

                if self.user_pre_hooks:
                    expected_len = len(self.grad_outputs)
                    for user_pre_hook in self.user_pre_hooks:
                        hook_grad_outputs = user_pre_hook(self.module, self.grad_outputs)
                        if hook_grad_outputs is not None:
                            actual_len = len(hook_grad_outputs)
                            if actual_len != expected_len:
                                raise RuntimeError("Backward pre hook returned an invalid number of grad_output, "
                                                   "got {}, but expected {}".format(actual_len, expected_len))
                            self.grad_outputs = hook_grad_outputs

                if self.input_tensors_index is not None:
                    return

                # Special case if no input required gradients, this hook should call the user
                # hook directly
                grad_inputs = self._pack_with_none([], [], self.n_inputs)
                for user_hook in self.user_hooks:
                    result = user_hook(self.module, grad_inputs, self.grad_outputs)
                    if result is None:
                        continue
                    only_none = isinstance(result, tuple) and all(el is None for el in result)
                    if not only_none:
                        raise RuntimeError("Backward hook for Modules where no input requires "
                                           "gradient should always return None or None for all gradients.")
                self.grad_outputs = None

            grad_fn.register_hook(hook)

        is_tuple = isinstance(args, tuple)
        if not is_tuple:
            args = (args,)

        wrapped, tensor_positions = self._apply_on_tensors(fn, args)
        self.n_outputs = len(args)
        self.output_tensors_index = tensor_positions

        return wrapped if is_tuple else wrapped[0]