| # mypy: allow-untyped-defs |
| |
| from torch import nn |
| |
| |
class QuantStub(nn.Module):
    r"""Quantize stub module.

    Its ``forward`` returns the input unchanged. Before calibration it stands
    in for an observer of the input tensor, and it will be swapped for
    `nnq.Quantize` in `convert`.

    Args:
        qconfig: quantization configuration for the tensor. If not provided
            (``None``), no ``qconfig`` attribute is set on this module and the
            qconfig is taken from parent modules instead.
    """

    def __init__(self, qconfig=None):
        super().__init__()
        # Compare against None explicitly: the attribute should be set for any
        # supplied qconfig, not only for truthy ones. Leaving it unset when
        # nothing is supplied keeps the "inherit from parent" contract above.
        if qconfig is not None:
            self.qconfig = qconfig

    def forward(self, x):
        """Return ``x`` unchanged; the real quantization happens after `convert`."""
        return x
| |
| |
class DeQuantStub(nn.Module):
    r"""Dequantize stub module.

    Before calibration this is the same as identity (``forward`` returns its
    input unchanged); it will be swapped for `nnq.DeQuantize` in `convert`.

    Args:
        qconfig: quantization configuration for the tensor. If not provided
            (``None``), no ``qconfig`` attribute is set on this module and the
            qconfig is taken from parent modules instead.
    """

    def __init__(self, qconfig=None):
        super().__init__()
        # Compare against None explicitly: the attribute should be set for any
        # supplied qconfig, not only for truthy ones. Leaving it unset when
        # nothing is supplied keeps the "inherit from parent" contract above.
        if qconfig is not None:
            self.qconfig = qconfig

    def forward(self, x):
        """Return ``x`` unchanged; the real dequantization happens after `convert`."""
        return x
| |
| |
class QuantWrapper(nn.Module):
    r"""Wrap a module so that its input is quantized and its output dequantized.

    The wrapper registers three children, in this order: ``quant`` (a
    `QuantStub`), ``dequant`` (a `DeQuantStub`) and ``module`` (the wrapped
    module). ``forward`` runs ``quant -> module -> dequant``.

    This is used by the `quantization` utility functions to add the quant and
    dequant modules. Before `convert`, `QuantStub` is just an observer of the
    input tensor; after `convert`, it is swapped for `nnq.Quantize`, which
    does the actual quantization. `DeQuantStub` is handled in the same way.
    """

    quant: QuantStub
    dequant: DeQuantStub
    module: nn.Module

    def __init__(self, module):
        super().__init__()
        # Both stubs receive the wrapped module's qconfig, or None if it has none.
        shared_qconfig = getattr(module, "qconfig", None)
        quant_stub = QuantStub(shared_qconfig)
        dequant_stub = DeQuantStub(shared_qconfig)
        # Keep the registration order: quant, dequant, then the wrapped module.
        for child_name, child in (
            ("quant", quant_stub),
            ("dequant", dequant_stub),
            ("module", module),
        ):
            self.add_module(child_name, child)
        # Put the wrapper in the same train/eval mode as the wrapped module.
        self.train(module.training)

    def forward(self, X):
        """Quantize ``X``, run the wrapped module, and dequantize its output."""
        quantized = self.quant(X)
        return self.dequant(self.module(quantized))