# blob: 89f16ad931d40424477a8ccad8c0bddfab5c81ee [file] [log] [blame]
# flake8: noqa: E266, C417, B950
import torch
import torch.nn as nn
import torch.nn.functional as F
##### Quantization Primitives ######
def dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype):
    """Symmetrically quantize ``x`` using one scale per row.

    The extrema are reduced over ``dim=1`` and the resulting scales are
    broadcast back over the last dimension, so this is written for a dense
    tensor whose rows (axis 0) are the channels.
    TODO(future): relax the symmetric / axis-0 / dense assumptions as needed.

    Args:
        x: floating-point tensor to quantize.
        quant_min: smallest representable quantized value (e.g. -128).
        quant_max: largest representable quantized value (e.g. 127).
        target_dtype: dtype of the returned quantized tensor.

    Returns:
        A ``(quant, scales, zero_points)`` tuple: ``quant`` has the shape of
        ``x`` and dtype ``target_dtype``; ``scales`` has one entry per row in
        ``x.dtype``; ``zero_points`` is an all-zero int64 tensor of the same
        shape as ``scales``.
    """
    # Per-row extrema.
    row_min, row_max = torch.aminmax(x, dim=1)

    # Force the range to contain zero, then keep the larger magnitude of the
    # two sides so the quantization grid is symmetric around zero.
    zero = torch.zeros_like(row_min)
    neg_extent = torch.min(row_min, zero)
    pos_extent = torch.max(row_max, zero)
    abs_extent = torch.max(-neg_extent, pos_extent)

    # Half of the integer range maps onto the largest magnitude.
    half_range = float(quant_max - quant_min) / 2
    raw_scales = abs_extent / half_range

    # Clamp away from zero (an all-zero row would otherwise divide by zero)
    # and make sure the scales carry the same dtype as the input.
    smallest_scale = torch.finfo(torch.float32).eps
    scales = torch.clamp(raw_scales, min=smallest_scale).to(x.dtype)

    # Symmetric quantization: every zero point is 0.
    zero_points = torch.zeros(
        neg_extent.size(), dtype=torch.int64, device=neg_extent.device
    )

    # Divide, round, shift, clamp, cast; same sequence as the decomposed
    # quantize op in torch/ao/quantization/fx/_decomposed.py.
    rounded = torch.round(x / scales.unsqueeze(-1))
    shifted = rounded + zero_points.unsqueeze(-1)
    quant = torch.clamp(shifted, quant_min, quant_max).to(target_dtype)
    return quant, scales, zero_points
##### Weight-only int8 per-channel quantized code ######
def replace_linear_weight_only_int8_per_channel(module):
    """Recursively swap every ``nn.Linear`` under ``module`` in place.

    Each direct child that is an ``nn.Linear`` is replaced by a new
    ``WeightOnlyInt8Linear`` built from the child's ``in_features`` and
    ``out_features`` only; no weights or bias are copied over here. Any other
    child is descended into. ``module`` itself is never replaced.

    Args:
        module: root ``nn.Module`` whose descendants are rewritten.
    """
    for child_name, submodule in module.named_children():
        if not isinstance(submodule, nn.Linear):
            # Not a Linear: look for Linear layers further down the tree.
            replace_linear_weight_only_int8_per_channel(submodule)
            continue
        quantized = WeightOnlyInt8Linear(
            submodule.in_features, submodule.out_features
        )
        setattr(module, child_name, quantized)
class WeightOnlyInt8QuantHandler:
    """Produce int8 weight-only checkpoints and the matching runtime model.

    Wraps a module and offers two independent steps:
    ``create_quantized_state_dict`` builds a state dict with quantized Linear
    weights, and ``convert_for_runtime`` swaps the Linear layers for
    ``WeightOnlyInt8Linear`` modules.
    """

    def __init__(self, mod):
        """Store the module to be quantized.

        Args:
            mod: the ``nn.Module`` whose Linear layers are handled.
        """
        self.mod = mod

    @torch.no_grad()
    def create_quantized_state_dict(self):
        """Return the wrapped module's state dict with Linear weights in int8.

        For every ``torch.nn.Linear`` found by ``named_modules()`` the
        ``weight`` entry is replaced by its per-channel int8 quantization and
        a ``scales`` entry is added next to it, both moved to CPU. All other
        entries are left as returned by ``state_dict()``.

        Returns:
            The updated state dict.
        """
        cur_state_dict = self.mod.state_dict()
        for fqn, mod in self.mod.named_modules():
            if not isinstance(mod, torch.nn.Linear):
                continue
            int8_weight, scales, _ = dynamically_quantize_per_channel(
                mod.weight.float(), -128, 127, torch.int8
            )
            # named_modules() reports the root module under the empty name;
            # without this guard a root Linear would be written to the bogus
            # keys ".weight" / ".scales" instead of "weight" / "scales".
            prefix = f"{fqn}." if fqn else ""
            cur_state_dict[f"{prefix}weight"] = int8_weight.to("cpu")
            # Quantization ran in float32; store scales in the weight's dtype.
            cur_state_dict[f"{prefix}scales"] = scales.to(mod.weight.dtype).to("cpu")
        return cur_state_dict

    def convert_for_runtime(self):
        """Replace the Linear children of the wrapped module in place.

        Returns:
            The same module object that was passed to the constructor.
        """
        replace_linear_weight_only_int8_per_channel(self.mod)
        return self.mod
class WeightOnlyInt8Linear(torch.nn.Module):
    """Linear layer that stores int8 weights plus one scale per output row.

    ``weight`` is an int8 buffer of shape ``(out_features, in_features)`` and
    ``scales`` is a bfloat16 buffer of length ``out_features``. The forward
    pass casts the weight to the input's dtype, applies ``F.linear`` without
    a bias, and multiplies the result by ``scales``.
    """

    __constants__ = ["in_features", "out_features"]
    in_features: int
    out_features: int
    weight: torch.Tensor

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        """Allocate the quantized weight and scale buffers.

        Args:
            in_features: size of each input sample.
            out_features: size of each output sample.
            bias: accepted for signature parity with ``nn.Linear`` but unused;
                this layer never creates or applies a bias.
            device: device on which both buffers are allocated; ``None`` uses
                the default device.
            dtype: accepted for signature parity but unused; ``weight`` is
                always int8 and ``scales`` is always bfloat16.
        """
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # The weight is left uninitialized (torch.empty); real values are
        # expected to be copied in afterwards, e.g. from a state dict.
        self.register_buffer(
            "weight",
            torch.empty((out_features, in_features), dtype=torch.int8, device=device),
        )
        self.register_buffer(
            "scales", torch.ones(out_features, dtype=torch.bfloat16, device=device)
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply the dequantized linear map to ``input``.

        Args:
            input: tensor whose last dimension is ``in_features``.

        Returns:
            ``F.linear(input, weight)`` scaled per output feature.
        """
        # Scaling the output instead of the weight keeps the per-row scale
        # out of the matmul; the result is mathematically the same.
        return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales