| import functools |
| import sys |
| import warnings |
| from typing import Callable, Sequence |
| |
| import torch |
| import torch._C._onnx as _C_onnx |
| import torch.onnx |
| from torch import _C |
| |
| # Monkey-patch graph manipulation methods on Graph, used for the ONNX symbolics |
| from torch.onnx import ( # noqa: F401 |
| _patch_torch, |
| _type_utils, |
| errors, |
| symbolic_helper, |
| symbolic_opset9 as opset9, |
| ) |
| from torch.onnx._globals import GLOBALS |
| from torch.onnx._internal import _beartype, jit_utils, registration |
| |
| # EDITING THIS FILE? READ THIS FIRST! |
| # see Note [Edit Symbolic Files] in README.md |
| |
| # This file exports ONNX ops for opset 10 |
| # Opset 10 is supported by ONNX release 1.5.0 |
| # released on 04/24/19 |
| |
| |
# Public names exported by this module (the symbolic functions defined below).
__all__ = [
    "dequantize",
    "div",
    "embedding_bag",
    "fake_quantize_per_tensor_affine",
    "flip",
    "fmod",
    "isfinite",
    "isinf",
    "nan_to_num",
    "quantize_per_tensor",
    "quantized_add_relu",
    "quantized_add",
    "quantized_cat",
    "quantized_conv2d_relu",
    "quantized_conv2d",
    "quantized_group_norm",
    "quantized_hardswish",
    "quantized_instance_norm",
    "quantized_layer_norm",
    "quantized_leaky_relu",
    "quantized_linear",
    "quantized_mul",
    "quantized_sigmoid",
    "slice",
    "sort",
    "topk",
]
| |
| |
# ``registration.onnx_symbolic`` with ``opset`` pre-bound to 10; used as the
# registration decorator for every symbolic function in this file.
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=10)
| |
| |
| def _apply_params(*args, **kwargs): |
| """Returns a decorator that calls the decorated (higher-order) function with the given parameters.""" |
| |
| def _apply(fn): |
| return fn(*args, **kwargs) |
| |
| return _apply |
| |
| |
@_onnx_symbolic("aten::div")
@_beartype.beartype
def div(g: jit_utils.GraphContext, self, other, *args):
    """Symbolic for ``aten::div``.

    Without extra positional arguments the op is lowered through
    ``opset9.true_divide``; otherwise the extra arguments (the rounding mode)
    are forwarded to ``_div_rounding_mode``.
    """
    if args:
        return _div_rounding_mode(g, self, other, *args)
    return opset9.true_divide(g, self, other)
| |
| |
@symbolic_helper.parse_args("v", "v", "s")
@_beartype.beartype
def _div_rounding_mode(g: jit_utils.GraphContext, self, other, rounding_mode):
    """Dispatch a division with an explicit rounding mode.

    ``"floor"`` is handled by this opset's ``_floor_divide``; every other
    rounding mode is delegated to the opset 9 implementation.
    """
    if rounding_mode != "floor":
        return opset9._div_rounding_mode(g, self, other, rounding_mode)
    return _floor_divide(g, self, other)
| |
| |
@_onnx_symbolic("aten::_floor_divide")
@_beartype.beartype
def _floor_divide(g: jit_utils.GraphContext, self, other):
    """Symbolic for ``aten::_floor_divide``.

    If either operand is floating point, the result is ``Floor`` applied to
    the true division. Otherwise the integer ``Div`` result is decremented
    by one wherever the operands have different signs and the remainder is
    non-zero.
    """
    if symbolic_helper._is_fp(self) or symbolic_helper._is_fp(other):
        return g.op("Floor", opset9.true_divide(g, self, other))

    # Integer division does truncation rounding.
    quotient = g.op("Div", self, other)

    # The quotient is negative exactly when one operand is < 0 and the other is not.
    zero = g.op("Constant", value_t=torch.tensor(0, dtype=torch.int64))
    self_is_negative = g.op("Less", self, zero)
    other_is_negative = g.op("Less", other, zero)
    signs_differ = g.op("Xor", self_is_negative, other_is_negative)

    # Only negative quotients with a non-zero remainder need to be rounded
    # down (subtract 1) instead of up.
    remainder = g.op("Mod", self, other, fmod_i=0)
    remainder_is_zero = g.op("Equal", remainder, zero)
    has_remainder = g.op("Not", remainder_is_zero)
    needs_fixup = g.op("And", signs_differ, has_remainder)

    one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))
    quotient_minus_one = g.op("Sub", quotient, one)
    return g.op("Where", needs_fixup, quotient_minus_one, quotient)
| |
| |
@_onnx_symbolic("aten::sort")
@symbolic_helper.parse_args("v", "i", "i", "none")
@_beartype.beartype
def sort(g: jit_utils.GraphContext, self, dim, decending, out=None):
    """Symbolic for ``aten::sort``; delegates to ``symbolic_helper._sort_helper``."""
    helper_options = {"decending": decending, "out": out}
    return symbolic_helper._sort_helper(g, self, dim, **helper_options)
| |
| |
@_onnx_symbolic("aten::topk")
@symbolic_helper.parse_args("v", "v", "i", "i", "i", "none")
@_beartype.beartype
def topk(g: jit_utils.GraphContext, self, k, dim, largest, sorted, out=None):
    """Symbolic for ``aten::topk``; delegates to ``symbolic_helper._topk_helper``."""
    helper_options = {"largest": largest, "sorted": sorted, "out": out}
    return symbolic_helper._topk_helper(g, self, k, dim, **helper_options)
| |
| |
@_onnx_symbolic(
    "aten::max_pool1d",
    decorate=[
        _apply_params(
            "max_pool1d", torch.nn.modules.utils._single, 1, return_indices=False
        )
    ],
)
@_onnx_symbolic(
    "aten::max_pool2d",
    decorate=[
        _apply_params(
            "max_pool2d", torch.nn.modules.utils._pair, 2, return_indices=False
        )
    ],
)
@_onnx_symbolic(
    "aten::max_pool3d",
    decorate=[
        _apply_params(
            "max_pool3d", torch.nn.modules.utils._triple, 3, return_indices=False
        )
    ],
)
@_onnx_symbolic(
    "aten::max_pool1d_with_indices",
    decorate=[
        _apply_params(
            "max_pool1d_with_indices",
            torch.nn.modules.utils._single,
            1,
            return_indices=True,
        )
    ],
)
@_onnx_symbolic(
    "aten::max_pool2d_with_indices",
    decorate=[
        _apply_params(
            "max_pool2d_with_indices",
            torch.nn.modules.utils._pair,
            2,
            return_indices=True,
        )
    ],
)
@_onnx_symbolic(
    "aten::max_pool3d_with_indices",
    decorate=[
        _apply_params(
            "max_pool3d_with_indices",
            torch.nn.modules.utils._triple,
            3,
            return_indices=True,
        )
    ],
)
@_beartype.beartype
def _max_pool(name: str, tuple_fn: Callable, ndims: int, return_indices: bool):
    """Build the symbolic function for one ``aten::max_pool*`` variant.

    Args:
        name: Name of the registered variant (not read by the generated function).
        tuple_fn: Callable applied to ``kernel_size``, ``stride``, ``padding``
            and ``dilation`` before they are passed as ``MaxPool`` attributes.
        ndims: Number of pooled (spatial) dimensions.
        return_indices: Whether the generated function also returns indices.

    Returns:
        The symbolic function emitting ONNX ``MaxPool``.
    """

    @symbolic_helper.quantized_args(True, False, False, False, False, False)
    @symbolic_helper.parse_args("v", "is", "is", "is", "is", "i")
    def symbolic_fn(g, input, kernel_size, stride, padding, dilation, ceil_mode):
        # An empty stride means "use the kernel size as stride".
        effective_stride = stride if stride else kernel_size
        pool_attrs = dict(
            kernel_shape_i=tuple_fn(kernel_size),
            pads_i=tuple_fn(padding) * 2,
            strides_i=tuple_fn(effective_stride),
            ceil_mode_i=ceil_mode,
        )
        # Only emit the dilations attribute when it is not all ones.
        dilations = tuple_fn(dilation)
        if set(dilations) != {1}:
            pool_attrs["dilations_i"] = dilations

        if not return_indices:
            return g.op("MaxPool", input, outputs=1, **pool_attrs)

        # ONNX computes indices into the flattened tensor, i.e. values in
        # [0, N x C x D1 x ... x Dn), whereas PyTorch expects indices local to
        # the pooled dimensions. The (hacky) conversion: run a second MaxPool
        # with kernel and stride of 1 on the same input so every position
        # yields its own flattened index, take the first index along each
        # pooled axis from that reference, and subtract it from the indices
        # produced by the real pooling.
        # For more information:
        # https://github.com/pytorch/pytorch/pull/16455#issuecomment-460776407
        pooled, indices = g.op("MaxPool", input, outputs=2, **pool_attrs)
        _, reference_indices = g.op(
            "MaxPool",
            input,
            outputs=2,
            kernel_shape_i=[1] * ndims,
            strides_i=[1] * ndims,
        )
        # Offsets that turn flattened indices into non-flattened ones.
        first_indices = symbolic_helper._slice_helper(
            g,
            reference_indices,
            axes=list(range(2, 2 + ndims)),
            starts=tuple_fn(0),
            ends=tuple_fn(1),
        )
        local_indices = opset9.sub(g, indices, first_indices)
        return pooled, local_indices

    return symbolic_fn
| |
| |
@_onnx_symbolic(
    "aten::avg_pool1d",
    decorate=[_apply_params("avg_pool1d", torch.nn.modules.utils._single)],
)
@_onnx_symbolic(
    "aten::avg_pool2d",
    decorate=[_apply_params("avg_pool2d", torch.nn.modules.utils._pair)],
)
@_onnx_symbolic(
    "aten::avg_pool3d",
    decorate=[_apply_params("avg_pool3d", torch.nn.modules.utils._triple)],
)
@_beartype.beartype
def _avg_pool(name, tuple_fn):
    """Build the symbolic function for one ``aten::avg_pool*`` variant.

    Args:
        name: Name of the variant, forwarded to ``symbolic_helper._avgpool_helper``.
        tuple_fn: Callable applied to ``kernel_size`` and ``stride`` before they
            are passed as ``AveragePool`` attributes; also handed to the helper.

    Returns:
        The symbolic function emitting ONNX ``AveragePool``.
    """

    @symbolic_helper.quantized_args(True, False, False, False, False, False, False)
    @symbolic_helper.parse_args("v", "is", "is", "is", "i", "i", "none")
    @_beartype.beartype
    def symbolic_fn(
        g,
        input: _C.Value,
        kernel_size: Sequence[int],
        stride: Sequence[int],
        padding: Sequence[int],
        ceil_mode: int,
        count_include_pad: int,
        divisor_override=None,
    ):
        # An empty stride means "use the kernel size as stride".
        effective_stride = stride if stride else kernel_size
        pads = symbolic_helper._avgpool_helper(
            tuple_fn, padding, kernel_size, effective_stride, divisor_override, name
        )
        assert isinstance(pads, tuple)

        pooled_input = input
        if count_include_pad:
            # Pad explicitly with zeros (no padding on the two leading
            # dimensions) so the padded values take part in the average, then
            # disable padding in AveragePool itself.
            pooled_input = opset9._op_with_optional_float_cast(
                g,
                "Pad",
                input,
                pads_i=((0,) * 2 + pads) * 2,
                mode_s="constant",
                value_f=0.0,
                opset_before=11,
            )
            pads = (0,) * len(pads)

        return g.op(
            "AveragePool",
            pooled_input,
            kernel_shape_i=tuple_fn(kernel_size),
            strides_i=tuple_fn(effective_stride),
            pads_i=pads * 2,
            ceil_mode_i=ceil_mode,
        )

    return symbolic_fn
| |
| |
@_onnx_symbolic(
    "aten::upsample_nearest1d",
    decorate=[_apply_params("upsample_nearest1d", 3, "nearest")],
)
@_onnx_symbolic(
    "aten::upsample_nearest2d",
    decorate=[_apply_params("upsample_nearest2d", 4, "nearest")],
)
@_onnx_symbolic(
    "aten::upsample_nearest3d",
    decorate=[_apply_params("upsample_nearest3d", 5, "nearest")],
)
@_onnx_symbolic(
    "aten::upsample_linear1d",
    decorate=[_apply_params("upsample_linear1d", 3, "linear")],
)
@_onnx_symbolic(
    "aten::upsample_bilinear2d",
    decorate=[_apply_params("upsample_bilinear2d", 4, "linear")],
)
@_onnx_symbolic(
    "aten::upsample_trilinear3d",
    decorate=[_apply_params("upsample_trilinear3d", 5, "linear")],
)
@_beartype.beartype
def _interpolate(name, dim, interpolate_mode):
    """Build the symbolic function for one ``aten::upsample_*`` variant.

    Args:
        name: Name of the variant, used when reporting an unimplemented case.
        dim: Forwarded to ``symbolic_helper._interpolate_size_to_scales`` when
            scales have to be derived from ``output_size``.
        interpolate_mode: Value of the ``mode`` attribute of the emitted ``Resize``.

    Returns:
        The symbolic function emitting ONNX ``Resize``.
    """

    @symbolic_helper.quantized_args(True, False, False)
    @_beartype.beartype
    def symbolic_fn(g, input, output_size, *args):
        scales, raw_align_corners = symbolic_helper._get_interpolate_attributes(
            g, interpolate_mode, args
        )
        symbolic_helper._interpolate_warning(interpolate_mode)

        # align_corners is reported as unimplemented for this opset.
        if symbolic_helper._maybe_get_scalar(raw_align_corners):
            return symbolic_helper._unimplemented(name, "align_corners == True", input)

        # Without explicit scales, derive them from the requested output size.
        if scales is None:
            scales = symbolic_helper._interpolate_size_to_scales(
                g, input, output_size, dim
            )
        return g.op("Resize", input, scales, mode_s=interpolate_mode)

    return symbolic_fn
| |
| |
@_onnx_symbolic("aten::__interpolate")
@_beartype.beartype
def __interpolate(
    g: jit_utils.GraphContext,
    input,
    size,
    scale_factor,
    mode,
    align_corners,
    recompute_scale_factor,
    antialias,
):
    """Symbolic for ``aten::__interpolate``.

    Scales and the resize mode are computed by
    ``symbolic_helper._interpolate_get_scales_and_mode``; the result is an
    ONNX ``Resize``. ``recompute_scale_factor`` and ``antialias`` are not read.
    """
    scales_and_mode = symbolic_helper._interpolate_get_scales_and_mode(
        g, input, size, scale_factor, mode, align_corners
    )
    scales, resize_mode = scales_and_mode
    return g.op("Resize", input, scales, mode_s=resize_mode)
