import torch
from torch.utils._pytree import tree_map
from typing import List, Any, Dict, Optional, Union
from collections import defaultdict
from torch.utils._python_dispatch import TorchDispatchMode
from math import prod
__all__ = ["FlopCounterMode"]
aten = torch.ops.aten
def get_shape(i):
if isinstance(i, torch.Tensor):
return i.shape
return i
def mm_flop(input_shapes: List[Any], _=None) -> int:
"""
Count flops for matmul.
"""
    # input_shapes should be a list of length 2 containing
    # the shapes of the two matrices.
assert len(input_shapes) == 2, input_shapes
a_shape, b_shape = input_shapes
m, k = a_shape
k2, n = b_shape
assert k == k2
# NB(chilli): Should be 2 * k - 1 technically for FLOPs.
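    # Worked example: an (8, 16) @ (16, 32) matmul is counted as
    # 8 * 32 * 2 * 16 = 8192 flops (one multiply and one add per
    # (output element, k) pair).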
return m * n * 2 * k
def addmm_flop(input_shapes: List[Any], _=None) -> int:
"""
    Count flops for addmm.
"""
assert len(input_shapes) == 3
_, a_shape, b_shape = input_shapes
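    # The first shape belongs to the ``self``/bias tensor; its elementwise add
    # is not counted here, so only the matmul term contributes.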
return mm_flop([a_shape, b_shape])
def bmm_flop(input_shapes: List[Any], _=None) -> int:
"""
Count flops for the bmm operation.
"""
    # input_shapes should be a list of length 2 containing
    # the shapes of the two batched tensors.
assert len(input_shapes) == 2, len(input_shapes)
a_shape, b_shape = input_shapes
b, m, k = a_shape
b2, k2, n = b_shape
assert b == b2
assert k == k2
# NB(chilli): Should be 2 * k - 1 technically for FLOPs.
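    # Worked example: bmm of (4, 8, 16) with (4, 16, 32) is counted as
    # 4 * 8 * 32 * 2 * 16 = 32768 flops.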
flop = b * m * n * 2 * k
return flop
def baddbmm_flop(input_shapes: List[Any], _=None) -> int:
"""
Count flops for the baddbmm operation.
"""
    # input_shapes should be a list of length 3 containing
    # the shapes of the three tensors.
assert len(input_shapes) == 3, len(input_shapes)
_, a_shape, b_shape = input_shapes
return bmm_flop([a_shape, b_shape])
def conv_flop_count(
x_shape: List[int],
w_shape: List[int],
out_shape: List[int],
transposed: bool = False,
) -> int:
"""
Count flops for convolution. Note only multiplication is
counted. Computation for bias are ignored.
Flops for a transposed convolution are calculated as
flops = (x_shape[2:] * prod(w_shape) * batch_size).
Args:
x_shape (list(int)): The input shape before convolution.
w_shape (list(int)): The filter shape.
out_shape (list(int)): The output shape after convolution.
        transposed (bool): whether the convolution is transposed.
Returns:
int: the number of flops
"""
batch_size = x_shape[0]
conv_shape = (x_shape if transposed else out_shape)[2:]
c_out, c_in, *dims = w_shape
# NB(chilli): I don't think this properly accounts for padding :think:
# NB(chilli): Should be 2 * c_in - 1 technically for FLOPs.
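    # Worked example: a 3x3 convolution with w_shape (32, 16, 3, 3) and an
    # output of shape (N, 32, 8, 8) is counted as
    # N * (8 * 8) * 32 * (3 * 3) * 2 * 16 = 589,824 * N flops.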
flop = batch_size * prod(conv_shape) * c_out * prod(dims) * 2 * c_in
return flop
def conv_flop(input_shapes: List[Any], output_shapes: List[Any]) -> int:
"""
Count flops for convolution.
"""
x_shape, w_shape = input_shapes[:2]
out_shape = output_shapes[0]
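    # ``transposed`` is argument 6 in aten.convolution's argument list.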
transposed = input_shapes[6]
return conv_flop_count(x_shape, w_shape, out_shape, transposed=transposed)
def transpose_shape(shape):
return [shape[1], shape[0]] + list(shape[2:])
def conv_backward_flop(input_shapes: List[Any], output_shapes: List[Any]) -> int:
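    """
    Count flops for convolution_backward.
    The last argument is the output_mask, indicating which of
    (grad_input, grad_weight, grad_bias) are actually needed; the bias
    gradient is treated as free, matching conv_flop_count above.
    """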
grad_out_shape, x_shape, w_shape = input_shapes[:3]
output_mask = input_shapes[-1]
fwd_transposed = input_shapes[7]
flop_count = 0
if output_mask[0]:
grad_input_shape = get_shape(output_shapes[0])
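        # The input gradient is (computationally) a convolution of grad_out
        # with the weight, with the forward ``transposed`` flag flipped.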
flop_count += conv_flop_count(grad_out_shape, w_shape, grad_input_shape, not fwd_transposed)
if output_mask[1]:
grad_weight_shape = get_shape(output_shapes[1])
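        # The weight gradient is counted as a convolution between the
        # (batch/channel-swapped) input and grad_out, using the forward
        # ``transposed`` flag.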
flop_count += conv_flop_count(transpose_shape(x_shape), grad_out_shape, grad_weight_shape, fwd_transposed)
return flop_count
flop_mapping = {
aten.mm: mm_flop,
aten.addmm: addmm_flop,
aten.bmm: bmm_flop,
aten.baddbmm: baddbmm_flop,
aten.convolution: conv_flop,
aten._convolution: conv_flop,
aten.convolution_backward: conv_backward_flop,
}
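# This table can be extended by passing ``custom_mapping`` to FlopCounterMode;
# entries map an ATen overload packet to a callable taking
# (input_shapes, output_shapes) and returning an int, e.g. (hypothetical)
# {aten.addbmm: addbmm_flop}.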
def normalize_tuple(x):
if not isinstance(x, tuple):
return (x,)
return x
# Thanks BingChat!
def convert_num_to_suffix_str(number):
# Define the suffixes for different orders of magnitude
suffixes = ["", "K", "M", "B", "T"]
# Find the index of the appropriate suffix based on the number of digits
index = max(0, min(len(suffixes) - 1, (len(str(number)) - 1) // 3))
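    # e.g. number=123456 has 6 digits, so index == 1 and the result below is
    # "123.46" + "K".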
# Divide the number by 1000^index and format it to two decimal places
value = "{:.2f}".format(number / (1000 ** index))
# Return the value and the suffix as a string
return value + suffixes[index]
class FlopCounterMode(TorchDispatchMode):
"""
``FlopCounterMode`` is a context manager that counts the number of
flops within its context. It does this using a ``TorchDispatchMode``.
It also supports hierarchical output by passing a module (or list of modules) to FlopCounterMode on construction.
    Example usage:

    .. code-block:: python

        mod = ...
        inp = ...
        flop_counter = FlopCounterMode(mod)
        with flop_counter:
            mod(inp).sum().backward()
"""
def __init__(
self,
mods: Optional[Union[torch.nn.Module, List[torch.nn.Module]]] = None,
depth: int = 2,
display: bool = True,
            custom_mapping: Optional[Dict[Any, Any]] = None):
self.flop_counts: Dict[str, Dict[Any, int]] = defaultdict(lambda: defaultdict(int))
self.depth = depth
self.parents = ["Global"]
self.display = display
if custom_mapping is None:
custom_mapping = {}
if isinstance(mods, torch.nn.Module):
mods = [mods]
self.mods = mods
if mods is not None:
for mod in mods:
prefix = type(mod).__name__
for name, module in dict(mod.named_modules()).items():
if name == "":
name = prefix
else:
name = ".".join([prefix, name])
module.register_forward_pre_hook(self._enter_module(name))
module.register_forward_hook(self._exit_module(name))
self.flop_mapping = {**flop_mapping, **custom_mapping}
def _enter_module(self, name):
def f(module, inputs):
inputs = normalize_tuple(inputs)
out = self._create_pre_module(name)(*inputs)
return out
return f
def _exit_module(self, name):
def f(module, inputs, outputs):
outputs = normalize_tuple(outputs)
return self._create_post_module(name)(*outputs)
return f
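    # The autograd.Functions created below mirror the module hierarchy into the
    # backward pass: the pre-hook's Function pushes ``name`` onto
    # ``self.parents`` in forward and pops it in backward, while the
    # post-hook's Function pops in forward and re-pushes in backward, so flops
    # from gradient computation are attributed to the module that produced
    # them.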
def _create_post_module(self, name):
class PushState(torch.autograd.Function):
@staticmethod
def forward(ctx, *args):
                assert self.parents[-1] == name
self.parents.pop()
args = tree_map(lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args)
if len(args) == 1:
return args[0]
return args
@staticmethod
def backward(ctx, *grad_outs):
self.parents.append(name)
return grad_outs
return PushState.apply
def _create_pre_module(self, name):
class PopState(torch.autograd.Function):
@staticmethod
def forward(ctx, *args):
self.parents.append(name)
args = tree_map(lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args)
if len(args) == 1:
return args[0]
return args
@staticmethod
def backward(ctx, *grad_outs):
                assert self.parents[-1] == name
self.parents.pop()
return grad_outs
return PopState.apply
def get_flop_counts(self) -> Dict[str, Dict[Any, int]]:
"""Returns the flop counts as a dictionary of dictionaries. The outer
dictionary is keyed by module name, and the inner dictionary is keyed by
operation name.
Returns:
Dict[str, Dict[Any, int]]: The flop counts as a dictionary.
"""
return dict(self.flop_counts)
def get_table(self, depth=None):
if depth is None:
depth = self.depth
if depth is None:
depth = 999999
import tabulate
tabulate.PRESERVE_WHITESPACE = True
header = ["Module", "FLOPS"]
values = []
max_flops = 0
def process_mod(mod_name, depth):
nonlocal max_flops
total_flops = sum(self.flop_counts[mod_name].values())
max_flops = max(total_flops, max_flops)
padding = " " * depth
values = []
values.append([padding + mod_name, convert_num_to_suffix_str(total_flops)])
for k, v in self.flop_counts[mod_name].items():
values.append([padding + " - " + str(k), convert_num_to_suffix_str(v)])
return values
for mod in self.flop_counts.keys():
if mod == 'Global':
continue
mod_depth = mod.count(".") + 1
if mod_depth > depth:
continue
cur_values = process_mod(mod, mod_depth - 1)
for value in cur_values:
values.append(value)
# We do a bit of messing around here to only output the "Global" value
# if there are any FLOPs in there that aren't already fully contained by
# a module.
if 'Global' in self.flop_counts and sum(self.flop_counts['Global'].values()) > max_flops:
for idx, value in enumerate(values):
values[idx][0] = " " + values[idx][0]
values = process_mod('Global', 0) + values
if len(values) == 0:
values = [["Global", "0"]]
return tabulate.tabulate(values, headers=header, colalign=("left", "right"))
def __enter__(self):
self.flop_counts.clear()
super().__enter__()
return self
def __exit__(self, *args):
if self.display:
print(self.get_table(self.depth))
super().__exit__(*args)
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
kwargs = kwargs if kwargs else {}
out = func(*args, **kwargs)
func_packet = func._overloadpacket
if func_packet in self.flop_mapping:
flop_count_func = self.flop_mapping[func_packet]
args_shape, out_shape = tree_map(get_shape, (args, normalize_tuple(out)))
flop_count = flop_count_func(args_shape, out_shape) # type: ignore[operator]
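            # Attribute the flops to every scope currently on the stack
            # ("Global" plus any modules we are currently inside).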
for par in self.parents:
self.flop_counts[par][func_packet] += flop_count
return out