import weakref
from typing import Set

import torch
from torch.autograd.graph import register_multi_grad_hook
from torch.nn.modules.module import (
    register_module_forward_hook,
    register_module_forward_pre_hook,
)
from torch.utils._pytree import tree_flatten

__all__ = ["ModuleTracker"]


class ModuleTracker:
    """
    ``ModuleTracker`` is a context manager that tracks the nn.Module hierarchy during execution
    so that other systems can query which Module is currently being executed (or whose backward
    is being executed).

    You can access the ``parents`` attribute on this context manager to get the set of all the
    Modules currently being executed, identified by their fqn (fully qualified name, also used as
    the key within the state_dict).
    You can access the ``is_bw`` attribute to know whether you are currently running the backward
    pass or not.

    Note that ``parents`` is never empty and always contains the "Global" key. The ``is_bw`` flag
    will remain ``True`` after the forward until another Module is executed. If you need it to be
    more accurate, please submit an issue requesting this. Adding a map from fqn to the module
    instance is possible but not done yet; please submit an issue requesting this if you need it.

    Example usage

    .. code-block:: python

        mod = torch.nn.Linear(2, 2)

        with ModuleTracker() as tracker:
            # Access anything during the forward pass
            def my_linear(m1, m2, bias):
                print(f"Current modules: {tracker.parents}")
                return torch.mm(m1, m2.t()) + bias

            torch.nn.functional.linear = my_linear

            mod(torch.rand(2, 2))

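    ``parents`` and ``is_bw`` can also be inspected during the backward pass, for example from a
    tensor hook that fires while gradients are being computed. A minimal sketch (the exact
    contents of ``parents`` inside the hook depend on autograd hook ordering):

    .. code-block:: python

        mod = torch.nn.Linear(2, 2)

        with ModuleTracker() as tracker:
            out = mod(torch.rand(2, 2, requires_grad=True))
            # This hook runs while autograd is processing ``out``, so is_bw is True there.
            out.register_hook(lambda grad: print(tracker.is_bw, tracker.parents))
            out.sum().backward()
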
| """ |
| |
| parents: Set[str] |
| """ |
| A Set containing the fqn for each module currently running their forward |
| """ |
| |
    def __init__(self):
        self.parents = {"Global"}
        # Weak references so that tracking does not keep modules alive.
        self._known_modules: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
        self._seen_modules: weakref.WeakSet = weakref.WeakSet()
        self._has_callback = False

    def _maybe_set_engine_callback(self):
        # This assumes no concurrent calls to backward
        if self._has_callback:
            return

        def callback():
            # Reset the hierarchy once the current backward pass has finished.
            self.parents = {"Global"}
            self._has_callback = False

        torch.autograd.Variable._execution_engine.queue_callback(callback)
        self._has_callback = True

    @property
    def is_bw(self):
        """
        A boolean marking whether we are currently running during the backward pass or not
        """
        return torch._C._current_graph_task_id() != -1

    def _get_mod_name(self, mod):
        if mod not in self._known_modules:
            self._known_modules[mod] = type(mod).__name__
        mod_name = self._known_modules[mod]
        if mod not in self._seen_modules:
            # Record each direct child as "<parent fqn>.<attribute name>" and recurse,
            # so nested modules get fqns like "MyModel.layer1.block".
            for name, submod in mod.named_children():
                self._known_modules[submod] = f"{mod_name}.{name}"
                self._get_mod_name(submod)
            self._seen_modules.add(mod)
        return mod_name

    def _get_append_fn(self, name, is_bw):
        def fn(*args):
            if is_bw:
                self._maybe_set_engine_callback()
            if name in self.parents:
                print(
                    "The module hierarchy tracking seems to be messed up. "
                    "Please file a bug to PyTorch."
                )
            self.parents.add(name)

        return fn

    def _get_pop_fn(self, name, is_bw):
        def fn(*args):
            if name in self.parents:
                self.parents.remove(name)
            elif not is_bw:
                # Some inputs/outputs may not require gradients, so proper nesting
                # cannot be enforced in backward; only raise in forward.
                raise RuntimeError(
                    "The Module hierarchy tracking is wrong. Report a bug to PyTorch."
                )

        return fn

    def _fw_pre_hook(self, mod, input):
        name = self._get_mod_name(mod)
        self._get_append_fn(name, False)()

        args, _ = tree_flatten(input)
        tensors = [a for a in args if isinstance(a, torch.Tensor) and a.requires_grad]
        if tensors:
            # The backward of this module is done once the gradients for all of its
            # inputs have been computed, so pop the name there.
            register_multi_grad_hook(tensors, self._get_pop_fn(name, True))

    def _fw_post_hook(self, mod, input, output):
        name = self._get_mod_name(mod)
        self._get_pop_fn(name, False)()

        args, _ = tree_flatten(output)
        tensors = [a for a in args if isinstance(a, torch.Tensor) and a.requires_grad]
        if tensors:
            # The backward of this module starts after the gradients for its outputs
            # have been computed, so push the name there.
            register_multi_grad_hook(tensors, self._get_append_fn(name, True))

    def __enter__(self):
        self._fw_pre_handle = register_module_forward_pre_hook(self._fw_pre_hook)
        self._fw_post_handle = register_module_forward_hook(self._fw_post_hook)
        return self

    def __exit__(self, *args):
        self._fw_pre_handle.remove()
        self._fw_post_handle.remove()
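

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the tracker API): run this
    # file directly to watch ``parents`` evolve across a small Sequential model. The
    # model shape and the ``peek`` helper are arbitrary choices for the demo.
    import torch.nn as nn

    model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 1))

    def peek(module, args):
        # Per-module forward pre-hook; the tracker's global pre-hook has already run
        # here, so this module's fqn should already be present in ``tracker.parents``.
        print(f"entering {type(module).__name__}: parents={sorted(tracker.parents)}")

    with ModuleTracker() as tracker:
        handles = [m.register_forward_pre_hook(peek) for m in model]
        out = model(torch.rand(2, 4, requires_grad=True))
        out.sum().backward()
        print(f"after backward: is_bw={tracker.is_bw}, parents={sorted(tracker.parents)}")
        for h in handles:
            h.remove()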