| |
| |
@_beartype.beartype
def _slice(
    g: jit_utils.GraphContext,
    input,
    axes,
    starts,
    ends,
    steps=None,
    dynamic_slice=False,
):
    """Emit an ONNX ``Slice`` of ``input``.

    With ``dynamic_slice`` the ``starts``/``ends`` (and ``axes``) are graph
    values that get unsqueezed along axis 0; an ``int`` ``axes`` is first
    turned into a constant. Otherwise they are Python sequences of equal
    length that become constants; a single-axis slice covering
    ``[0, 9223372036854775807)`` with unit or no step returns ``input`` as is.
    ``steps``, when given, is passed as a constant fifth input.
    """
    if not dynamic_slice:
        num_slices = len(starts)
        assert num_slices == len(ends)
        assert num_slices == len(axes)
        assert steps is None or num_slices == len(steps)
        # A full-range, unit-step slice on one axis is the identity.
        if (
            num_slices == 1
            and starts[0] == 0
            and ends[0] == 9223372036854775807
            and (steps is None or (len(steps) == 1 and steps[0] == 1))
        ):
            return input
        axes = g.op("Constant", value_t=torch.tensor(axes))
        starts = g.op("Constant", value_t=torch.tensor(starts))
        ends = g.op("Constant", value_t=torch.tensor(ends))
    else:
        starts = symbolic_helper._unsqueeze_helper(g, starts, [0])
        ends = symbolic_helper._unsqueeze_helper(g, ends, [0])
        if isinstance(axes, int):
            axes = g.op("Constant", value_t=torch.tensor(axes))
        axes = symbolic_helper._unsqueeze_helper(g, axes, [0])

    slice_inputs = [input, starts, ends, axes]
    if steps is not None:
        slice_inputs.append(g.op("Constant", value_t=torch.tensor(steps)))
    return g.op("Slice", *slice_inputs)
| |
| |
@_onnx_symbolic("aten::slice")
@_beartype.beartype
def slice(g: jit_utils.GraphContext, self, *args):
    """Symbolic for ``aten::slice`` (tensor and list overloads).

    Start, end and dim are passed to ``symbolic_helper._slice_helper`` as
    Python lists when all of them are statically known (``None`` or
    ``onnx::Constant``), and as graph values (``dynamic_slice=True``)
    otherwise. A ``None`` start maps to 0 and a ``None`` end to
    9223372036854775807.

    Raises:
        errors.SymbolicValueError: If the number of arguments matches neither overload.
    """
    num_args = len(args)
    if num_args == 4:
        # aten::slice(Tensor self, int dim, int? start=None, int? end=None, int step=1) -> Tensor
        dim, start, end, step = args
    elif num_args == 3:
        # aten::slice(t[] l, int? start=None, int? end=None, int step=1) -> t[]
        start, end, step = args
        dim = 0
    else:
        raise errors.SymbolicValueError("Unknown aten::slice signature", self)

    start_kind = start.node().kind()
    end_kind = end.node().kind()
    is_start_none = start_kind == "prim::Constant" and isinstance(
        start.type(), _C.NoneType
    )
    is_end_none = end_kind == "prim::Constant" and isinstance(
        end.type(), _C.NoneType
    )
    step = symbolic_helper._parse_arg(step, "i")

    # Each bound is static when it is None or already an ONNX constant.
    start_is_static = is_start_none or start_kind == "onnx::Constant"
    end_is_static = (
        isinstance(end, int) or is_end_none or end_kind == "onnx::Constant"
    )
    dim_is_static = isinstance(dim, int) or dim.node().kind() == "onnx::Constant"
    dynamic_slice = not (start_is_static and end_is_static and dim_is_static)

    if dynamic_slice:
        # Replace missing bounds by constant graph values.
        if is_start_none:
            start = g.op("Constant", value_t=torch.tensor(0))
        if is_end_none:
            end = g.op("Constant", value_t=torch.tensor(9223372036854775807))
    else:
        start = [0 if is_start_none else symbolic_helper._parse_arg(start, "i")]
        end = [
            9223372036854775807 if is_end_none else symbolic_helper._parse_arg(end, "i")
        ]
        dim = [symbolic_helper._parse_arg(dim, "i")]

    return symbolic_helper._slice_helper(
        g,
        self,
        axes=dim,
        starts=start,
        ends=end,
        steps=[step],
        dynamic_slice=dynamic_slice,
    )
| |
| |
@_onnx_symbolic("aten::flip")
@symbolic_helper.parse_args("v", "is")
@_beartype.beartype
def flip(g: jit_utils.GraphContext, input, dims):
    """Symbolic for ``aten::flip``.

    Each axis in ``dims`` is sliced from -1 down to -9223372036854775807
    with step -1.
    """
    num_axes = len(dims)
    slice_options = {
        "axes": dims,
        "starts": [-1] * num_axes,
        "ends": [-9223372036854775807] * num_axes,
        "steps": [-1] * num_axes,
    }
    return symbolic_helper._slice_helper(g, input, **slice_options)
| |
| |
@_onnx_symbolic("aten::fmod")
@_beartype.beartype
def fmod(g: jit_utils.GraphContext, input, other):
    """Symbolic for ``aten::fmod``: ONNX ``Mod`` with the ``fmod`` attribute set to 1."""
    remainder = g.op("Mod", input, other, fmod_i=1)
    return remainder
