| """This file exports ONNX ops for opset 14. |
| |
| Note [ONNX operators that are added/updated in opset 14] |
| ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
| New operators: |
| HardSwish, Trilu |
| |
| Updated operators: |
| Reshape |
| Add, Sub, Mul, Div |
| GRU, LSTM, RNN |
| BatchNorm, Cumsum, Relu |
| """ |

# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in README.md
from __future__ import annotations

import functools
from typing import Optional

import torch
from torch.onnx import _constants, _type_utils, symbolic_helper
from torch.onnx._globals import GLOBALS
from torch.onnx._internal import _beartype, jit_utils, registration

__all__ = [
    "hardswish",
    "tril",
    "triu",
    "reshape",
    "batch_norm",
    "quantized_hardswish",
    "scaled_dot_product_attention",
]

_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=14)


@_onnx_symbolic("aten::hardswish")
@symbolic_helper.parse_args("v")
@_beartype.beartype
def hardswish(g: jit_utils.GraphContext, self):
    return g.op("HardSwish", self)


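# Both tril and triu map to ONNX Trilu (new in opset 14): the second input is the
# diagonal offset `k`, and the `upper` attribute selects which triangle is kept
# (0 = lower, 1 = upper).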
| @_onnx_symbolic("aten::tril") |
| @_beartype.beartype |
| def tril(g: jit_utils.GraphContext, self, diagonal, out=None): |
| return g.op("Trilu", self, diagonal, upper_i=0) |
| |
| |
| @_onnx_symbolic("aten::triu") |
| @_beartype.beartype |
| def triu(g: jit_utils.GraphContext, self, diagonal, out=None): |
| return g.op("Trilu", self, diagonal, upper_i=1) |
| |
| |
| @_onnx_symbolic("aten::reshape") |
| @symbolic_helper.quantized_args(True) |
| @symbolic_helper.parse_args("v", "v") |
| @_beartype.beartype |
| def reshape(g: jit_utils.GraphContext, self, shape): |
    # NOTE: Due to a bug in ORT (https://github.com/microsoft/onnxruntime/issues/10664),
    # the Reshape export cannot use the new `allowzero` attribute introduced in opset 14.
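    # With allowzero=0 (the opset 13 behavior), a zero in `shape` means "copy the
    # corresponding dimension from the input" rather than an explicit zero-sized dim.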
    return symbolic_helper._reshape_helper(g, self, shape, allowzero=0)


@_onnx_symbolic("aten::batch_norm")
@symbolic_helper.parse_args("v", "v", "v", "v", "v", "i", "f", "f", "i")
@_beartype.beartype
def batch_norm(
    g: jit_utils.GraphContext,
    input,
    weight,
    bias,
    running_mean,
    running_var,
    training,
    momentum,
    eps,
    cudnn_enabled,
):
    if (
        torch.is_autocast_enabled()
        and not symbolic_helper.args_have_same_dtype(
            [input, weight, bias, running_mean, running_var]
        )
        and GLOBALS.export_onnx_opset_version < 15
    ):
        return symbolic_helper._onnx_opset_unsupported_detailed(
            "BatchNormalization",
            14,
            15,
            "All input tensors must have the same `dtype`."
            " Turn off Autocast or export using opset version 15.",
            input,
        )

    symbolic_helper.check_training_mode(training, "batch_norm")
    weight, bias, running_mean, running_var = symbolic_helper._batchnorm_helper(
        g, input, weight, bias, running_mean, running_var
    )
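    # ONNX BatchNormalization defines momentum as the weight on the existing running
    # statistic (running_mean = running_mean * momentum + batch_mean * (1 - momentum)),
    # which is the complement of PyTorch's momentum, hence `1 - momentum` below.
    # In training mode the node produces three outputs: Y and the updated running stats.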
    out = g.op(
        "BatchNormalization",
        input,
        weight,
        bias,
        running_mean,
        running_var,
        epsilon_f=eps,
        momentum_f=1 - momentum,
        training_mode_i=0 if not training else 1,
        outputs=1 if not training else 3,
    )
    if not training:
        return out
    else:
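        # Propagate the original running-stat types to the updated outputs so that
        # type/shape inference on the exported graph stays consistent.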
        res, new_running_mean, new_running_var = out
        new_running_mean.setType(running_mean.type())
        new_running_var.setType(running_var.type())
        return res


@_onnx_symbolic("quantized::hardswish")
@_beartype.beartype
def quantized_hardswish(g: jit_utils.GraphContext, x, op_scale, op_zero_point):
    x, _, _, _ = symbolic_helper.dequantize_helper(g, x)

    output = hardswish(g, x)

    return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)


# Ported from
# https://github.com/microsoft/onnxscript/blob/6b1b81700b4523f31d8c6d3321e5d8ef5d42b764/onnxscript/function_libs/torch_aten/ops/nn.py#L1504
# aten_scaled_dot_product_attention
# NOTE: Requires op.Trilu, which is new in opset 14.
| @_onnx_symbolic("aten::scaled_dot_product_attention") |
| @symbolic_helper.parse_args("v", "v", "v", "v", "f", "b", "v") |
| @_beartype.beartype |
| def scaled_dot_product_attention( |
| g: jit_utils.GraphContext, |
| query: torch._C.Value, |
| key: torch._C.Value, |
| value: torch._C.Value, |
| attn_mask: Optional[torch._C.Value] = None, |
| dropout_p: float = 0.0, |
| is_causal: bool = False, |
| scale: Optional[torch._C.Value] = None, |
| ): |
| assert (not is_causal) or ( |
| is_causal and symbolic_helper._is_none(attn_mask) |
| ), "is_causal and attn_mask cannot be set at the same time" |
| |
| scale = symbolic_helper._maybe_get_const(scale, "f") |
| if symbolic_helper._is_none(scale): |
| scale = _attention_scale(g, query) |
| |
| if is_causal: |
| attn_mask = _causal_attention_mask(g, query, key) |
| |
| # Swap the last two axes of key |
    # NOTE: onnx-script uses different logic here, because the `perm` attribute of
    # Transpose needs a list of ints.
    key_shape_builtin = symbolic_helper._get_tensor_rank(key)
    key_transposed_axes = list(range(key_shape_builtin))
    key_transposed_axes[-1], key_transposed_axes[-2] = (
        key_transposed_axes[-2],
        key_transposed_axes[-1],
    )
    key_transposed = g.op("Transpose", key, perm_i=key_transposed_axes)

    # https://github.com/pytorch/pytorch/blob/12da0c70378b5be9135c6fda62a9863bce4a4818/aten/src/ATen/native/transformers/attention.cpp#L653
    # Scale q and k before the matmul for numerical stability; see https://tinyurl.com/sudb9s96 for the math.
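    # Multiplying both q and transposed k by sqrt(scale) is mathematically equivalent
    # to scaling q @ k^T by `scale`, but keeps intermediate magnitudes smaller.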
| query_scaled = g.op("Mul", query, g.op("Sqrt", scale)) |
| key_transposed_scaled = g.op("Mul", key_transposed, g.op("Sqrt", scale)) |
| mul_qk = g.op("MatMul", query_scaled, key_transposed_scaled) |
| |
| if symbolic_helper._is_none(attn_mask): |
| mul_qk_add = mul_qk |
| elif ( |
| _type_utils.JitScalarType.from_value(attn_mask) |
| == _type_utils.JitScalarType.BOOL |
| ): |
        # Convert the Boolean mask to float: True -> 0.0 (keep), False -> -inf (masked);
        # i.e. attn_mask.masked_fill(attn_mask.logical_not(), -float("inf"))
| const_zero = g.op("Constant", value_t=torch.tensor([0.0])) |
| const_neg_inf = g.op("Constant", value_t=torch.tensor([-float("inf")])) |
| attn_mask = g.op("Where", attn_mask, const_zero, const_neg_inf) |
| mul_qk_add = g.op("Add", mul_qk, attn_mask) |
| elif _type_utils.JitScalarType.from_value(attn_mask) in ( |
| _type_utils.JitScalarType.FLOAT, |
| _type_utils.JitScalarType.HALF, |
| ): |
| mul_qk_add = g.op("Add", mul_qk, attn_mask) |
| else: |
| raise ValueError( |
| f"Unsupported type for attn_mask: {_type_utils.JitScalarType.from_value(attn_mask)}" |
| ) |
| |
| attn_weight = g.op("Softmax", mul_qk_add, axis_i=-1) |
| |
| if dropout_p != 0: |
| attn_weight = g.op( |
| "Dropout", |
| attn_weight, |
| g.op("Constant", value_t=torch.tensor(dropout_p, dtype=torch.float)), |
| ) |
| |
| return g.op("MatMul", attn_weight, value) |
| |
| |
| @_beartype.beartype |
| def _attention_scale( |
| g: jit_utils.GraphContext, query: torch._C.Value |
| ) -> torch._C.Value: |
| """Calculate the scale factor for the attention result. |
| |
| Args: |
| query: Tensor of shape [..., L, E] |
| |
| Returns: |
| Scalar scale factor := 1 / math.sqrt(query.size(-1)) |
| """ |
| query_shape = g.op("Shape", query) |
| query_shape_last = g.op( |
| "Slice", |
| query_shape, |
| g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)), |
| g.op( |
| "Constant", value_t=torch.tensor([_constants.INT64_MAX], dtype=torch.int64) |
| ), |
| ) |
| embedding_size = g.op( |
| "Cast", |
| query_shape_last, |
| to_i=_type_utils.JitScalarType.from_value(query).onnx_type(), |
| ) |
| const_one = g.op("Constant", value_t=torch.tensor([1.0], dtype=torch.float)) |
| scale = g.op("Div", const_one, g.op("Sqrt", embedding_size)) |
| return scale |
| |
| |
| @_beartype.beartype |
| def _causal_attention_mask( |
| g: jit_utils.GraphContext, query: torch._C.Value, key: torch._C.Value |
| ) -> torch._C.Value: |
| """Create a causal mask for the given query and key tensors. |
| |
    Equivalent to::
        mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
        attn_mask = torch.zeros(L, S, dtype=torch.float)
        attn_mask = attn_mask.masked_fill(mask.logical_not(), -float("inf"))

    Args:
        query: Tensor of shape [..., L, E]
        key: Tensor of shape [..., S, E]

    Returns:
        Tensor of shape [L, S]
    """

    query_shape = g.op("Shape", query)
    key_shape = g.op("Shape", key)

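    # L = query.shape[-2] and S = key.shape[-2], extracted with Slice so the mask
    # shape follows the runtime sequence lengths.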
| last_idx = g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)) |
| second_last_idx = g.op("Constant", value_t=torch.tensor([-2], dtype=torch.int64)) |
| target_length = g.op("Slice", query_shape, second_last_idx, last_idx) |
| source_length = g.op("Slice", key_shape, second_last_idx, last_idx) |
| # attn_mask = torch.ones(L, S) := { |
| size = g.op("Concat", target_length, source_length, axis_i=0) |
| const_one = g.op("Constant", value_t=torch.tensor([1.0])) |
| attn_mask = g.op("Expand", const_one, size) |
| # } |
| attn_mask = g.op("Trilu", attn_mask, upper_i=0) |
| # The causal mask has 0s in the lower triangle and -inf in the upper triangle. |
| const_zero = g.op("Constant", value_t=torch.tensor([0.0])) |
| const_neg_inf = g.op("Constant", value_t=torch.tensor([-float("inf")])) |
| attn_mask = g.op( |
| "Where", g.op("Equal", attn_mask, const_zero), const_neg_inf, const_zero |
| ) |
| return attn_mask |