# flake8: noqa: E266, C417, B950
import torch
import torch.nn as nn
import torch.nn.functional as F


##### Quantization Primitives #####


def dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype):
    # assumes symmetric quantization
    # assumes axis == 0
    # assumes dense memory format
    # TODO(future): relax ^ as needed

    # default setup for affine quantization of activations
    eps = torch.finfo(torch.float32).eps

    # get per-channel min and max (axis == 0, so reduce over dim 1)
    min_val, max_val = torch.aminmax(x, dim=1)

    # calculate scales and zero_points based on min and max
    # reference: https://fburl.com/code/srbiybme
    min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
    max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
    device = min_val_neg.device

    # reference: https://fburl.com/code/4wll53rk
    # symmetric quantization: use the larger magnitude of min/max per channel
    max_val_pos = torch.max(-min_val_neg, max_val_pos)
    scales = max_val_pos / (float(quant_max - quant_min) / 2)
    # ensure scales is the same dtype as the original tensor
    scales = torch.clamp(scales, min=eps).to(x.dtype)
    # zero_points are all zero because quantization is symmetric
    zero_points = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)

    # quantize based on qmin/qmax/scales/zp
    # reference: https://www.internalfb.com/code/fbsource/[8edc275012b1]/fbcode/caffe2/torch/ao/quantization/fx/_decomposed.py?lines=63
    x_div = x / scales.unsqueeze(-1)
    x_round = torch.round(x_div)
    x_zp = x_round + zero_points.unsqueeze(-1)
    quant = torch.clamp(x_zp, quant_min, quant_max).to(target_dtype)

    return quant, scales, zero_points
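

# Illustrative inverse of the mapping above (a sketch, not part of the original
# API; the helper name is an example): with symmetric quantization the zero
# points are all zero, so dequantization reduces to `quant * scales` broadcast
# over dim 1.
def _dequantize_per_channel_example(quant, scales, zero_points):
    # undo the (all-zero) zero-point shift, then rescale each output channel
    zp = zero_points.unsqueeze(-1).to(torch.float32)
    dequant = (quant.to(torch.float32) - zp) * scales.unsqueeze(-1)
    return dequant.to(scales.dtype)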


##### Weight-only int8 per-channel quantized code #####


def replace_linear_weight_only_int8_per_channel(module):
    # Recursively swap every nn.Linear child for a WeightOnlyInt8Linear of the
    # same shape. Note that any existing bias is dropped (WeightOnlyInt8Linear
    # is bias-free) and that the root module itself is never replaced, even if
    # it is an nn.Linear.
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            setattr(
                module,
                name,
                WeightOnlyInt8Linear(child.in_features, child.out_features),
            )
        else:
            replace_linear_weight_only_int8_per_channel(child)


class WeightOnlyInt8QuantHandler:
    def __init__(self, mod):
        self.mod = mod

    @torch.no_grad()
    def create_quantized_state_dict(self):
        cur_state_dict = self.mod.state_dict()
        for fqn, mod in self.mod.named_modules():
            if isinstance(mod, torch.nn.Linear):
                # replace the float weight with its int8 counterpart and store
                # the per-channel scales under a sibling key; all other entries
                # pass through unchanged
                int8_weight, scales, _ = dynamically_quantize_per_channel(
                    mod.weight.float(), -128, 127, torch.int8
                )
                cur_state_dict[f"{fqn}.weight"] = int8_weight.to("cpu")
                cur_state_dict[f"{fqn}.scales"] = scales.to(mod.weight.dtype).to("cpu")

        return cur_state_dict

    def convert_for_runtime(self):
        replace_linear_weight_only_int8_per_channel(self.mod)
        return self.mod
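

# Usage sketch (illustrative only; `model` stands for any nn.Module containing
# nn.Linear submodules, and the function name is an example). The intended flow:
# build the int8 state dict from the float weights, swap the Linear modules in
# place, then load the quantized buffers. Assumes the Linear layers are
# bias-free, since WeightOnlyInt8Linear drops any bias.
def _quantize_model_example(model: nn.Module) -> nn.Module:
    handler = WeightOnlyInt8QuantHandler(model)
    quantized_state_dict = handler.create_quantized_state_dict()
    model = handler.convert_for_runtime()
    model.load_state_dict(quantized_state_dict)
    return model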


class WeightOnlyInt8Linear(torch.nn.Module):
    __constants__ = ["in_features", "out_features"]
    in_features: int
    out_features: int
    weight: torch.Tensor

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        # bias, device, and dtype are accepted only to mirror the nn.Linear
        # signature; they are not used. The buffers have fixed dtypes: int8 for
        # the quantized weight and bfloat16 for the per-channel scales.
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.register_buffer(
            "weight", torch.empty((out_features, in_features), dtype=torch.int8)
        )
        self.register_buffer("scales", torch.ones(out_features, dtype=torch.bfloat16))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # matmul in the activation dtype, then apply the per-output-channel scales
        return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales
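

# Illustrative sanity check (a sketch, not part of the original module; names
# are examples only): because `scales` has one entry per output channel, scaling
# the matmul result is equivalent to dequantizing the weight before F.linear.
def _weight_only_int8_linear_equivalence_example():
    torch.manual_seed(0)
    float_weight = torch.randn(8, 16)
    int8_weight, scales, _ = dynamically_quantize_per_channel(
        float_weight, -128, 127, torch.int8
    )

    qlinear = WeightOnlyInt8Linear(16, 8)
    qlinear.weight.copy_(int8_weight)
    qlinear.scales = scales  # replace the bfloat16 placeholder buffer for this check

    x = torch.randn(4, 16)
    reference = F.linear(x, int8_weight.float() * scales.unsqueeze(-1))
    assert torch.allclose(qlinear(x), reference, atol=1e-5)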