| |
| |
@_onnx_symbolic("aten::embedding_bag")
@symbolic_helper.parse_args("v", "v", "v", "i", "i", "i", "v", "i", "i")
@_beartype.beartype
def embedding_bag(
    g: jit_utils.GraphContext,
    embedding_matrix,
    indices,
    offsets,
    scale_grad_by_freq,
    mode,
    sparse,
    per_sample_weights,
    include_last_offset,
    padding_idx,
):
    """Symbolic for ``aten::embedding_bag`` with statically shaped ``offsets``.

    The bags are unrolled at export time: for every bag the matching rows of
    ``indices`` are sliced out, gathered from ``embedding_matrix``, optionally
    multiplied by the matching ``per_sample_weights`` and reduced along axis 0
    (``mode`` 0: sum, 1: mean, anything else: max). The per-bag results are
    concatenated along axis 0.

    Returns:
        ``(output, None, None, None)``, or the result of
        ``symbolic_helper._onnx_unsupported`` when ``scale_grad_by_freq`` is
        used while exporting in training mode or when the first dimension of
        ``offsets`` is unknown.

    Raises:
        RuntimeError: If ``padding_idx`` is a non-negative integer.
    """
    if scale_grad_by_freq and GLOBALS.export_training:
        return symbolic_helper._onnx_unsupported(
            "embedding_bag with scale_grad_by_freq for training mode"
        )
    if padding_idx is not None and padding_idx >= 0:
        raise RuntimeError("embedding_bag with padding_idx")

    # The message previously ended with a stray apostrophe.
    warnings.warn(
        "Export of embedding_bag with dynamic input/offsets shape is not supported in opset 10. "
        "Please use opset 11 or higher to export model for dynamic input shape."
    )

    offsets_dim_0 = symbolic_helper._get_tensor_dim_size(offsets, 0)
    if offsets_dim_0 is None:
        return symbolic_helper._onnx_unsupported(
            "embedding_bag with unknown shape of offsets for opset 10 is not supported. "
            "please use opset 11 or higher."
        )

    if include_last_offset:
        # The last entry of ``offsets`` already marks the end of the last bag.
        offset_len = offsets_dim_0 - 1
        offsets_extended = offsets
    else:
        # Append a sentinel end offset so the last bag runs to the end of ``indices``.
        offset_len = offsets_dim_0
        offsets_extended = g.op(
            "Concat",
            offsets,
            g.op("Constant", value_t=torch.tensor([sys.maxsize])),
            axis_i=0,
        )

    # Loop-invariant: whether per-sample weights were supplied at all.
    has_per_sample_weights = not symbolic_helper._is_none(per_sample_weights)

    bag_embeddings = []
    for i in range(offset_len):
        # Bag i covers indices[offsets[i]:offsets[i + 1]].
        start_ = symbolic_helper._unsqueeze_helper(
            g,
            opset9.select(g, offsets_extended, torch.tensor(0), torch.tensor(i)),
            [0],
        )
        end_ = symbolic_helper._unsqueeze_helper(
            g,
            opset9.select(g, offsets_extended, torch.tensor(0), torch.tensor(i + 1)),
            [0],
        )
        axes_ = g.op("Constant", value_t=torch.tensor([0]))
        indices_row = g.op("Slice", indices, start_, end_, axes_)

        embeddings = g.op("Gather", embedding_matrix, indices_row)
        if has_per_sample_weights:
            per_sample_weights_row = g.op(
                "Slice", per_sample_weights, start_, end_, axes_
            )
            per_sample_weights_row = symbolic_helper._unsqueeze_helper(
                g, per_sample_weights_row, [1]
            )
            embeddings = g.op("Mul", embeddings, per_sample_weights_row)

        if mode == 0:
            embeddings = symbolic_helper._reducesum_helper(
                g, embeddings, axes_i=[0], keepdims_i=0
            )
        elif mode == 1:
            embeddings = g.op("ReduceMean", embeddings, axes_i=[0], keepdims_i=0)
        else:
            embeddings = g.op("ReduceMax", embeddings, axes_i=[0], keepdims_i=0)

        bag_embeddings.append(symbolic_helper._unsqueeze_helper(g, embeddings, [0]))

    output = g.op("Concat", *bag_embeddings, axis_i=0)
    # aten::embedding_bag returns a tuple of 4 elements: output, offset2bag, bag_size, max_indices.
    # But the last three outputs are not used in torch.nn.EmbeddingBag or torch.nn.functional.embedding_bag.
    return output, None, None, None
