| import torch |
| from torch._export.passes.constant_folding import constant_fold |
| from torch.ao.quantization.pt2e.duplicate_dq_pass import DuplicateDQPass |
| from torch.ao.quantization.pt2e.port_metadata_pass import PortNodeMetaForQDQ |
| from torch.ao.quantization.quantizer import ( # noqa: F401 |
| DerivedQuantizationSpec, |
| FixedQParamsQuantizationSpec, |
| QuantizationAnnotation, |
| QuantizationSpec, |
| QuantizationSpecBase, |
| Quantizer, |
| SharedQuantizationSpec, |
| ) |
| from torch.fx import GraphModule, Node |
| from torch.fx.passes.infra.pass_manager import PassManager |
| |
| from .pt2e.prepare import prepare |
| from .pt2e.qat_utils import _fold_conv_bn_qat, _fuse_conv_bn_qat |
| from .pt2e.representation import reference_representation_rewrite |
| from .pt2e.utils import _disallow_eval_train, _fuse_conv_bn_, _get_node_name_to_scope |
| from .quantize_fx import _convert_to_reference_decomposed_fx |
| |
| |
| __all__ = [ |
| "prepare_pt2e", |
| "prepare_qat_pt2e", |
| "convert_pt2e", |
| ] |
| |
| |
| def prepare_pt2e( |
| model: GraphModule, |
| quantizer: Quantizer, |
| ) -> GraphModule: |
| """Prepare a model for post training quantization |
| |
| Args: |
| * `model` (torch.fx.GraphModule): a model captured by the `torch.export` API. |
| In the short term we are using `torch._export.capture_pre_autograd_graph`; |
| in the long term we'll migrate to the official `torch.export` API. |
| * `quantizer`: A backend-specific quantizer that conveys how the user wants the |
| model to be quantized. A tutorial on how to write a quantizer can be found here: |
| https://pytorch.org/tutorials/prototype/pt2e_quantizer.html |
| |
| Return: |
| A GraphModule with observers inserted (based on the quantizer annotation), ready for calibration |
| |
| Example:: |
| |
| import torch |
| from torch.ao.quantization.quantize_pt2e import prepare_pt2e |
| from torch._export import capture_pre_autograd_graph |
| from torch.ao.quantization.quantizer.xnnpack_quantizer import ( |
| XNNPACKQuantizer, |
| get_symmetric_quantization_config, |
| ) |
| |
| class M(torch.nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| self.linear = torch.nn.Linear(5, 10) |
| |
| def forward(self, x): |
| return self.linear(x) |
| |
| # initialize a floating point model |
| float_model = M().eval() |
| |
| # define calibration function |
| def calibrate(model, data_loader): |
| model.eval() |
| with torch.no_grad(): |
| for image, target in data_loader: |
| model(image) |
| |
| # Step 1. program capture |
| # NOTE: this API will be updated to the torch.export API in the future, but the captured |
| # result should mostly stay the same |
| example_inputs = (torch.randn(1, 5),) |
| m = capture_pre_autograd_graph(float_model, example_inputs) |
| # we get a model with aten ops |
| |
| # Step 2. quantization |
| # backend developers will write their own Quantizer and expose methods to allow |
| # users to express how they want the model to be quantized |
| quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config()) |
| m = prepare_pt2e(m, quantizer) |
| |
| # run calibration |
| # calibrate(m, sample_inference_data) |
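| |
| # after calibration, the calibrated model can be lowered to a quantized model; |
| # a sketch (see `convert_pt2e` below): |
| # m = convert_pt2e(m) |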
| """ |
| torch._C._log_api_usage_once("quantization_api.quantize_pt2e.prepare_pt2e") |
| original_graph_meta = model.meta |
| node_name_to_scope = _get_node_name_to_scope(model) |
| # TODO: check qconfig_mapping to make sure conv and bn are both configured |
| # to be quantized before fusion |
| # TODO: (maybe) rewrite this with subgraph_rewriter |
| _fuse_conv_bn_(model) |
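| # the quantizer annotates the graph with quantization specs; observers are then |
| # inserted according to those annotations |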
| model = quantizer.transform_for_annotation(model) |
| quantizer.annotate(model) |
| quantizer.validate(model) |
| model = prepare(model, node_name_to_scope, is_qat=False) |
| model.meta.update(original_graph_meta) |
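| # block direct `.eval()` / `.train()` calls on the returned exported model |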
| model = _disallow_eval_train(model) |
| return model |
| |
| |
| def prepare_qat_pt2e( |
| model: GraphModule, |
| quantizer: Quantizer, |
| ) -> GraphModule: |
| """Prepare a model for quantization aware training |
| |
| Args: |
| * `model` (torch.fx.GraphModule): see :func:`~torch.ao.quantization.quantize_pt2e.prepare_pt2e` |
| * `quantizer`: see :func:`~torch.ao.quantization.quantize_pt2e.prepare_pt2e` |
| |
| Return: |
| A GraphModule with fake quant modules (based on quantizer annotation), ready for |
| quantization aware training |
| |
| Example:: |
| |
| import torch |
| from torch.ao.quantization.quantize_pt2e import prepare_qat_pt2e |
| from torch._export import capture_pre_autograd_graph |
| from torch.ao.quantization.quantizer.xnnpack_quantizer import ( |
| XNNPACKQuantizer, |
| get_symmetric_quantization_config, |
| ) |
| |
| class M(torch.nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| self.linear = torch.nn.Linear(5, 10) |
| |
| def forward(self, x): |
| return self.linear(x) |
| |
| # initialize a floating point model |
| float_model = M().eval() |
| |
| # define the training loop for quantization aware training |
| def train_loop(model, train_data): |
| model.train() |
| for image, target in train_data: |
| ... |
| |
| # Step 1. program capture |
| # NOTE: this API will be updated to the torch.export API in the future, but the captured |
| # result should mostly stay the same |
| example_inputs = (torch.randn(1, 5),) |
| m = capture_pre_autograd_graph(float_model, example_inputs) |
| # we get a model with aten ops |
| |
| # Step 2. quantization |
| # backend developers will write their own Quantizer and expose methods to allow |
| # users to express how they want the model to be quantized |
| quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config()) |
| m = prepare_qat_pt2e(m, quantizer) |
| |
| # run quantization aware training |
| train_loop(m, train_data) |
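| |
| # after training, the model can be lowered to a quantized model; a sketch |
| # (see `convert_pt2e` below): |
| # m = convert_pt2e(m) |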
| |
| """ |
| torch._C._log_api_usage_once("quantization_api.quantize_pt2e.prepare_qat_pt2e") |
| original_graph_meta = model.meta |
| node_name_to_scope = _get_node_name_to_scope(model) |
| model = quantizer.transform_for_annotation(model) |
| quantizer.annotate(model) |
| quantizer.validate(model) |
| # Perform fusion after annotate to avoid quantizing ops in the new |
| # subgraph that don't need to be quantized |
| # TODO: only fuse if conv and bn are both configured to be quantized |
| _fuse_conv_bn_qat(model) |
| model = prepare(model, node_name_to_scope, is_qat=True) |
| model.meta.update(original_graph_meta) |
| model = _disallow_eval_train(model) |
| return model |
| |
| |
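| # quantize ops whose constant inputs (e.g. weights) may be pre-computed at convert |
| # time; see `_quant_node_constraint` below |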
| _QUANT_OPS = [ |
| torch.ops.quantized_decomposed.quantize_per_tensor.default, |
| torch.ops.quantized_decomposed.quantize_per_tensor.tensor, |
| torch.ops.quantized_decomposed.quantize_per_channel.default, |
| ] |
| |
| |
| def _quant_node_constraint(n: Node) -> bool: |
| """If there is any pure ops between get_attr and quantize op they will be const propagated |
| e.g. get_attr(weight) -> transpose -> quantize -> dequantize* |
| (Note: dequantize op is not going to be constant propagated) |
| |
| This filter is added because we don't want to constant fold the things that are not |
| related to quantization |
| """ |
| return n.op == "call_function" and n.target in _QUANT_OPS |
| |
| |
| def convert_pt2e( |
| model: GraphModule, |
| use_reference_representation: bool = False, |
| fold_quantize: bool = True, |
| ) -> GraphModule: |
| """Convert a calibrated/trained model to a quantized model |
| |
| Args: |
| * `model` (torch.fx.GraphModule): calibrated/trained model |
| * `use_reference_representation` (bool): boolean flag to indicate whether to produce the reference representation or not |
| * `fold_quantize` (bool): boolean flag for whether to fold the quantize op or not |
| |
| Returns: |
| quantized model, either in q/dq representation or reference representation |
| |
| Example:: |
| |
| # prepared_model: the model produced by `prepare_pt2e`/`prepare_qat_pt2e` and calibration/training |
| # `convert_pt2e` produces a quantized model that represents quantized computation with |
| # quantize/dequantize ops and fp32 ops by default. |
| # Please refer to |
| # https://pytorch.org/tutorials/prototype/pt2e_quant_ptq_static.html#convert-the-calibrated-model-to-a-quantized-model |
| # for a detailed explanation of the output quantized model |
| quantized_model = convert_pt2e(prepared_model) |
| |
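| # a sketch of the non-default options documented above: produce the reference |
| # representation and skip folding the quantize ops on weights |
| reference_model = convert_pt2e( |
| prepared_model, |
| use_reference_representation=True, |
| fold_quantize=False, |
| ) |
| |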
| """ # flake8: noqa |
| torch._C._log_api_usage_once("quantization_api.quantize_pt2e.convert_pt2e") |
| if not isinstance(use_reference_representation, bool): |
| raise ValueError( |
| "Unexpected argument type for `use_reference_representation`, " |
| f"please make sure you intend to pass argument {use_reference_representation} to convert_pt2e" |
| ) |
| original_graph_meta = model.meta |
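| # replace observers/fake-quantize modules with decomposed quantize/dequantize ops, |
| # then fold the QAT conv + bn pattern back into conv |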
| model = _convert_to_reference_decomposed_fx(model) |
| model = _fold_conv_bn_qat(model) |
| |
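| # duplicate dequantize nodes that are shared by multiple users so that each |
| # consumer gets its own dequantize node |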
| pm = PassManager([DuplicateDQPass()]) |
| model = pm(model).graph_module |
| |
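| # port node metadata onto the newly inserted quantize/dequantize nodes |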
| pm = PassManager([PortNodeMetaForQDQ()]) |
| model = pm(model).graph_module |
| |
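| # constant-fold the weight quantization: subgraphs feeding the quantize ops in |
| # `_QUANT_OPS` (e.g. get_attr(weight) -> transpose -> quantize) are pre-computed |
| # into quantized constants; dequantize ops stay in the graph |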
| if fold_quantize: |
| constant_fold(model, _quant_node_constraint) |
| |
| if use_reference_representation: |
| model = reference_representation_rewrite(model) |
| |
| model.meta.update(original_graph_meta) |
| model = _disallow_eval_train(model) |
| return model |