# blob: dcf03e7a28d0856d5ea225c419b2fbd0a5fbf420 [file] [log] [blame]
import torch
from collections import defaultdict
from torch import nn, Tensor
from typing import List, Tuple, Dict, Union, Callable
# Type helpers
# Inputs passed to the callable returned by a Getter: either a single Tensor
# or a tuple of Tensors.
InputsType = Union[Tensor, Tuple[Tensor, ...]]
# A Getter takes in a device and returns a callable and the inputs to that callable
GetterReturnType = Tuple[Callable[..., Tensor], InputsType]
GetterType = Callable[[torch.device], GetterReturnType]
# V here refers to the v in either vjp, jvp, vhp or hvp.
# It can be None, a single Tensor, or a tuple of Tensors.
VType = Union[None, Tensor, Tuple[Tensor, ...]]
# Type used to store timing results. The first key is the model name, the second key
# is the task name, the result is a Tuple of: speedup, mean_before, var_before, mean_after, var_after.
# NOTE(review): the markdown helpers in this file default to (mean, var) rows,
# so the tuple length is not fixed by this alias.
TimingResultType = Dict[str, Dict[str, Tuple[float, ...]]]
# Utilities to make nn.Module "functional".
# In particular, the goal is to be able to provide a function that takes the
# parameters as input and evaluates the nn.Module using fixed inputs.
def _del_nested_attr(obj: nn.Module, names: List[str]) -> None:
    """
    Remove the attribute reached by following `names` starting at `obj`.

    Every entry but the last is resolved with getattr to reach the direct
    owner; the last entry is the attribute deleted from that owner.
    For example, _del_nested_attr(obj, ['conv', 'weight']) deletes
    obj.conv.weight.
    """
    owner = obj
    # Walk down through the intermediate attributes to the direct owner.
    for attr in names[:-1]:
        owner = getattr(owner, attr)
    delattr(owner, names[-1])
def _set_nested_attr(obj: nn.Module, names: List[str], value: Tensor) -> None:
    """
    Assign `value` to the attribute reached by following `names` from `obj`.

    Every entry but the last is resolved with getattr to reach the direct
    owner; the last entry is the attribute set on that owner.
    For example, _set_nested_attr(obj, ['conv', 'weight'], value) performs
    obj.conv.weight = value.
    """
    owner = obj
    # Walk down through the intermediate attributes to the direct owner.
    for attr in names[:-1]:
        owner = getattr(owner, attr)
    setattr(owner, names[-1], value)
def extract_weights(mod: nn.Module) -> Tuple[Tuple[Tensor, ...], List[str]]:
    """
    Strip every Parameter off `mod` and hand them back as plain Tensors.

    Returns a pair `(params, names)`: `params` holds one detached Tensor per
    original Parameter (each flagged as requiring grad) and `names` holds the
    matching dotted attribute names, in the same order.

    The model is modified in place: once this returns, mod.parameters() is
    empty and the model cannot be used until the weights are put back with
    `load_weights`.
    """
    # Snapshot the Parameter objects and their names before anything is
    # deleted; deleting while iterating the module would skip entries.
    originals = tuple(mod.parameters())
    names = [name for name, _ in mod.named_parameters()]

    # Remove all the parameters from the model.
    for name in names:
        _del_nested_attr(mod, name.split("."))

    # Make params regular Tensors instead of nn.Parameter.
    detached = tuple(p.detach().requires_grad_() for p in originals)
    return detached, names
def load_weights(mod: nn.Module, names: List[str], params: Tuple[Tensor, ...]) -> None:
    """
    Attach `params` back onto `mod` under the dotted attribute `names`.

    After this call `mod` can run a forward pass again. The values are set as
    ordinary Tensors (which may carry autograd history), not as nn.Parameter,
    so mod.parameters() remains empty.

    `names` and `params` are paired positionally.
    """
    for dotted_name, tensor in zip(names, params):
        path = dotted_name.split(".")
        _set_nested_attr(mod, path, tensor)
# Utilities to read/write markdown table-like content.
def to_markdown_table(
    res: TimingResultType, header: Union[None, Tuple[str, ...]] = None
) -> str:
    """
    Render timing results as a markdown table.

    Args:
        res: mapping of model name -> task name -> row values.
        header: column titles for the first table line. When None, the
            default ("model", "task", "mean", "var") is used.

    Returns:
        The table as a string: the header line, a "--" separator line with
        one cell per header column, then one line per (model, task) pair
        holding the model, the task and the row values. Every line ends
        with a newline.
    """
    if header is None:
        header = ("model", "task", "mean", "var")

    def format_row(cells) -> str:
        # One markdown row: cells converted with str() and separated by " | ".
        return "| {} |\n".format(" | ".join(str(c) for c in cells))

    # Collect the rows and join once rather than growing a string with "+=".
    rows = [format_row(header), format_row(["--"] * len(header))]
    for model, tasks in res.items():
        for task, values in tasks.items():
            # Unpack rather than concatenate so any iterable of values works,
            # not only tuples.
            rows.append(format_row((model, task, *values)))
    return "".join(rows)
def from_markdown_table(data: str) -> TimingResultType:
    """
    Parse a markdown table (as written by `to_markdown_table`) into results.

    The first two lines (header and separator) are skipped. Each remaining
    line is split on "|": the first two cells are the model and task names,
    and every following cell is converted to float. The values are stored as
    a tuple under res[model][task], so tables with any number of value
    columns can be read back.

    Returns:
        A defaultdict keyed by model name whose values map task name to the
        tuple of floats. Because it is a defaultdict, looking up an unknown
        model creates an empty entry for it.

    Raises:
        ValueError: if a row has fewer than two cells or a value cell cannot
            be converted to float.
    """
    rows = data.strip().split("\n")[2:]  # Ignore the header lines
    res: TimingResultType = defaultdict(defaultdict)
    for row in rows:
        # Splitting on "|" yields empty strings for the leading and trailing
        # pipes; those are dropped before the cells are stripped.
        model, task, *values = [cell.strip() for cell in row.strip().split("|") if cell]
        res[model][task] = tuple(float(v) for v in values)
    return res
def check_for_functorch():
    """Return True if `functorch` can be imported, False otherwise."""
    try:
        import functorch  # noqa: F401
    except ImportError:
        return False
    else:
        return True