| |
| |
@_onnx_symbolic("aten::fake_quantize_per_tensor_affine")
@symbolic_helper.parse_args("v", "v", "v", "i", "i")
@_beartype.beartype
def fake_quantize_per_tensor_affine(
    g: jit_utils.GraphContext,
    inputs,
    scale,
    zero_point,
    quant_min=-128,
    quant_max=127,
):
    """Symbolic for ``aten::fake_quantize_per_tensor_affine``.

    Emits ``QuantizeLinear`` followed by ``DequantizeLinear`` with a float
    scale and a zero point cast to UINT8 (``quant_min == 0``) or INT8.

    Raises:
        errors.SymbolicValueError: If ``(quant_min, quant_max)`` is neither
            ``(0, 255)`` nor ``(-128, 127)``.
    """
    quant_range = (quant_min, quant_max)
    # NOTE: (0, 127) is a special case. PyTorch restricts activations to be in the range (0, 127).
    # https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422
    if quant_range == (0, 127):
        symbolic_helper._onnx_opset_unsupported_detailed(
            "fake_quantize_per_tensor_affine",
            10,
            13,
            "Quantize range (0, 127) not supported, requires opset 13 Clip",
            inputs,
        )
    if quant_range not in [(0, 255), (-128, 127)]:
        raise errors.SymbolicValueError(
            f"For (quant_min, quant_max), ONNX allows only (0, 255) and (-128, 127). "
            f"Got ({quant_min}, {quant_max})",
            inputs,
        )

    scale = symbolic_helper._maybe_get_scalar(scale)
    if scale is None:
        symbolic_helper._onnx_opset_unsupported_detailed(
            "fake_quantize_per_tensor_affine",
            10,
            13,
            "Non-constant scale not supported",
            inputs,
        )
    # Avoid exporter generating double type
    scale = scale.float().data

    # An unsigned range starts at 0; everything else is treated as signed.
    if quant_min == 0:
        zero_point_type = _C_onnx.TensorProtoDataType.UINT8
    else:
        zero_point_type = _C_onnx.TensorProtoDataType.INT8
    zero_point = g.op("Cast", zero_point, to_i=zero_point_type)

    quantized = g.op("QuantizeLinear", inputs, scale, zero_point)
    return g.op("DequantizeLinear", quantized, scale, zero_point)
| |
| |
@_onnx_symbolic("aten::isinf")
@_beartype.beartype
def isinf(g: jit_utils.GraphContext, input):
    """Symbolic for ``aten::isinf``: ONNX ``IsInf`` on the input cast via ``opset9._cast_Double``."""
    as_double = opset9._cast_Double(g, input, False)  # type: ignore[attr-defined]
    return g.op("IsInf", as_double)
| |
| |
@_onnx_symbolic("aten::isfinite")
@_beartype.beartype
def isfinite(g: jit_utils.GraphContext, input):
    """Symbolic for ``aten::isfinite``: the negation of ``isinf(input) or isnan(input)``."""
    is_inf = isinf(g, input)
    is_nan = opset9.isnan(g, input)
    is_inf_or_nan = opset9.__or_(g, is_inf, is_nan)
    return opset9.__not_(g, is_inf_or_nan)
| |
| |
@_onnx_symbolic("aten::quantize_per_tensor")
@_beartype.beartype
def quantize_per_tensor(g: jit_utils.GraphContext, input, scale, zero_point, dtype):
    """Symbolic for ``aten::quantize_per_tensor``.

    The zero point is cast to the ONNX type matching the constant ``dtype``
    and the scale to FLOAT before both are handed to
    ``symbolic_helper.quantize_helper``.
    """
    dtype = symbolic_helper._get_const(dtype, "i", "dtype")
    # TODO(justinchuby): Extract all the cast ops into a helper function.
    zero_point_onnx_type = _type_utils.JitScalarType(dtype).onnx_type()
    cast_zero_point = g.op("Cast", zero_point, to_i=zero_point_onnx_type)
    cast_scale = g.op("Cast", scale, to_i=_C_onnx.TensorProtoDataType.FLOAT)
    return symbolic_helper.quantize_helper(g, input, cast_scale, cast_zero_point)
| |
| |
@_onnx_symbolic("aten::dequantize")
@_beartype.beartype
def dequantize(g: jit_utils.GraphContext, input):
    """Symbolic for ``aten::dequantize``: first element returned by ``symbolic_helper.dequantize_helper``."""
    helper_outputs = symbolic_helper.dequantize_helper(g, input)
    return helper_outputs[0]
| |
| |
@_onnx_symbolic("aten::nan_to_num")
@symbolic_helper.parse_args("v", "f", "f", "f")
@_beartype.beartype
def nan_to_num(g: jit_utils.GraphContext, input, nan, posinf, neginf):
    """Symbolic for ``aten::nan_to_num``.

    Non floating point inputs are returned unchanged. Otherwise NaN, +inf and
    -inf are replaced in turn via ``Where``. Defaults: ``nan`` -> 0.0,
    ``posinf`` -> ``torch.finfo(dtype).max``, ``neginf`` ->
    ``torch.finfo(dtype).min``.
    """
    # An integer tensor cannot hold inf/nan values, so there is nothing to replace.
    if not symbolic_helper._is_fp(input):
        return input
    input_dtype = _type_utils.JitScalarType.from_name(input.type().scalarType()).dtype()

    # Step 1: replace NaN.
    nan_fill = 0.0 if nan is None else nan
    is_nan = opset9.isnan(g, input)
    nan_const = g.op("Constant", value_t=torch.tensor([nan_fill], dtype=input_dtype))
    without_nan = g.op("Where", is_nan, nan_const, input)

    # When posinf/neginf are None, fall back to the greatest/lowest finite
    # value representable by the input's dtype.
    limits = torch.finfo(input_dtype)

    # Step 2: replace +inf (infinite and greater than zero).
    posinf_fill = limits.max if posinf is None else posinf
    is_inf_after_nan = isinf(g, without_nan)
    is_positive = opset9.gt(
        g, without_nan, g.op("Constant", value_t=torch.LongTensor([0]))
    )
    is_posinf = opset9.logical_and(g, is_inf_after_nan, is_positive)
    posinf_const = g.op(
        "Constant", value_t=torch.tensor([posinf_fill], dtype=input_dtype)
    )
    without_posinf = g.op("Where", is_posinf, posinf_const, without_nan)

    # Step 3: replace -inf (infinite and less than zero).
    neginf_fill = limits.min if neginf is None else neginf
    is_inf_after_posinf = isinf(g, without_posinf)
    is_negative = opset9.lt(
        g, without_posinf, g.op("Constant", value_t=torch.LongTensor([0]))
    )
    is_neginf = opset9.logical_and(g, is_inf_after_posinf, is_negative)
    neginf_const = g.op(
        "Constant", value_t=torch.tensor([neginf_fill], dtype=input_dtype)
    )
    return g.op("Where", is_neginf, neginf_const, without_posinf)
| |
| |
| # Quantized symbolics --------------------------------------------------------- |
| # https://github.com/pytorch/pytorch/wiki/PyTorch-ONNX-exporter#quantized-model-export |
| # Support starts from opset 10 because `DequantizeLinear` and `QuantizeLinear` were |
| # introduced in opset version 10. |
@_onnx_symbolic("quantized::linear")
@_beartype.beartype
def quantized_linear(
    g: jit_utils.GraphContext, q_input, q_weight, bias, op_scale, op_zero_point
):
    """Symbolic for ``quantized::linear``.

    Input and weight are dequantized, the bias is requantized with their
    scales and dequantized again, ``opset9.linear`` is applied and the result
    is quantized with ``op_scale`` / ``op_zero_point``.
    """
    float_input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
    float_weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight)
    requantized_bias = symbolic_helper.requantize_bias_helper(
        g, bias, input_scale, weight_scale
    )
    float_bias, _, _, _ = symbolic_helper.dequantize_helper(g, requantized_bias)

    float_output = opset9.linear(g, float_input, float_weight, float_bias)
    return symbolic_helper.quantize_helper(g, float_output, op_scale, op_zero_point)
| |
| |
@_onnx_symbolic("quantized::add")
@_beartype.beartype
def quantized_add(g: jit_utils.GraphContext, x, y, op_scale, op_zero_point):
    """Symbolic for ``quantized::add``: dequantize both operands, add, requantize."""
    float_x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
    float_y, _, _, _ = symbolic_helper.dequantize_helper(g, y)
    float_sum = opset9.add(g, float_x, float_y)
    return symbolic_helper.quantize_helper(g, float_sum, op_scale, op_zero_point)
| |
| |
@_onnx_symbolic("quantized::add_relu")
@_beartype.beartype
def quantized_add_relu(g: jit_utils.GraphContext, x, y, op_scale, op_zero_point):
    """Symbolic for ``quantized::add_relu``: dequantize, add, apply ReLU, requantize."""
    float_x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
    float_y, _, _, _ = symbolic_helper.dequantize_helper(g, y)
    activated = opset9.relu(g, opset9.add(g, float_x, float_y))
    return symbolic_helper.quantize_helper(g, activated, op_scale, op_zero_point)
| |
| |
@_onnx_symbolic("quantized::mul")
@_beartype.beartype
def quantized_mul(g: jit_utils.GraphContext, x, y, op_scale, op_zero_point):
    """Symbolic for ``quantized::mul``: dequantize both operands, multiply, requantize."""
    float_x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
    float_y, _, _, _ = symbolic_helper.dequantize_helper(g, y)
    float_product = opset9.mul(g, float_x, float_y)
    return symbolic_helper.quantize_helper(g, float_product, op_scale, op_zero_point)
| |
| |
@_onnx_symbolic("quantized::hardswish")
@_beartype.beartype
def quantized_hardswish(g: jit_utils.GraphContext, x, op_scale, op_zero_point):
    """Symbolic for ``quantized::hardswish``: dequantize, apply ``opset9.hardswish``, requantize."""
    float_x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
    activated = opset9.hardswish(g, float_x)
    return symbolic_helper.quantize_helper(g, activated, op_scale, op_zero_point)
| |
| |
@_onnx_symbolic("quantized::sigmoid")
@_beartype.beartype
def quantized_sigmoid(g: jit_utils.GraphContext, x, op_scale, op_zero_point):
    """Symbolic for ``quantized::sigmoid``: dequantize, apply ``opset9.sigmoid``, requantize."""
    float_x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
    activated = opset9.sigmoid(g, float_x)
    return symbolic_helper.quantize_helper(g, activated, op_scale, op_zero_point)
| |
| |
@_onnx_symbolic("quantized::leaky_relu")
@_beartype.beartype
def quantized_leaky_relu(
    g: jit_utils.GraphContext, x, negative_slope, inplace, op_scale, op_zero_point
):
    """Symbolic for ``quantized::leaky_relu``: dequantize, apply ``opset9.leaky_relu``, requantize."""
    float_x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
    activated = opset9.leaky_relu(g, float_x, negative_slope, inplace)
    return symbolic_helper.quantize_helper(g, activated, op_scale, op_zero_point)
| |
| |
@_onnx_symbolic("quantized::layer_norm")
@_beartype.beartype
def quantized_layer_norm(
    g: jit_utils.GraphContext,
    x,
    normalized_shape,
    weight,
    bias,
    eps,
    op_scale,
    op_zero_point,
):
    """Symbolic for ``quantized::layer_norm``: dequantize, apply ``opset9.layer_norm``, requantize."""
    float_x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
    normalized = opset9.layer_norm(
        g, float_x, normalized_shape, weight, bias, eps, False
    )
    return symbolic_helper.quantize_helper(g, normalized, op_scale, op_zero_point)
| |
| |
@_onnx_symbolic("quantized::group_norm")
@_beartype.beartype
def quantized_group_norm(
    g: jit_utils.GraphContext,
    x,
    num_groups,
    weight,
    bias,
    eps,
    op_scale,
    op_zero_point,
):
    """Symbolic for ``quantized::group_norm``: dequantize, apply ``opset9.group_norm``, requantize."""
    float_x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
    normalized = opset9.group_norm(g, float_x, num_groups, weight, bias, eps, False)
    return symbolic_helper.quantize_helper(g, normalized, op_scale, op_zero_point)
| |
| |
@_onnx_symbolic("quantized::instance_norm")
@symbolic_helper.parse_args("v", "v", "v", "f", "v", "v")
@_beartype.beartype
def quantized_instance_norm(
    g: jit_utils.GraphContext,
    q_input,
    weight,
    bias,
    eps,
    op_scale,
    op_zero_point,
):
    """Symbolic for ``quantized::instance_norm``: dequantize, apply ``opset9.instance_norm``, requantize."""
    float_input, _, _, _ = symbolic_helper.dequantize_helper(g, q_input)
    normalized = opset9.instance_norm(
        g, float_input, weight, bias, None, None, False, 0.0, eps, False
    )
    return symbolic_helper.quantize_helper(g, normalized, op_scale, op_zero_point)
| |
| |
@_onnx_symbolic("quantized::conv2d_relu")
@_beartype.beartype
def quantized_conv2d_relu(
    g: jit_utils.GraphContext,
    q_input,
    q_weight,
    bias,
    stride,
    padding,
    dilation,
    groups,
    op_scale,
    op_zero_point,
):
    """Symbolic for ``quantized::conv2d_relu``.

    Input and weight are dequantized, the bias is requantized with their
    scales and dequantized again, ``opset9.conv2d`` followed by
    ``opset9.relu`` is applied and the result is quantized with
    ``op_scale`` / ``op_zero_point``.
    """
    float_input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
    float_weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight)
    requantized_bias = symbolic_helper.requantize_bias_helper(
        g, bias, input_scale, weight_scale
    )
    float_bias, _, _, _ = symbolic_helper.dequantize_helper(g, requantized_bias)

    convolved = opset9.conv2d(
        g, float_input, float_weight, float_bias, stride, padding, dilation, groups
    )
    activated = opset9.relu(g, convolved)
    return symbolic_helper.quantize_helper(g, activated, op_scale, op_zero_point)
| |
| |
@_onnx_symbolic("quantized::conv2d")
@_beartype.beartype
def quantized_conv2d(
    g: jit_utils.GraphContext,
    q_input,
    q_weight,
    bias,
    stride,
    padding,
    dilation,
    groups,
    op_scale,
    op_zero_point,
):
    """Symbolic for ``quantized::conv2d``.

    Input and weight are dequantized, the bias is requantized with their
    scales and dequantized again, ``opset9.conv2d`` is applied and the result
    is quantized with ``op_scale`` / ``op_zero_point``.
    """
    float_input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
    float_weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight)
    requantized_bias = symbolic_helper.requantize_bias_helper(
        g, bias, input_scale, weight_scale
    )
    float_bias, _, _, _ = symbolic_helper.dequantize_helper(g, requantized_bias)

    convolved = opset9.conv2d(
        g, float_input, float_weight, float_bias, stride, padding, dilation, groups
    )
    return symbolic_helper.quantize_helper(g, convolved, op_scale, op_zero_point)
| |
| |
@_onnx_symbolic("quantized::cat")
@symbolic_helper.parse_args("v", "i", "v", "v")
@_beartype.beartype
def quantized_cat(
    g: jit_utils.GraphContext,
    q_inputs: _C.Value,
    dim: int,
    op_scale: _C.Value,
    op_zero_point: _C.Value,
) -> _C.Value:
    """Symbolic for ``quantized::cat``.

    Every element of the unpacked input list is dequantized, the results are
    concatenated along ``dim`` and the concatenation is quantized with
    ``op_scale`` / ``op_zero_point``.
    """
    float_inputs = []
    for q_tensor in symbolic_helper._unpack_list(q_inputs):
        float_inputs.append(symbolic_helper.dequantize_helper(g, q_tensor)[0])
    joined = g.op("Concat", *float_inputs, axis_i=dim)
    return symbolic_helper.quantize_helper(g, joined, op_scale, op_zero_point)