| """This file exports ONNX ops for opset 9. |
| |
| Opset 9 is supported by ONNX release 1.4.1 |
| released on 01/23/19 |
| """ |
| from __future__ import annotations |
| |
| import builtins |
| import functools |
| import math |
| import sys |
| import warnings |
| from typing import Callable, List, Optional, Sequence, Tuple, Union |
| |
| import torch |
| import torch._C._onnx as _C_onnx |
| import torch.nn.modules.utils |
| import torch.onnx |
| from torch import _C |
| |
| # Monkey-patch graph manipulation methods on Graph, used for the ONNX symbolics |
| from torch.onnx import _constants, _deprecation, _type_utils, errors, symbolic_helper |
| from torch.onnx._globals import GLOBALS |
| from torch.onnx._internal import _beartype, jit_utils, registration |
| from torch.types import Number |
| |
| # EDITING THIS FILE? READ THIS FIRST! |
| # see Note [Edit Symbolic Files] in README.md |
| |
| __all__ = [ |
| "abs", |
| "acos", |
| "add", |
| "addcmul", |
| "addmm", |
| "alias", |
| "amax", |
| "amin", |
| "aminmax", |
| "arange", |
| "argmax", |
| "argmin", |
| "as_strided", |
| "as_tensor", |
| "asin", |
| "atan", |
| "atan2", |
| "baddbmm", |
| "batch_norm", |
| "bernoulli", |
| "bitwise_not", |
| "bitwise_or", |
| "bmm", |
| "broadcast_tensors", |
| "broadcast_to", |
| "bucketize", |
| "cat", |
| "cdist", |
| "ceil", |
| "clamp_max", |
| "clamp_min", |
| "clamp", |
| "clone", |
| "constant_pad_nd", |
| "contiguous", |
| "conv_tbc", |
| "conv_transpose1d", |
| "conv_transpose2d", |
| "conv_transpose3d", |
| "conv1d", |
| "conv2d", |
| "conv3d", |
| "convert_element_type", |
| "convolution", |
| "cos", |
| "cosine_similarity", |
| "cross", |
| "cumsum", |
| "detach", |
| "dim", |
| "div", |
| "dot", |
| "dropout", |
| "elu", |
| "embedding_bag", |
| "embedding", |
| "empty_like", |
| "empty", |
| "eq", |
| "erf", |
| "exp", |
| "expand_as", |
| "expand", |
| "eye", |
| "fill", |
| "flatten", |
| "floor_divide", |
| "floor", |
| "floordiv", |
| "frobenius_norm", |
| "full_like", |
| "full", |
| "gather", |
| "ge", |
| "gelu", |
| "get_pool_ceil_padding", |
| "glu", |
| "group_norm", |
| "gt", |
| "hann_window", |
| "hardshrink", |
| "hardsigmoid", |
| "hardswish", |
| "hardtanh", |
| "index_add", |
| "index_copy", |
| "index_fill", |
| "index_put", |
| "index_select", |
| "index", |
| "instance_norm", |
| "is_floating_point", |
| "is_pinned", |
| "isnan", |
| "item", |
| "kl_div", |
| "layer_norm", |
| "le", |
| "leaky_relu", |
| "lerp", |
| "lift", |
| "linalg_cross", |
| "linalg_matrix_norm", |
| "linalg_norm", |
| "linalg_vector_norm", |
| "linear", |
| "linspace", |
| "log_sigmoid", |
| "log_softmax", |
| "log", |
| "log10", |
| "log1p", |
| "log2", |
| "logical_and", |
| "logical_not", |
| "logical_or", |
| "logical_xor", |
| "logit", |
| "logsumexp", |
| "lstm_cell", |
| "lstm", |
| "lt", |
| "masked_fill", |
| "masked_fill_", |
| "matmul", |
| "max_pool1d_with_indices", |
| "max_pool2d_with_indices", |
| "max_pool3d_with_indices", |
| "max", |
| "maximum", |
| "meshgrid", |
| "min", |
| "minimum", |
| "mish", |
| "mm", |
| "movedim", |
| "mse_loss", |
| "mul", |
| "multinomial", |
| "mv", |
| "narrow", |
| "native_layer_norm", |
| "ne", |
| "neg", |
| "new_empty", |
| "new_full", |
| "new_ones", |
| "new_zeros", |
| "nonzero_numpy", |
| "nonzero", |
| "norm", |
| "numel", |
| "numpy_T", |
| "one_hot", |
| "ones_like", |
| "ones", |
| "onnx_placeholder", |
| "overload_by_arg_count", |
| "pad", |
| "pairwise_distance", |
| "permute", |
| "pixel_shuffle", |
| "pixel_unshuffle", |
| "pow", |
| "prelu", |
| "prim_constant_chunk", |
| "prim_constant_split", |
| "prim_constant", |
| "prim_data", |
| "prim_device", |
| "prim_dtype", |
| "prim_if", |
| "prim_layout", |
| "prim_list_construct", |
| "prim_list_unpack", |
| "prim_loop", |
| "prim_max", |
| "prim_min", |
| "prim_shape", |
| "prim_tolist", |
| "prim_tuple_construct", |
| "prim_type", |
| "prim_unchecked_cast", |
| "prim_uninitialized", |
| "rand_like", |
| "rand", |
| "randint_like", |
| "randint", |
| "randn_like", |
| "randn", |
| "reciprocal", |
| "reflection_pad", |
| "relu", |
| "relu6", |
| "remainder", |
| "repeat_interleave", |
| "repeat", |
| "replication_pad", |
| "reshape_as", |
| "reshape", |
| "roll", |
| "rrelu", |
| "rsqrt", |
| "rsub", |
| "scalar_tensor", |
| "scatter_add", |
| "scatter", |
| "select", |
| "selu", |
| "sigmoid", |
| "sign", |
| "silu", |
| "sin", |
| "size", |
| "slice", |
| "softmax", |
| "softplus", |
| "softshrink", |
| "sort", |
| "split_with_sizes", |
| "split", |
| "sqrt", |
| "square", |
| "squeeze", |
| "stack", |
| "std_mean", |
| "std", |
| "sub", |
| "t", |
| "take", |
| "tan", |
| "tanh", |
| "tanhshrink", |
| "tensor", |
| "threshold", |
| "to", |
| "topk", |
| "transpose", |
| "true_divide", |
| "type_as", |
| "unbind", |
| "unfold", |
| "unsafe_chunk", |
| "unsafe_split_with_sizes", |
| "unsafe_split", |
| "unsqueeze", |
| "unsupported_complex_operators", |
| "noop_complex_operators", |
| "unused", |
| "var_mean", |
| "var", |
| "view_as", |
| "view", |
| "where", |
| "wrap_logical_op_with_cast_to", |
| "wrap_logical_op_with_negation", |
| "zeros_like", |
| "zeros", |
| "zero", |
| ] |
| |
| |
# Registration decorator with the opset fixed to 9: `registration.onnx_symbolic`
# partially applied with `opset=9`.
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=9)
| |
| |
def _apply_params(*args, **kwargs):
    """Build a decorator that invokes the decorated callable with the stored arguments.

    The returned decorator replaces ``factory`` with ``factory(*args, **kwargs)``,
    which is useful when the decorated object is itself a higher-order function.
    """

    def _call_with_params(factory):
        return factory(*args, **kwargs)

    return _call_with_params
| |
| |
def _export(name: str):
    """Return a decorator that publishes a function under ``name``.

    The decorator binds the function in this module's globals, appends
    ``name`` to ``__all__`` and returns the function unchanged.
    """

    def register(symbol):
        globals().update({name: symbol})
        __all__.append(name)
        return symbol

    return register
| |
| |
@_beartype.beartype
def unused(g):
    """Create the placeholder value that stands in for a "missing" optional input.

    Emits a ``prim::Constant`` node and sets its type to an optional tensor.
    """
    placeholder = g.op("prim::Constant")
    placeholder.setType(_C.OptionalType.ofTensor())
    return placeholder
| |
| |
@_onnx_symbolic("aten::_shape_as_tensor")
@_beartype.beartype
def _shape_as_tensor(g: jit_utils.GraphContext, input):
    """Lower ``aten::_shape_as_tensor`` to an ONNX ``Shape`` node on ``input``."""
    shape_value = g.op("Shape", input)
    return shape_value
| |
| |
@_onnx_symbolic("aten::_reshape_from_tensor")
@_beartype.beartype
def _reshape_from_tensor(g: jit_utils.GraphContext, input, shape):
    """Reshape ``input`` to ``shape``.

    When ``shape`` is a Python list its elements are first joined with an
    ONNX ``Concat`` along axis 0; otherwise it is forwarded as given.
    """
    target_shape = (
        g.op("Concat", *shape, axis_i=0) if isinstance(shape, list) else shape
    )
    return reshape(g, input, target_shape)
| |
| |
@_onnx_symbolic("aten::reshape")
@symbolic_helper.quantized_args(True)
@_beartype.beartype
def reshape(g: jit_utils.GraphContext, self, shape):
    """Lower ``aten::reshape`` by delegating to ``symbolic_helper._reshape_helper``."""
    reshaped = symbolic_helper._reshape_helper(g, self, shape)
    return reshaped
| |
| |
@_onnx_symbolic("aten::reshape_as")
@symbolic_helper.quantized_args(True)
@_beartype.beartype
def reshape_as(g: jit_utils.GraphContext, self, other):
    """Reshape ``self`` to the shape of ``other`` (taken via an ONNX ``Shape`` node)."""
    return reshape(g, self, g.op("Shape", other))
| |
| |
@_onnx_symbolic("aten::add")
@_beartype.beartype
def add(g: jit_utils.GraphContext, self, other, alpha=None):
    """Lower ``aten::add`` to ONNX ``Add``, scaling ``other`` by ``alpha`` when it is not 1.

    Adding lists of tensors is reported as unsupported for this opset.
    """
    self_is_tensor_list = symbolic_helper._is_value(
        self
    ) and symbolic_helper._is_tensor_list(self)
    if self_is_tensor_list:
        return symbolic_helper._onnx_opset_unsupported_detailed(
            "Add", 9, 11, "Add between list of tensors not supported", self
        )

    if alpha:
        alpha_scalar = symbolic_helper._scalar(
            symbolic_helper._maybe_get_scalar(alpha)
        )
        if alpha_scalar != 1:
            # Fold the multiplier into the right-hand operand before adding.
            other = g.op("Mul", other, alpha)
    return g.op("Add", self, other)
| |
| |
@_onnx_symbolic("aten::sub")
@_beartype.beartype
def sub(g: jit_utils.GraphContext, self, other, alpha=None):
    """Lower ``aten::sub`` to ONNX ``Sub``, scaling ``other`` by ``alpha`` when it is not 1."""
    if alpha:
        alpha_scalar = symbolic_helper._scalar(
            symbolic_helper._maybe_get_scalar(alpha)
        )
        if alpha_scalar != 1:
            # Fold the multiplier into the subtrahend before subtracting.
            other = g.op("Mul", other, alpha)
    return g.op("Sub", self, other)
| |
| |
@_onnx_symbolic("aten::rsub")
@_beartype.beartype
def rsub(g: jit_utils.GraphContext, self, other, alpha=None):
    """Lower ``aten::rsub`` as ``sub`` with the two operands swapped."""
    minuend, subtrahend = other, self
    return sub(g, minuend, subtrahend, alpha=alpha)
| |
| |
@_onnx_symbolic("aten::mul")
@_beartype.beartype
def mul(g: jit_utils.GraphContext, self, other):
    """Lower ``aten::mul`` to ONNX ``Mul``, or to ``And`` when both operands are Boolean."""
    both_bool = symbolic_helper._is_bool(self) and symbolic_helper._is_bool(other)
    # ONNX Mul doesn't support Boolean, so And is used as an equivalent operator.
    onnx_op = "And" if both_bool else "Mul"
    return g.op(onnx_op, self, other)
| |
| |
@_onnx_symbolic("aten::div")
@_beartype.beartype
def div(g: jit_utils.GraphContext, self, other, *args):
    """Lower ``aten::div``.

    Without extra arguments this is a true division; any extra arguments are
    forwarded to ``_div_rounding_mode`` (the rounding-mode overload).
    """
    if args:
        return _div_rounding_mode(g, self, other, *args)
    return true_divide(g, self, other)
| |
| |
@_onnx_symbolic("aten::addcmul")
@symbolic_helper.parse_args("v", "v", "v", "f")
@_beartype.beartype
def addcmul(g: jit_utils.GraphContext, self, tensor1, tensor2, value=1.0):
    """Lower ``aten::addcmul`` as ``self + (tensor1 * tensor2) * value``."""
    scale = g.op("Constant", value_t=torch.tensor([value]))
    elementwise_product = mul(g, tensor1, tensor2)
    scaled_product = mul(g, elementwise_product, scale)
    return add(g, self, scaled_product)
| |
| |
@symbolic_helper.parse_args("v", "v", "s")
@_beartype.beartype
def _div_rounding_mode(g: jit_utils.GraphContext, self, other, rounding_mode):
    """Dispatch a division on ``rounding_mode``.

    ``None`` selects true division, ``"floor"`` floor division and ``"trunc"``
    truncating division.

    Raises:
        errors.SymbolicValueError: for any other rounding mode.
    """
    if rounding_mode is None:
        return true_divide(g, self, other)
    if rounding_mode == "floor":
        return _floor_divide(g, self, other)
    if rounding_mode == "trunc":
        return _trunc_divide(g, self, other)

    raise errors.SymbolicValueError(
        f'Unsupported rounding mode: "{rounding_mode}". Expected None, "floor" or "trunc"',
        self,
    )
| |
| |
@_beartype.beartype
def _trunc_divide(g: jit_utils.GraphContext, self, other):
    """Emit a division whose result is truncated, then cast to the output type.

    Output type selection (matching PyTorch's behavior):
      - ``self``'s scalar type is unknown: Float;
      - ``self`` is not floating point and ``other`` is: Float;
      - otherwise: ``self``'s scalar type.
    """
    quotient = g.op("Div", self, other)
    # The correct operation is truncate, which ONNX does not provide. Floor
    # cannot be used because it behaves differently for negative numbers
    # (e.g. -0.1 should become -0), so the quotient is cast to INT64 instead.
    truncated = g.op("Cast", quotient, to_i=_C_onnx.TensorProtoDataType.INT64)

    self_type = _type_utils.JitScalarType.from_value(
        self, _type_utils.JitScalarType.UNDEFINED
    )
    if self_type == _type_utils.JitScalarType.UNDEFINED or (
        not symbolic_helper._is_fp(self) and symbolic_helper._is_fp(other)
    ):
        target_type = _C_onnx.TensorProtoDataType.FLOAT
    else:
        target_type = self_type.onnx_type()
    return g.op("Cast", truncated, to_i=target_type)
| |
| |
@_beartype.beartype
def _floor_divide(g: jit_utils.GraphContext, self, other):
    """Emit a floor division.

    If either operand is floating point the true quotient is passed through
    ``Floor``. Otherwise the integer quotient is corrected by subtracting 1
    whenever the operands have different signs and the division has a
    non-zero remainder.
    """
    if symbolic_helper._is_fp(self) or symbolic_helper._is_fp(other):
        return g.op("Floor", true_divide(g, self, other))

    # Integer division does truncation rounding.
    quotient = g.op("Div", self, other)

    # The division is negative if exactly one of the operands is negative.
    zero = g.op("Constant", value_t=torch.tensor(0, dtype=torch.int64))
    self_is_negative = symbolic_helper._lt_helper(g, self, zero)
    other_is_negative = symbolic_helper._lt_helper(g, other, zero)
    signs_differ = g.op("Xor", self_is_negative, other_is_negative)

    # For negative results with self % other != 0, subtract 1 so the result
    # rounds down instead of up.
    remainder = g.op("Sub", self, g.op("Mul", quotient, other))
    has_remainder = g.op("Not", g.op("Equal", remainder, zero))
    needs_fixup = g.op("And", signs_differ, has_remainder)

    one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))
    correction = g.op("Mul", needs_fixup, one)
    return g.op("Sub", quotient, correction)
| |
| |
@_onnx_symbolic("aten::floor_divide")
@_beartype.beartype
def floor_divide(g: jit_utils.GraphContext, self, other):
    """Lower ``aten::floor_divide`` using truncating division.

    This mirrors the deprecated behavior in which ``floor_divide`` actually
    truncates.
    """
    truncated_quotient = _trunc_divide(g, self, other)
    return truncated_quotient
| |
| |
@_onnx_symbolic("aten::floordiv")
@_beartype.beartype
def floordiv(g: jit_utils.GraphContext, self, other):
    """Lower ``aten::floordiv`` by delegating to ``floor_divide``."""
    result = floor_divide(g, self, other)
    return result
| |
| |
@_onnx_symbolic("aten::true_divide")
@_beartype.beartype
def true_divide(g: jit_utils.GraphContext, self, other):
    """Division where both inputs are treated as floating types.

    If at least one input is floating point, ``Div`` is emitted directly and
    implicit casting is left to the scalar type analysis pass. If neither is,
    both inputs are cast to the default scalar type (Float or Double) first.
    """
    if symbolic_helper._is_fp(self) or symbolic_helper._is_fp(other):
        return g.op("Div", self, other)

    # Neither operand is floating: cast both to the default dtype.
    default_dtype = torch.get_default_dtype()
    assert default_dtype is torch.float or default_dtype is torch.double
    if default_dtype is torch.double:
        target_type = _C_onnx.TensorProtoDataType.DOUBLE
    else:
        target_type = _C_onnx.TensorProtoDataType.FLOAT

    numerator = g.op("Cast", self, to_i=target_type)
    denominator = g.op("Cast", other, to_i=target_type)
    return g.op("Div", numerator, denominator)
| |
| |
@_onnx_symbolic("aten::reciprocal")
@_beartype.beartype
def reciprocal(g: jit_utils.GraphContext, self):
    """Lower ``aten::reciprocal`` to ONNX ``Reciprocal``.

    torch.reciprocal implicitly casts to float, so non-floating inputs are
    cast to Float first.
    """
    if symbolic_helper._is_fp(self):
        operand = self
    else:
        operand = g.op("Cast", self, to_i=_C_onnx.TensorProtoDataType.FLOAT)
    return g.op("Reciprocal", operand)
| |
| |
@_onnx_symbolic("aten::cat")
@symbolic_helper.parse_args("v", "i")
@_beartype.beartype
def cat(g: jit_utils.GraphContext, tensor_list, dim):
    """Lower ``aten::cat`` to ONNX ``Concat`` along ``dim``.

    torch.cat ignores empty tensors such as ``torch.Tensor([])``. They are
    removed from the Concat inputs as well, otherwise shape inference will
    likely fail due to inputs with different ranks (0 for an empty tensor,
    > 0 for anything else). The list node feeding ``tensor_list`` is rewritten
    in place to hold only the kept tensors.
    """

    def _is_dropped(tensor):
        # Constants whose first dimension size is falsy are skipped.
        return symbolic_helper._is_constant(
            tensor
        ) and not symbolic_helper._get_tensor_dim_size(tensor, 0)

    def _ranks_compatible(reference, tensor):
        # Ranks agree when either one is unknown or both are equal.
        return (
            symbolic_helper._get_tensor_rank(reference) is None
            or symbolic_helper._get_tensor_rank(tensor) is None
            or symbolic_helper._get_tensor_rank(tensor)
            == symbolic_helper._get_tensor_rank(reference)
        )

    kept = [
        tensor
        for tensor in symbolic_helper._unpack_list(tensor_list)
        if not _is_dropped(tensor)
    ]
    assert len(kept) > 0
    assert all(_ranks_compatible(kept[0], tensor) for tensor in kept)

    list_node = tensor_list.node()
    list_node.removeAllInputs()
    for tensor in kept:
        list_node.addInput(tensor)

    return g.op("Concat", *symbolic_helper._unpack_list(tensor_list), axis_i=dim)
| |
| |
@_onnx_symbolic("aten::stack")
@symbolic_helper.parse_args("v", "i")
@_beartype.beartype
def stack(g: jit_utils.GraphContext, tensor_list, dim):
    """Lower ``aten::stack``: unsqueeze every tensor at ``dim``, then ``Concat`` along ``dim``."""
    expanded = []
    for tensor in symbolic_helper._unpack_list(tensor_list):
        expanded.append(symbolic_helper._unsqueeze_helper(g, tensor, [dim]))
    return g.op("Concat", *expanded, axis_i=dim)
| |
| |
@_onnx_symbolic("aten::list")
@_beartype.beartype
def _list(g: jit_utils.GraphContext, self):
    """Lower ``aten::list`` as a pass-through: the input value is returned unchanged."""
    passthrough = self
    return passthrough
| |
| |
@_onnx_symbolic("aten::mm")
@_beartype.beartype
def mm(g: jit_utils.GraphContext, self, other):
    """Lower ``aten::mm`` to ONNX ``Gemm`` with ``alpha=1`` and ``beta=0``."""
    # Gemm takes a third input C. A dummy constant is supplied only to satisfy
    # the API; it is emitted together with beta = 0.
    dummy_c = g.op("Constant", value_t=torch.tensor([1]))
    return g.op("Gemm", self, other, dummy_c, beta_f=0.0, alpha_f=1.0)
| |
| |
@_onnx_symbolic("aten::bmm")
@_beartype.beartype
def bmm(g: jit_utils.GraphContext, self, other):
    """Lower ``aten::bmm`` to an ONNX ``MatMul`` node."""
    product = g.op("MatMul", self, other)
    return product
| |
| |
@_onnx_symbolic("aten::matmul")
@_beartype.beartype
def matmul(g: jit_utils.GraphContext, self, other):
    """Lower ``aten::matmul`` to an ONNX ``MatMul`` node."""
    product = g.op("MatMul", self, other)
    return product
| |
| |
@_onnx_symbolic("aten::addmm")
@symbolic_helper.parse_args("v", "v", "v", "t", "t")
@_beartype.beartype
def addmm(g: jit_utils.GraphContext, self, mat1, mat2, beta, alpha):
    """Lower ``aten::addmm`` (``beta * self + alpha * (mat1 @ mat2)``).

    Two lowerings are used:
      - If a scalar type is known for one of the inputs and either matrix has
        a known rank other than 2, the result is built from ``MatMul``,
        optional ``Mul`` nodes for ``alpha``/``beta`` (skipped when equal to 1)
        and a final ``Add``.
      - Otherwise a single ``Gemm`` node is emitted with ``alpha``/``beta`` as
        attributes.

    Args:
        g: Graph context used to emit nodes.
        self: The value added to the matrix product.
        mat1: Left matrix operand.
        mat2: Right matrix operand.
        beta: Tensor holding the multiplier for ``self``.
        alpha: Tensor holding the multiplier for the matrix product.
    """
    # Use the first scalar type that is known among self, mat1 and mat2.
    scalar_type = None
    for candidate in (self, mat1, mat2):
        candidate_type = symbolic_helper._try_get_scalar_type(candidate)
        if candidate_type is not None:
            scalar_type = candidate_type
            break

    mat1_rank = symbolic_helper._get_tensor_rank(mat1)
    mat2_rank = symbolic_helper._get_tensor_rank(mat2)

    def is_not_none_nor(v, u):
        return v is not None and v != u

    if scalar_type is not None and (
        is_not_none_nor(mat1_rank, 2) or is_not_none_nor(mat2_rank, 2)
    ):
        res1 = g.op("MatMul", mat1, mat2)
        res2 = self

        alpha_scalar = symbolic_helper._scalar(alpha)
        beta_scalar = symbolic_helper._scalar(beta)

        if alpha_scalar != 1:
            alpha_const = g.op(
                "Constant",
                value_t=torch.tensor(alpha_scalar, dtype=scalar_type.dtype()),
            )
            res1 = g.op("Mul", res1, alpha_const)
        if beta_scalar != 1:
            # beta has already been passed through `_scalar` above; use the
            # converted value directly, exactly as the alpha branch does,
            # instead of converting it a second time.
            beta_const = g.op(
                "Constant",
                value_t=torch.tensor(beta_scalar, dtype=scalar_type.dtype()),
            )
            res2 = g.op("Mul", res2, beta_const)

        return g.op("Add", res1, res2)

    return g.op(
        "Gemm",
        mat1,
        mat2,
        self,
        beta_f=symbolic_helper._scalar(beta),
        alpha_f=symbolic_helper._scalar(alpha),
    )
| |
| |
@_onnx_symbolic("aten::neg")
@_beartype.beartype
def neg(g: jit_utils.GraphContext, self):
    """Lower ``aten::neg`` to an ONNX ``Neg`` node."""
    negated = g.op("Neg", self)
    return negated
| |
| |
@_onnx_symbolic("aten::sqrt")
@_beartype.beartype
def sqrt(g: jit_utils.GraphContext, self):
    """Lower ``aten::sqrt`` to ONNX ``Sqrt``, casting integer inputs to Float first."""
    input_type = _type_utils.JitScalarType.from_value(
        self, _type_utils.JitScalarType.UNDEFINED
    )
    integer_types = {
        _type_utils.JitScalarType.UINT8,
        _type_utils.JitScalarType.INT8,
        _type_utils.JitScalarType.INT16,
        _type_utils.JitScalarType.INT,
        _type_utils.JitScalarType.INT64,
    }
    if input_type in integer_types:
        # torch converts all int inputs to sqrt to float.
        self = g.op("Cast", self, to_i=_C_onnx.TensorProtoDataType.FLOAT)

    return g.op("Sqrt", self)
| |
| |
@_onnx_symbolic("aten::rsqrt")
@_beartype.beartype
def rsqrt(g: jit_utils.GraphContext, self):
    """Lower ``aten::rsqrt`` as ``1 / sqrt(self)`` using ONNX ``Div``."""
    one = symbolic_helper._if_scalar_type_as(torch.ones(1), self)
    root = sqrt(g, self)
    return g.op("Div", one, root)
| |
| |
@_onnx_symbolic("aten::tanh")
# Fixed scale and zero_point, discovered from aten/src/ATen/native/quantized/cpu/qtanh.cpp
@symbolic_helper.quantized_args(True, scale=2.0 / 256.0, zero_point=128)
@_beartype.beartype
def tanh(g: jit_utils.GraphContext, self):
    """Lower ``aten::tanh`` to an ONNX ``Tanh`` node."""
    activated = g.op("Tanh", self)
    return activated
| |
| |
@_onnx_symbolic("aten::sin")
@_beartype.beartype
def sin(g: jit_utils.GraphContext, self):
    """Lower ``aten::sin`` to an ONNX ``Sin`` node."""
    result = g.op("Sin", self)
    return result
| |
| |
@_onnx_symbolic("aten::cos")
@_beartype.beartype
def cos(g: jit_utils.GraphContext, self):
    """Lower ``aten::cos`` to an ONNX ``Cos`` node."""
    result = g.op("Cos", self)
    return result
| |
| |
@_onnx_symbolic("aten::tan")
@_beartype.beartype
def tan(g: jit_utils.GraphContext, self):
    """Lower ``aten::tan`` to an ONNX ``Tan`` node."""
    result = g.op("Tan", self)
    return result
| |
| |
@_onnx_symbolic("aten::asin")
@_beartype.beartype
def asin(g: jit_utils.GraphContext, self):
    """Lower ``aten::asin`` to an ONNX ``Asin`` node."""
    result = g.op("Asin", self)
    return result
| |
| |
@_onnx_symbolic("aten::acos")
@_beartype.beartype
def acos(g: jit_utils.GraphContext, self):
    """Lower ``aten::acos`` to an ONNX ``Acos`` node."""
    result = g.op("Acos", self)
    return result
| |
| |
@_onnx_symbolic("aten::atan")
@_beartype.beartype
def atan(g: jit_utils.GraphContext, self):
    """Lower ``aten::atan`` to an ONNX ``Atan`` node."""
    result = g.op("Atan", self)
    return result
| |
| |
@_onnx_symbolic("aten::atan2")
@_beartype.beartype
def atan2(g: jit_utils.GraphContext, self, other):
    """Lower ``aten::atan2`` using ``Atan`` plus a quadrant correction.

    ``self`` is the y coordinate and ``other`` is the x coordinate. The base
    angle is ``Atan(self / other)``. When ``other < 0`` the result is the base
    angle plus pi if ``self > 0`` and minus pi otherwise; when ``other`` is not
    negative the base angle is returned as is.
    """
    ratio = g.op("Div", self, other)
    base_angle = g.op("Atan", ratio)
    zero = g.op("Constant", value_t=torch.tensor(0))
    pi = g.op("Constant", value_t=torch.tensor(math.pi))

    # Correction used when x is negative: shift by +pi for y > 0, else by -pi.
    y_is_positive = g.op("Greater", self, zero)
    shifted_up = g.op("Add", base_angle, pi)
    shifted_down = g.op("Sub", base_angle, pi)
    corrected_angle = g.op("Where", y_is_positive, shifted_up, shifted_down)

    x_is_negative = g.op("Less", other, zero)
    return g.op("Where", x_is_negative, corrected_angle, base_angle)
| |
| |
@_onnx_symbolic("aten::sigmoid")
# Fixed scale and zero_point, discovered from aten/src/ATen/native/quantized/cpu/qsigmoid.cpp
@symbolic_helper.quantized_args(True, scale=1.0 / 256.0, zero_point=0)
@_beartype.beartype
def sigmoid(g: jit_utils.GraphContext, self):
    """Lower ``aten::sigmoid`` to an ONNX ``Sigmoid`` node."""
    activated = g.op("Sigmoid", self)
    return activated
| |
| |
@_onnx_symbolic("aten::sign")
@_beartype.beartype
def sign(g: jit_utils.GraphContext, self):
    """Lower ``aten::sign`` to an ONNX ``Sign`` node."""
    result = g.op("Sign", self)
    return result
| |
| |
@symbolic_helper.quantized_args(True)
@_beartype.beartype
def _slice(g: jit_utils.GraphContext, input, axes, starts, ends):
    """Emit an ONNX ``Slice`` with attribute-style axes, starts and ends.

    A single slice running from 0 to ``_constants.INT64_MAX`` selects
    everything, so ``input`` is returned unchanged in that case.
    """
    assert len(starts) == len(ends)
    covers_everything = (
        len(starts) == 1 and starts[0] == 0 and ends[0] == _constants.INT64_MAX
    )
    if covers_everything:
        return input
    return g.op("Slice", input, axes_i=axes, starts_i=starts, ends_i=ends)
| |
| |
@_beartype.beartype
def _maybe_cast_reduce_op_input(g: jit_utils.GraphContext, self):
    """Cast integral, non-INT64 inputs of a reduce op to INT64.

    Inputs with an unknown scalar type, floating-point inputs and INT64 inputs
    are returned unchanged.
    """
    scalar_type = _type_utils.JitScalarType.from_value(
        self, _type_utils.JitScalarType.UNDEFINED
    )
    # This check only covers traced modules where dtype is present.
    if scalar_type == _type_utils.JitScalarType.UNDEFINED:
        return self
    if symbolic_helper._is_fp(self) or scalar_type == _type_utils.JitScalarType.INT64:
        return self
    # pytorch reduce-ops cast all other integral types to int64.
    return g.op("Cast", self, to_i=_C_onnx.TensorProtoDataType.INT64)
| |
| |
@_beartype.beartype
def _reduce_op_symbolic(onnx_op_name, allow_multi_dim_support=True):
    """Build a symbolic function for the ONNX reduce op ``onnx_op_name``.

    With ``allow_multi_dim_support`` the ``dim`` argument is parsed as a list
    of ints, otherwise as a single int that is wrapped in a one-element list.
    """
    dim_descriptor = "is" if allow_multi_dim_support else "i"

    @_beartype.beartype
    def symbolic(g, self, dim=None, keepdim=None):
        self = _maybe_cast_reduce_op_input(g, self)

        # Reduce over everything. `dim` can legitimately be 0, so the test is
        # written explicitly instead of relying on `not dim`.
        if dim is None or dim == ():
            return symbolic_helper._handle_reduce_dim_none(g, self, onnx_op_name)

        # Reduce over the requested dimension(s).
        dim = symbolic_helper._get_const(dim, dim_descriptor, "dim")
        keepdim = symbolic_helper._get_const(keepdim, "i", "keepdim")
        axes = dim if allow_multi_dim_support else [dim]
        return g.op(onnx_op_name, self, axes_i=axes, keepdims_i=keepdim)

    return symbolic
| |
| |
@_beartype.beartype
def overload_by_arg_count(fn):
    """Wrap ``fn`` so calls are dispatched to an overload by argument count.

    ``fn(g, *args)`` must return an iterable of candidate overloads, each
    carrying an ``_arg_descriptors`` attribute. The first candidate whose
    number of descriptors equals ``len(args)`` is called; if none matches the
    call is reported through ``symbolic_helper._unimplemented``.
    """

    @functools.wraps(fn)
    @_beartype.beartype
    def wrapper(g, *args):
        arg_count = len(args)
        for candidate in fn(g, *args):
            if len(candidate._arg_descriptors) == arg_count:
                return candidate(g, *args)
        return symbolic_helper._unimplemented(
            f"aten::{fn.__name__}", f"with {arg_count} arguments"
        )

    return wrapper
| |
| |
@_onnx_symbolic("aten::sum", decorate=[_apply_params("ReduceSum", "sum")])
@_onnx_symbolic("aten::mean", decorate=[_apply_params("ReduceMean", "mean")])
# torch.prod does not support multidimensional "dim"
@_onnx_symbolic(
    "aten::prod",
    decorate=[_apply_params("ReduceProd", "prod", allow_multi_dim_support=False)],
)
@_beartype.beartype
def _reduce_with_dtype(onnx_op: str, name: str, allow_multi_dim_support: bool = True):
    """Build the symbolic for a reduce op that accepts an optional ``dtype``.

    The returned function dispatches by argument count between a no-dim
    overload ``(self, dtype)`` and a dim overload
    ``(self, dim, keepdim, dtype)``. In both overloads:
      - a ``dtype`` produced by an ``onnx::Constant`` node makes the input be
        cast to that type before reducing, and the result be cast again if its
        type differs;
      - a ``dtype`` produced by a ``prim::Constant`` node is ignored;
      - any other ``dtype`` is reported as unimplemented.
    """
    symbolic = _reduce_op_symbolic(
        onnx_op, allow_multi_dim_support=allow_multi_dim_support
    )
    dim_desc = "is" if allow_multi_dim_support else "i"

    def _reduce_honoring_dtype(g, self, dtype, run_reduction):
        # Shared body of both overloads; `run_reduction` performs the actual
        # reduction on the (possibly cast) input.
        requested_onnx_type = None
        if dtype.node().kind() == "onnx::Constant":
            dtype = symbolic_helper._get_const(dtype, "i", "dtype")
            requested_onnx_type = _type_utils.JitScalarType(dtype).onnx_type()
            self = g.op("Cast", self, to_i=requested_onnx_type)
        elif dtype.node().kind() != "prim::Constant":
            return symbolic_helper._unimplemented(name, "dtype", dtype)

        result = run_reduction(self)
        if requested_onnx_type is not None:
            actual_onnx_type = _type_utils.JitScalarType.from_value(
                result
            ).onnx_type()
            if actual_onnx_type != requested_onnx_type:
                result = g.op("Cast", result, to_i=requested_onnx_type)
        return result

    @overload_by_arg_count
    def reduce(g, *args, **kwargs):
        @symbolic_helper.quantized_args(True)
        @symbolic_helper.parse_args("v", "none")
        def reduce_nodim(g, self, dtype):
            return _reduce_honoring_dtype(
                g, self, dtype, lambda operand: symbolic(g, operand)
            )

        @symbolic_helper.quantized_args(True)
        @symbolic_helper.parse_args("v", dim_desc, "i", "none")  # type: ignore[arg-type]
        def reduce_dim(g, self, dim, keepdim, dtype):
            return _reduce_honoring_dtype(
                g, self, dtype, lambda operand: symbolic(g, operand, dim, keepdim)
            )

        return reduce_nodim, reduce_dim

    return reduce
| |
| |
@_onnx_symbolic("aten::cumsum")
@symbolic_helper.parse_args("v", "i", "none")
@_beartype.beartype
def cumsum(g: jit_utils.GraphContext, input, dim, dtype):
    """Handle ``aten::cumsum`` for opset 9.

    With the Caffe2 ATen fallback an ATen ``cumsum`` node is emitted, provided
    ``dtype`` comes from a ``prim::Constant`` node. Without the fallback the op
    is reported through ``symbolic_helper._onnx_opset_unsupported`` (opset 9,
    supported from opset 11).
    """
    if symbolic_helper.is_caffe2_aten_fallback():
        dtype_is_prim_constant = dtype.node().kind() == "prim::Constant"
        if not dtype_is_prim_constant:
            return symbolic_helper._unimplemented("cumsum", "dtype", dtype)
        return g.at("cumsum", input, dim_i=dim)

    # NOTE(review): the helper's result is not returned here; presumably it
    # raises - confirm against symbolic_helper.
    symbolic_helper._onnx_opset_unsupported("cumsum", 9, 11, input)
| |
| |
@_onnx_symbolic("aten::_sample_dirichlet")
@_beartype.beartype
def _sample_dirichlet(g: jit_utils.GraphContext, self, generator):
    """Handle ``aten::_sample_dirichlet``.

    Only the Caffe2 ATen fallback is supported, and only without a generator;
    every other case is reported as unsupported or unimplemented.
    """
    if not symbolic_helper.is_caffe2_aten_fallback():
        return symbolic_helper._onnx_unsupported("_sample_dirichlet", self)
    if not symbolic_helper._is_none(generator):
        return symbolic_helper._unimplemented(
            "_sample_dirichlet", "We are not able to export generator", self
        )
    return g.at("_sample_dirichlet", self)
| |
| |
@_onnx_symbolic("aten::_standard_gamma")
@_beartype.beartype
def _standard_gamma(g: jit_utils.GraphContext, self, generator):
    """Handle ``aten::_standard_gamma``.

    Only the Caffe2 ATen fallback is supported, and only without a generator;
    every other case is reported as unsupported or unimplemented.
    """
    if not symbolic_helper.is_caffe2_aten_fallback():
        return symbolic_helper._onnx_unsupported("_standard_gamma", self)
    if not symbolic_helper._is_none(generator):
        return symbolic_helper._unimplemented(
            "_standard_gamma", "not able to export generator", self
        )
    return g.at("_standard_gamma", self)
| |
| |
@_onnx_symbolic("aten::t")
@_beartype.beartype
def t(g: jit_utils.GraphContext, self):
    """Lower ``aten::t``: swap the two axes when the rank is known to be at least 2.

    For an unknown rank or a rank below 2 an ``Identity`` node is emitted.
    """
    rank = symbolic_helper._get_tensor_rank(self)
    if rank is not None and rank >= 2:
        return g.op("Transpose", self, perm_i=(1, 0))
    # The transpose of a 1d or 0d tensor is itself. ONNX does not define the
    # behavior clearly and onnxruntime fails on these cases, so an Identity
    # node mirrors the behavior of eager mode.
    return g.op("Identity", self)
| |
| |
@_onnx_symbolic("aten::numpy_T")
@symbolic_helper.quantized_args(True)
@_beartype.beartype
def numpy_T(g: jit_utils.GraphContext, input):
    """Lower ``aten::numpy_T`` to a ``Transpose`` that reverses all axes.

    The rank of ``input`` must be known.
    """
    rank = symbolic_helper._get_tensor_rank(input)
    assert rank is not None
    reversed_axes = list(range(rank - 1, -1, -1))
    return g.op("Transpose", input, perm_i=reversed_axes)
| |
| |
@_onnx_symbolic("aten::expand")
@symbolic_helper.quantized_args(True)
@_beartype.beartype
def expand(g: jit_utils.GraphContext, self, size, implicit):
    """Lower ``aten::expand`` to ONNX ``Expand``.

    A constant ``size`` becomes a ``Constant`` node. A packed list of values
    is stacked into a 1-D tensor in which every -1 entry is replaced by 1.
    """
    size = symbolic_helper._maybe_get_const(size, "is")
    if not symbolic_helper._is_value(size):
        size = g.op("Constant", value_t=torch.LongTensor(size))
    elif symbolic_helper._is_packed_list(size):
        # Expand with -1 dim value means dim is unchanged.
        # Since onnx::expand supports two-way broadcasting,
        # -1 dim value can be exported to onnx as 1
        stacked = stack(g, size, 0)
        flat_shape = g.op("Constant", value_t=torch.tensor([-1]))
        size = symbolic_helper._reshape_helper(g, stacked, flat_shape)

        ones = ones_like(g, size, _type_utils.JitScalarType.INT64)
        minus_one = g.op("Constant", value_t=torch.tensor(-1))
        neg_ones = mul(g, ones, minus_one)
        is_minus_one = g.op("Equal", size, neg_ones)
        size = where(g, is_minus_one, ones, size)
    return g.op("Expand", self, size)
| |
| |
@_onnx_symbolic("aten::broadcast_to")
@symbolic_helper.quantized_args(True)
@_beartype.beartype
def broadcast_to(g: jit_utils.GraphContext, self, size):
    """Lower ``aten::broadcast_to`` to ONNX ``Expand``.

    A constant ``size`` becomes a ``Constant`` node. A packed list of values
    is stacked into a 1-D tensor in which every -1 entry is replaced by 1.
    """
    size = symbolic_helper._maybe_get_const(size, "is")
    if not symbolic_helper._is_value(size):
        size = g.op("Constant", value_t=torch.LongTensor(size))
    elif symbolic_helper._is_packed_list(size):
        # Expand with -1 dim value means dim is unchanged.
        # Since onnx::expand supports two-way broadcasting,
        # -1 dim value can be exported to onnx as 1
        stacked = stack(g, size, 0)
        flat_shape = g.op("Constant", value_t=torch.tensor([-1]))
        size = symbolic_helper._reshape_helper(g, stacked, flat_shape)

        ones = ones_like(g, size, _type_utils.JitScalarType.INT64)
        minus_one = g.op("Constant", value_t=torch.tensor(-1))
        neg_ones = mul(g, ones, minus_one)
        is_minus_one = g.op("Equal", size, neg_ones)
        size = where(g, is_minus_one, ones, size)
    return g.op("Expand", self, size)
| |
| |
@_onnx_symbolic("aten::expand_as")
@symbolic_helper.quantized_args(True, True)
@_beartype.beartype
def expand_as(g: jit_utils.GraphContext, self, other):
    """Lower ``aten::expand_as`` to ONNX ``Expand`` with the shape of ``other``.

    When ``self`` is a constant tensor, every dimension along which it equals
    its own mean broadcast back is collected, and ``self`` is replaced by a
    ``Constant`` holding the mean over the collected dimensions (kept as size
    1) in the original dtype.
    """
    const_self = symbolic_helper._maybe_get_const(self, "t")
    if isinstance(const_self, torch.Tensor):
        original_dtype = const_self.dtype
        # Means are computed in double precision.
        as_double = const_self.to(torch.double)
        uniform_dims = []
        for axis in range(as_double.dim()):
            broadcast_mean = as_double.mean(axis).unsqueeze(axis).expand_as(as_double)
            if torch.equal(broadcast_mean, as_double):
                uniform_dims.append(axis)
                collapsed = as_double.mean(uniform_dims, keepdim=True)
                self = g.op("Constant", value_t=collapsed.to(original_dtype))

    target_shape = g.op("Shape", other)
    return g.op("Expand", self, target_shape)
| |
| |
@_onnx_symbolic("aten::embedding")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "v", "i", "b", "v")
@_beartype.beartype
def embedding(
    g: jit_utils.GraphContext,
    weight,
    indices,
    padding_idx,
    scale_grad_by_freq,
    sparse,
):
    """Lower ``aten::embedding`` to an ONNX ``Gather`` of ``weight`` by ``indices``.

    When exporting in training mode, ``scale_grad_by_freq=True`` is rejected
    and ``padding_idx >= 0`` triggers a warning.

    Raises:
        errors.SymbolicValueError: if ``scale_grad_by_freq`` is set while
            exporting in training mode.
    """
    if scale_grad_by_freq and GLOBALS.export_training:
        raise errors.SymbolicValueError(
            "Unsupported: ONNX export of embedding with scale_grad_by_freq=True "
            "for training mode. ONNX does not support scaling the gradients.",
            weight,
        )

    padding_in_training = padding_idx >= 0 and GLOBALS.export_training
    if padding_in_training:
        warnings.warn(
            "Warning: ONNX export of embedding with padding_idx >= 0 "
            "for training mode. "
            "ONNX does not support not updating the embedding vector at padding_idx during training."
        )

    return g.op("Gather", weight, indices)
| |
| |
@_onnx_symbolic("aten::embedding_bag")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "v", "v", "i", "i", "i", "v", "i", "i")
@_beartype.beartype
def embedding_bag(
    g: jit_utils.GraphContext,
    embedding_matrix,
    indices,
    offsets,
    scale_grad_by_freq,
    mode,
    sparse,
    per_sample_weights,
    include_last_offset,
    padding_idx,
):
    """Handle ``aten::embedding_bag`` for opset 9.

    ``per_sample_weights`` is never supported. Otherwise an ATen
    ``embedding_bag`` node with four outputs is emitted under the Caffe2 ATen
    fallback, and the op is reported as unsupported without it.
    """
    if not symbolic_helper._is_none(per_sample_weights):
        return symbolic_helper._onnx_unsupported(
            "embedding_bag with per_sample_weights"
        )

    if not symbolic_helper.is_caffe2_aten_fallback():
        return symbolic_helper._onnx_unsupported("embedding_bag", embedding_matrix)

    return g.at(
        "embedding_bag",
        embedding_matrix,
        indices,
        offsets,
        outputs=4,
        scale_grad_by_freq_i=scale_grad_by_freq,
        mode_i=mode,
        sparse_i=sparse,
        include_last_offset_i=include_last_offset,
        padding_idx_i=padding_idx,
    )
| |
| |
@_onnx_symbolic("aten::size")
@symbolic_helper.quantized_args(True, quantize_output=False)
@_beartype.beartype
def size(g: jit_utils.GraphContext, self, dim=None):
    """Lower ``aten::size``.

    Without ``dim`` the full ``Shape`` is returned. A negative ``dim`` is
    rewritten to ``dim + rank`` as a new ``Constant`` when the rank of
    ``self`` is known; the lookup itself is done by
    ``symbolic_helper._size_helper``.
    """
    if dim is None:
        return g.op("Shape", self)

    dim_value = symbolic_helper._maybe_get_const(dim, "i")
    if dim_value < 0:
        rank = symbolic_helper._get_tensor_rank(self)
        if rank is not None:
            dim = g.op("Constant", value_t=torch.tensor(dim_value + rank))
    return symbolic_helper._size_helper(g, self, dim)
| |
| |
@_onnx_symbolic("aten::transpose")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "i", "i")
@_beartype.beartype
def transpose(g: jit_utils.GraphContext, self, dim0, dim1):
    """Export ``aten::transpose`` as an ONNX ``Transpose`` that swaps two axes.

    Raises:
        errors.SymbolicValueError: If the rank of ``self`` is unknown and the
            Caffe2 ATen fallback is not active.
    """
    # Swapping an axis with itself is a no-op.
    if dim0 == dim1:
        return self

    rank = symbolic_helper._get_tensor_rank(self)
    if rank is None:
        # Without the rank no permutation can be built; fall back to ATen
        # when that is allowed.
        if symbolic_helper.is_caffe2_aten_fallback():
            return g.at(
                "transpose", self, overload_name="int", dim0_i=dim0, dim1_i=dim1
            )
        raise errors.SymbolicValueError(
            "Unsupported: ONNX export of transpose for tensor of unknown rank.",
            self,
        )

    # NB: Transpose in ONNX is actually a Permute
    permutation = list(range(rank))
    permutation[dim0], permutation[dim1] = permutation[dim1], permutation[dim0]
    return g.op("Transpose", self, perm_i=permutation)
| |
| |
@_onnx_symbolic("aten::permute")
@symbolic_helper.parse_args("v", "is")
@_beartype.beartype
def permute(g: jit_utils.GraphContext, self, dims):
    """Export ``aten::permute`` as ``Transpose``; an identity order returns ``self``."""
    identity_order = list(range(len(dims)))
    if dims == identity_order:
        return self
    return g.op("Transpose", self, perm_i=dims)
| |
| |
@_onnx_symbolic("aten::view")
@symbolic_helper.quantized_args(True)
@_beartype.beartype
def view(g: jit_utils.GraphContext, self, size):
    """Export ``aten::view`` by delegating to ``reshape`` with the same arguments."""
    reshaped = reshape(g, self, size)
    return reshaped
| |
| |
@_onnx_symbolic("aten::view_as")
@_beartype.beartype
def view_as(g: jit_utils.GraphContext, self, other):
    """Export ``aten::view_as``: reshape ``self`` to the ``Shape`` of ``other``."""
    return reshape(g, self, g.op("Shape", other))
| |
| |
@_onnx_symbolic("aten::unsafe_chunk")
@symbolic_helper.parse_args("v", "i", "i", "i")
@_beartype.beartype
def unsafe_chunk(g: jit_utils.GraphContext, self, chunks, dim, _outputs=None):
    """Export ``aten::unsafe_chunk`` as an ONNX ``Split`` with static sizes.

    The number of outputs and the size of ``dim`` must both be known at export
    time; otherwise the corresponding "unsupported" helper result is returned.
    """
    if _outputs is None:
        return symbolic_helper._onnx_opset_unsupported_detailed(
            "unsafe_chunk", 9, 11, "Dynamic number of outputs not supported", self
        )

    dim_size = symbolic_helper._get_tensor_dim_size(self, dim)
    if dim_size is None:
        return symbolic_helper._unimplemented(
            "unsafe_chunk", "unknown dimension size", self
        )

    # Ceil-divide so that at most ``chunks`` pieces are produced.
    chunk_length = (dim_size + chunks - 1) // chunks
    full_chunks, remainder = divmod(dim_size, chunk_length)
    split_lengths = [chunk_length] * full_chunks
    if remainder:
        split_lengths.append(remainder)
    return g.op("Split", self, split_i=split_lengths, axis_i=dim, outputs=_outputs)
| |
| |
@_onnx_symbolic("aten::split")
@symbolic_helper.parse_args("v", "v", "i", "i")
@_beartype.beartype
def split(g: jit_utils.GraphContext, self, split_size_or_sizes, dim, _outputs=None):
    """Export ``aten::split`` as an ONNX ``Split`` with static sizes.

    A non-scalar ``split_size_or_sizes`` constant is forwarded to
    ``split_with_sizes``. For a scalar split size the piece lengths are derived
    from the size of ``dim``, or from ``_outputs`` when that size is unknown.
    """
    if not symbolic_helper._is_split_static(split_size_or_sizes, _outputs):
        return symbolic_helper._onnx_opset_unsupported_detailed(
            "split", 9, 11, "Dynamic number of outputs not supported", self
        )

    constant_value = symbolic_helper._node_get(split_size_or_sizes.node(), "value")
    if constant_value.dim() > 0:
        # A list of sizes was supplied rather than a single chunk length.
        return split_with_sizes(g, self, split_size_or_sizes, dim, _outputs)
    chunk_length = symbolic_helper._get_const(split_size_or_sizes, "i", "split_size")

    dim_size = symbolic_helper._get_tensor_dim_size(self, dim)
    if dim_size is None:
        if _outputs is None:
            return symbolic_helper._onnx_opset_unsupported_detailed(
                "split", 9, 11, "Unknown dimension size not supported", self
            )
        # Size unknown: assume every one of the expected outputs is a full chunk.
        dim_size = chunk_length * _outputs

    full_chunks, remainder = divmod(dim_size, chunk_length)
    split_lengths = [chunk_length] * full_chunks
    if remainder:
        split_lengths.append(remainder)
    return g.op("Split", self, split_i=split_lengths, axis_i=dim, outputs=_outputs)
| |
| |
@_onnx_symbolic("aten::unsafe_split")
@_beartype.beartype
def unsafe_split(
    g: jit_utils.GraphContext, self, split_size_or_sizes, dim, _outputs=None
):
    """Export ``aten::unsafe_split`` exactly like ``split``."""
    exported = split(g, self, split_size_or_sizes, dim, _outputs)
    return exported
| |
| |
@_onnx_symbolic("aten::split_with_sizes")
@symbolic_helper.parse_args("v", "is", "i", "i")
@_beartype.beartype
def split_with_sizes(g: jit_utils.GraphContext, self, split_sizes, dim, _outputs=None):
    """Export ``aten::split_with_sizes`` as ``Split`` with the given static sizes."""
    if symbolic_helper._is_split_static(split_sizes, _outputs):
        return g.op("Split", self, split_i=split_sizes, axis_i=dim, outputs=_outputs)

    return symbolic_helper._onnx_opset_unsupported_detailed(
        "split_with_sizes", 9, 11, "Dynamic number of outputs not supported", self
    )
| |
| |
@_onnx_symbolic("aten::unsafe_split_with_sizes")
@_beartype.beartype
def unsafe_split_with_sizes(
    g: jit_utils.GraphContext, self, split_sizes, dim, _outputs=None
):
    """Export ``aten::unsafe_split_with_sizes`` exactly like ``split_with_sizes``."""
    exported = split_with_sizes(g, self, split_sizes, dim, _outputs)
    return exported
| |
| |
@_onnx_symbolic("aten::unbind")
@symbolic_helper.parse_args("v", "i", "i")
@_beartype.beartype
def unbind(g: jit_utils.GraphContext, self, dim=0, _outputs=None):
    """Export ``aten::unbind``: split ``dim`` into size-1 pieces and squeeze each.

    Returns:
        A list with one squeezed value per output, or the "unsupported" helper
        result when the number of outputs is not known.
    """
    if _outputs is None:
        return symbolic_helper._onnx_opset_unsupported_detailed(
            "unbind", 9, 11, "Dynamic number of outputs not supported", self
        )

    split_result = g.op(
        "Split", self, split_i=[1] * _outputs, axis_i=dim, outputs=_outputs
    )
    # With a single output the op result is wrapped so it can be iterated
    # like the multi-output case.
    slices = split_result if _outputs != 1 else [split_result]
    return [symbolic_helper._squeeze_helper(g, piece, [dim]) for piece in slices]
| |
| |
@_onnx_symbolic("aten::select")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "i", "v")
@_beartype.beartype
def select(g: jit_utils.GraphContext, self, dim, index):
    """Export ``aten::select``.

    A constant negative index is exported as a one-element slice followed by a
    squeeze of ``dim``; everything else is exported as ``Gather`` along ``dim``.
    """
    index = symbolic_helper._maybe_get_scalar(index)

    if symbolic_helper._is_value(index) or not (index < 0):
        # FIXME(justinchuby): can index be an int and not a value?
        return g.op("Gather", self, index, axis_i=dim)

    # For the last element there is no valid "index + 1" bound (it would be 0),
    # so the slice is left open-ended instead.
    stop = _constants.INT64_MAX if index == -1 else index + 1
    sliced = symbolic_helper._slice_helper(
        g, self, axes=[dim], starts=[index], ends=[stop]
    )
    return symbolic_helper._squeeze_helper(g, sliced, [dim])
| |
| |
@_onnx_symbolic("aten::square")
@_beartype.beartype
def square(g: jit_utils.GraphContext, self):
    """Export ``aten::square`` as ``Mul(self, self)``."""
    squared = g.op("Mul", self, self)
    return squared
| |
| |
@_onnx_symbolic("aten::squeeze")
@_beartype.beartype
def squeeze(g: jit_utils.GraphContext, self, dim=None):
    """Export ``aten::squeeze``.

    Without ``dim`` a plain ``Squeeze`` is emitted. With ``dim`` the axis must be
    a constant; negative axes are normalized using the rank known at export
    time. When the axis is known to be larger than 1 no node is emitted and
    ``self`` is returned. Each of these paths issues a warning.
    """
    if dim is None:
        return g.op("Squeeze", self)

    axis = symbolic_helper._get_const(dim, "i", "dim")

    # Handle negative dims
    if axis < 0:
        rank = symbolic_helper._get_tensor_rank(self)
        if rank is None:
            return symbolic_helper._unimplemented(
                "squeeze", "negative axis with unknown input rank", self
            )
        warnings.warn(
            f"ONNX export squeeze with negative axis {axis}"
            " might cause the onnx model to be incorrect. "
            "Negative axis is not supported in ONNX. "
            f"Axis is converted to {axis + rank}"
            " based on input shape at export time. "
            "Passing an tensor of different rank in execution will be incorrect."
        )
        axis += rank

    # Every remaining warning starts with the same sentence fragment.
    prefix = f"This model contains a squeeze operation on dimension {axis}"

    axis_size = symbolic_helper._get_tensor_dim_size(self, axis)
    if axis_size is None:
        warnings.warn(
            prefix
            + " on an input "
            "with unknown shape. Note that if the size of dimension "
            f"{axis}"
            " of the input "
            "is not 1, the ONNX model will return an error. Opset version 11 supports squeezing on "
            "non-singleton dimensions, it is recommended to export this model using opset "
            "version 11 or higher."
        )
        return symbolic_helper._squeeze_helper(g, self, axes_i=[axis])

    if axis_size > 1:
        # Nothing is emitted for an axis known to be larger than one.
        warnings.warn(
            prefix
            + ". The size of "
            "this dimension in the given input is "
            f"{axis_size}"
            ". The model will "
            "be exported without the squeeze node. If the model is intended to be used with dynamic "
            "input shapes, please use opset version 11 to "
            "export the model."
        )
        return self

    warnings.warn(
        prefix
        + ". If the model is "
        "intended to be used with dynamic input shapes, please use opset version 11 to export the model."
    )
    return symbolic_helper._squeeze_helper(g, self, axes_i=[axis])
| |
| |
@_onnx_symbolic("aten::prelu")
@_beartype.beartype
def prelu(g: jit_utils.GraphContext, self, weight):
    """Export ``aten::prelu`` as ONNX ``PRelu``.

    When the rank of ``self`` is known, ``weight`` is reshaped so it can be
    broadcast against ``self``: unsqueezed on axes ``1 .. rank-2`` for inputs of
    rank > 2, or squeezed to a scalar when ``self`` is a scalar and ``weight``
    has sizes ``[1]``.

    Raises:
        AssertionError: If both ranks are known and ``self`` has a lower rank
            than ``weight``.
    """
    self_rank = symbolic_helper._get_tensor_rank(self)
    weight_sizes = symbolic_helper._get_tensor_sizes(weight)
    # The sizes lookup can yield None when the shape is not known; in that
    # case the rank is unknown too and the rank check below is skipped.
    weight_rank = len(weight_sizes) if weight_sizes is not None else None

    if self_rank is not None:
        if self_rank > 2:
            # make weight unidirectional broadcastable
            weight = symbolic_helper._unsqueeze_helper(
                g, weight, list(range(1, self_rank - 1))
            )
        elif self_rank == 0 and weight_sizes == [1]:
            # self and weight are both scalar but weight has rank == 1, squeeze weight.
            weight = symbolic_helper._squeeze_helper(g, weight, [0])
            weight_rank = 0

    if self_rank is not None and weight_rank is not None:
        assert (
            self_rank >= weight_rank
        ), f"rank(x) should be >= rank(slope) but got {self_rank} < {weight_rank}"
    return g.op("PRelu", self, weight)
| |
| |
@_onnx_symbolic("aten::silu")
@_beartype.beartype
def silu(g: jit_utils.GraphContext, input):
    """Export ``aten::silu`` as ``input * Sigmoid(input)``."""
    gate = g.op("Sigmoid", input)
    return g.op("Mul", input, gate)
| |
| |
@_onnx_symbolic("aten::mish")
@_beartype.beartype
def mish(g: jit_utils.GraphContext, input):
    """Export ``aten::mish`` as ``input * Tanh(Softplus(input))``."""
    softplus_out = g.op("Softplus", input)
    gate = g.op("Tanh", softplus_out)
    return g.op("Mul", input, gate)
| |
| |
@_beartype.beartype
def _op_with_optional_float_cast(g: jit_utils.GraphContext, op_name, *args, **kwargs):
    """Some PyTorch operators (e.g., Clip/Min/ReLU/Pad) are super set of ONNX in terms of data types.
    This function maximizes the exportability of PyTorch-ONNX by allowing ONNX-unsupported PyTorch
    operator data type. For example, `Cast<int>(Clip<float>(Cast<float>(INPUT)))` can be used to mimic
    `Clip<int>(INPUT)` (opset version < 12).

    Args:
        g (jit_utils.GraphContext): graph to write the ONNX representation into.
        op_name (str): operator name in ONNX.
        *args (tuple): operands to the operator.
        **kwargs (dict): attributes to the operator along with "opset_before" (optional, None by default)
            indicating the smallest opset version to trigger such casting behavior and "target_float_t"
            (optional, torch.onnx.JitScalarType.FLOAT by default) indicating the data type of internal operator.

    Returns:
        The output(s) of ``g.op(op_name, ...)``, cast back to the dtype of the
        first operand when a float cast was inserted.

    Raises:
        errors.SymbolicValueError: If a cast is required and a complete-tensor
            operand has a dtype different from the first operand.
    """
    # Both control options are removed from kwargs so that only real operator
    # attributes are forwarded to g.op below.
    opset_before = kwargs.pop("opset_before", None)
    target_float_t = kwargs.pop("target_float_t", _type_utils.JitScalarType.FLOAT)

    inputs = list(args)
    dtype_0 = _type_utils.JitScalarType.from_value(inputs[0])

    # Casting is needed only for non floating-point inputs, and only for opsets
    # older than ``opset_before`` (or always, when no bound is given).
    require_cast = not symbolic_helper._is_fp(inputs[0]) and (
        opset_before is None or GLOBALS.export_onnx_opset_version < opset_before
    )

    if require_cast:
        # Validate every operand before emitting any Cast node.
        for operand in inputs:
            if operand.isCompleteTensor():
                operand_scalar_type = _type_utils.JitScalarType.from_value(operand)
                if operand_scalar_type != dtype_0:
                    raise errors.SymbolicValueError(
                        f"Inputs of {op_name} must have same dtype. "
                        f"Got {dtype_0.scalar_name()} and {operand_scalar_type.scalar_name()}",
                        operand,
                    )
        for i, operand in enumerate(inputs):
            if operand.isCompleteTensor() and not symbolic_helper._is_fp(operand):
                inputs[i] = g.op(
                    "Cast",
                    operand,
                    to_i=target_float_t.onnx_type(),
                )

    result = g.op(op_name, *inputs, **kwargs)

    if require_cast:
        # Restore the dtype the caller started with.
        result = g.op("Cast", result, to_i=dtype_0.onnx_type())

    return result
| |
| |
@_onnx_symbolic("aten::relu")
@symbolic_helper.quantized_args(True)
@_beartype.beartype
def relu(g: jit_utils.GraphContext, input):
    """Export ``aten::relu`` as ``Relu`` via ``_op_with_optional_float_cast`` (``opset_before=14``)."""
    activated = _op_with_optional_float_cast(g, "Relu", input, opset_before=14)
    return activated
| |
| |
@_onnx_symbolic("aten::relu6")
@symbolic_helper.quantized_args(True)
@_beartype.beartype
def relu6(g: jit_utils.GraphContext, input):
    """Export ``aten::relu6`` as ``clamp(input, 0, 6)``."""
    lower_bound, upper_bound = 0, 6
    return clamp(g, input, lower_bound, upper_bound)
| |
| |
@_onnx_symbolic("aten::ceil")
@_beartype.beartype
def ceil(g: jit_utils.GraphContext, input):
    """Export ``aten::ceil`` as ONNX ``Ceil``."""
    rounded_up = g.op("Ceil", input)
    return rounded_up
| |
| |
@_onnx_symbolic("aten::floor")
@_beartype.beartype
def floor(g: jit_utils.GraphContext, input):
    """Export ``aten::floor`` as ONNX ``Floor``."""
    rounded_down = g.op("Floor", input)
    return rounded_down
| |
| |
@_onnx_symbolic("aten::len")
@_beartype.beartype
def _len(g: jit_utils.GraphContext, self):
    """Export ``aten::len``: take the size of dimension 0 and squeeze axis 0 of it."""
    first_axis = g.op("Constant", value_t=torch.LongTensor([0]))
    leading_size = size(g, self, first_axis)
    return symbolic_helper._squeeze_helper(g, leading_size, [0])
| |
| |
@_onnx_symbolic("aten::threshold")
@symbolic_helper.parse_args("v", "t", "t")
@_beartype.beartype
def threshold(g: jit_utils.GraphContext, self, threshold, value):
    """Export ``aten::threshold`` as ``Relu``; only threshold 0 and value 0 are handled."""
    # See Note [Export inplace]
    required_zero = (
        (threshold, "non-zero threshold"),
        (value, "non-zero value"),
    )
    for constant, reason in required_zero:
        if symbolic_helper._scalar(constant) != 0:
            return symbolic_helper._unimplemented("threshold", reason, self)
    return g.op("Relu", self)
| |
| |
@_onnx_symbolic("aten::leaky_relu")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "f", "b")
@_beartype.beartype
def leaky_relu(
    g: jit_utils.GraphContext,
    input: _C.Value,
    negative_slope: float,
    inplace: bool = False,
):
    """Export ``aten::leaky_relu`` as ``LeakyRelu`` with ``alpha`` set to ``negative_slope``.

    ``inplace`` is accepted but not used by the export.
    """
    # See Note [Export inplace]
    activated = g.op("LeakyRelu", input, alpha_f=negative_slope)
    return activated
| |
| |
@_onnx_symbolic("aten::glu")
@symbolic_helper.parse_args("v", "i")
@_beartype.beartype
def glu(g: jit_utils.GraphContext, input, dim):
    """Export ``aten::glu``: split ``dim`` in two and gate the first half with the second.

    Raises:
        AssertionError: If the size of ``dim`` is known and odd.
    """
    dim_size = symbolic_helper._get_tensor_dim_size(input, dim)
    assert dim_size is None or dim_size % 2 == 0

    halves = g.op("Split", input, axis_i=dim, outputs=2)
    value_half, gate_half = halves
    return g.op("Mul", value_half, g.op("Sigmoid", gate_half))
| |
| |
@_onnx_symbolic("aten::softmax")
@symbolic_helper.parse_args("v", "i", "none")
@_beartype.beartype
def softmax(g: jit_utils.GraphContext, input, dim, dtype=None):
    """Export ``aten::softmax``.

    With a known input rank the ONNX ``Softmax`` op is applied on the last axis
    (transposing ``dim`` there and back when necessary). With an unknown rank
    the result is assembled from ``ReduceMax``/``Sub``/``Exp``/reduce-sum/``Div``.
    A ``Cast`` is appended when ``dtype`` is given and its node kind is not
    ``prim::Constant``.
    """
    # Softmax does normalization at vector level.
    # PyTorch and ONNX use different strategies to split the input tensor into vectors.
    # Thus dim and axis have different meanings.
    # PyTorch slices the input tensor into vectors along the `dim`-th dimension.
    # ONNX reshapes the input into a 2-D tensor, and `axis` indicates where the input is coerced.
    # If input is a 2 x 3 tensor:
    # input = [[1.0, 1.0, 1.0],
    #          [1.0, 1.0, 1.0]]
    # with dim = 0, the result is:
    # result = [[0.5, 0.5, 0.5],
    #           [0.5, 0.5, 0.5]]
    # with axis = 0, the result is:
    # result = [[0.167, 0.167, 0.167],
    #           [0.167, 0.167, 0.167]]
    # So only when dim and axis both equal to ndim - 1 (the last dimension),
    # their semantics are equivalent.
    # So use softmax when dim and axis both equal to ndim - 1,
    # otherwise transpose the input to put the vectors to be normalized to the last dimension.
    # When input rank is not known at export time we compute softmax using a subgraph
    # with other operators

    def _cast_if_requested(result):
        # Shared tail of both export paths: optional cast to the requested dtype.
        if dtype and dtype.node().kind() != "prim::Constant":
            parsed_dtype = symbolic_helper._get_const(dtype, "i", "dtype")
            return g.op(
                "Cast",
                result,
                to_i=_type_utils.JitScalarType(parsed_dtype).onnx_type(),
            )
        return result

    rank = symbolic_helper._get_tensor_rank(input)
    if rank is None:
        # Apply max normalization.
        shifted = g.op(
            "Sub", input, g.op("ReduceMax", input, axes_i=[dim], keepdims_i=1)
        )
        exponentials = g.op("Exp", shifted)
        normalizer = symbolic_helper._reducesum_helper(g, exponentials, axes_i=[dim])
        return _cast_if_requested(g.op("Div", exponentials, normalizer))

    # TODO: remove this as onnx opset 11 spec allows negative axes
    if dim < 0:
        dim = rank + dim

    last_axis = rank - 1
    if dim == last_axis:
        return _cast_if_requested(g.op("Softmax", input, axis_i=dim))

    # Swap ``dim`` with the last axis, normalize there, then swap back with the
    # same (self-inverse) permutation.
    permutation = list(range(rank))
    permutation[dim], permutation[-1] = permutation[-1], permutation[dim]
    moved = g.op("Transpose", input, perm_i=permutation)
    normalized = _cast_if_requested(g.op("Softmax", moved, axis_i=last_axis))
    return g.op("Transpose", normalized, perm_i=permutation)
| |
| |
@_onnx_symbolic("aten::softplus")
@_beartype.beartype
def softplus(g: jit_utils.GraphContext, self, beta, threshold):
    """Export ``aten::softplus``.

    When ``beta`` is not the constant 1 the result is
    ``Softplus(self * beta) / beta``; otherwise a plain ``Softplus``.
    ``threshold`` is not used by the export.
    """
    beta_const = symbolic_helper._maybe_get_const(beta, "f")
    if beta_const != 1:
        scaled_input = g.op("Mul", self, beta)
        scaled_output = g.op("Softplus", scaled_input)
        return g.op("Div", scaled_output, beta)
    return g.op("Softplus", self)
| |
| |
@_onnx_symbolic("aten::get_pool_ceil_padding")
@_beartype.beartype
def get_pool_ceil_padding(input, kernel_size, stride, padding):
    """Compute the extra per-axis padding used to emulate ``ceil_mode`` pooling.

    Args:
        input: Value whose trailing ``len(padding)`` sizes must be known.
        kernel_size: Per-axis kernel sizes.
        stride: Per-axis strides.
        padding: Per-axis paddings.

    Returns:
        A list of ints, one per entry of ``padding``, or the "unimplemented"
        helper result when the needed input sizes are not available.
    """
    # TODO(justinchuby): Looks like this op is deprecated in torch
    sizes = symbolic_helper._get_tensor_sizes(input)
    dim = None if sizes is None else sizes[-len(padding) :]
    if dim is None or any(extent is None for extent in dim):
        return symbolic_helper._unimplemented(
            "get_pool_ceil_padding", "input size not accessible", input
        )

    padding_ceil = []
    for axis in range(len(padding)):
        extent = dim[axis]
        pad = padding[axis]
        kernel = kernel_size[axis]
        step = stride[axis]

        # Output length under ceil rounding.
        out_len = int(math.ceil((extent + 2 * pad - kernel) / float(step))) + 1
        # ensure last pooling starts inside
        if (out_len - 1) * step >= extent + pad:
            out_len -= 1

        if step == 1:
            extra = 0
        else:
            extra = kernel - (extent + 2 * pad - ((out_len - 1) * step + 1))

        # ensure padding is not > kernel_size
        if extra + 2 * pad >= kernel and extra >= kernel - 1:
            extra = kernel - 1
        padding_ceil.append(int(extra))

    return padding_ceil
| |
| |
@_onnx_symbolic(
    "aten::max_pool1d",
    decorate=[
        _apply_params(
            "max_pool1d", torch.nn.modules.utils._single, 1, return_indices=False
        ),
        _export("max_pool1d"),
    ],
)
@_onnx_symbolic(
    "aten::max_pool2d",
    decorate=[
        _apply_params(
            "max_pool2d", torch.nn.modules.utils._pair, 2, return_indices=False
        ),
        _export("max_pool2d"),
    ],
)
@_onnx_symbolic(
    "aten::max_pool3d",
    decorate=[
        _apply_params(
            "max_pool3d", torch.nn.modules.utils._triple, 3, return_indices=False
        ),
        _export("max_pool3d"),
    ],
)
@_beartype.beartype
def _max_pool(name, tuple_fn, ndims, return_indices):
    """Build the symbolic function for an ``ndims``-dimensional max pooling op.

    Args:
        name: Operator name used in "unimplemented" diagnostics.
        tuple_fn: Callable that expands an argument to an ``ndims``-tuple.
        ndims: Number of pooled (spatial) dimensions.
        return_indices: Whether the generated symbolic also returns indices.

    Returns:
        The symbolic function emitting ONNX ``MaxPool``.
    """

    @symbolic_helper.quantized_args(True, False, False, False, False, False)
    @symbolic_helper.parse_args("v", "is", "is", "is", "is", "i")
    @_beartype.beartype
    def symbolic_fn(g, input, kernel_size, stride, padding, dilation, ceil_mode):
        if set(tuple_fn(dilation)) != {1}:
            return symbolic_helper._unimplemented(name, "dilation", input)

        # An empty stride means "same as the kernel size".
        stride = stride or kernel_size

        begin_pads = tuple(tuple_fn(padding))
        if ceil_mode:
            ceil_extra = get_pool_ceil_padding(input, kernel_size, stride, begin_pads)
            end_pads = tuple(
                extra + base for (extra, base) in zip(ceil_extra, begin_pads)
            )
        else:
            end_pads = begin_pads

        pool_attributes = {
            "kernel_shape_i": tuple_fn(kernel_size),
            "pads_i": begin_pads + end_pads,
            "strides_i": tuple_fn(stride),
        }

        if not return_indices:
            return g.op("MaxPool", input, outputs=1, **pool_attributes)

        # easy but hacky way to get flattened indices values
        # to be used to convert the indices values to non-flattened.
        # In ONNX the indices are computed as a flatten 1-D tensor,
        # so the values in indices are in [0, N x C x D1 x ... x Dn).
        # To convert the indices to the same format used by Pytorch,
        # we first execute a maxpool with a kernel and stride of 1 on the same input.
        # This will result in a tensor of indices in which each index will have its own value.
        # Using this tensor as a reference, we extract the first index of each axis and subtract
        # it from each index of this axis in the indices to convert.
        # This step will result in a tensor where each dimension has values of indices within
        # the dimension it is in.
        # For more information :
        # https://github.com/pytorch/pytorch/pull/16455#issuecomment-460776407
        pooled, indices = g.op("MaxPool", input, outputs=2, **pool_attributes)
        _, flattened_indices = g.op(
            "MaxPool",
            input,
            outputs=2,
            kernel_shape_i=[1] * ndims,
            strides_i=[1] * ndims,
        )
        # convert indices to have non-flattened indices values
        first_index_per_axis = symbolic_helper._slice_helper(
            g,
            flattened_indices,
            axes=list(range(2, 2 + ndims)),
            starts=list(tuple_fn(0)),
            ends=list(tuple_fn(1)),
        )
        indices = sub(g, indices, first_index_per_axis)
        return pooled, indices

    return symbolic_fn
| |
| |
# Register the index-returning max-pool symbolics. Each one is the function
# produced by ``_max_pool`` with ``return_indices=True`` and is also bound to a
# module-level name.
max_pool1d_with_indices = _onnx_symbolic("aten::max_pool1d_with_indices")(
    _max_pool(
        "max_pool1d_with_indices",
        torch.nn.modules.utils._single,
        1,
        return_indices=True,
    )
)
max_pool2d_with_indices = _onnx_symbolic("aten::max_pool2d_with_indices")(
    _max_pool(
        "max_pool2d_with_indices",
        torch.nn.modules.utils._pair,
        2,
        return_indices=True,
    )
)
max_pool3d_with_indices = _onnx_symbolic("aten::max_pool3d_with_indices")(
    _max_pool(
        "max_pool3d_with_indices",
        torch.nn.modules.utils._triple,
        3,
        return_indices=True,
    )
)
| |
| |
@_onnx_symbolic(
    "aten::avg_pool1d",
    decorate=[
        _apply_params("avg_pool1d", torch.nn.modules.utils._single),
        _export("avg_pool1d"),
    ],
)
@_onnx_symbolic(
    "aten::avg_pool2d",
    decorate=[
        _apply_params("avg_pool2d", torch.nn.modules.utils._pair),
        _export("avg_pool2d"),
    ],
)
@_onnx_symbolic(
    "aten::avg_pool3d",
    decorate=[
        _apply_params("avg_pool3d", torch.nn.modules.utils._triple),
        _export("avg_pool3d"),
    ],
)
@_beartype.beartype
def _avg_pool(name, tuple_fn):
    """Build the symbolic function for an average pooling op.

    Args:
        name: Operator name passed on to ``symbolic_helper._avgpool_helper``.
        tuple_fn: Callable that expands an argument to a per-axis tuple.

    Returns:
        The symbolic function emitting ONNX ``AveragePool``.
    """

    @symbolic_helper.quantized_args(True)
    @symbolic_helper.parse_args("v", "is", "is", "is", "i", "i", "none")
    @_beartype.beartype
    def symbolic_fn(
        g,
        input: _C.Value,
        kernel_size: Sequence[int],
        stride: Sequence[int],
        padding: Union[int, Sequence[int]],
        ceil_mode: int,
        count_include_pad: int,
        divisor_override=None,
    ):
        # An empty stride means "same as the kernel size".
        stride = stride or kernel_size
        padding = symbolic_helper._avgpool_helper(
            tuple_fn, padding, kernel_size, stride, divisor_override, name
        )
        assert isinstance(padding, tuple)

        begin_pads = padding
        # Although onnx::AvgPool provides count_include_pad,
        # The corner case of Average Pooling with ceil_mode on
        # PyTorch allows sliding window go off bound, which leads to
        # this accommodation.
        # More detail on https://github.com/pytorch/pytorch/issues/57178
        if count_include_pad:
            # Pad explicitly with zeros, then pool the padded tensor without
            # any begin padding.
            input = _op_with_optional_float_cast(
                g,
                "Pad",
                input,
                pads_i=((0,) * 2 + padding) * 2,
                mode_s="constant",
                value_f=0.0,
                opset_before=11,
            )
            begin_pads = (0,) * len(padding)

        if ceil_mode:
            ceil_extra = get_pool_ceil_padding(input, kernel_size, stride, padding)
            end_pads = tuple(
                extra + base for (extra, base) in zip(ceil_extra, begin_pads)
            )
        else:
            end_pads = begin_pads

        return g.op(
            "AveragePool",
            input,
            kernel_shape_i=tuple_fn(kernel_size),
            strides_i=tuple_fn(stride),
            pads_i=begin_pads + end_pads,
        )

    return symbolic_fn
| |
| |
@_onnx_symbolic(
    "aten::adaptive_avg_pool1d",
    decorate=[
        _apply_params(
            "adaptive_avg_pool1d", "AveragePool", torch.nn.modules.utils._single
        ),
        _export("adaptive_avg_pool1d"),
    ],
)
@_onnx_symbolic(
    "aten::adaptive_avg_pool2d",
    decorate=[
        _apply_params(
            "adaptive_avg_pool2d", "AveragePool", torch.nn.modules.utils._pair
        ),
        _export("adaptive_avg_pool2d"),
    ],
)
@_onnx_symbolic(
    "aten::adaptive_avg_pool3d",
    decorate=[
        _apply_params(
            "adaptive_avg_pool3d", "AveragePool", torch.nn.modules.utils._triple
        ),
        _export("adaptive_avg_pool3d"),
    ],
)
@_onnx_symbolic(
    "aten::adaptive_max_pool1d",
    decorate=[
        _apply_params(
            "adaptive_max_pool1d",
            "MaxPool",
            torch.nn.modules.utils._single,
            max_pool1d_with_indices,
        ),
        _export("adaptive_max_pool1d"),
    ],
)
@_onnx_symbolic(
    "aten::adaptive_max_pool2d",
    decorate=[
        _apply_params(
            "adaptive_max_pool2d",
            "MaxPool",
            torch.nn.modules.utils._pair,
            max_pool2d_with_indices,
        ),
        _export("adaptive_max_pool2d"),
    ],
)
@_onnx_symbolic(
    "aten::adaptive_max_pool3d",
    decorate=[
        _apply_params(
            "adaptive_max_pool3d",
            "MaxPool",
            torch.nn.modules.utils._triple,
            max_pool3d_with_indices,
        ),
        _export("adaptive_max_pool3d"),
    ],
)
@_beartype.beartype
def _adaptive_pool(name, type, tuple_fn, fn=None):
    """Build the symbolic function for an adaptive average/max pooling op.

    Args:
        name: Operator name used in "unimplemented" diagnostics.
        type: ONNX pooling op to emit, ``"AveragePool"`` or ``"MaxPool"``.
        tuple_fn: Callable that expands an argument to a per-axis tuple.
        fn: For ``"MaxPool"``, the max-pool-with-indices symbolic to delegate to.

    Returns:
        The symbolic function for the adaptive pooling op.
    """

    @symbolic_helper.quantized_args(True, False)
    @_beartype.beartype
    def symbolic_fn(g, input, output_size):
        # _adaptive_pool is supported for cases where output_size is 1 for all dimensions,
        # by executing a GlobalPool.
        # It is also supported for cases where the output size is a factor of the input size.
        # For these cases the stride and kernel size are uniform along all the indices of
        # the same dimension, which makes it possible to export it to ONNX.
        # for MaxPool, GlobalMaxPool does not return indices,
        # so we try using max_poolxd_with_indices, and if it is not possible
        # (input is not a complete tensor or output size not factor of input size)
        # then we call GlobalMaxPool and return None for the indices
        output_size_value = output_size
        try:
            output_size = symbolic_helper._parse_arg(output_size, "is")
        except Exception:
            # FIXME(justinchuby): Avoid catching Exception.
            # Catch a more specific exception instead.
            return symbolic_helper._onnx_unsupported(
                "adaptive pooling, since output_size is not constant.", input
            )

        # True when every requested output extent is 1 (global pooling).
        is_global = output_size == [1] * len(output_size)
        if is_global and type == "AveragePool":
            return g.op("GlobalAveragePool", input)

        sizes = symbolic_helper._get_tensor_sizes(input)
        # Sizes are None when the input shape is not available; check for that
        # explicitly instead of relying on the slice raising.
        dim = sizes[2:] if sizes is not None else None
        if dim is None or any(extent is None for extent in dim):
            if is_global:
                return g.op("GlobalMaxPool", input), None
            return symbolic_helper._unimplemented(
                name, "input size not accessible", input
            )

        # verify if input size % output size = 0 for all dim
        remainders = [extent % output_size[i] for i, extent in enumerate(dim)]
        if any(remainder != 0 for remainder in remainders):
            if is_global:
                return g.op("GlobalMaxPool", input), None
            return symbolic_helper._unimplemented(
                name, "output size that are not factor of input size", output_size_value
            )

        # Divisibility was verified above, so floor division is exact.
        k = [extent // output_size[i] for i, extent in enumerate(dim)]
        # call max_poolxd_with_indices to get indices in the output
        if type == "MaxPool":
            return fn(g, input, k, k, (0,) * len(dim), (1,) * len(dim), False)
        return g.op(type, input, kernel_shape_i=tuple_fn(k), strides_i=tuple_fn(k))

    return symbolic_fn
| |
| |
@_beartype.beartype
def _prepare_onnx_paddings(dim: int, pad):
    """Generate paddings in ONNX order based on pad in pytorch.
    Args:
        dim: the dimension of the tensor.
        pad: the paddings in pytorch.
            The order is dim_n_begin, dim_n_end, dim_n-1_begin, dim_n-1_end, ...
    Returns:
        A list ordered dim_0_begin, dim_1_begin, ..., dim_0_end, ..., dim_n_end,
        where n is the dimension of input.
    """
    # Dimensions that ``pad`` does not mention are leading ones: give them zeros.
    full_pad = list(pad)
    full_pad.extend([0] * (dim * 2 - len(pad)))

    # Walking backwards with step 2 visits dimension 0 first; even offsets hold
    # the begins and odd offsets hold the ends.
    begins = full_pad[-2::-2]
    ends = full_pad[-1::-2]
    return begins + ends
| |
| |
@_beartype.beartype
def _convert_padding_node(input):
    """Turn a padding graph value into a list of constant ints where possible.

    A packed list is unpacked element by element; if any element is not a
    constant, the opset-unsupported helper result for ``Pad`` is returned.
    Anything else is returned as produced by ``_maybe_get_const``.
    """
    padding = symbolic_helper._maybe_get_const(input, "is")
    is_packed = symbolic_helper._is_value(padding) and symbolic_helper._is_packed_list(
        padding
    )
    if not is_packed:
        return padding

    elements = symbolic_helper._unpack_list(padding)
    try:
        return [
            symbolic_helper._get_const(element, "i", "padding") for element in elements
        ]
    except Exception:
        # FIXME(justinchuby): Avoid catching Exception.
        # Catch a more specific exception instead.
        return symbolic_helper._onnx_opset_unsupported_detailed(
            "Pad", 9, 11, "The sizes of the padding must be constant", input
        )
| |
| |
@_onnx_symbolic("aten::constant_pad_nd")
@_beartype.beartype
def constant_pad_nd(g: jit_utils.GraphContext, input, padding, value):
    """Export ``aten::constant_pad_nd`` as ONNX ``Pad`` in ``constant`` mode.

    The fill ``value`` must be a constant; otherwise the opset-unsupported helper
    result is returned.
    """
    try:
        value = symbolic_helper._get_const(value, "f", "value")
    except Exception:
        # FIXME(justinchuby): Avoid catching Exception.
        # Catch a more specific exception instead.
        return symbolic_helper._onnx_opset_unsupported_detailed(
            "Pad", 9, 11, "The value for the padding must be constant", value
        )

    torch_pads = _convert_padding_node(padding)
    rank = symbolic_helper._get_tensor_rank(input)
    onnx_pads = _prepare_onnx_paddings(rank, torch_pads)
    return _op_with_optional_float_cast(
        g,
        "Pad",
        input,
        pads_i=onnx_pads,
        mode_s="constant",
        value_f=value,
        opset_before=11,
    )
| |
| |
@_beartype.beartype
def _pad_circular(g: jit_utils.GraphContext, input: _C.Value, pad: _C.Value):
    """Export circular padding as per-axis ``Slice`` + ``Concat``.

    Padding pairs are read from the end of ``pad``; the pair at position
    ``idx`` from the end is applied to axis ``2 + idx``. A positive pad wraps
    elements around from the opposite edge, a negative pad crops that side.

    Args:
        g: Graph context the ONNX nodes are added to.
        input: Tensor value to pad.
        pad: Padding value, converted with ``_convert_padding_node``.

    Returns:
        The padded tensor value.

    Raises:
        AssertionError: If the number of padding entries is odd.
    """
    padding = _convert_padding_node(pad)
    assert len(padding) % 2 == 0
    ndim = len(padding) // 2

    cur = input
    for idx in range(ndim):
        axis = 2 + idx
        pad_r = padding[-(2 * idx + 1)]
        pad_l = padding[-(2 * idx + 2)]
        tensors = []

        if pad_l > 0:
            # Left border: the last ``pad_l`` elements along the axis.
            left = symbolic_helper._slice_helper(
                g, cur, axes=[axis], starts=[-(pad_l)], ends=[_constants.INT64_MAX]
            )
            tensors.append(left)

        if pad_l < 0 or pad_r < 0:
            # Negative padding crops the corresponding side.
            start = builtins.max(0, -pad_l)
            # When the right side is not cropped the slice must run to the end
            # of the axis. An end of 0 (which negating max(0, -pad_r) yields for
            # pad_r >= 0) would select nothing at all.
            end = pad_r if pad_r < 0 else _constants.INT64_MAX
            middle = symbolic_helper._slice_helper(
                g,
                cur,
                axes=[axis],
                starts=[start],
                ends=[end],
            )
            tensors.append(middle)
        else:
            tensors.append(cur)

        if pad_r > 0:
            # Right border: the first ``pad_r`` elements along the axis.
            # NOTE(review): both borders are sliced from the uncropped tensor;
            # confirm this matches eager results when the other side is negative.
            right = symbolic_helper._slice_helper(
                g, cur, axes=[axis], starts=[0], ends=[pad_r]
            )
            tensors.append(right)

        cur = g.op("Concat", *tensors, axis_i=axis)

    return cur
| |
| |
@_onnx_symbolic("aten::reflection_pad1d")
@_onnx_symbolic("aten::reflection_pad2d")
@_onnx_symbolic("aten::reflection_pad3d")
@_beartype.beartype
def reflection_pad(g: jit_utils.GraphContext, input, padding):
    """Export the ``aten::reflection_pad{1,2,3}d`` ops as ONNX ``Pad`` in ``reflect`` mode."""
    torch_pads = _convert_padding_node(padding)
    rank = symbolic_helper._get_tensor_rank(input)
    onnx_pads = _prepare_onnx_paddings(rank, torch_pads)
    return _op_with_optional_float_cast(
        g, "Pad", input, pads_i=onnx_pads, mode_s="reflect", opset_before=11
    )
| |
| |
@_onnx_symbolic("aten::replication_pad1d")
@_onnx_symbolic("aten::replication_pad2d")
@_onnx_symbolic("aten::replication_pad3d")
@_beartype.beartype
def replication_pad(g: jit_utils.GraphContext, input, padding):
    """Export the ``aten::replication_pad{1,2,3}d`` ops as ONNX ``Pad`` in ``edge`` mode."""
    torch_pads = _convert_padding_node(padding)
    rank = symbolic_helper._get_tensor_rank(input)
    onnx_pads = _prepare_onnx_paddings(rank, torch_pads)
    return _op_with_optional_float_cast(
        g, "Pad", input, pads_i=onnx_pads, mode_s="edge", opset_before=11
    )
| |
| |
@_onnx_symbolic("aten::pad")
@_beartype.beartype
def pad(
    g: jit_utils.GraphContext,
    input: _C.Value,
    pad: _C.Value,
    mode: _C.Value,
    value: _C.Value,
):
    """Export ``aten::pad`` by dispatching on the padding mode string.

    Raises:
        errors.SymbolicValueError: If the parsed mode is not one of
            "replicate", "reflect", "constant" or "circular".
    """
    mode = symbolic_helper._parse_arg(mode, "s")

    if mode == "constant":
        return constant_pad_nd(g, input, pad, value)
    if mode == "reflect":
        return reflection_pad(g, input, pad)
    if mode == "replicate":
        return replication_pad(g, input, pad)
    if mode == "circular":
        return _pad_circular(g, input, pad)

    raise errors.SymbolicValueError(f"Unrecognized padding mode {mode}", input)
| |
| |
@_onnx_symbolic(
    "aten::upsample_nearest1d",
    decorate=[
        _apply_params("upsample_nearest1d", 3, "nearest"),
        _export("upsample_nearest1d"),
    ],
)
@_onnx_symbolic(
    "aten::upsample_nearest2d",
    decorate=[
        _apply_params("upsample_nearest2d", 4, "nearest"),
        _export("upsample_nearest2d"),
    ],
)
@_onnx_symbolic(
    "aten::upsample_nearest3d",
    decorate=[
        _apply_params("upsample_nearest3d", 5, "nearest"),
        _export("upsample_nearest3d"),
    ],
)
@_onnx_symbolic(
    "aten::upsample_linear1d",
    decorate=[
        _apply_params("upsample_linear1d", 3, "linear"),
        _export("upsample_linear1d"),
    ],
)
@_onnx_symbolic(
    "aten::upsample_bilinear2d",
    decorate=[
        _apply_params("upsample_bilinear2d", 4, "linear"),
        _export("upsample_bilinear2d"),
    ],
)
@_onnx_symbolic(
    "aten::upsample_trilinear3d",
    decorate=[
        _apply_params("upsample_trilinear3d", 5, "linear"),
        _export("upsample_trilinear3d"),
    ],
)
@_beartype.beartype
def _interpolate(name: str, dim: int, interpolate_mode: str):
    """Build the symbolic function for one upsample operator.

    Args:
        name: Operator name, used when reporting an unimplemented case.
        dim: Passed to ``_interpolate_size_to_scales`` when scales have to be
            derived from ``output_size``.
        interpolate_mode: Value of the ``mode`` attribute of the emitted
            ``Upsample`` op.

    Returns:
        A ``symbolic_fn(g, input, output_size, *args)`` that emits ``Upsample``.
    """

    def symbolic_fn(g, input, output_size, *args):
        scale_values, corners_flag = symbolic_helper._get_interpolate_attributes(
            g, interpolate_mode, args
        )
        symbolic_helper._interpolate_warning(interpolate_mode)

        corners_flag = symbolic_helper._maybe_get_scalar(corners_flag)
        if corners_flag:
            return symbolic_helper._unimplemented(name, "align_corners == True", input)

        # Derive the scales from the requested output size when none were given.
        if scale_values is None:
            scale_values = symbolic_helper._interpolate_size_to_scales(
                g, input, output_size, dim
            )
        return g.op("Upsample", input, scale_values, mode_s=interpolate_mode)

    return symbolic_fn
| |
| |
@_onnx_symbolic("aten::__interpolate")
@_beartype.beartype
def __interpolate(
    g: jit_utils.GraphContext,
    input,
    size,
    scale_factor,
    mode,
    align_corners,
    recompute_scale_factor,
    antialias,
):
    """Export ``aten::__interpolate`` as an ONNX ``Upsample`` op.

    ``recompute_scale_factor`` and ``antialias`` are accepted but not used here.
    """
    resolved = symbolic_helper._interpolate_get_scales_and_mode(
        g, input, size, scale_factor, mode, align_corners
    )
    scales, mode = resolved
    return g.op("Upsample", input, scales, mode_s=mode)
| |
| |
@_onnx_symbolic("aten::bitwise_not")
@_beartype.beartype
def bitwise_not(g: jit_utils.GraphContext, input):
    """Export ``aten::bitwise_not`` as ONNX ``Not``; only boolean input is accepted.

    Raises:
        errors.SymbolicValueError: If ``input`` is not boolean.
    """
    if symbolic_helper._is_bool(input):
        return g.op("Not", input)

    raise errors.SymbolicValueError(
        "ONNX export does NOT support exporting bitwise Not "
        "for non-boolean input values",
        input,
    )
| |
| |
@_onnx_symbolic("aten::bitwise_or")
@_beartype.beartype
def bitwise_or(g, self, other):
    """Export ``aten::bitwise_or`` as ONNX ``Or``; both operands must be boolean.

    Raises:
        errors.SymbolicValueError: If ``self`` or ``other`` is not boolean
            (``self`` is checked first).
    """
    for label, operand in (("self", self), ("other", other)):
        if not symbolic_helper._is_bool(operand):
            raise errors.SymbolicValueError(
                "ONNX export does NOT support exporting bitwise OR "
                f"for non-boolean input values. {label}: ",
                operand,
            )
    return g.op("Or", self, other)
| |
| |
@_beartype.beartype
def wrap_logical_op_with_cast_to(to_type):
    """Return a decorator that casts both operands before calling the wrapped op.

    The cast function is looked up in this module's globals under the name
    ``_cast_<to_type>`` each time the wrapped function is called.
    """
    cast_name = f"_cast_{to_type}"

    def decorator(fn):
        @functools.wraps(fn)
        def wrap_with_cast(g, input, other):
            cast = globals()[cast_name]
            lhs = cast(g, input, False)
            rhs = cast(g, other, False)
            return fn(g, lhs, rhs)

        return wrap_with_cast

    return decorator
| |
| |
@_beartype.beartype
def wrap_logical_op_with_negation(func: Callable) -> Callable:
    """Wrap a binary symbolic so that its result is passed through ONNX ``Not``."""

    @functools.wraps(func)
    def wrap_with_not(g, input, other):
        result = func(g, input, other)
        return g.op("Not", result)

    return wrap_with_not
| |
| |
@_onnx_symbolic("aten::__not_")
@_beartype.beartype
def __not_(g: jit_utils.GraphContext, self):
    """Export ``aten::__not_`` as ONNX ``Not``; only boolean input is accepted.

    Raises:
        errors.SymbolicValueError: If ``self`` is not boolean.
    """
    if symbolic_helper._is_bool(self):
        return g.op("Not", self)

    raise errors.SymbolicValueError(
        "ONNX export does NOT support exporting bitwise Not "
        "for non-boolean input values",
        self,
    )
| |
| |
@_onnx_symbolic("aten::eq")
@symbolic_helper.quantized_args(True, True)
@_beartype.beartype
def eq(g: jit_utils.GraphContext, self, other):
    """Export ``aten::eq``.

    Two device objects compare as a constant ``True`` and two constant strings
    are compared at export time; everything else becomes ONNX ``Equal``.
    """
    both_are_devices = isinstance(self.type(), _C.DeviceObjType) and isinstance(
        other.type(), _C.DeviceObjType
    )
    if both_are_devices:
        # ONNX has no notion of devices, so treat them all as equal.
        # The resulting no-op equality check gets constant-folded.
        return g.op("Constant", value_t=torch.tensor(True, dtype=torch.bool))

    lhs_node = self.node()
    rhs_node = other.node()
    both_are_constants = lhs_node.kind() == rhs_node.kind() == "onnx::Constant"
    if (
        both_are_constants
        and lhs_node.kindOf("value") == rhs_node.kindOf("value") == "s"
    ):
        # Strings cannot be exported to ONNX, but two constant strings can be
        # compared right here; the result gets constant-folded.
        strings_match = lhs_node.s("value") == rhs_node.s("value")
        return g.op(
            "Constant", value_t=torch.tensor(strings_match, dtype=torch.bool)
        )

    return g.op("Equal", self, other)
| |
| |
@_onnx_symbolic("aten::ne")
@symbolic_helper.quantized_args(True, True)
@wrap_logical_op_with_negation
@_beartype.beartype
def ne(g: jit_utils.GraphContext, self, other):
    """Export ``aten::ne``: the ``eq`` result, negated by the decorator."""
    equal = eq(g, self, other)
    return equal
| |
| |
@_onnx_symbolic("aten::gt")
@symbolic_helper.quantized_args(True, True)
@_beartype.beartype
def gt(g: jit_utils.GraphContext, input, other):
    """Export ``aten::gt`` by delegating to ``_gt_impl``."""
    greater = _gt_impl(g, input, other)
    return greater
| |
| |
@_beartype.beartype
def _gt_impl(g: jit_utils.GraphContext, input, other):
    """Emit ONNX ``Greater``; when both operands are boolean they are cast to INT32 first."""
    both_bool = symbolic_helper._is_bool(input) and symbolic_helper._is_bool(other)
    if both_bool:
        int32 = _C_onnx.TensorProtoDataType.INT32
        input, other = (g.op("Cast", v, to_i=int32) for v in (input, other))
    return g.op("Greater", input, other)
| |
| |
@_onnx_symbolic("aten::lt")
@symbolic_helper.quantized_args(True, True)
@_beartype.beartype
def lt(g: jit_utils.GraphContext, input, other):
    """Export ``aten::lt`` by delegating to ``_lt_impl``."""
    less = _lt_impl(g, input, other)
    return less
| |
| |
@_beartype.beartype
def _lt_impl(g: jit_utils.GraphContext, input, other):
    """Emit ONNX ``Less``; when both operands are boolean they are cast to INT32 first."""
    both_bool = symbolic_helper._is_bool(input) and symbolic_helper._is_bool(other)
    if both_bool:
        int32 = _C_onnx.TensorProtoDataType.INT32
        input, other = (g.op("Cast", v, to_i=int32) for v in (input, other))
    return g.op("Less", input, other)
| |
| |
@_onnx_symbolic("aten::ge")
@symbolic_helper.quantized_args(True, True)
@wrap_logical_op_with_negation
@_beartype.beartype
def ge(g: jit_utils.GraphContext, input, other):
    """Export ``aten::ge`` as the negation (applied by the decorator) of ``Less``."""
    less = _lt_impl(g, input, other)
    return less
| |
| |
@_onnx_symbolic("aten::le")
@symbolic_helper.quantized_args(True, True)
@wrap_logical_op_with_negation
@_beartype.beartype
def le(g: jit_utils.GraphContext, input, other):
    """Export ``aten::le`` as the negation (applied by the decorator) of ``Greater``."""
    greater = _gt_impl(g, input, other)
    return greater
| |
| |
@_onnx_symbolic("aten::__and_")
@_beartype.beartype
def __and_(g: jit_utils.GraphContext, input, other):
    """Export ``aten::__and_`` as ONNX ``And``; both operands must be boolean.

    Raises:
        errors.SymbolicValueError: If ``input`` or ``other`` is not boolean
            (``input`` is checked first).
    """
    for operand in (input, other):
        if not symbolic_helper._is_bool(operand):
            raise errors.SymbolicValueError(
                "ONNX export does NOT support exporting bitwise AND "
                "for non-boolean input values",
                operand,
            )
    return g.op("And", input, other)
| |
| |
@_onnx_symbolic("aten::__or_")
@_beartype.beartype
def __or_(g: jit_utils.GraphContext, input, other):
    """Export ``aten::__or_`` as ONNX ``Or``; both operands must be boolean.

    Raises:
        errors.SymbolicValueError: If ``input`` or ``other`` is not boolean
            (``input`` is checked first).
    """
    for operand in (input, other):
        if not symbolic_helper._is_bool(operand):
            raise errors.SymbolicValueError(
                "ONNX export does NOT support exporting bitwise OR "
                "for non-boolean input values",
                operand,
            )
    return g.op("Or", input, other)
| |
| |
@_onnx_symbolic("aten::__xor_")
@_beartype.beartype
def __xor_(g: jit_utils.GraphContext, input, other):
    """Export ``aten::__xor_`` as ONNX ``Xor``; both operands must be boolean.

    Raises:
        errors.SymbolicValueError: If ``input`` or ``other`` is not boolean
            (``input`` is checked first).
    """
    for operand in (input, other):
        if not symbolic_helper._is_bool(operand):
            raise errors.SymbolicValueError(
                "ONNX export does NOT support exporting bitwise XOR "
                "for non-boolean input values",
                operand,
            )
    return g.op("Xor", input, other)
| |
| |
@_onnx_symbolic("aten::logical_and")
@wrap_logical_op_with_cast_to("Bool")
@_beartype.beartype
def logical_and(g: jit_utils.GraphContext, input, other):
    """Export ``aten::logical_and`` as ONNX ``And``; the decorator casts both operands first."""
    conjunction = g.op("And", input, other)
    return conjunction
| |
| |
@_onnx_symbolic("aten::logical_or")
@wrap_logical_op_with_cast_to("Bool")
@_beartype.beartype
def logical_or(g: jit_utils.GraphContext, input, other):
    """Export ``aten::logical_or`` as ONNX ``Or``; the decorator casts both operands first."""
    disjunction = g.op("Or", input, other)
    return disjunction
| |
| |
@_onnx_symbolic("aten::logical_xor")
@wrap_logical_op_with_cast_to("Bool")
@_beartype.beartype
def logical_xor(g: jit_utils.GraphContext, input, other):
    """Export ``aten::logical_xor`` as ONNX ``Xor``; the decorator casts both operands first."""
    exclusive_or = g.op("Xor", input, other)
    return exclusive_or
| |
| |
@_onnx_symbolic("aten::logical_not")
@_beartype.beartype
def logical_not(g: jit_utils.GraphContext, input):
    """Export ``aten::logical_not`` as ``Not(Cast(input, BOOL))``."""
    as_bool = g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.BOOL)
    return g.op("Not", as_bool)
| |
| |
@_onnx_symbolic("aten::__rshift_")
@_beartype.beartype
def __rshift_(g: jit_utils.GraphContext, self, other):
    """Export ``aten::__rshift_`` as ``Div(self, Pow(2, other))``."""
    self_scalar_type = _type_utils.JitScalarType.from_value(self)
    other_scalar_type = _type_utils.JitScalarType.from_value(
        other, _type_utils.JitScalarType.UNDEFINED
    )
    # Bring ``other`` to the type of ``self``
    # (e.g. when self is long, other must not stay float).
    if other_scalar_type != self_scalar_type:
        other = g.op("Cast", other, to_i=self_scalar_type.onnx_type())

    base = g.op("Constant", value_t=torch.tensor(2, dtype=torch.float32))
    # onnx::Pow needs a float or double exponent, so a non-floating
    # exponent (same type as self) is cast to float.
    if not symbolic_helper._is_fp(self):
        other = g.op("Cast", other, to_i=_C_onnx.TensorProtoDataType.FLOAT)
    divisor = g.op("Pow", base, other)
    divisor = g.op("Cast", divisor, to_i=self_scalar_type.onnx_type())
    return g.op("Div", self, divisor)
| |
| |
@_onnx_symbolic("aten::__lshift_")
@_beartype.beartype
def __lshift_(g: jit_utils.GraphContext, self, other):
    """Export ``aten::__lshift_`` as ``Mul(self, Pow(2, other))``."""
    self_scalar_type = _type_utils.JitScalarType.from_value(self)
    other_scalar_type = _type_utils.JitScalarType.from_value(
        other, _type_utils.JitScalarType.UNDEFINED
    )
    # Bring ``other`` to the type of ``self``
    # (e.g. when self is long, other must not stay float).
    if other_scalar_type != self_scalar_type:
        other = g.op("Cast", other, to_i=self_scalar_type.onnx_type())

    base = g.op("Constant", value_t=torch.tensor(2, dtype=torch.float32))
    # onnx::Pow needs a float or double exponent, so a non-floating
    # exponent (same type as self) is cast to float.
    if not symbolic_helper._is_fp(self):
        other = g.op("Cast", other, to_i=_C_onnx.TensorProtoDataType.FLOAT)
    multiplier = g.op("Pow", base, other)
    multiplier = g.op("Cast", multiplier, to_i=self_scalar_type.onnx_type())
    return g.op("Mul", self, multiplier)
| |
| |
@_onnx_symbolic("aten::where")
@symbolic_helper.parse_args("v", "v", "v", "i")
@_beartype.beartype
def where(g: jit_utils.GraphContext, condition, self=None, other=None, _outputs=None):
    """Export ``aten::where``.

    With ``self`` given, emits ONNX ``Where``. With ``self`` omitted, the
    ``nonzero`` result of the condition is unbound via ``_unbind_helper``.
    """
    # Assumes that torch.where's first argument takes only Bool and Byte tensors.
    if not symbolic_helper._is_bool(condition):
        condition = g.op("Cast", condition, to_i=_C_onnx.TensorProtoDataType.BOOL)

    if self is not None:
        return g.op("Where", condition, self, other)

    indices = nonzero(g, condition)
    unbind_dim = g.op("Constant", value_t=torch.tensor(1))
    return symbolic_helper._unbind_helper(g, indices, unbind_dim, _outputs)
| |
| |
@_onnx_symbolic("aten::log_softmax")
@symbolic_helper.parse_args("v", "i", "none")
@_beartype.beartype
def log_softmax(g: jit_utils.GraphContext, input, dim, dtype=None):
    """Export ``aten::log_softmax`` as ONNX ``LogSoftmax`` over the last axis.

    When ``dim`` is not the last axis, the input is transposed so that it
    becomes the last one, and the result is transposed back afterwards.
    """
    # PyTorch dim and ONNX axis have different meanings.
    # See Softmax comment for details.
    # TODO: remove this as onnx opset 11 spec allows negative axes
    rank = symbolic_helper._get_tensor_rank(input)
    if rank is None:
        return symbolic_helper._unimplemented(
            "dim",
            "ONNX and PyTorch use different strategies to split the input. "
            "Input rank must be known at export time.",
        )

    if dim < 0:
        dim += rank
    last_axis = rank - 1
    needs_transpose = dim != last_axis

    # ONNX only supports log_softmax with dim = -1, so any other dim is
    # swapped with the last axis before the op and swapped back after it.
    if needs_transpose:
        perm = list(range(rank))
        perm[dim], perm[-1] = perm[-1], perm[dim]
        input = g.op("Transpose", input, perm_i=perm)

    result = g.op("LogSoftmax", input, axis_i=last_axis)

    if dtype and dtype.node().kind() != "prim::Constant":
        target_dtype = symbolic_helper._get_const(dtype, "i", "dtype")
        result = g.op(
            "Cast", result, to_i=_type_utils.JitScalarType(target_dtype).onnx_type()
        )

    if needs_transpose:
        # Swapping two axes is its own inverse, so the same perm restores the layout.
        result = g.op("Transpose", result, perm_i=perm)
    return result
| |
| |
@_onnx_symbolic("aten::_log_softmax")
@symbolic_helper.parse_args("v", "i", "i")
@_beartype.beartype
def _log_softmax(g: jit_utils.GraphContext, input, dim, half_to_float):
    """Export ``aten::_log_softmax``; a HALF input is cast to FLOAT when ``half_to_float`` is set."""
    if half_to_float:
        input_type = _type_utils.JitScalarType.from_value(
            input, _type_utils.JitScalarType.UNDEFINED
        )
        if input_type == _type_utils.JitScalarType.HALF:
            input = g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.FLOAT)
    return log_softmax(g, input, dim)
| |
| |
@_onnx_symbolic("aten::_convolution")
@symbolic_helper.parse_args(
    "v", "v", "v", "is", "is", "is", "i", "is", "i", "i", "i", "i", "i"
)
@_beartype.beartype
def _convolution(
    g: jit_utils.GraphContext,
    input,
    weight,
    bias,
    stride,
    padding,
    dilation,
    transposed,
    output_padding,
    groups,
    benchmark,
    deterministic,
    cudnn_enabled,
    allow_tf32=None,
):
    """Export ``aten::_convolution`` as ONNX ``Conv`` or ``ConvTranspose``.

    ``benchmark``, ``deterministic``, ``cudnn_enabled`` and ``allow_tf32`` are
    accepted but not used here.

    Returns:
        The convolution output. A bias whose rank is not 1 is applied with a
        separate ``Add`` instead of being passed to the convolution op.

    Raises:
        errors.SymbolicValueError: If the kernel shape (weight sizes from
            index 2 on) is unknown or contains an unknown dimension.
    """
    weight_size = symbolic_helper._get_tensor_sizes(weight)
    # An explicit None check replaces the previous blanket ``except Exception``
    # around the slice, so unrelated errors are no longer swallowed.
    kernel_shape = None if weight_size is None else weight_size[2:]

    if kernel_shape is None or any(i is None for i in kernel_shape):
        raise errors.SymbolicValueError(
            "Unsupported: ONNX export of convolution for kernel of unknown shape.",
            input,
        )

    has_bias = not symbolic_helper._is_none(bias)
    bias_is_1d = has_bias and symbolic_helper._get_tensor_rank(bias) == 1

    args = [input, weight]
    # ONNX only supports 1D bias
    if bias_is_1d:
        args.append(bias)

    kwargs = {
        "kernel_shape_i": kernel_shape,
        "strides_i": stride,
        # NB: ONNX supports asymmetric padding, whereas PyTorch supports only
        # symmetric padding
        "pads_i": padding + padding,
        "dilations_i": dilation,
        "group_i": groups,
    }

    if any(o != 0 for o in output_padding):
        # ONNX supports both output_shape and output_padding; they are equally expressive.
        # output_padding is more straightforward, so we use it here.
        # output_shape = stride * (input_shape - 1) + output_padding + kernel_shape - padding * 2
        assert transposed
        assert len(stride) == len(output_padding)
        kwargs["output_padding_i"] = output_padding

    n = g.op("ConvTranspose" if transposed else "Conv", *args, **kwargs)

    if has_bias and not bias_is_1d:
        # The bias could not be handed to the conv op above; add it explicitly.
        return g.op("Add", n, bias)
    return n
| |
| |
@_onnx_symbolic("aten::_convolution_mode")
@symbolic_helper.parse_args(
    "v",
    "v",
    "v",
    "is",
    "s",
    "is",
    "i",
)
@_beartype.beartype
def _convolution_mode(
    g: jit_utils.GraphContext,
    input,
    weight,
    bias,
    stride,
    padding,
    dilation,
    groups,
):
    """Export ``aten::_convolution_mode`` as ONNX ``Conv`` with ``auto_pad``.

    The padding strings "valid" and "same" are translated to "VALID" and
    "SAME_UPPER"; any other string is passed through unchanged.

    Returns:
        The convolution output. A bias whose rank is not 1 is applied with a
        separate ``Add`` instead of being passed to ``Conv``.

    Raises:
        errors.SymbolicValueError: If the kernel shape (weight sizes from
            index 2 on) is unknown or contains an unknown dimension.
    """
    weight_size = symbolic_helper._get_tensor_sizes(weight)
    # An explicit None check replaces the previous blanket ``except Exception``
    # around the slice, so unrelated errors are no longer swallowed.
    kernel_shape = None if weight_size is None else weight_size[2:]

    if kernel_shape is None or any(i is None for i in kernel_shape):
        raise errors.SymbolicValueError(
            "Unsupported: ONNX export of convolution for kernel of unknown shape.",
            input,
        )

    has_bias = not symbolic_helper._is_none(bias)
    bias_is_1d = has_bias and symbolic_helper._get_tensor_rank(bias) == 1

    args = [input, weight]
    # ONNX only supports 1D bias
    if bias_is_1d:
        args.append(bias)

    if padding == "valid":
        padding = "VALID"
    elif padding == "same":
        padding = "SAME_UPPER"
    kwargs = {
        "kernel_shape_i": kernel_shape,
        "strides_i": stride,
        "auto_pad_s": padding,
        "dilations_i": dilation,
        "group_i": groups,
    }

    n = g.op("Conv", *args, **kwargs)

    if has_bias and not bias_is_1d:
        # The bias could not be handed to Conv above; add it explicitly.
        return g.op("Add", n, bias)
    return n
| |
| |
@_onnx_symbolic("aten::convolution")
@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i", "is", "i")
@_beartype.beartype
def convolution(
    g: jit_utils.GraphContext,
    input,
    weight,
    bias,
    stride,
    padding,
    dilation,
    transposed,
    output_padding,
    groups,
):
    """Export ``aten::convolution`` by forwarding to ``_convolution``.

    The four trailing ``_convolution`` arguments are passed as ``None``.
    """
    trailing_unused = (None, None, None, None)
    return _convolution(
        g,
        input,
        weight,
        bias,
        stride,
        padding,
        dilation,
        transposed,
        output_padding,
        groups,
        *trailing_unused,
    )
| |
| |
@_onnx_symbolic("aten::conv1d")
@symbolic_helper.parse_args("v", "v", "v", "is", "v", "is", "i")
@_beartype.beartype
def conv1d(
    g: jit_utils.GraphContext, input, weight, bias, stride, padding, dilation, groups
):
    """Export ``aten::conv1d``.

    A "valid"/"same" padding string goes through ``_convolution_mode``;
    otherwise padding is parsed as an int list and sent to ``_convolution``.
    """
    padding_mode = symbolic_helper._parse_arg(padding, "s")
    if padding_mode in ("valid", "same"):
        return _convolution_mode(
            g, input, weight, bias, stride, padding_mode, dilation, groups
        )

    explicit_padding = symbolic_helper._parse_arg(padding, "is")
    is_transposed = False
    no_output_padding = ()
    trailing_unused = (None, None, None, None)
    return _convolution(
        g,
        input,
        weight,
        bias,
        stride,
        explicit_padding,
        dilation,
        is_transposed,
        no_output_padding,
        groups,
        *trailing_unused,
    )
| |
| |
@_onnx_symbolic("aten::conv2d")
@symbolic_helper.parse_args("v", "v", "v", "is", "v", "is", "i")
@_beartype.beartype
def conv2d(
    g: jit_utils.GraphContext, input, weight, bias, stride, padding, dilation, groups
):
    """Export ``aten::conv2d``.

    A "valid"/"same" padding string goes through ``_convolution_mode``;
    otherwise padding is parsed as an int list and sent to ``_convolution``.
    """
    padding_mode = symbolic_helper._parse_arg(padding, "s")
    if padding_mode in ("valid", "same"):
        return _convolution_mode(
            g, input, weight, bias, stride, padding_mode, dilation, groups
        )

    explicit_padding = symbolic_helper._parse_arg(padding, "is")
    is_transposed = False
    no_output_padding = ()
    trailing_unused = (None, None, None, None)
    return _convolution(
        g,
        input,
        weight,
        bias,
        stride,
        explicit_padding,
        dilation,
        is_transposed,
        no_output_padding,
        groups,
        *trailing_unused,
    )
| |
| |
@_onnx_symbolic("aten::conv3d")
@symbolic_helper.parse_args("v", "v", "v", "is", "v", "is", "i")
@_beartype.beartype
def conv3d(
    g: jit_utils.GraphContext, input, weight, bias, stride, padding, dilation, groups
):
    """Export ``aten::conv3d``.

    A "valid"/"same" padding string goes through ``_convolution_mode``;
    otherwise padding is parsed as an int list and sent to ``_convolution``.
    """
    padding_mode = symbolic_helper._parse_arg(padding, "s")
    if padding_mode in ("valid", "same"):
        return _convolution_mode(
            g, input, weight, bias, stride, padding_mode, dilation, groups
        )

    explicit_padding = symbolic_helper._parse_arg(padding, "is")
    is_transposed = False
    no_output_padding = ()
    trailing_unused = (None, None, None, None)
    return _convolution(
        g,
        input,
        weight,
        bias,
        stride,
        explicit_padding,
        dilation,
        is_transposed,
        no_output_padding,
        groups,
        *trailing_unused,
    )
| |
| |
@_onnx_symbolic("aten::conv_transpose1d")
@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i", "is")
@_beartype.beartype
def conv_transpose1d(
    g: jit_utils.GraphContext,
    input,
    weight,
    bias,
    stride,
    padding,
    output_padding,
    groups,
    dilation,
):
    """Export ``aten::conv_transpose1d`` via ``_convolution`` with ``transposed=True``."""
    is_transposed = True
    trailing_unused = (None, None, None, None)
    return _convolution(
        g,
        input,
        weight,
        bias,
        stride,
        padding,
        dilation,
        is_transposed,
        output_padding,
        groups,
        *trailing_unused,
    )
| |
| |
@_onnx_symbolic("aten::conv_transpose2d")
@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i", "is")
@_beartype.beartype
def conv_transpose2d(
    g: jit_utils.GraphContext,
    input,
    weight,
    bias,
    stride,
    padding,
    output_padding,
    groups,
    dilation,
):
    """Export ``aten::conv_transpose2d`` via ``_convolution`` with ``transposed=True``."""
    is_transposed = True
    trailing_unused = (None, None, None, None)
    return _convolution(
        g,
        input,
        weight,
        bias,
        stride,
        padding,
        dilation,
        is_transposed,
        output_padding,
        groups,
        *trailing_unused,
    )
| |
| |
@_onnx_symbolic("aten::conv_transpose3d")
@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i", "is")
@_beartype.beartype
def conv_transpose3d(
    g: jit_utils.GraphContext,
    input,
    weight,
    bias,
    stride,
    padding,
    output_padding,
    groups,
    dilation,
):
    """Export ``aten::conv_transpose3d`` via ``_convolution`` with ``transposed=True``."""
    is_transposed = True
    trailing_unused = (None, None, None, None)
    return _convolution(
        g,
        input,
        weight,
        bias,
        stride,
        padding,
        dilation,
        is_transposed,
        output_padding,
        groups,
        *trailing_unused,
    )
| |
| |
@_onnx_symbolic("aten::batch_norm")
@symbolic_helper.parse_args("v", "v", "v", "v", "v", "i", "f", "f", "i")
@_beartype.beartype
def batch_norm(
    g: jit_utils.GraphContext,
    input,
    weight,
    bias,
    running_mean,
    running_var,
    training,
    momentum,
    eps,
    cudnn_enabled,
):
    """Export ``aten::batch_norm`` as ONNX ``BatchNormalization``.

    The op is emitted with one output when ``training`` is falsy and with
    five outputs otherwise; in both cases only the normalized result is
    returned. ``cudnn_enabled`` is accepted but not used here.
    """
    symbolic_helper.check_training_mode(training, "batch_norm")

    if (
        torch.is_autocast_enabled()
        and not symbolic_helper.args_have_same_dtype(
            [input, weight, bias, running_mean, running_var]
        )
        and GLOBALS.export_onnx_opset_version < 15
    ):
        return symbolic_helper._onnx_opset_unsupported_detailed(
            "BatchNormalization",
            9,
            15,
            "All input tensors must have the same `dtype`."
            " Turn off Autocast or export using opset version 15.",
            input,
        )

    weight, bias, running_mean, running_var = symbolic_helper._batchnorm_helper(
        g, input, weight, bias, running_mean, running_var
    )
    num_outputs = 5 if training else 1
    out = g.op(
        "BatchNormalization",
        input,
        weight,
        bias,
        running_mean,
        running_var,
        epsilon_f=eps,
        momentum_f=1 - momentum,
        outputs=num_outputs,
    )
    if not training:
        return out

    res, new_running_mean, new_running_var, saved_mean, saved_var = out
    new_running_mean.setType(running_mean.type())
    new_running_var.setType(running_var.type())
    # The saved statistics are not returned; tag them as dead outputs by name.
    for dead_output in (saved_mean, saved_var):
        dead_output.setDebugName("batch_norm_dead_output-" + dead_output.debugName())
    return res
| |
| |
@_onnx_symbolic("aten::native_layer_norm")
@symbolic_helper.quantized_args(True, False, False, False)
@symbolic_helper.parse_args("v", "is", "v", "v", "f")
@_beartype.beartype
def native_layer_norm(
    g: jit_utils.GraphContext,
    input: _C.Value,
    normalized_shape: Sequence[int],
    weight: _C.Value,
    bias: _C.Value,
    eps: float,
) -> Tuple[_C.Value, _C.Value, _C.Value]:
    """Export ``aten::native_layer_norm`` from elementary ONNX ops.

    Statistics are reduced over the trailing ``len(normalized_shape)`` axes.

    Returns:
        ``(normalized, mean, rdenominator)`` where ``rdenominator`` is
        ``1 / sqrt(variance + eps)``.
    """
    # The last len(normalized_shape) axes, written as negative indices.
    axes = list(range(-len(normalized_shape), 0))

    two_cst = symbolic_helper._generate_wrapped_number(g, 2.0)
    eps_cst = symbolic_helper._generate_wrapped_number(g, eps)

    mean = g.op("ReduceMean", input, axes_i=axes)
    centered = sub(g, input, mean)

    # For half inputs, compute in the dtype of eps to avoid precision loss.
    is_type_half = (
        _type_utils.JitScalarType.from_value(centered)
        == _type_utils.JitScalarType.HALF
    )
    if is_type_half:
        eps_dtype = _type_utils.JitScalarType.from_value(eps_cst)
        centered = g.op(
            "Cast", centered, to_i=_type_utils.JitScalarType(eps_dtype).onnx_type()
        )

    # variance = E[(x - E[x])^2]; (x - E[x]) is also the layer_norm numerator.
    squared = pow(g, centered, two_cst)
    variance = g.op("ReduceMean", squared, axes_i=axes)
    denominator = sqrt(g, g.op("Add", variance, eps_cst))
    normalized = g.op("Div", centered, denominator)

    # All eps-related ops are done, so go back to the input dtype.
    if is_type_half:
        input_dtype = _type_utils.JitScalarType.from_value(input)
        normalized = g.op(
            "Cast", normalized, to_i=_type_utils.JitScalarType(input_dtype).onnx_type()
        )

    if weight is not None and not symbolic_helper._is_none(weight):
        normalized = mul(g, normalized, weight)
    if bias is not None and not symbolic_helper._is_none(bias):
        normalized = add(g, normalized, bias)

    # rdenominator := 1 / sqrt(variance + eps)
    # According to aten::native_layer_norm, rdenominator should have the same
    # dtype as input, mean and normalized, so it is cast back for half inputs.
    if not is_type_half:
        rdenominator = reciprocal(g, denominator)
    else:
        denominator = g.op(
            "Cast", denominator, to_i=_type_utils.JitScalarType(input_dtype).onnx_type()
        )
        rdenominator = g.op("Reciprocal", denominator)

    return normalized, mean, rdenominator
| |
| |
@_onnx_symbolic("aten::layer_norm")
@symbolic_helper.quantized_args(True, False, False, False)
@symbolic_helper.parse_args("v", "is", "v", "v", "f", "b")
@_beartype.beartype
def layer_norm(
    g: jit_utils.GraphContext,
    input: _C.Value,
    normalized_shape: Sequence[int],
    weight: _C.Value,
    bias: _C.Value,
    eps: float,
    cudnn_enable: bool,
) -> _C.Value:
    """Export ``aten::layer_norm``.

    Uses the ATen fallback op when ``is_caffe2_aten_fallback()`` is true;
    otherwise returns the first output of ``native_layer_norm``.
    """
    if not symbolic_helper.is_caffe2_aten_fallback():
        outputs = native_layer_norm(g, input, normalized_shape, weight, bias, eps)
        normalized, _, _ = outputs
        return normalized

    return g.at(
        "layer_norm",
        input,
        weight,
        bias,
        normalized_shape_i=normalized_shape,
        eps_f=eps,
        cudnn_enable_i=cudnn_enable,
    )
| |
| |
@_onnx_symbolic("aten::instance_norm")
@symbolic_helper.parse_args("v", "v", "v", "v", "v", "b", "f", "f", "b")
@_beartype.beartype
def instance_norm(
    g: jit_utils.GraphContext,
    input,
    weight,
    bias,
    running_mean,
    running_var,
    use_input_stats: bool,
    momentum: Number,
    eps: Number,
    cudnn_enabled: bool,
):
    """Export ``aten::instance_norm``.

    Without running statistics this emits ``InstanceNormalization``. With
    them, the input is reshaped to ``[1, N * C, ...]`` and routed through
    ``batch_norm``, then viewed back to the original sizes.

    Raises:
        errors.SymbolicValueError: If a default weight/bias is needed but the
            channel size is unknown, or if running statistics are given but
            the batch size is unknown.
    """
    symbolic_helper.check_training_mode(use_input_stats, "instance_norm")
    channel_size = symbolic_helper._get_tensor_dim_size(input, 1)

    def _per_channel_constant(fill_value):
        # Constant of length ``channel_size`` filled with ``fill_value``,
        # created with the dtype of ``input``.
        if channel_size is None:
            raise errors.SymbolicValueError(
                "Unsupported: ONNX export of instance_norm for unknown channel size.",
                input,
            )
        filled = torch.tensor(
            [fill_value] * channel_size,
            dtype=_type_utils.JitScalarType.from_value(input).dtype(),
        )
        return g.op("Constant", value_t=filled)

    if weight is None or symbolic_helper._is_none(weight):
        weight = _per_channel_constant(1.0)
    if bias is None or symbolic_helper._is_none(bias):
        bias = _per_channel_constant(0.0)

    stats_missing = (
        running_mean is None
        or symbolic_helper._is_none(running_mean)
        or running_var is None
        or symbolic_helper._is_none(running_var)
    )
    if stats_missing:
        return g.op("InstanceNormalization", input, weight, bias, epsilon_f=eps)

    input_size = symbolic_helper._get_tensor_sizes(input)
    # If input shape is [N, C, H, W], reshape to [1, N * C, H, W] and call batch_norm.
    # For more information instance_norm():
    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Normalization.cpp#L542
    input_size_reshape = input_size.copy()
    n = input_size[0]
    if n is None:
        raise errors.SymbolicValueError(
            "Unsupported: ONNX export of instance_norm training for unknown "
            "batch size.",
            input,
        )
    c = input_size[1]
    input_size_reshape[0] = 1
    input_size_reshape[1] = n * c

    def _repeat_per_batch(per_channel):
        # Repeat a per-channel value ``n`` times to match the N * C layout.
        times = g.op("Constant", value_t=torch.tensor([n], dtype=torch.int64))
        return repeat(g, per_channel, times)

    weight_ = _repeat_per_batch(weight)
    bias_ = _repeat_per_batch(bias)
    running_mean_ = _repeat_per_batch(running_mean)
    running_var_ = _repeat_per_batch(running_var)

    reshape_target = g.op("Constant", value_t=torch.LongTensor(input_size_reshape))
    input_reshaped = g.op("Reshape", input, reshape_target)
    out = batch_norm(
        g,
        input_reshaped,
        weight_,
        bias_,
        running_mean_,
        running_var_,
        use_input_stats,
        momentum,
        eps,
        cudnn_enabled,
    )
    original_shape = g.op("Constant", value_t=torch.tensor(input_size))
    return view(g, out, original_shape)
| |
| |
@_onnx_symbolic("aten::unfold")
@symbolic_helper.parse_args("v", "i", "i", "i")
@_beartype.beartype
def unfold(g: jit_utils.GraphContext, input, dimension, size, step):
    """Export ``aten::unfold`` as a ``Concat`` of sliced, transposed windows.

    Each window is a slice of length ``size`` along ``dimension``, taken every
    ``step`` elements. The sliced axis is moved to the end, a new axis is
    inserted at ``dimension``, and all windows are concatenated along it.

    Returns:
        The concatenated windows, the ATen fallback op when
        ``is_caffe2_aten_fallback()`` is true, or the result of
        ``_unimplemented`` when the size of ``dimension`` is not known.
    """
    if symbolic_helper.is_caffe2_aten_fallback():
        return g.at("unfold", input, dimension_i=dimension, size_i=size, step_i=step)

    sizes = symbolic_helper._get_tensor_sizes(input)
    try:
        sizedim = sizes[dimension]
    except (TypeError, IndexError):
        # TypeError: the sizes are unavailable (None is not subscriptable).
        # IndexError: ``dimension`` is outside the known rank.
        sizedim = None

    if sizedim is None:
        return symbolic_helper._unimplemented(
            "Unfold", "input size not accessible", input
        )

    low_indices = range(0, sizedim, step)
    hi_indices = range(size, sizedim + 1, step)
    stack = [
        symbolic_helper._slice_helper(
            g, input, axes=[dimension], starts=[low], ends=[hi]
        )
        for low, hi in zip(low_indices, hi_indices)
    ]
    ndim = len(sizes)
    # Permutation that moves the unfolded axis to the last position.
    perm = list(range(0, ndim))
    perm.append(perm.pop(dimension))
    unsqueeze = [
        symbolic_helper._unsqueeze_helper(
            g, g.op("Transpose", t, perm_i=perm), [dimension]
        )
        for t in stack
    ]
    return g.op("Concat", *unsqueeze, axis_i=dimension)
| |
| |
@_onnx_symbolic("aten::elu")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "t", "t", "t")
@_beartype.beartype
def elu(g: jit_utils.GraphContext, input, alpha, scale, input_scale):
    """Export ``aten::elu`` as ONNX ``Elu``.

    A ``scale`` or ``input_scale`` that is truthy and different from 1.0 is
    reported as unimplemented.
    """
    unsupported = (
        ("scale", "does not support scale in Elu", scale),
        ("input_scale", "does not support input_scale in Elu", input_scale),
    )
    for arg_name, message, arg_value in unsupported:
        if arg_value and arg_value != 1.0:
            return symbolic_helper._unimplemented(arg_name, message, arg_value)

    # See Note [Export inplace]
    alpha_value = symbolic_helper._scalar(alpha)
    return g.op("Elu", input, alpha_f=alpha_value)
| |
| |
@_onnx_symbolic("aten::selu")
@symbolic_helper.quantized_args(True)
@_beartype.beartype
def selu(g: jit_utils.GraphContext, input):
    """Symbolic for ``aten::selu``; emits a single ONNX ``Selu`` node."""
    selu_out = g.op("Selu", input)
    return selu_out
| |
| |
@_onnx_symbolic("aten::index_select")
@symbolic_helper.parse_args("v", "i", "v")
@_beartype.beartype
def index_select(g: jit_utils.GraphContext, self, dim, index):
    """Symbolic for ``aten::index_select``; delegates to ``_select_helper``.

    With a scalar index, index_select returns a tensor of the same rank as the
    input. To reproduce that in ONNX the index is made 1-D so the subsequent
    gather keeps the input's rank.
    """
    selected = symbolic_helper._select_helper(g, self, dim, index)
    return selected
| |
| |
@_onnx_symbolic("aten::index_put")
@_beartype.beartype
def index_put(g: jit_utils.GraphContext, self, indices_list_value, values, accumulate):
    """Symbolic for ``aten::index_put``.

    Only the empty-indices case is exported here: the result is ``values``, or
    ``self + values`` when ``accumulate`` is true. Any non-empty index list is
    reported via ``_onnx_opset_unsupported`` (opset 9, supported from 11).
    With the Caffe2 ATen fallback active the op is emitted as an ATen node.
    """
    indices_list = (
        symbolic_helper._unpack_list(indices_list_value)
        if symbolic_helper._is_packed_list(indices_list_value)
        else [indices_list_value]
    )

    if symbolic_helper.is_caffe2_aten_fallback():
        return g.at("index_put", self, *indices_list, values, accumulate)

    accumulate = symbolic_helper._parse_arg(accumulate, "b")

    if not indices_list:
        return add(g, self, values) if accumulate else values

    symbolic_helper._onnx_opset_unsupported("index_put", 9, 11, self)
| |
| |
@_onnx_symbolic("aten::index_fill")
@_beartype.beartype
def index_fill(g: jit_utils.GraphContext, self, dim, index, value):
    """Symbolic for ``aten::index_fill``.

    The index is reshaped/expanded by ``_index_fill_reshape_helper``, the fill
    value is expanded to the same shape, and the two are combined with
    ``scatter`` along ``dim``.
    """
    dim_value = symbolic_helper._parse_arg(dim, "i")
    if symbolic_helper.is_caffe2_aten_fallback():
        return g.at(
            "index_fill",
            self,
            index,
            value,
            overload_name="int_Scalar",
            dim_i=dim_value,
        )

    index_shape, index_expanded = symbolic_helper._index_fill_reshape_helper(
        g, self, dim, index
    )
    fill = symbolic_helper._maybe_get_scalar(value)
    fill = symbolic_helper._if_scalar_type_as(fill, self)
    fill_expanded = expand(g, fill, index_shape, None)

    return scatter(g, self, dim, index_expanded, fill_expanded)
| |
| |
@_onnx_symbolic("aten::index_copy")
@_beartype.beartype
def index_copy(g: jit_utils.GraphContext, self, dim, index, source):
    """Symbolic for ``aten::index_copy``.

    Expands ``index`` with ``_index_fill_reshape_helper`` and scatters
    ``source`` into ``self`` along ``dim``.
    """
    dim_value = symbolic_helper._parse_arg(dim, "i")
    if symbolic_helper.is_caffe2_aten_fallback():
        return g.at("index_copy", self, index, source, dim_i=dim_value)

    # Only the expanded index is needed; the accompanying shape is unused.
    _, index_expanded = symbolic_helper._index_fill_reshape_helper(
        g, self, dim, index
    )
    return scatter(g, self, dim, index_expanded, source)
| |
| |
@_onnx_symbolic("aten::bucketize")
@symbolic_helper.parse_args("v", "v", "b", "b")
@_beartype.beartype
def bucketize(
    g: jit_utils.GraphContext, self, boundaries, out_int32=False, right=False
):
    """Symbolic for ``aten::bucketize``.

    Each element of ``self`` is compared against every boundary and the
    number of boundaries it exceeds is summed to obtain the bucket index.

    Args:
        g: Graph context used to emit nodes.
        self: Values to bucketize; their rank must be statically known.
        boundaries: Boundary values.
        out_int32: Emit INT32 indices instead of INT64.
        right: Compare with ``>=`` instead of ``>``.

    Returns:
        The reduced-sum node holding the bucket indices.

    Raises:
        errors.SymbolicValueError: If the rank of ``self`` is unknown.
    """
    out_type = _C_onnx.TensorProtoDataType.INT64
    if out_int32:
        out_type = _C_onnx.TensorProtoDataType.INT32
    # A tensor expanded_boundaries is created such that it
    # contains a copy of boundaries for each element of self.
    new_shape = g.op("Concat", g.op("Shape", boundaries), g.op("Shape", self), axis_i=0)
    # Unsqueeze step is performed to respect ONNX's numpy style broadcasting for comparison ops
    # https://github.com/onnx/onnx/blob/main/docs/Broadcasting.md
    tensor_rank = symbolic_helper._get_tensor_rank(self)
    if tensor_rank is None:
        # An explicit error instead of ``assert``: asserts are stripped under
        # ``python -O`` and would surface as a TypeError on the next line.
        raise errors.SymbolicValueError(
            "Unsupported: ONNX export of bucketize for input of unknown rank.",
            self,
        )
    unsqueeze_axes = list(range(1, tensor_rank + 1))
    expanded_boundaries = expand(
        g,
        symbolic_helper._unsqueeze_helper(g, boundaries, unsqueeze_axes),
        new_shape,
        None,
    )
    # Compare each element of self to boundaries to get a tensor
    # with leading 1s and trailing 0s.
    # e.g., 4 > [1, 3, 4] = [1, 1, 0]
    # The index of the last 1 is the bucket where the element should go.
    if right:
        cond = ge(g, self, expanded_boundaries)
    else:
        cond = gt(g, self, expanded_boundaries)
    cond_out = g.op("Cast", cond, to_i=out_type)
    # Sum to get the number of 1s corresponding to each element,
    # which is the same as the bucket index.
    # e.g., sum(4 > [1, 3, 4]) = sum([1, 1, 0]) = 2
    return symbolic_helper._reducesum_helper(g, cond_out, axes_i=[0], keepdims_i=0)
| |
| |
@_onnx_symbolic("aten::type_as")
@_beartype.beartype
def type_as(g: jit_utils.GraphContext, self, other):
    """Symbolic for ``aten::type_as``.

    Returns ``self`` unchanged when both dtypes are known and equal, a
    ``Cast`` to ``other``'s dtype when that dtype is known, an ATen node under
    the Caffe2 fallback, and raises ``SymbolicValueError`` otherwise.
    """
    self_dtype = symbolic_helper._try_get_scalar_type(self)
    other_dtype = symbolic_helper._try_get_scalar_type(other)

    if other_dtype is None:
        if symbolic_helper.is_caffe2_aten_fallback():
            # We don't know the type of other, bail by emitting ATen
            return g.at("type_as", self, other)
        raise errors.SymbolicValueError(
            "Unsupported: ONNX export of type_as for tensor "
            "of unknown dtype. Please check if the dtype of the "
            "parameter passed to the type_as function is correct.",
            other,
        )

    # ``other_dtype`` is known here, so equality implies ``self_dtype`` is too.
    if self_dtype == other_dtype:
        return self

    target = other_dtype.onnx_type()
    return g.op("Cast", self, to_i=target)
| |
| |
@_onnx_symbolic("aten::cosine_similarity")
@symbolic_helper.parse_args("v", "v", "i", "f")
@_beartype.beartype
def cosine_similarity(g: jit_utils.GraphContext, x1, x2, dim, eps):
    """Symbolic for ``aten::cosine_similarity``.

    Computes ``sum(x1 * x2) / max(sqrt(sum(x1 * x1) * sum(x2 * x2)), eps)``
    with all sums taken over ``dim`` (dimension not kept).
    """
    if symbolic_helper.is_caffe2_aten_fallback():
        return g.at("cosine_similarity", x1, x2, dim_i=dim, eps_f=eps)

    def _sum_of_products(lhs, rhs):
        # Element-wise product reduced over ``dim``.
        return symbolic_helper._reducesum_helper(
            g, mul(g, lhs, rhs), axes_i=[dim], keepdims_i=0
        )

    dot_product = _sum_of_products(x1, x2)
    x1_sq_norm = _sum_of_products(x1, x1)
    x2_sq_norm = _sum_of_products(x2, x2)

    norm_product = sqrt(g, mul(g, x1_sq_norm, x2_sq_norm))
    eps_const = g.op("Constant", value_t=torch.tensor([eps]))
    denominator = max(g, norm_product, eps_const)
    return div(g, dot_product, denominator)
| |
| |
@_onnx_symbolic("aten::pairwise_distance")
@_beartype.beartype
def pairwise_distance(g: jit_utils.GraphContext, input1, input2, p, eps, keepdim):
    """Symbolic for ``aten::pairwise_distance``.

    Emits ``pow(sum(pow(input1 - input2, p), axis=-1), 1 / (p + eps))``.
    ``eps`` is wrapped in a ``Constant`` node when it is not already a value.
    """
    if not symbolic_helper._is_value(eps):
        eps = g.op("Constant", value_t=torch.tensor([eps]))

    one = g.op("Constant", value_t=torch.tensor([1], dtype=torch.float))
    inv_p = div(g, one, add(g, p, eps))

    powered_diff = pow(g, sub(g, input1, input2), p)
    keepdims = symbolic_helper._parse_arg(keepdim, "i")
    summation = symbolic_helper._reducesum_helper(
        g, powered_diff, axes_i=[-1], keepdims_i=keepdims
    )
    return pow(g, summation, inv_p)
| |
| |
@_onnx_symbolic("aten::clone")
# ignore clone operators that are inserted by PyTorch autograd
@_beartype.beartype
def clone(g: jit_utils.GraphContext, input, unused_memory_format):
    """Symbolic for ``aten::clone``; a no-op that returns ``input`` as is."""
    passthrough = input
    return passthrough
| |
| |
@_onnx_symbolic("aten::abs")
@_beartype.beartype
def abs(g: jit_utils.GraphContext, self):
    """Symbolic for ``aten::abs``; emits a single ONNX ``Abs`` node."""
    abs_out = g.op("Abs", self)
    return abs_out
| |
| |
@_onnx_symbolic("aten::log")
@_beartype.beartype
def log(g: jit_utils.GraphContext, self):
    """Symbolic for ``aten::log``; emits a single ONNX ``Log`` node."""
    log_out = g.op("Log", self)
    return log_out
| |
| |
@_onnx_symbolic("aten::log1p")
@_beartype.beartype
def log1p(g: jit_utils.GraphContext, self):
    """Symbolic for ``aten::log1p``; exported as ``log(1 + self)``."""
    one = symbolic_helper._if_scalar_type_as(torch.ones(1), self)
    shifted = add(g, one, self)
    return log(g, shifted)
| |
| |
@_onnx_symbolic("aten::log10")
@_beartype.beartype
def log10(g: jit_utils.GraphContext, self):
    """Symbolic for ``aten::log10``; exported as ``log(self) / ln(10)``."""
    ln_of_ten = 2.30258509299404568401
    natural_log = log(g, self)
    divisor = g.op("Constant", value_t=torch.tensor([ln_of_ten]))
    return g.op("Div", natural_log, divisor)
| |
| |
@_onnx_symbolic("aten::pow")
@_beartype.beartype
def pow(g: jit_utils.GraphContext, self, exponent):
    """Symbolic for ``aten::pow``; emits an ONNX ``Pow`` node.

    A non floating-point base is cast to FLOAT; a non floating-point exponent
    is cast to the (possibly updated) dtype of the base.
    """
    float_type = _type_utils.JitScalarType.from_value(self)

    if not symbolic_helper._is_fp(self):
        float_type = _type_utils.JitScalarType.FLOAT
        self = g.op("Cast", self, to_i=float_type.onnx_type())

    if not symbolic_helper._is_fp(exponent):
        exponent = g.op("Cast", exponent, to_i=float_type.onnx_type())

    return g.op("Pow", self, exponent)
| |
| |
@_onnx_symbolic("aten::clamp")
@_beartype.beartype
def clamp(g: jit_utils.GraphContext, self, min, max):
    """Symbolic for ``aten::clamp``.

    ``min`` or ``max`` may be None; since ONNX has no None syntax those cases
    are dispatched to ``clamp_max`` / ``clamp_min``. When both bounds are
    constants a single ``Clip`` is emitted, otherwise the two one-sided
    clamps are chained.
    """
    if symbolic_helper._is_none(min):
        return clamp_max(g, self, max)
    if symbolic_helper._is_none(max):
        return clamp_min(g, self, min)

    both_constant = symbolic_helper._is_constant(
        min
    ) and symbolic_helper._is_constant(max)
    if not both_constant:
        lower_clamped = clamp_min(g, self, min)
        return clamp_max(g, lower_clamped, max)

    return _op_with_optional_float_cast(
        g,
        "Clip",
        self,
        min_f=symbolic_helper._parse_arg(min, "f"),
        max_f=symbolic_helper._parse_arg(max, "f"),
        opset_before=12,
    )
| |
| |
@_onnx_symbolic("aten::clamp_min")
@symbolic_helper.parse_args("v", "v")
@_beartype.beartype
def clamp_min(g: jit_utils.GraphContext, self, min):
    """Symbolic for ``aten::clamp_min``.

    A constant bound becomes a ``Clip`` with ``min_f``; otherwise the bound is
    cast to the dtype of ``self`` and combined with a ``Max``.
    """
    if not symbolic_helper._is_constant(min):
        self_type = _type_utils.JitScalarType.from_value(self)
        min = g.op("Cast", min, to_i=self_type.onnx_type())
        return _op_with_optional_float_cast(g, "Max", self, min, opset_before=12)

    lower_bound = symbolic_helper._parse_arg(min, "f")
    return _op_with_optional_float_cast(
        g, "Clip", self, min_f=lower_bound, opset_before=12
    )
| |
| |
@_onnx_symbolic("aten::clamp_max")
@symbolic_helper.parse_args("v", "v")
@_beartype.beartype
def clamp_max(g: jit_utils.GraphContext, self, max):
    """Symbolic for ``aten::clamp_max``.

    A constant bound becomes a ``Clip`` with ``max_f``; otherwise the bound is
    cast to the dtype of ``self`` and combined with a ``Min``.
    """
    if not symbolic_helper._is_constant(max):
        self_type = _type_utils.JitScalarType.from_value(self)
        max = g.op("Cast", max, to_i=self_type.onnx_type())
        return _op_with_optional_float_cast(g, "Min", self, max, opset_before=12)

    upper_bound = symbolic_helper._parse_arg(max, "f")
    return _op_with_optional_float_cast(
        g, "Clip", self, max_f=upper_bound, opset_before=12
    )
| |
| |
@_onnx_symbolic("aten::max")
# torch.max (same for torch.min) actually has two interfaces smashed together:
# torch.max(x, dim, keepdim) and torch.max(x, y)
# TODO(justinchuby): Support multiple quantized args in output
@_beartype.beartype
def max(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None):
    """Symbolic for ``aten::max`` covering its three call forms.

    * ``max(input)`` -> ``ReduceMax`` over everything.
    * ``max(input, other)`` -> element-wise ``Max``.
    * ``max(input, dim, keepdim)`` -> ``(ReduceMax, ArgMax)`` along ``dim``.
    """
    if keepdim is None:
        if dim_or_y is None:
            # torch.max(input)
            return g.op("ReduceMax", self, keepdims_i=0)
        # torch.max(input, other)
        return _op_with_optional_float_cast(g, "Max", self, dim_or_y, opset_before=12)

    # torch.max(input, dim, keepdim)
    dim = symbolic_helper._get_const(dim_or_y, "i", "dim")
    keepdim = symbolic_helper._get_const(keepdim, "i", "keepdim")
    values = g.op("ReduceMax", self, axes_i=[dim], keepdims_i=keepdim)
    indices = g.op("ArgMax", self, axis_i=dim, keepdims_i=keepdim)
    return values, indices
| |
| |
@_onnx_symbolic("aten::maximum")
@symbolic_helper.quantized_args(True, True)
@_beartype.beartype
def maximum(g: jit_utils.GraphContext, input, other):
    """Symbolic for ``aten::maximum``; the two-tensor form of ``max``."""
    elementwise_max = max(g, input, dim_or_y=other)
    return elementwise_max
| |
| |
@_onnx_symbolic("aten::min")
# TODO(justinchuby): Support multiple quantized args in output
@_beartype.beartype
def min(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None):
    """Symbolic for ``aten::min`` covering its three call forms.

    * ``min(input)`` -> ``ReduceMin`` over everything.
    * ``min(input, other)`` -> element-wise ``Min``.
    * ``min(input, dim, keepdim)`` -> ``(ReduceMin, ArgMin)`` along ``dim``.
    """
    if keepdim is None:
        if dim_or_y is None:
            # torch.min(input)
            return g.op("ReduceMin", self, keepdims_i=0)
        # torch.min(input, other)
        return _op_with_optional_float_cast(g, "Min", self, dim_or_y, opset_before=12)

    # torch.min(input, dim, keepdim)
    dim = symbolic_helper._get_const(dim_or_y, "i", "dim")
    keepdim = symbolic_helper._get_const(keepdim, "i", "keepdim")
    values = g.op("ReduceMin", self, axes_i=[dim], keepdims_i=keepdim)
    indices = g.op("ArgMin", self, axis_i=dim, keepdims_i=keepdim)
    return values, indices
| |
| |
@_onnx_symbolic("aten::minimum")
@symbolic_helper.quantized_args(True, True)
@_beartype.beartype
def minimum(g: jit_utils.GraphContext, input, other):
    """Symbolic for ``aten::minimum``; the two-tensor form of ``min``."""
    elementwise_min = min(g, input, dim_or_y=other)
    return elementwise_min
| |
| |
@_onnx_symbolic("aten::amax")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "is", "i")
@_beartype.beartype
def amax(g: jit_utils.GraphContext, self, dim, keepdim):
    """Symbolic for ``aten::amax``; ``ReduceMax`` over the axes in ``dim``."""
    reduced = g.op("ReduceMax", self, axes_i=dim, keepdims_i=keepdim)
    return reduced
| |
| |
@_onnx_symbolic("aten::amin")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "is", "i")
@_beartype.beartype
def amin(g: jit_utils.GraphContext, self, dim, keepdim):
    """Symbolic for ``aten::amin``; ``ReduceMin`` over the axes in ``dim``."""
    reduced = g.op("ReduceMin", self, axes_i=dim, keepdims_i=keepdim)
    return reduced
| |
| |
@_onnx_symbolic("aten::aminmax")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "v", "i")
@_beartype.beartype
def aminmax(g: jit_utils.GraphContext, self, dim, keepdim):
    """Symbolic for ``aten::aminmax``; returns ``(ReduceMin, ReduceMax)``.

    ``axes_i`` is only passed to the reductions when ``dim`` is not None.
    """
    reduce_kwargs = {"keepdims_i": keepdim}
    has_dim = not symbolic_helper._is_none(dim)
    if has_dim:
        dim = symbolic_helper._get_const(dim, "i", "dim")
        reduce_kwargs.update(axes_i=[dim])

    minimum_out = g.op("ReduceMin", self, **reduce_kwargs)
    maximum_out = g.op("ReduceMax", self, **reduce_kwargs)
    return minimum_out, maximum_out
| |
| |
@_onnx_symbolic("aten::exp")
@_beartype.beartype
def exp(g: jit_utils.GraphContext, self):
    """Symbolic for ``aten::exp``; emits a single ONNX ``Exp`` node."""
    exp_out = g.op("Exp", self)
    return exp_out
| |
| |
@_onnx_symbolic("aten::dropout_")
@_onnx_symbolic("aten::dropout")
@symbolic_helper.parse_args("v", "f", "i")
@_beartype.beartype
def dropout(g: jit_utils.GraphContext, input, p, train):
    """Symbolic for ``aten::dropout`` / ``aten::dropout_``.

    Returns ``input`` untouched when ``train`` is falsy; otherwise emits a
    two-output ``Dropout`` node with ratio ``p`` and returns its first output.
    """
    symbolic_helper.check_training_mode(train, "dropout")
    if train:
        dropped, _mask = g.op("Dropout", input, ratio_f=p, outputs=2)
        return dropped
    # if train is False, dropout is no-op
    return input
| |
| |
@_onnx_symbolic(
    "aten::alpha_dropout_", decorate=[_apply_params("aten::alpha_dropout_")]
)  # See Note [Export inplace]
@_onnx_symbolic(
    "aten::feature_alpha_dropout_",
    decorate=[_apply_params("aten::feature_alpha_dropout_")],
)
@_onnx_symbolic(
    "aten::feature_dropout_", decorate=[_apply_params("aten::feature_dropout_")]
)
@_onnx_symbolic(
    "aten::feature_alpha_dropout",
    decorate=[_apply_params("aten::feature_alpha_dropout")],
)
@_onnx_symbolic("aten::alpha_dropout", decorate=[_apply_params("aten::alpha_dropout")])
@_onnx_symbolic(
    "aten::feature_dropout", decorate=[_apply_params("aten::feature_dropout")]
)
@_beartype.beartype
def _unsupported_dropout(name: str):
    """Build the symbolic shared by the dropout variants registered above.

    The returned symbolic passes ``input`` through when ``train`` is false and
    reports ``name`` as unimplemented in training mode.
    """

    @symbolic_helper.parse_args("v", "none", "b")
    @_beartype.beartype
    def feature_dropout(g, input, p, train):
        # NB: In inference mode, FeatureDropout is exported as an identity op.
        if not train:
            return input
        return symbolic_helper._unimplemented(name, "training mode", input)

    return feature_dropout
| |
| |
@_onnx_symbolic("aten::norm")
@symbolic_helper.parse_args("v", "t", "is", "i", "v")
@_beartype.beartype
def norm(g: jit_utils.GraphContext, self, p, dim, keepdim, dtype=None):
    """Symbolic for ``aten::norm``.

    ``p == 1`` maps to ``ReduceL1`` and ``p == 2`` to ``ReduceL2``; any other
    ``p`` raises ``SymbolicValueError``. When ``dtype`` is given the result is
    additionally cast to it.
    """
    if p == 1:
        reduce_name = "ReduceL1"
    elif p == 2:
        reduce_name = "ReduceL2"
    else:
        raise errors.SymbolicValueError(
            "ONNX export only p-norms with p of 1 or 2", self
        )

    reduce_fn = _reduce_op_symbolic(reduce_name)
    result = reduce_fn(g, self, dim=dim, keepdim=keepdim)

    if dtype is None:
        return result

    dtype = symbolic_helper._get_const(dtype, "i", "dtype")
    onnx_dtype = _type_utils.JitScalarType(dtype).onnx_type()
    return g.op("Cast", result, to_i=onnx_dtype)
| |
| |
@_onnx_symbolic("aten::conv_tbc")
@symbolic_helper.parse_args("v", "v", "v", "i")
@_beartype.beartype
def conv_tbc(g: jit_utils.GraphContext, input, weight, bias, pad):
    """Symbolic for ``aten::conv_tbc``.

    Transposes ``input`` and ``weight``, runs ``conv1d`` on them, and
    transposes the result back.
    """
    if symbolic_helper.is_caffe2_aten_fallback():
        return g.at("conv_tbc", input, weight, bias, pad_i=pad)

    # input must have 3 dimensions, see:
    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/ConvolutionTBC.cpp#L8-L10
    # input = (time, batch, in_channels)
    # weight = (kernel_width, in_channels, out_channels)
    # bias = (out_channels,)
    input_transposed = g.op("Transpose", input, perm_i=[1, 2, 0])
    weight_transposed = g.op("Transpose", weight, perm_i=[2, 1, 0])
    conv_out = conv1d(g, input_transposed, weight_transposed, bias, [1], [pad], [1], 1)
    return g.op("Transpose", conv_out, perm_i=[2, 0, 1])
| |
| |
@_onnx_symbolic("aten::_unique")
@symbolic_helper.parse_args("v", "i", "i")
@_beartype.beartype
def _unique(g: jit_utils.GraphContext, input, sorted, return_inverse):
    """Symbolic for ``aten::_unique``.

    Only available through the Caffe2 ATen fallback; otherwise the op is
    reported through ``_onnx_unsupported``.
    """
    if not symbolic_helper.is_caffe2_aten_fallback():
        return symbolic_helper._onnx_unsupported("_unique", input)

    return g.at(
        "_unique",
        input,
        sorted_i=sorted,
        return_inverse_i=return_inverse,
        outputs=2,
    )
| |
| |
@_onnx_symbolic("aten::_unique2")
@symbolic_helper.parse_args("v", "i", "i", "i")
@_beartype.beartype
def _unique2(g: jit_utils.GraphContext, input, sorted, return_inverse, return_counts):
    """Symbolic for ``aten::_unique2``.

    Emitted as a three-output ATen node under the Caffe2 fallback; otherwise
    reported via ``_onnx_opset_unsupported`` (opset 9, supported from 11).
    """
    if symbolic_helper.is_caffe2_aten_fallback():
        aten_kwargs = {
            "sorted_i": sorted,
            "return_inverse_i": return_inverse,
            "return_counts_i": return_counts,
            "outputs": 3,
        }
        return g.at("_unique2", input, **aten_kwargs)

    symbolic_helper._onnx_opset_unsupported("_unique2", 9, 11, input)
| |
| |
@_onnx_symbolic("aten::_cast_Byte")
@_deprecation.deprecated(
    "2.0",
    "the future",
    "Avoid using this function and create a Cast node instead",
)
@_beartype.beartype
def _cast_Byte(g: jit_utils.GraphContext, input, non_blocking):
    """Deprecated symbolic: cast ``input`` to UINT8; ``non_blocking`` is ignored."""
    target_type = _C_onnx.TensorProtoDataType.UINT8
    return g.op("Cast", input, to_i=target_type)
| |
| |
@_onnx_symbolic("aten::_cast_Char")
@_deprecation.deprecated(
    "2.0",
    "the future",
    "Avoid using this function and create a Cast node instead",
)
@_beartype.beartype
def _cast_Char(g: jit_utils.GraphContext, input, non_blocking):
    """Deprecated symbolic: cast ``input`` to INT8; ``non_blocking`` is ignored."""
    target_type = _C_onnx.TensorProtoDataType.INT8
    return g.op("Cast", input, to_i=target_type)
| |
| |
@_onnx_symbolic("aten::_cast_Short")
@_deprecation.deprecated(
    "2.0",
    "the future",
    "Avoid using this function and create a Cast node instead",
)
@_beartype.beartype
def _cast_Short(g: jit_utils.GraphContext, input, non_blocking):
    """Deprecated symbolic: cast ``input`` to INT16; ``non_blocking`` is ignored."""
    target_type = _C_onnx.TensorProtoDataType.INT16
    return g.op("Cast", input, to_i=target_type)
| |
| |
@_onnx_symbolic("aten::_cast_Int")
@_deprecation.deprecated(
    "2.0",
    "the future",
    "Avoid using this function and create a Cast node instead",
)
@_beartype.beartype
def _cast_Int(g: jit_utils.GraphContext, input, non_blocking):
    """Deprecated symbolic: cast ``input`` to INT32; ``non_blocking`` is ignored."""
    target_type = _C_onnx.TensorProtoDataType.INT32
    return g.op("Cast", input, to_i=target_type)
| |
| |
@_onnx_symbolic("aten::_cast_Long")
@_deprecation.deprecated(
    "2.0",
    "the future",
    "Avoid using this function and create a Cast node instead",
)
@_beartype.beartype
def _cast_Long(g: jit_utils.GraphContext, input, non_blocking):
    """Deprecated symbolic: cast ``input`` to INT64; ``non_blocking`` is ignored."""
    target_type = _C_onnx.TensorProtoDataType.INT64
    return g.op("Cast", input, to_i=target_type)
| |
| |
@_onnx_symbolic("aten::_cast_Half")
@_deprecation.deprecated(
    "2.0",
    "the future",
    "Avoid using this function and create a Cast node instead",
)
@_beartype.beartype
def _cast_Half(g: jit_utils.GraphContext, input, non_blocking):
    """Deprecated symbolic: cast ``input`` to FLOAT16; ``non_blocking`` is ignored."""
    target_type = _C_onnx.TensorProtoDataType.FLOAT16
    return g.op("Cast", input, to_i=target_type)
| |
| |
@_onnx_symbolic("aten::_cast_Float")
@_deprecation.deprecated(
    "2.0",
    "the future",
    "Avoid using this function and create a Cast node instead",
)
@_beartype.beartype
def _cast_Float(g: jit_utils.GraphContext, input, non_blocking):
    """Deprecated symbolic: cast ``input`` to FLOAT; ``non_blocking`` is ignored."""
    target_type = _C_onnx.TensorProtoDataType.FLOAT
    return g.op("Cast", input, to_i=target_type)
| |
| |
@_onnx_symbolic("aten::_cast_Double")
@_deprecation.deprecated(
    "2.0",
    "the future",
    "Avoid using this function and create a Cast node instead",
)
@_beartype.beartype
def _cast_Double(g: jit_utils.GraphContext, input, non_blocking):
    """Deprecated symbolic: cast ``input`` to DOUBLE; ``non_blocking`` is ignored."""
    target_type = _C_onnx.TensorProtoDataType.DOUBLE
    return g.op("Cast", input, to_i=target_type)
| |
| |
@_onnx_symbolic("aten::_cast_Bool")
@_deprecation.deprecated(
    "2.0",
    "the future",
    "Avoid using this function and create a Cast node instead",
)
@_beartype.beartype
def _cast_Bool(g: jit_utils.GraphContext, input, non_blocking):
    """Deprecated symbolic: cast ``input`` to BOOL; ``non_blocking`` is ignored."""
    target_type = _C_onnx.TensorProtoDataType.BOOL
    return g.op("Cast", input, to_i=target_type)
| |
| |
@_onnx_symbolic("aten::empty")
@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v")
@_beartype.beartype
def empty(
    g: jit_utils.GraphContext,
    sizes,
    dtype,
    layout,
    device,
    pin_memory=False,
    memory_format=None,
):
    """Symbolic for ``aten::empty``; exported as ``zeros``.

    ``memory_format`` is accepted but not forwarded.
    """
    zero_filled = zeros(g, sizes, dtype, layout, device, pin_memory)
    return zero_filled
| |
| |
@_onnx_symbolic("aten::empty_like")
@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v")
@_beartype.beartype
def empty_like(
    g: jit_utils.GraphContext,
    input,
    dtype=None,
    layout=None,
    device=None,
    pin_memory=False,
    memory_format=None,
):
    """Symbolic for ``aten::empty_like``; exported as ``zeros_like``.

    ``memory_format`` is accepted but not forwarded.
    """
    zero_filled = zeros_like(g, input, dtype, layout, device, pin_memory)
    return zero_filled
| |
| |
@_onnx_symbolic("aten::new_empty")
@_beartype.beartype
def new_empty(
    g: jit_utils.GraphContext, self, sizes, dtype, layout, device, pin_memory=False
):
    """Symbolic for ``aten::new_empty``.

    Falls back to the dtype of ``self`` (when known) if ``dtype`` is None,
    then delegates to ``empty``.
    """
    inherited_dtype = symbolic_helper._try_get_scalar_type(self)
    use_inherited = symbolic_helper._is_none(dtype) and inherited_dtype is not None
    if use_inherited:
        dtype = inherited_dtype
    return empty(g, sizes, dtype, layout, device, pin_memory)
| |
| |
@_onnx_symbolic("aten::scalar_tensor")
@_beartype.beartype
def scalar_tensor(g: jit_utils.GraphContext, scalar, dtype, *options):
    """Symbolic for ``aten::scalar_tensor``.

    Casts ``scalar`` to ``dtype``, defaulting to FLOAT when ``dtype`` is None.
    The remaining ``options`` are ignored.
    """
    dtype = symbolic_helper._get_const(dtype, "i", "dtype")
    if dtype is None:
        dtype = _type_utils.JitScalarType.FLOAT
    onnx_dtype = _type_utils.JitScalarType(dtype).onnx_type()
    return g.op("Cast", scalar, to_i=onnx_dtype)
| |
| |
@_onnx_symbolic("aten::tensor")
@_beartype.beartype
def tensor(
    g: jit_utils.GraphContext, data, dtype=None, device=None, requires_grad=False
):
    """Symbolic for ``aten::tensor``.

    A packed list is exported by reshaping every element to shape ``[1]``,
    casting it, and concatenating along axis 0. Any other input is cast
    directly, after a ``ConcatFromSequence`` when it is a tensor/scalar list.
    When ``dtype`` is None it is taken from the data.
    """
    dtype = symbolic_helper._get_const(dtype, "i", "dtype")

    if symbolic_helper._is_packed_list(data):
        elements = symbolic_helper._unpack_list(data)
        if dtype is None:
            dtype = _type_utils.JitScalarType.from_value(elements[0])
        pieces = []
        for element in elements:
            shape_reference = g.op("Constant", value_t=torch.LongTensor([1]))
            reshaped = symbolic_helper._reshape_helper(g, element, shape_reference)
            casted = g.op(
                "Cast", reshaped, to_i=_type_utils.JitScalarType(dtype).onnx_type()
            )
            pieces.append(casted)
        return g.op("Concat", *pieces, axis_i=0)

    if dtype is None:
        dtype = _type_utils.JitScalarType.from_value(data)
    is_sequence = symbolic_helper._is_list(data) and (
        symbolic_helper._is_tensor_list(data) or symbolic_helper._is_scalar_list(data)
    )
    if is_sequence:
        data = g.op("ConcatFromSequence", data, axis_i=0, new_axis_i=1)
    return g.op("Cast", data, to_i=_type_utils.JitScalarType(dtype).onnx_type())
| |
| |
@_onnx_symbolic("aten::as_tensor")
@_beartype.beartype
def as_tensor(g: jit_utils.GraphContext, data, dtype=None, device=None):
    """Symbolic for ``aten::as_tensor``; forwards to ``tensor``."""
    converted = tensor(g, data, dtype, device)
    return converted
| |
| |
@_onnx_symbolic("aten::zeros")
@symbolic_helper.parse_args("v", "i", "v", "v", "v")
@_beartype.beartype
def zeros(g: jit_utils.GraphContext, sizes, dtype, layout, device, pin_memory=False):
    """Symbolic for ``aten::zeros``; emits ``ConstantOfShape`` filled with 0.

    ``dtype`` defaults to FLOAT. A constant empty ``sizes`` list is replaced
    by an empty int64 ``Constant``.
    """
    # NOTE: no way to set device, layout and pin_memory in ONNX, so we ignore it
    scalar_type = (
        _type_utils.JitScalarType.FLOAT
        if dtype is None
        else _type_utils.JitScalarType(dtype)
    )
    const_sizes = symbolic_helper._maybe_get_const(sizes, "is")
    if isinstance(const_sizes, list) and not const_sizes:
        sizes = g.op("Constant", value_t=torch.tensor([]).to(torch.int64))
    fill = torch.tensor([0], dtype=scalar_type.dtype())
    return g.op("ConstantOfShape", sizes, value_t=fill)
| |
| |
@_onnx_symbolic("aten::zeros_like")
@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v")
@_beartype.beartype
def zeros_like(
    g: jit_utils.GraphContext,
    input,
    dtype=None,
    layout=None,
    device=None,
    pin_memory=False,
    memory_format=None,
):
    """Symbolic for ``aten::zeros_like``.

    Emits ``ConstantOfShape`` filled with 0 over ``Shape(input)``. Without an
    explicit ``dtype`` the dtype of ``input`` is used, with FLOAT as default.
    """
    shape = g.op("Shape", input)
    scalar_type = (
        _type_utils.JitScalarType.from_value(input, _type_utils.JitScalarType.FLOAT)
        if symbolic_helper._is_none(dtype)
        else _type_utils.JitScalarType(dtype)
    )
    fill = torch.tensor([0], dtype=scalar_type.dtype())
    return g.op("ConstantOfShape", shape, value_t=fill)
| |
| |
@_onnx_symbolic("aten::new_zeros")
@_beartype.beartype
def new_zeros(
    g: jit_utils.GraphContext, self, sizes, dtype, layout, device, pin_memory=False
):
    """Symbolic for ``aten::new_zeros``.

    Falls back to the dtype of ``self`` (when known) if ``dtype`` is None,
    then delegates to ``zeros``.
    """
    inherited_dtype = symbolic_helper._try_get_scalar_type(self)
    use_inherited = symbolic_helper._is_none(dtype) and inherited_dtype is not None
    if use_inherited:
        dtype = inherited_dtype
    return zeros(g, sizes, dtype, layout, device, pin_memory)
| |
| |
@_onnx_symbolic("aten::zero")
@_beartype.beartype
def zero(g: jit_utils.GraphContext, self):
    """Symbolic for ``aten::zero``; ``zeros_like`` with the dtype of ``self``."""
    own_dtype = symbolic_helper._try_get_scalar_type(self)
    zeroed = zeros_like(g, self, own_dtype)
    return zeroed
| |
| |
@_onnx_symbolic("aten::ones")
@symbolic_helper.parse_args("v", "i", "v", "v", "v")
@_beartype.beartype
def ones(g: jit_utils.GraphContext, sizes, dtype, layout, device, pin_memory=False):
    """Symbolic for ``aten::ones``; emits ``ConstantOfShape`` filled with 1.

    ``dtype`` defaults to FLOAT. A constant empty ``sizes`` list is replaced
    by an empty int64 ``Constant``. ``layout``, ``device`` and ``pin_memory``
    are not used.
    """
    scalar_type = (
        _type_utils.JitScalarType.FLOAT
        if dtype is None
        else _type_utils.JitScalarType(dtype)
    )
    const_sizes = symbolic_helper._maybe_get_const(sizes, "is")
    if isinstance(const_sizes, list) and not const_sizes:
        sizes = g.op("Constant", value_t=torch.tensor([]).to(torch.int64))
    fill = torch.tensor([1], dtype=scalar_type.dtype())
    return g.op("ConstantOfShape", sizes, value_t=fill)
| |
| |
@_onnx_symbolic("aten::ones_like")
@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v")
@_beartype.beartype
def ones_like(
    g: jit_utils.GraphContext,
    input,
    dtype=None,
    layout=None,
    device=None,
    pin_memory=False,
    memory_format=None,
):
    """Symbolic for ``aten::ones_like``.

    Emits ``ConstantOfShape`` filled with 1 over ``Shape(input)``. Without an
    explicit ``dtype`` the dtype of ``input`` is used, with FLOAT as default.
    """
    shape = g.op("Shape", input)
    scalar_type = (
        _type_utils.JitScalarType.from_value(input, _type_utils.JitScalarType.FLOAT)
        if symbolic_helper._is_none(dtype)
        else _type_utils.JitScalarType(dtype)
    )
    fill = torch.tensor([1], dtype=scalar_type.dtype())
    return g.op("ConstantOfShape", shape, value_t=fill)
| |
| |
@_onnx_symbolic("aten::new_ones")
@_beartype.beartype
def new_ones(
    g: jit_utils.GraphContext, self, sizes, dtype, layout, device, pin_memory=False
):
    """Symbolic for ``aten::new_ones``.

    Falls back to the dtype of ``self`` (when known) if ``dtype`` is None,
    then delegates to ``ones``.
    """
    inherited_dtype = symbolic_helper._try_get_scalar_type(self)
    use_inherited = symbolic_helper._is_none(dtype) and inherited_dtype is not None
    if use_inherited:
        dtype = inherited_dtype
    return ones(g, sizes, dtype, layout, device, pin_memory)
| |
| |
@_onnx_symbolic("aten::full")
@_beartype.beartype
def full(
    g: jit_utils.GraphContext, sizes, value, dtype, layout, device, pin_memory=False
):
    """Symbolic for ``aten::full``.

    A constant fill value becomes the ``value_t`` of a ``ConstantOfShape``.
    A non-constant fill value is added to a ``zeros`` tensor of the requested
    shape instead.
    """
    const_value = symbolic_helper._maybe_get_const(value, "t")

    if not symbolic_helper._is_value(const_value):
        # Constant fill value: bake it into ConstantOfShape.
        dtype = symbolic_helper._get_const(dtype, "i", "dtype")
        scalar_type = (
            _type_utils.JitScalarType.FLOAT
            if dtype is None
            else _type_utils.JitScalarType(dtype)
        )
        const_sizes = symbolic_helper._maybe_get_const(sizes, "is")
        if isinstance(const_sizes, list) and not const_sizes:
            sizes = g.op("Constant", value_t=torch.tensor([]).to(torch.int64))
        fill = const_value.view(1).to(scalar_type.dtype())
        return g.op("ConstantOfShape", sizes, value_t=fill)

    # Non-constant fill value: zeros + value.
    if dtype is None:
        dtype = _type_utils.JitScalarType.FLOAT
    base = zeros(g, sizes, dtype, layout, device)
    alpha = g.op("Constant", value_t=torch.tensor(1))
    return add(g, base, value, alpha)
| |
| |
@_onnx_symbolic("aten::full_like")
@_beartype.beartype
def full_like(
    g: jit_utils.GraphContext,
    input,
    fill_value,
    dtype=None,
    layout=None,
    device=None,
    pin_memory=False,
    memory_format=None,
):
    """Symbolic for ``aten::full_like``.

    A constant ``fill_value`` becomes the ``value_t`` of a ``ConstantOfShape``
    over ``Shape(input)``. A non-constant one is cast and added to
    ``zeros_like(input)``. Without ``dtype`` the dtype of ``input`` is used,
    with FLOAT as default.
    """
    fill_value = symbolic_helper._maybe_get_const(fill_value, "f")
    dtype = symbolic_helper._get_const(dtype, "i", "dtype")
    scalar_type = (
        _type_utils.JitScalarType.from_value(input, _type_utils.JitScalarType.FLOAT)
        if dtype is None
        else _type_utils.JitScalarType(dtype)
    )

    if not symbolic_helper._is_value(fill_value):
        shape = g.op("Shape", input)
        fill = torch.tensor([fill_value], dtype=scalar_type.dtype())
        return g.op("ConstantOfShape", shape, value_t=fill)

    base = zeros_like(g, input, dtype, layout, device)
    fill_value = g.op("Cast", fill_value, to_i=scalar_type.onnx_type())
    alpha = g.op("Constant", value_t=torch.tensor(1))
    return add(g, base, fill_value, alpha)
| |
| |
@_onnx_symbolic("aten::new_full")
@_beartype.beartype
def new_full(
    g: jit_utils.GraphContext,
    self,
    size,
    fill_value,
    dtype,
    layout,
    device,
    pin_memory=False,
):
    """Symbolic for ``aten::new_full``.

    Falls back to the dtype of ``self`` (when known) if ``dtype`` is None,
    then delegates to ``full``.
    """
    inherited_dtype = symbolic_helper._try_get_scalar_type(self)
    use_inherited = symbolic_helper._is_none(dtype) and inherited_dtype is not None
    if use_inherited:
        dtype = inherited_dtype
    return full(g, size, fill_value, dtype, layout, device, pin_memory)
| |
| |
@_onnx_symbolic("aten::eye")
@_beartype.beartype
def eye(g: jit_utils.GraphContext, *args):
    """Symbolic for ``aten::eye``; ``EyeLike`` applied to a zeros tensor.

    Supports the 5-argument form ``(n, dtype, layout, device, pin_memory)``
    and the 6-argument form ``(n, m, dtype, layout, device, pin_memory)``.
    Any other arity is reported through ``_unimplemented``.
    """
    arg_count = len(args)
    if arg_count == 5:
        # aten::eye(n, dtype, layout, device, pin_memory)
        n, dtype, layout, device, pin_memory = args
        side = symbolic_helper._unsqueeze_helper(g, n, [0])
        shape = g.op("Concat", side, side, axis_i=0)
    elif arg_count == 6:
        # aten::eye(n, m, dtype, layout, device, pin_memory)
        n, m, dtype, layout, device, pin_memory = args
        rows = symbolic_helper._unsqueeze_helper(g, n, [0])
        cols = symbolic_helper._unsqueeze_helper(g, m, [0])
        shape = g.op("Concat", rows, cols, axis_i=0)
    else:
        return symbolic_helper._unimplemented(
            "aten::eye", f"with {len(args)} arguments"
        )

    zero_matrix = zeros(g, shape, dtype, layout, device)
    return g.op("EyeLike", zero_matrix)
| |
| |
@_onnx_symbolic("aten::slice")
@_beartype.beartype
def slice(g: jit_utils.GraphContext, self, *args):
    """Symbolic for ``aten::slice``.

    Handles the tensor overload ``(dim, start, end, step)`` and the list
    overload ``(start, end, step)``. A None ``start`` is treated as 0 and a
    None ``end`` as ``INT64_MAX``. For the tensor overload, ``step`` must be 1
    and dynamic ``dim``/``start``/``end`` are only exported (as
    ``DynamicSlice``) when the operator export type is not plain ONNX.
    """

    def _is_none_constant(value):
        # True when ``value`` is a ``prim::Constant`` of None type.
        return value.node().kind() == "prim::Constant" and isinstance(
            value.type(), _C.NoneType
        )

    if len(args) == 4:
        # aten::slice(Tensor self, int dim, int start, int end, int step) -> Tensor
        dim, start, end, step = args
        step = symbolic_helper._parse_arg(step, "i")
        if step != 1:
            raise errors.SymbolicValueError("step!=1 is currently not supported", self)

        is_start_none = _is_none_constant(start)
        is_end_none = _is_none_constant(end)
        start_is_static = is_start_none or start.node().kind() == "onnx::Constant"
        end_is_static = is_end_none or end.node().kind() == "onnx::Constant"
        all_static = (
            start_is_static
            and end_is_static
            and dim.node().kind() == "onnx::Constant"
        )

        if all_static:
            start = 0 if is_start_none else symbolic_helper._parse_arg(start, "i")
            end = (
                _constants.INT64_MAX
                if is_end_none
                else symbolic_helper._parse_arg(end, "i")
            )
            dim = symbolic_helper._parse_arg(dim, "i")
            return symbolic_helper._slice_helper(
                g, self, axes=[dim], starts=[start], ends=[end]
            )

        if GLOBALS.operator_export_type == _C_onnx.OperatorExportTypes.ONNX:
            raise errors.SymbolicValueError(
                "Unsupported: ONNX export of Slice with dynamic inputs. DynamicSlice "
                "is a deprecated experimental op. Please use statically allocated "
                "variables or export to a higher opset version.",
                self,
            )

        start_unsqueezed = symbolic_helper._unsqueeze_helper(g, start, [0])
        end_unsqueezed = symbolic_helper._unsqueeze_helper(g, end, [0])
        dim_unsqueezed = symbolic_helper._unsqueeze_helper(g, dim, [0])
        return g.op(
            "DynamicSlice",
            self,
            start_unsqueezed,
            end_unsqueezed,
            dim_unsqueezed,
        )

    if len(args) == 3:
        # aten::slice(t[] l, int start, int end, int step) -> t[]
        start, end, step = args
        dim = 0
        is_start_none = _is_none_constant(start)
        is_end_none = _is_none_constant(end)
        start = 0 if is_start_none else symbolic_helper._parse_arg(start, "i")
        end = (
            _constants.INT64_MAX
            if is_end_none
            else symbolic_helper._parse_arg(end, "i")
        )
        return symbolic_helper._slice_helper(
            g, self, axes=[dim], starts=[start], ends=[end]
        )

    return symbolic_helper._unimplemented("aten::slice", f"with {len(args)} arguments")
| |
| |
@_onnx_symbolic("aten::hardtanh")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "f", "f")
@_beartype.beartype
def hardtanh(g: jit_utils.GraphContext, self: _C.Value, min_val: float, max_val: float):
    """Export ``aten::hardtanh`` as an ONNX ``Clip`` bounded by ``min_val`` and ``max_val``.

    Args:
        g: Graph context the node is added to.
        self: Input value to clamp.
        min_val: Lower bound, passed as the ``min`` float attribute.
        max_val: Upper bound, passed as the ``max`` float attribute.

    Returns:
        Whatever ``_op_with_optional_float_cast`` produces for the ``Clip`` op.
    """
    clip_bounds = {"min_f": min_val, "max_f": max_val}
    return _op_with_optional_float_cast(
        g, "Clip", self, **clip_bounds, opset_before=12
    )
| |
| |
@_onnx_symbolic("aten::hardswish")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v")
@_beartype.beartype
def hardswish(g: jit_utils.GraphContext, self):
    """Export ``aten::hardswish`` as ``Mul(self, hardsigmoid(self))``."""
    return g.op("Mul", self, hardsigmoid(g, self))
| |
| |
@_onnx_symbolic("aten::hardsigmoid")
# Fixed scale and zero_point, discovered from aten/src/ATen/native/quantized/cpu/qhardsigmoid.cpp
@symbolic_helper.quantized_args(True, scale=1.0 / 256.0, zero_point=0)
@symbolic_helper.parse_args("v")
@_beartype.beartype
def hardsigmoid(g: jit_utils.GraphContext, self):
    """Export ``aten::hardsigmoid`` as an ONNX ``HardSigmoid`` node with ``alpha = 1/6``."""
    # An alpha of 1 / 6 makes the op equivalent to PyTorch's definition of Hardsigmoid.
    # See https://pytorch.org/docs/stable/generated/torch.nn.Hardsigmoid.html
    pytorch_alpha = 1 / 6
    return g.op("HardSigmoid", self, alpha_f=pytorch_alpha)
| |
| |
@_onnx_symbolic("aten::tanhshrink")
@symbolic_helper.parse_args("v")
@_beartype.beartype
def tanhshrink(g: jit_utils.GraphContext, self):
    """Export ``aten::tanhshrink`` as ``Sub(self, tanh(self))``."""
    tanh_of_self = tanh(g, self)
    return g.op("Sub", self, tanh_of_self)
| |
| |
@_onnx_symbolic("aten::hardshrink")
@symbolic_helper.parse_args("v", "f")
@_beartype.beartype
def hardshrink(g: jit_utils.GraphContext, self, lambd):
    """Export ``aten::hardshrink``: keep ``self`` where it is > lambd or < -lambd, else 0.

    Args:
        g: Graph context the nodes are added to.
        self: Input value.
        lambd: Threshold, parsed as a float.

    Returns:
        The output of the ONNX ``Where`` node selecting between ``self`` and zero.
    """
    # Constants use the input's scalar type, falling back to FLOAT.
    torch_dtype = _type_utils.JitScalarType.from_value(
        self, _type_utils.JitScalarType.FLOAT
    ).dtype()
    threshold = g.op("Constant", value_t=torch.tensor(lambd, dtype=torch_dtype))
    above = gt(g, self, threshold)
    below = lt(g, self, neg(g, threshold))
    outside_band = logical_or(g, above, below)
    zero = g.op("Constant", value_t=torch.tensor(0, dtype=torch_dtype))
    return g.op("Where", outside_band, self, zero)
| |
| |
@_onnx_symbolic("aten::softshrink")
@symbolic_helper.parse_args("v", "f")
@_beartype.beartype
def softshrink(g: jit_utils.GraphContext, self, lambd):
    """Export ``aten::softshrink`` as the sum of two ``Where`` selections.

    One selection yields ``self - lambd`` where ``self > lambd`` (else 0); the
    other yields ``self + lambd`` where ``self < -lambd`` (else 0).

    Args:
        g: Graph context the nodes are added to.
        self: Input value.
        lambd: Threshold, parsed as a float.

    Returns:
        The value produced by adding the two selections.
    """
    # Constants use the input's scalar type, falling back to FLOAT.
    torch_dtype = _type_utils.JitScalarType.from_value(
        self, _type_utils.JitScalarType.FLOAT
    ).dtype()
    threshold = g.op("Constant", value_t=torch.tensor(lambd, dtype=torch_dtype))

    def _zero():
        # A separate zero constant node is emitted for each Where.
        return g.op("Constant", value_t=torch.tensor(0, dtype=torch_dtype))

    is_above = gt(g, self, threshold)
    shifted_down = sub(g, self, threshold)
    positive_part = g.op("Where", is_above, shifted_down, _zero())

    is_below = lt(g, self, neg(g, threshold))
    shifted_up = add(g, self, threshold)
    negative_part = g.op("Where", is_below, shifted_up, _zero())

    return add(g, positive_part, negative_part)
| |
| |
@_onnx_symbolic("aten::alias")
@_beartype.beartype
def alias(g: jit_utils.GraphContext, self):
    """Export ``aten::alias`` as a no-op: the input value is returned unchanged."""
    return self
| |
| |
@_onnx_symbolic("aten::unsqueeze")
@symbolic_helper.parse_args("v", "i")
@_beartype.beartype
def unsqueeze(g: jit_utils.GraphContext, self, dim):
    """Export ``aten::unsqueeze``, resolving a negative ``dim`` at export time.

    Args:
        g: Graph context the node is added to.
        self: Input value.
        dim: Axis to insert, parsed as an int. A negative value is converted to
            ``dim + rank + 1`` using the input rank known at export time, and a
            warning is emitted.

    Returns:
        The result of ``symbolic_helper._unsqueeze_helper``, or the result of
        ``symbolic_helper._unimplemented`` when ``dim`` is negative and the
        input rank is unknown.
    """
    if dim >= 0:
        return symbolic_helper._unsqueeze_helper(g, self, axes_i=[dim])

    # Handle negative dim
    rank = symbolic_helper._get_tensor_rank(self)
    if rank is None:
        return symbolic_helper._unimplemented(
            "unsqueeze", "negative axis with unknown input rank", self
        )

    normalized_dim = dim + rank + 1
    warnings.warn(
        f"ONNX export unsqueeze with negative axis {dim}"
        " might cause the onnx model to be incorrect. "
        "Negative axis is not supported in ONNX. "
        f"Axis is converted to {normalized_dim}"
        " based on input shape at export time. "
        "Passing an tensor of different rank in execution will be incorrect."
    )
    return symbolic_helper._unsqueeze_helper(g, self, axes_i=[normalized_dim])
| |
| |
@_onnx_symbolic("aten::sort")
# TODO(justinchuby): Support multiple quantized args in output
@symbolic_helper.parse_args("v", "i", "i", "none")
@_beartype.beartype
def sort(g: jit_utils.GraphContext, self, dim, decending, out=None):
    """Export ``aten::sort`` as a ``TopK`` over the whole ``dim`` axis.

    The node emitted here has no ordering attribute (compare ``topk`` in this
    file, which rejects ``largest=False``), so only descending sorts can be
    represented. Ascending sorts are reported as unsupported instead of being
    silently exported with the wrong order.

    Args:
        g: Graph context the node is added to.
        self: Input value to sort.
        dim: Axis to sort along, parsed as an int.
        decending: Non-zero for a descending sort. The misspelled name is kept
            for interface compatibility.
        out: Must be None; an ``out`` argument is not supported.

    Returns:
        The two outputs of the ``TopK`` node, or the result of
        ``symbolic_helper._unimplemented`` when the request cannot be exported.
    """
    if out is not None:
        symbolic_helper._unimplemented(
            "Sort", "Out parameter is not supported for sort", self
        )
    if not decending:
        return symbolic_helper._unimplemented(
            "Sort", "Ascending is not supported", self
        )

    self_sizes = symbolic_helper._get_tensor_sizes(self)
    try:
        dim_size = self_sizes[dim]
    except (IndexError, TypeError):
        # TypeError: sizes unavailable (self_sizes is None).
        # IndexError: dim is out of range for the known sizes.
        dim_size = None

    if dim_size is None:
        return symbolic_helper._unimplemented("Sort", "input size not accessible", self)

    return g.op("TopK", self, k_i=dim_size, axis_i=dim, outputs=2)
| |
| |
@_onnx_symbolic("aten::numel")
@_beartype.beartype
def numel(g: jit_utils.GraphContext, self):
    """Export ``aten::numel`` as ``ReduceProd`` (keepdims=0) over ``Shape(self)``."""
    return g.op("ReduceProd", g.op("Shape", self), keepdims_i=0)
| |
| |
@_onnx_symbolic("aten::topk")
# TODO(justinchuby): Support multiple quantized args in output
@symbolic_helper.parse_args("v", "i", "i", "i", "i", "none")
@_beartype.beartype
def topk(g: jit_utils.GraphContext, self, k, dim, largest, sorted, out=None):
    """Export ``aten::topk`` as an ONNX ``TopK`` node with attribute ``k``.

    Args:
        g: Graph context the node is added to.
        self: Input value.
        k: Number of elements to keep, parsed as an int.
        dim: Axis to operate on, parsed as an int.
        largest: Must be non-zero; ascending TopK is reported as unsupported.
        sorted: Parsed but not used by this symbolic.
        out: Must be None; an ``out`` argument is reported as unsupported.

    Returns:
        The two outputs of the ``TopK`` node.
    """
    unsupported_checks = (
        (out is not None, "Out parameter is not supported for topk"),
        (not largest, "Ascending TopK is not supported"),
    )
    for is_unsupported, reason in unsupported_checks:
        if is_unsupported:
            symbolic_helper._unimplemented("TopK", reason, self)

    topk_attrs = {"k_i": k, "axis_i": dim}
    return g.op("TopK", self, **topk_attrs, outputs=2)
| |
| |
@_onnx_symbolic("prim::convert_element_type")
@_beartype.beartype
def convert_element_type(g: jit_utils.GraphContext, self, *args):
    """Export ``prim::convert_element_type`` as a ``Cast`` to the dtype in ``args[0]``."""
    target_dtype = symbolic_helper._get_const(args[0], "i", "dtype")
    onnx_type = _type_utils.JitScalarType(target_dtype).onnx_type()
    return g.op("Cast", self, to_i=onnx_type)
| |
| |
@_onnx_symbolic("aten::to")
@_beartype.beartype
def to(g: jit_utils.GraphContext, self, *args):
    """Export ``aten::to`` as an ONNX ``Cast``, or as a no-op for device-only calls.

    The overload is identified by the number of trailing arguments:

    * 4: ``aten::to(Tensor, Device|Tensor|ScalarType, bool, bool, memory_format)``
    * 5: ``aten::to(Tensor, Device, ScalarType, bool, bool, memory_format)``
    * 6/7: ``aten::to(Tensor, ScalarType, Layout, Device, bool, bool[, bool], memory_format)``

    Returns:
        ``self`` unchanged for device-only calls, otherwise the ``Cast`` output.
        Any other argument count is routed to ``symbolic_helper._onnx_unsupported``.
    """

    @_beartype.beartype
    def is_aten_to_device_only(args):
        num_args = len(args)
        if num_args == 4:
            # aten::to(Tensor, Device, bool, bool, memory_format)
            first = args[0]
            return (
                first.node().kind() == "prim::device"
                or first.type().isSubtypeOf(_C.ListType.ofInts())
                or isinstance(first.type(), _C.DeviceObjType)
            )
        if num_args == 5:
            # aten::to(Tensor, Device, ScalarType, bool, bool, memory_format)
            # When dtype is None, this is a aten::to(device) call
            return symbolic_helper._get_const(args[1], "i", "dtype") is None
        if num_args in (6, 7):
            # aten::to(Tensor, ScalarType, Layout, Device, bool, bool, memory_format) -> Tensor
            # aten::to(Tensor, ScalarType, Layout, Device, bool, bool, bool, memory_format) -> Tensor
            # When dtype is None, this is a aten::to(device) call
            return symbolic_helper._get_const(args[0], "i", "dtype") is None
        return False

    # ONNX doesn't have a concept of a device, so we ignore device-only casts
    if is_aten_to_device_only(args):
        return self

    num_args = len(args)
    if num_args == 4:
        # TestONNXRuntime::test_ones_bool shows args[0] of aten::to() can be onnx::Constant[value=<Tensor>]()
        # In this case, the constant value is a tensor not int,
        # so symbolic_helper._maybe_get_const(args[0], 'i') would not work.
        dtype = args[0]
        if (
            symbolic_helper._is_value(dtype)
            and dtype.node().kind() == "onnx::Constant"
        ):
            tval = symbolic_helper._node_get(dtype.node(), "value")
            if isinstance(tval, torch.Tensor):
                # A 0-dim tensor constant carries the dtype as an integer.
                dtype = int(tval.item()) if len(tval.shape) == 0 else tval

        if symbolic_helper._is_value(dtype) or isinstance(dtype, torch.Tensor):
            # aten::to(Tensor, Tensor, bool, bool, memory_format)
            scalar_type = _type_utils.JitScalarType.from_value(args[0])
            return g.op("Cast", self, to_i=scalar_type.onnx_type())

        # aten::to(Tensor, ScalarType, bool, bool, memory_format)
        # memory_format is ignored
        return g.op("Cast", self, to_i=_type_utils.JitScalarType(dtype).onnx_type())

    if num_args == 5:
        # aten::to(Tensor, Device, ScalarType, bool, bool, memory_format)
        dtype_index = 1
    elif num_args in (6, 7):
        # aten::to(Tensor, ScalarType, Layout, Device, bool, bool[, bool], memory_format) -> Tensor
        dtype_index = 0
    else:
        return symbolic_helper._onnx_unsupported("Unknown aten::to signature", self)

    # Layout, device and memory_format are ignored
    dtype = symbolic_helper._get_const(args[dtype_index], "i", "dtype")
    return g.op("Cast", self, to_i=_type_utils.JitScalarType(dtype).onnx_type())
| |
| |
@_onnx_symbolic("aten::repeat")
@_beartype.beartype
def repeat(g: jit_utils.GraphContext, self, repeats):
    """Export ``aten::repeat`` as ``Expand`` followed by ``Tile``.

    The input is first expanded with an INT64 ``ones_like`` of ``repeats`` and
    the expanded value is then tiled by ``repeats``.
    """
    unit_shape = ones_like(g, repeats, _type_utils.JitScalarType.INT64)
    expanded = g.op("Expand", self, unit_shape)
    return g.op("Tile", expanded, repeats)
| |
| |
@_onnx_symbolic("aten::repeat_interleave")
@_beartype.beartype
def repeat_interleave(
    g: jit_utils.GraphContext, self, repeats, dim=None, output_size=None
):
    """Export ``aten::repeat_interleave`` for opset 9.

    Args:
        g: Graph context the nodes are added to.
        self: Input value.
        repeats: Repeat counts; must have rank 0 or 1 with known sizes.
        dim: Axis to repeat along. When it is None the input is reshaped to
            1-D and axis 0 is used.
        output_size: Not read by this function.

    Returns:
        The result of the single-value helper when ``repeats`` is a scalar or
        a one-element tensor; otherwise a ``Concat`` along ``dim`` of the
        per-element expanded splits. Unknown input size along ``dim`` or a
        dynamic ``repeats`` length is routed to
        ``symbolic_helper._onnx_opset_unsupported_detailed``.

    Raises:
        errors.SymbolicValueError: If the rank or sizes of ``repeats``, or the
            sizes of the input, are unknown, or ``repeats`` has rank > 1.
        AssertionError: If a 1-D ``repeats`` does not match the input size
            along ``dim``.
    """
    input = self
    # if dim is None flatten
    # By default, use the flattened input array, and return a flat output array
    if symbolic_helper._is_none(dim):
        input = symbolic_helper._reshape_helper(
            g, self, g.op("Constant", value_t=torch.tensor([-1]))
        )
        dim = torch.tensor(0, dtype=torch.int64)
    else:
        dim = symbolic_helper._maybe_get_scalar(dim)

    repeats_dim = symbolic_helper._get_tensor_rank(repeats)
    repeats_sizes = symbolic_helper._get_tensor_sizes(repeats)
    input_sizes = symbolic_helper._get_tensor_sizes(input)
    if repeats_dim is None:
        raise errors.SymbolicValueError(
            "Unsupported: ONNX export of repeat_interleave for unknown repeats rank.",
            input,
        )
    if repeats_sizes is None:
        raise errors.SymbolicValueError(
            "Unsupported: ONNX export of repeat_interleave for unknown repeats size.",
            input,
        )
    if input_sizes is None:
        raise errors.SymbolicValueError(
            "Unsupported: ONNX export of repeat_interleave for unknown input size.",
            input,
        )

    # Handle cases where dim is negative
    if dim < 0:
        dim += len(input_sizes)

    # Two working copies of the sizes: unknown (None) entries become 0 in
    # `input_sizes` (used below as the final Reshape target) and -1 in
    # `input_sizes_temp` (used below to build the Expand shape).
    input_sizes_temp = input_sizes.copy()
    for idx, input_size in enumerate(input_sizes):
        if input_size is None:
            input_sizes[idx], input_sizes_temp[idx] = 0, -1

    # Cases where repeats is an int or single value tensor
    if repeats_dim == 0 or (repeats_dim == 1 and repeats_sizes[0] == 1):
        # A 0 here marks a size that was unknown (None) before the loop above.
        if input_sizes[dim] == 0:
            return symbolic_helper._onnx_opset_unsupported_detailed(
                "repeat_interleave",
                9,
                13,
                "Unsupported along dimension with unknown input size",
                self,
            )
        return symbolic_helper._repeat_interleave_single_value_repeat_helper(
            g, self, repeats, dim
        )

    # Cases where repeats is a 1 dim Tensor
    elif repeats_dim == 1:
        if input_sizes[dim] == 0:
            return symbolic_helper._onnx_opset_unsupported_detailed(
                "repeat_interleave",
                9,
                13,
                "Unsupported along dimension with unknown input size",
                self,
            )
        if repeats_sizes[0] is None:
            return symbolic_helper._onnx_opset_unsupported_detailed(
                "repeat_interleave",
                9,
                13,
                "Unsupported for cases with dynamic repeats",
                self,
            )
        assert (
            repeats_sizes[0] == input_sizes[dim]
        ), "repeats must have the same size as input along dim"
        reps = repeats_sizes[0]
    else:
        raise errors.SymbolicValueError("repeats must be 0-dim or 1-dim tensor", self)

    final_splits = list()
    # Split both `repeats` (along axis 0) and the input (along `dim`) using `reps`.
    r_splits = symbolic_helper._repeat_interleave_split_helper(g, repeats, reps, 0)
    i_splits = symbolic_helper._repeat_interleave_split_helper(g, input, reps, dim)
    # Along `dim`: -1 in the final Reshape target, 1 in the Expand shape prefix.
    input_sizes[dim], input_sizes_temp[dim] = -1, 1
    for idx, r_split in enumerate(r_splits):
        # Insert a new axis right after `dim`, expand it by this split's
        # repeat count, then reshape back to the `input_sizes` target.
        i_split = unsqueeze(g, i_splits[idx], dim + 1)
        r_concat = [
            g.op("Constant", value_t=torch.LongTensor(input_sizes_temp[: dim + 1])),
            r_split,
            g.op("Constant", value_t=torch.LongTensor(input_sizes_temp[dim + 1 :])),
        ]
        r_concat = g.op("Concat", *r_concat, axis_i=0)
        i_split = expand(g, i_split, r_concat, None)
        i_split = symbolic_helper._reshape_helper(
            g,
            i_split,
            g.op("Constant", value_t=torch.LongTensor(input_sizes)),
            allowzero=0,
        )
        final_splits.append(i_split)
    return g.op("Concat", *final_splits, axis_i=dim)
| |
| |
@_onnx_symbolic("aten::pixel_shuffle")
@symbolic_helper.parse_args("v", "i")
@_beartype.beartype
def pixel_shuffle(g: jit_utils.GraphContext, self, upscale_factor):
    """Export ``aten::pixel_shuffle`` as Reshape -> Transpose -> Reshape.

    Args:
        g: Graph context the nodes are added to.
        self: Input value; only 4-D inputs are supported.
        upscale_factor: Upscale factor, parsed as an int.

    Returns:
        The reshaped output value, or the result of
        ``symbolic_helper._unimplemented`` when the input is not known to be 4-D.
    """
    dims = symbolic_helper._get_tensor_sizes(self)
    # The size helper can return None (other symbolics in this file check for
    # it); report that as unsupported instead of failing inside len().
    if dims is None or len(dims) != 4:
        return symbolic_helper._unimplemented(
            "pixel_shuffle", "only support 4d input", self
        )
    if any(i is None for i in dims[1:]):
        # Some non-batch size is unknown: use only 0 / -1 entries in the
        # Reshape targets, working on a 6-D view of the input.
        after_view = symbolic_helper._reshape_helper(
            g,
            symbolic_helper._unsqueeze_helper(g, self, [2, 3]),
            g.op(
                "Constant",
                value_t=torch.tensor([0, -1, upscale_factor, upscale_factor, 0, 0]),
            ),
            allowzero=0,
        )
        after_transpose = g.op("Transpose", after_view, perm_i=[0, 1, 4, 2, 5, 3])
        # For dynamic input shapes, two reshapes are performed
        reshape_h = symbolic_helper._reshape_helper(
            g,
            after_transpose,
            g.op("Constant", value_t=torch.tensor([0, 0, -1, 1, 0, 0])),
            allowzero=0,
        )
        reshape_w = symbolic_helper._reshape_helper(
            g,
            reshape_h,
            g.op("Constant", value_t=torch.tensor([0, 0, 0, 0, -1, 1])),
            allowzero=0,
        )
        return symbolic_helper._squeeze_helper(g, reshape_w, [3, 5])
    else:
        # All non-batch sizes are known: compute the target shapes at export time.
        output_channel = dims[1] // upscale_factor // upscale_factor
        after_view = symbolic_helper._reshape_helper(
            g,
            self,
            g.op(
                "Constant",
                value_t=torch.tensor(
                    [
                        -1,
                        output_channel,
                        upscale_factor,
                        upscale_factor,
                        dims[2],
                        dims[3],
                    ]
                ),
            ),
            allowzero=0,
        )
        after_transpose = g.op("Transpose", after_view, perm_i=[0, 1, 4, 2, 5, 3])
        return symbolic_helper._reshape_helper(
            g,
            after_transpose,
            g.op(
                "Constant",
                value_t=torch.tensor(
                    [
                        -1,
                        output_channel,
                        dims[2] * upscale_factor,
                        dims[3] * upscale_factor,
                    ]
                ),
            ),
            allowzero=0,
        )
| |
| |
@_onnx_symbolic("aten::pixel_unshuffle")
@symbolic_helper.parse_args("v", "i")
@_beartype.beartype
def pixel_unshuffle(g: jit_utils.GraphContext, self, downscale_factor):
    """Export ``aten::pixel_unshuffle`` as Reshape -> Transpose -> Reshape.

    Args:
        g: Graph context the nodes are added to.
        self: Input value; only 4-D inputs are supported.
        downscale_factor: Downscale factor, parsed as an int.

    Returns:
        The reshaped output value, or the result of
        ``symbolic_helper._unimplemented`` when the input is not known to be 4-D.
    """
    dims = symbolic_helper._get_tensor_sizes(self)
    # The size helper can return None (other symbolics in this file check for
    # it); report that as unsupported instead of failing inside len().
    if dims is None or len(dims) != 4:
        # Name this operator in the message (it previously said "pixel_shuffle").
        return symbolic_helper._unimplemented(
            "pixel_unshuffle", "only support 4d input", self
        )
    if any(i is None for i in dims[1:]):
        # For dynamic input shapes, two reshapes are performed
        reshape_h = symbolic_helper._reshape_helper(
            g,
            symbolic_helper._unsqueeze_helper(g, self, [3]),
            g.op("Constant", value_t=torch.tensor([0, 0, -1, downscale_factor, 0])),
            allowzero=0,
        )
        reshape_w = symbolic_helper._reshape_helper(
            g,
            reshape_h,
            g.op("Constant", value_t=torch.tensor([0, 0, 0, 0, -1, downscale_factor])),
            allowzero=0,
        )
        after_transpose = g.op("Transpose", reshape_w, perm_i=[0, 1, 3, 5, 2, 4])
        final_reshape = symbolic_helper._reshape_helper(
            g,
            after_transpose,
            g.op("Constant", value_t=torch.tensor([0, -1, 1, 1, 0, 0])),
            allowzero=0,
        )
        return symbolic_helper._squeeze_helper(g, final_reshape, [2, 3])
    else:
        # All non-batch sizes are known: compute the target shapes at export time.
        output_channel = dims[1] * downscale_factor * downscale_factor
        after_view = symbolic_helper._reshape_helper(
            g,
            self,
            g.op(
                "Constant",
                value_t=torch.tensor(
                    [
                        -1,
                        dims[1],
                        dims[2] // downscale_factor,
                        downscale_factor,
                        dims[3] // downscale_factor,
                        downscale_factor,
                    ]
                ),
            ),
            allowzero=0,
        )
        after_transpose = g.op("Transpose", after_view, perm_i=[0, 1, 3, 5, 2, 4])
        return symbolic_helper._reshape_helper(
            g,
            after_transpose,
            g.op(
                "Constant",
                value_t=torch.tensor(
                    [
                        -1,
                        output_channel,
                        dims[2] // downscale_factor,
                        dims[3] // downscale_factor,
                    ]
                ),
            ),
            allowzero=0,
        )
| |
| |
@_beartype.beartype
def _generic_rnn(
    g: jit_utils.GraphContext,
    variant,
    input,
    initial_states,
    all_weights,
    has_biases,
    num_layers,
    dropout,
    train,
    bidirectional,
    batch_first=None,
    batch_sizes=None,
):
    """Shared exporter that emits one ONNX ``RNN``/``GRU``/``LSTM`` node per layer.

    A warning about exporting with a batch size other than 1 is emitted on
    every call.

    Args:
        g: Graph context the nodes are added to.
        variant: ``"LSTM"``, ``"GRU"``, or a string starting with ``"RNN"``
            whose text from index 4 onward names the activation (looked up
            case-insensitively in the ONNX activation list below).
        input: Input sequence value; transposed with perm ``[1, 0, 2]`` when
            ``batch_first`` is truthy.
        initial_states: ``h0`` for RNN/GRU, or a ``(h0, c0)`` pair for LSTM.
        all_weights: Flat sequence of weight values, consumed in groups of
            ``weights_per_layer`` (4 with biases, 2 without).
        has_biases: Truthy when each group also contains the two bias values.
        num_layers: Number of stacked layers.
        dropout: Dropout value; combined with ``train`` it is unsupported.
        train: Truthy in training mode.
        bidirectional: Truthy for bidirectional networks.
        batch_first: Truthy when input and output are batch-major.
        batch_sizes: Passed as the ``sequence_lens`` input when given;
            otherwise ``unused(g)`` is passed.

    Returns:
        ``(output, h_outs)`` for RNN/GRU and ``(output, h_outs, c_outs)`` for
        LSTM, or the result of ``symbolic_helper._unimplemented`` for LSTMs
        with projections, dropout in training mode, or an unknown hidden size.
    """
    warnings.warn(
        "Exporting a model to ONNX with a batch_size other than 1, "
        + "with a variable length with "
        + variant
        + " can cause an error "
        + "when running the ONNX model with a different batch size. "
        + "Make sure to save the model with a batch size of 1, "
        + "or define the initial states (h0/c0) as inputs of the model. "
    )

    onnxActivations = [
        "Relu",
        "Tanh",
        "Sigmoid",
        "Affine",
        "LeakyRelu",
        "ThresholdedRelu",
        "ScaledTanh",
        "HardSigmoid",
        "Elu",
        "Softsign",
        "Softplus",
    ]
    # Maps the lower-cased activation name to its ONNX spelling.
    variantToOnnxActivationMap = dict(
        zip([act_fun.lower() for act_fun in onnxActivations], onnxActivations)
    )
    weights_per_layer = 4 if has_biases else 2
    # this means that projections are used inside LSTM, so need to tell user that it's not supported
    if variant == "LSTM" and len(all_weights) != num_layers * weights_per_layer * (
        1 + bidirectional
    ):
        return symbolic_helper._unimplemented("LSTM", "LSTMs with projections", input)
    assert len(all_weights) == num_layers * weights_per_layer * (1 + bidirectional)
    # One entry per (layer, direction): consecutive groups of weights_per_layer values.
    layer_weights = [
        all_weights[i : i + weights_per_layer]
        for i in range(0, len(all_weights), weights_per_layer)
    ]
    if batch_first:
        # batch, seq, feat -> seq, batch, feat
        input = g.op("Transpose", input, perm_i=[1, 0, 2])
    if dropout and train:
        return symbolic_helper._unimplemented(
            "RNN/GRU/LSTM", "dropout in training mode", input
        )

    if variant.startswith("RNN"):
        # e.g. "RNN_TANH": everything after the 4-character prefix is the activation name.
        nonlinearity = variantToOnnxActivationMap[variant[4:].lower()]
        variant = "RNN"

    # The hidden size is read from dimension 1 of the second weight value.
    w_hh = all_weights[1]
    hidden_size = symbolic_helper._get_tensor_dim_size(w_hh, 1)
    if hidden_size is None:
        return symbolic_helper._unimplemented(
            "RNN/GRU/LSTM", "unknown hidden size", input
        )

    unidirectional = not bidirectional

    prev_output = input

    h_outs = []
    if variant == "RNN" or variant == "GRU":
        h0 = initial_states
    elif variant == "LSTM":
        h0, c0 = initial_states
        c_outs = []

    sequence_lens = unused(g) if batch_sizes is None else batch_sizes

    # reform_permutation lists (start, end) gate intervals, in units of
    # hidden_size rows, in the order they are concatenated for ONNX.
    if variant == "GRU":
        # pytorch is reset, input, hidden
        # onnx is input, reset, hidden
        reform_permutation = [(1, 2), (0, 1), (2, 3)]
    elif variant == "LSTM":
        # pytorch is input, forget, cell, output.
        # onnx is input, output, forget, cell.
        reform_permutation = [(0, 1), (3, 4), (1, 3)]

    @_beartype.beartype
    def reform_weights(g, w, n, intervals):
        # Slice rows [x * n, y * n) of `w` for each interval and concatenate
        # the slices along axis 0 in the given order.
        slices = [
            symbolic_helper._slice_helper(g, w, axes=[0], starts=[x * n], ends=[y * n])
            for x, y in intervals
        ]
        return g.op("Concat", *slices, axis_i=0)

    @_beartype.beartype
    def transform_weights_no_bias(layer_index):
        # Returns (weight_ih, weight_hh), gate-reordered for GRU/LSTM, each
        # unsqueezed at axis 0.
        weights = layer_weights[layer_index]
        if variant == "RNN":
            weight_ih, weight_hh = weights
        elif variant == "GRU" or variant == "LSTM":
            weight_ih, weight_hh = (
                reform_weights(g, w, hidden_size, reform_permutation) for w in weights
            )
        return tuple(
            symbolic_helper._unsqueeze_helper(g, x, [0]) for x in (weight_ih, weight_hh)
        )

    @_beartype.beartype
    def transform_weights(layer_index):
        # Same as transform_weights_no_bias, plus the two biases concatenated
        # along axis 0 into a single value.
        weights = layer_weights[layer_index]
        if variant == "RNN":
            weight_ih, weight_hh, bias_ih, bias_hh = weights
        elif variant == "GRU" or variant == "LSTM":
            weight_ih, weight_hh, bias_ih, bias_hh = (
                reform_weights(g, w, hidden_size, reform_permutation) for w in weights
            )
        bias_concat = g.op("Concat", bias_ih, bias_hh, axis_i=0)
        return tuple(
            symbolic_helper._unsqueeze_helper(g, x, [0])
            for x in (weight_ih, weight_hh, bias_concat)
        )

    @_beartype.beartype
    def retrieve_state(x, start, end):
        # With a single layer the state is used as is; otherwise the rows
        # [start, end) along axis 0 are sliced out for this layer.
        return (
            x
            if num_layers == 1
            else symbolic_helper._slice_helper(
                g, x, axes=[0], starts=[start], ends=[end]
            )
        )

    for i in range(num_layers):
        if unidirectional:
            if weights_per_layer == 4:
                weight_ih, weight_hh, bias_concat = transform_weights(i)
            else:
                weight_ih, weight_hh = transform_weights_no_bias(i)
                bias_concat = unused(g)

            state_indices = i, i + 1
        else:
            # Forward and backward weights are stored at consecutive indices
            # 2 * i and 2 * i + 1 and concatenated along axis 0.
            if weights_per_layer == 4:
                weight_ih_f, weight_hh_f, bias_f = transform_weights(2 * i)
                weight_ih_b, weight_hh_b, bias_b = transform_weights(2 * i + 1)
                bias_concat = g.op("Concat", bias_f, bias_b, axis_i=0)
            else:
                weight_ih_f, weight_hh_f = transform_weights_no_bias(2 * i)
                weight_ih_b, weight_hh_b = transform_weights_no_bias(2 * i + 1)
                bias_concat = unused(g)

            weight_ih = g.op("Concat", weight_ih_f, weight_ih_b, axis_i=0)
            weight_hh = g.op("Concat", weight_hh_f, weight_hh_b, axis_i=0)

            state_indices = 2 * i, 2 * i + 2

        inputs = [prev_output, weight_ih, weight_hh, bias_concat, sequence_lens]

        inputs.append(retrieve_state(h0, *state_indices))
        if variant == "LSTM":
            inputs.append(retrieve_state(c0, *state_indices))

        extra_kwargs = {} if unidirectional else {"direction_s": "bidirectional"}
        if variant == "RNN":
            # One activation name per direction.
            if bidirectional:
                activation = [nonlinearity, nonlinearity]
            else:
                activation = [nonlinearity]

            prev_output, h_out = g.op(
                "RNN",
                *inputs,
                outputs=2,
                hidden_size_i=hidden_size,
                activations_s=activation,
                **extra_kwargs,
            )
        elif variant == "GRU":
            prev_output, h_out = g.op(
                "GRU",
                *inputs,
                outputs=2,
                hidden_size_i=hidden_size,
                linear_before_reset_i=1,
                **extra_kwargs,
            )
        elif variant == "LSTM":
            prev_output, h_out, c_out = g.op(
                "LSTM", *inputs, outputs=3, hidden_size_i=hidden_size, **extra_kwargs
            )

        if bidirectional:
            # The ONNX RNN/GRU/LSTM produce an output of dimensions
            # seq_len, num_directions, batch, hidden_size
            # We have to convert to match pytorch's expected
            # seq_len, batch, num_directions * hidden_size
            # by first moving num_directions before hidden_size with
            # Transpose, and then combining it with hidden_size
            # with Reshape.
            prev_output = g.op("Transpose", prev_output, perm_i=[0, 2, 1, 3])
            prev_output = symbolic_helper._reshape_helper(
                g,
                prev_output,
                g.op("Constant", value_t=torch.LongTensor([0, 0, -1])),
                allowzero=0,
            )
        else:
            # Drop axis 1 of the node output.
            prev_output = symbolic_helper._squeeze_helper(g, prev_output, [1])

        h_outs.append(h_out)
        if variant == "LSTM":
            c_outs.append(c_out)
    if batch_first:
        # seq, batch, num_directions * hidden_size -> batch, seq, num_directions * hidden_size
        prev_output = g.op("Transpose", prev_output, perm_i=[1, 0, 2])
    # With several layers the per-layer states are concatenated along axis 0.
    h_outs = h_out if num_layers == 1 else g.op("Concat", *h_outs, axis_i=0)
    if variant == "RNN" or variant == "GRU":
        return prev_output, h_outs
    elif variant == "LSTM":
        c_outs = c_out if num_layers == 1 else g.op("Concat", *c_outs, axis_i=0)
        return prev_output, h_outs, c_outs
| |
| |
@symbolic_helper.parse_args("v", "v", "v", "i", "i", "f", "i", "i", "i")
@_beartype.beartype
def _lstm_full(
    g: jit_utils.GraphContext,
    input,
    hidden_v,
    weight_v,
    has_biases,
    num_layers,
    dropout,
    train,
    bidirectional,
    batch_first,
):
    """Export the non-packed ``aten::lstm`` overload through ``_generic_rnn``.

    ``hidden_v`` and ``weight_v`` are list values; they are unpacked before
    being forwarded together with the remaining arguments.
    """
    hidden = symbolic_helper._unpack_list(hidden_v)
    weight = symbolic_helper._unpack_list(weight_v)
    return _generic_rnn(
        g,
        "LSTM",
        input,
        hidden,
        weight,
        has_biases,
        num_layers,
        dropout,
        train,
        bidirectional,
        batch_first,
    )
| |
| |
@symbolic_helper.parse_args("v", "v", "v", "v", "i", "i", "f", "i", "i")
@_beartype.beartype
def _lstm_packed(
    g: jit_utils.GraphContext,
    input,
    batch_sizes,
    hidden_v,
    weight_v,
    has_biases,
    num_layers,
    dropout,
    train,
    bidirectional,
):
    """Export the packed-sequence ``aten::lstm`` overload through ``_generic_rnn``.

    ``hidden_v`` and ``weight_v`` are list values; they are unpacked before
    being forwarded. ``batch_sizes`` is passed by keyword and ``batch_first``
    is left at its default.
    """
    hidden = symbolic_helper._unpack_list(hidden_v)
    weight = symbolic_helper._unpack_list(weight_v)
    return _generic_rnn(
        g,
        "LSTM",
        input,
        hidden,
        weight,
        has_biases,
        num_layers,
        dropout,
        train,
        bidirectional,
        batch_sizes=batch_sizes,
    )
| |
| |
@_onnx_symbolic("aten::lstm")
@_beartype.beartype
def lstm(g: jit_utils.GraphContext, *args):
    """Export ``aten::lstm``, dispatching to the packed or full overload.

    In the packed signature ``args[3]`` is the weight list, while in the full
    signature it is ``has_biases``; a tensor list there selects the packed path.
    """
    is_packed = symbolic_helper._is_tensor_list(args[3])
    handler = _lstm_packed if is_packed else _lstm_full
    return handler(g, *args)
| |
| |
@_onnx_symbolic("aten::lstm_cell")
@_beartype.beartype
def lstm_cell(g: jit_utils.GraphContext, self, hidden, w_ih, w_hh, b_ih, b_hh):
    """Export ``aten::lstm_cell`` as a one-layer, unidirectional LSTM.

    The input and every hidden state get a leading axis of size 1 before the
    call to ``_generic_rnn``; that axis is squeezed from the returned states.

    Returns:
        ``(h, c)``: the hidden and cell state values with axis 0 removed.
    """
    seq_input = symbolic_helper._unsqueeze_helper(g, self, [0])
    states = [
        symbolic_helper._unsqueeze_helper(g, state, [0])
        for state in symbolic_helper._unpack_list(hidden)
    ]
    # Biases are forwarded only when b_ih is a tensor value.
    has_biases = bool(symbolic_helper._is_tensor(b_ih))
    weights = (w_ih, w_hh, b_ih, b_hh) if has_biases else (w_ih, w_hh)
    _, h_outs, c_outs = _generic_rnn(
        g,
        "LSTM",
        seq_input,
        states,
        weights,
        has_biases,
        num_layers=1,
        dropout=0,
        train=0,
        bidirectional=False,
        batch_first=False,
    )
    h_squeezed = symbolic_helper._squeeze_helper(g, h_outs, [0])
    c_squeezed = symbolic_helper._squeeze_helper(g, c_outs, [0])
    return h_squeezed, c_squeezed
| |
| |
@_onnx_symbolic("aten::gru", decorate=[_apply_params("GRU"), _export("gru")])
@_onnx_symbolic(
    "aten::rnn_tanh", decorate=[_apply_params("RNN_TANH"), _export("rnn_tanh")]
)
@_onnx_symbolic(
    "aten::rnn_relu", decorate=[_apply_params("RNN_RELU"), _export("rnn_relu")]
)
def _one_hidden_rnn(kind: str):
    """Build the symbolic function for an RNN variant with a single hidden state.

    Args:
        kind: Variant name forwarded to ``_generic_rnn`` (``"GRU"``,
            ``"RNN_TANH"`` or ``"RNN_RELU"`` in the registrations above).

    Returns:
        A ``symbolic(g, *args)`` function that dispatches to the packed or
        full overload.
    """

    @symbolic_helper.parse_args("v", "v", "v", "i", "i", "f", "i", "i", "i")
    @_beartype.beartype
    def _rnn_full(
        g,
        input,
        hidden,
        weight_v,
        has_biases,
        num_layers,
        dropout,
        train,
        bidirectional,
        batch_first,
    ):
        # Non-packed overload: unpack the weight list and forward everything.
        return _generic_rnn(
            g,
            kind,
            input,
            hidden,
            symbolic_helper._unpack_list(weight_v),
            has_biases,
            num_layers,
            dropout,
            train,
            bidirectional,
            batch_first,
        )

    @symbolic_helper.parse_args("v", "v", "v", "v", "i", "i", "f", "i", "i")
    def _rnn_packed(
        g,
        input,
        batch_sizes,
        hidden,
        weight_v,
        has_biases,
        num_layers,
        dropout,
        train,
        bidirectional,
    ):
        # Packed overload: batch_sizes is passed by keyword, batch_first is left unset.
        return _generic_rnn(
            g,
            kind,
            input,
            hidden,
            symbolic_helper._unpack_list(weight_v),
            has_biases,
            num_layers,
            dropout,
            train,
            bidirectional,
            batch_sizes=batch_sizes,
        )

    def symbolic(g, *args):
        # args[3] is the weight list in the packed signature and has_biases in
        # the full one, so a tensor list there selects the packed overload.
        is_packed = symbolic_helper._is_tensor_list(args[3])
        handler = _rnn_packed if is_packed else _rnn_full
        return handler(g, *args)

    return symbolic
| |
| |
@_onnx_symbolic("aten::_dim_arange")
@symbolic_helper.parse_args("v", "i")
@_beartype.beartype
def _dim_arange(g: jit_utils.GraphContext, like, dim):
    """Export ``aten::_dim_arange`` as a range whose end is the size of ``like`` along ``dim``.

    The size is read with ``Shape`` + ``Gather``. It is then passed to
    ``_caffe2::Range`` under the Caffe2 ATen fallback, or to ``arange`` otherwise.
    """
    like_shape = g.op("Shape", like)
    dim_index = g.op("Constant", value_t=torch.tensor(dim))
    stop = g.op("Gather", like_shape, dim_index, axis_i=0)
    if not symbolic_helper.is_caffe2_aten_fallback():
        # aten::arange(Scalar end, ScalarType dtype, Layout, Device, bool pin_memory)
        return arange(g, stop, 4, None, None, None)
    return g.op("_caffe2::Range", stop)
| |
| |
@_onnx_symbolic("aten::detach")
@_beartype.beartype
def detach(g: jit_utils.GraphContext, input):
    """Export ``aten::detach`` as a no-op: the input value is returned unchanged."""
    # aten::detach nodes are erased because ONNX is inference only.
    return input
| |
| |
@_onnx_symbolic("aten::contiguous")
@symbolic_helper.parse_args("v", "i")
@_beartype.beartype
def contiguous(g: jit_utils.GraphContext, input, memory_format):
    """Export ``aten::contiguous`` as a no-op for supported memory formats.

    Raises:
        errors.SymbolicValueError: If ``memory_format`` is greater than 2.
    """
    # Allowed values are any, preserve and contiguous_format.
    if memory_format <= 2:
        return input
    raise errors.SymbolicValueError(
        "onnx memory_format support is not implemented", input
    )
| |
| |
@_onnx_symbolic("aten::_pack_padded_sequence")
@symbolic_helper.parse_args("v", "v", "i")
@_beartype.beartype
def _pack_padded_sequence(g: jit_utils.GraphContext, input, lengths, batch_first):
    """Export ``aten::_pack_padded_sequence`` as a placeholder ``prim::PackPadded`` node.

    Args:
        g: Graph context the nodes are added to.
        input: Padded input value; transposed with perm ``[1, 0, 2]`` when
            ``batch_first`` is truthy.
        lengths: Sequence lengths; must be a tensor. Cast to INT32 unless its
            scalar type is already INT.
        batch_first: Truthy when ``input`` is batch-major.

    Returns:
        The two outputs of the ``prim::PackPadded`` node.

    Raises:
        errors.SymbolicValueError: If ``lengths`` is not a tensor.
    """
    # Currently there is no PackPadded operator in ONNX. We rely on an
    # optimization pass to remove this later. It is an error if all
    # PackPadded operators cannot be optimized out.
    if batch_first:
        input = g.op("Transpose", input, perm_i=[1, 0, 2])

    lengths_is_tensor = lengths.type().isSubtypeOf(torch._C.TensorType.get())
    if not lengths_is_tensor:
        raise errors.SymbolicValueError(
            "'lengths' must be a Tensor for ONNX export", input
        )

    # We know it's a TensorType so this check is now safe.
    # It's really only necessary because those operators expand to something that
    # only works with int32 types in Caffe2...
    lengths_scalar_type = _type_utils.JitScalarType.from_value(
        lengths, _type_utils.JitScalarType.UNDEFINED
    )
    if lengths_scalar_type != _type_utils.JitScalarType.INT:
        lengths = g.op("Cast", lengths, to_i=_C_onnx.TensorProtoDataType.INT32)
    return g.op("prim::PackPadded", input, lengths, outputs=2)
| |
| |
@_onnx_symbolic("aten::_pad_packed_sequence")
@symbolic_helper.parse_args("v", "v", "i", "t", "v")
@_beartype.beartype
def _pad_packed_sequence(
    g: jit_utils.GraphContext,
    data,
    batch_sizes,
    batch_first,
    padding_value,
    total_length,
):
    """Export ``aten::_pad_packed_sequence`` as a ``prim::PadPacked`` node.

    ``padding_value`` and ``total_length`` are not used by this symbolic.

    Returns:
        ``(data, lengths)``; ``data`` is transposed with perm ``[1, 0, 2]``
        when ``batch_first`` is truthy.
    """
    # Ignore total_length as it is not supported in _symbolic_pad_packed_sequence
    # It is only useful/used when training using data_parallel model, so
    # It shouldn't be relevant for ONNX anyway
    padded, lengths = g.op("prim::PadPacked", data, batch_sizes, outputs=2)
    if not batch_first:
        return padded, lengths
    return g.op("Transpose", padded, perm_i=[1, 0, 2]), lengths
| |
| |
@_onnx_symbolic("aten::randint")
@_beartype.beartype
def randint(g: jit_utils.GraphContext, low, high, shapes, dtype, *options):
    """Export ``aten::randint`` as a uniform sample cast to an integer type.

    Args:
        g: Graph context the nodes are added to.
        low: Constant lower bound, used as the ``low`` float attribute.
        high: Constant upper bound, used as the ``high`` float attribute.
        shapes: Output shape, either a constant int list or a graph value.
        dtype: Constant dtype; INT64 is used when it is None.
        *options: Remaining arguments, not used.

    Returns:
        The sample cast to INT64 and then, if different, to the requested dtype.
    """
    dtype = symbolic_helper._get_const(dtype, "i", "dtype")
    low_i = symbolic_helper._get_const(low, "i", "low")
    high_i = symbolic_helper._get_const(high, "i", "high")
    scalar_type = (
        _type_utils.JitScalarType.INT64
        if dtype is None
        else _type_utils.JitScalarType(dtype)
    )
    if low_i is None:
        raise symbolic_helper._onnx_unsupported("randint", low)
    if high_i is None:
        raise symbolic_helper._onnx_unsupported("randint", high)

    uniform_attrs = {"low_f": low_i, "high_f": high_i}
    shape = symbolic_helper._maybe_get_const(shapes, "is")
    if symbolic_helper._is_value(shape):
        # Shape only known at run time: build a float tensor of that shape
        # and sample "like" it.
        shape_const = g.op(
            "ConstantOfShape",
            shapes,
            value_t=torch.tensor([0], dtype=torch.float),
        )
        uniform_sample = g.op("RandomUniformLike", shape_const, **uniform_attrs)
    else:
        uniform_sample = g.op("RandomUniform", shape_i=shape, **uniform_attrs)

    # cast to integer type
    int_dtype = _type_utils.JitScalarType.INT64
    result = g.op("Cast", uniform_sample, to_i=int_dtype.onnx_type())
    if int_dtype != scalar_type:
        result = g.op("Cast", result, to_i=scalar_type.onnx_type())
    return result
| |
| |
@_onnx_symbolic("aten::randint_like")
@_beartype.beartype
def randint_like(g: jit_utils.GraphContext, self, low, high, dtype, *options):
    """Export ``aten::randint_like`` as ``RandomUniformLike`` plus integer casts.

    ``low``, ``high`` and ``dtype`` must be constants. The sample takes its
    shape from ``self``, is cast to INT64, and is then cast to the requested
    type when that differs (INT64 is used when ``dtype`` is ``None``).
    ``options`` are not used.

    Raises:
        Whatever ``symbolic_helper._onnx_unsupported`` raises when ``low`` or
        ``high`` resolves to ``None``.
    """
    dtype = symbolic_helper._get_const(dtype, "i", "dtype")
    low_i = symbolic_helper._get_const(low, "i", "low")
    high_i = symbolic_helper._get_const(high, "i", "high")
    if dtype is None:
        scalar_type = _type_utils.JitScalarType.INT64
    else:
        scalar_type = _type_utils.JitScalarType(dtype)
    # Report the operator that is actually being exported so the error
    # message points users at randint_like rather than randint.
    if low_i is None:
        raise symbolic_helper._onnx_unsupported("randint_like", low)
    if high_i is None:
        raise symbolic_helper._onnx_unsupported("randint_like", high)

    uniform = g.op(
        "RandomUniformLike",
        self,
        low_f=low_i,
        high_f=high_i,
    )

    # Cast to an integer type first, then to the requested type if different.
    int_dtype = _type_utils.JitScalarType.INT64
    result = g.op("Cast", uniform, to_i=int_dtype.onnx_type())
    if int_dtype != scalar_type:
        result = g.op("Cast", result, to_i=scalar_type.onnx_type())
    return result
| |
| |
@_onnx_symbolic("aten::randn")
@_beartype.beartype
def randn(g: jit_utils.GraphContext, shapes, dtype, *options):
    """Export ``aten::randn`` as ``RandomNormal`` or ``RandomNormalLike``.

    ``dtype`` must be a constant; FLOAT is used when it is ``None``.
    ``options`` are not used.
    """
    dtype = symbolic_helper._get_const(dtype, "i", "dtype")
    scalar_type = (
        _type_utils.JitScalarType.FLOAT
        if dtype is None
        else _type_utils.JitScalarType(dtype)
    )
    shape = symbolic_helper._maybe_get_const(shapes, "is")
    if not symbolic_helper._is_value(shape):
        # Constant shape: encode it as an attribute.
        return g.op(
            "RandomNormal",
            shape_i=shape,
            dtype_i=scalar_type.onnx_type(),
        )
    # Dynamic shape: sample "like" a zero tensor built from the shape value.
    template = g.op(
        "ConstantOfShape",
        shapes,
        value_t=torch.tensor([0], dtype=torch.float),
    )
    return g.op(
        "RandomNormalLike",
        template,
        dtype_i=scalar_type.onnx_type(),
    )
| |
| |
@_onnx_symbolic("aten::rand")
@_beartype.beartype
def rand(g: jit_utils.GraphContext, shapes, dtype, *options):
    """Export ``aten::rand`` as ``RandomUniform`` or ``RandomUniformLike``.

    ``dtype`` must be a constant; FLOAT is used when it is ``None``.
    ``options`` are not used.
    """
    dtype = symbolic_helper._get_const(dtype, "i", "dtype")
    scalar_type = (
        _type_utils.JitScalarType.FLOAT
        if dtype is None
        else _type_utils.JitScalarType(dtype)
    )
    shape = symbolic_helper._maybe_get_const(shapes, "is")
    if not symbolic_helper._is_value(shape):
        # Constant shape: encode it as an attribute.
        return g.op(
            "RandomUniform",
            shape_i=shape,
            dtype_i=scalar_type.onnx_type(),
        )
    # Dynamic shape: sample "like" a zero tensor built from the shape value.
    template = g.op(
        "ConstantOfShape",
        shapes,
        value_t=torch.tensor([0], dtype=torch.float),
    )
    return g.op(
        "RandomUniformLike",
        template,
        dtype_i=scalar_type.onnx_type(),
    )
| |
| |
@_onnx_symbolic("aten::randn_like")
@_beartype.beartype
def randn_like(
    g: jit_utils.GraphContext,
    self,
    dtype,
    layout=None,
    device=None,
    pin_memory=False,
    memory_format=None,
):
    """Export ``aten::randn_like`` as ``RandomNormalLike``.

    The element type comes from the constant ``dtype``; when it is ``None``
    the type of ``self`` is used, with FLOAT as the fallback. The remaining
    parameters are not used.
    """
    dtype = symbolic_helper._get_const(dtype, "i", "dtype")
    scalar_type = (
        _type_utils.JitScalarType.from_value(self, _type_utils.JitScalarType.FLOAT)
        if dtype is None
        else _type_utils.JitScalarType(dtype)
    )
    return g.op("RandomNormalLike", self, dtype_i=scalar_type.onnx_type())
| |
| |
@_onnx_symbolic("aten::rand_like")
@_beartype.beartype
def rand_like(
    g: jit_utils.GraphContext,
    self,
    dtype,
    layout=None,
    device=None,
    pin_memory=False,
    memory_format=None,
):
    """Export ``aten::rand_like`` as ``RandomUniformLike``.

    The element type comes from the constant ``dtype``; when it is ``None``
    the type of ``self`` is used, with FLOAT as the fallback. The remaining
    parameters are not used.
    """
    dtype = symbolic_helper._get_const(dtype, "i", "dtype")
    if dtype is None:
        scalar_type = _type_utils.JitScalarType.from_value(
            self, _type_utils.JitScalarType.FLOAT
        )
    else:
        scalar_type = _type_utils.JitScalarType(dtype)
    return g.op("RandomUniformLike", self, dtype_i=scalar_type.onnx_type())
| |
| |
@_onnx_symbolic("aten::rrelu")
@symbolic_helper.parse_args("v", "f", "f", "i", "none")
@_beartype.beartype
def rrelu(g: jit_utils.GraphContext, input, lower, upper, training, generator):
    """Export ``aten::rrelu``.

    In training mode the negative slope is sampled per element from
    ``[lower, upper)`` and applied with ``PRelu``; otherwise a ``LeakyRelu``
    with the midpoint slope is emitted. ``generator`` is not used.
    """
    if training:
        slopes = g.op("RandomUniformLike", input, high_f=upper, low_f=lower)
        return g.op("PRelu", input, slopes)
    midpoint = (upper + lower) / 2.0
    return g.op("LeakyRelu", input, alpha_f=midpoint)
| |
| |
@_onnx_symbolic("aten::bernoulli")
@_beartype.beartype
def bernoulli(g: jit_utils.GraphContext, input, p=None, generator=None, out=None):
    """Export ``aten::bernoulli`` as ``RandomUniformLike`` -> ``Less`` -> ``Cast``.

    Uniform samples in ``[0, 1)`` are compared against ``p`` when it is given,
    otherwise against ``input``; the boolean result is cast back to the dtype
    of ``input``. ``out`` and ``generator`` are reported as unimplemented.
    """
    if not (out is None or symbolic_helper._is_none(out)):
        symbolic_helper._unimplemented(
            "Bernoulli", "out parameter is not supported for bernoulli", input
        )
    if not (generator is None or symbolic_helper._is_none(generator)):
        symbolic_helper._unimplemented(
            "Bernoulli", "generator is not supported for bernoulli", input
        )

    input_type = _type_utils.JitScalarType.from_value(
        input, _type_utils.JitScalarType.UNDEFINED
    )
    if input_type == _type_utils.JitScalarType.UNDEFINED:
        return symbolic_helper._unimplemented(
            "Bernoulli", "input dtype not accessible", input
        )

    uniform = g.op(
        "RandomUniformLike",
        input,
        high_f=1.0,
        low_f=0.0,
        dtype_i=input_type.onnx_type(),
    )
    # An explicit, non-None ``p`` takes precedence over the input tensor.
    if p is not None and not symbolic_helper._is_none(p):
        threshold = p
    else:
        threshold = input
    sampled = g.op("Less", uniform, threshold)
    return g.op("Cast", sampled, to_i=input_type.onnx_type())
| |
| |
@_onnx_symbolic("aten::log_sigmoid")
@symbolic_helper.parse_args("v")
@_beartype.beartype
def log_sigmoid(g: jit_utils.GraphContext, input):
    """Export ``aten::log_sigmoid`` as ``Log(Sigmoid(input))``."""
    return g.op("Log", g.op("Sigmoid", input))
| |
| |
@_onnx_symbolic("aten::erf")
@symbolic_helper.parse_args("v")
@_beartype.beartype
def erf(g: jit_utils.GraphContext, input):
    """Export ``aten::erf`` as a single ONNX ``Erf`` node."""
    result = g.op("Erf", input)
    return result
| |
| |
@_onnx_symbolic("aten::flatten")
@symbolic_helper.quantized_args(True, False, False)
@symbolic_helper.parse_args("v", "i", "i")
@_beartype.beartype
def flatten(g: jit_utils.GraphContext, input, start_dim, end_dim):
    """Export ``aten::flatten``.

    Requires the rank of ``input`` to be known. Rank 0 is reshaped to ``[1]``
    and rank 1 becomes ``Identity``. ONNX ``Flatten`` is used when the result
    is 2-D; every other case is delegated to
    ``symbolic_helper._flatten_helper``.
    """
    rank = symbolic_helper._get_tensor_rank(input)
    if rank is None:
        return symbolic_helper._unimplemented(
            "dim",
            "ONNX and PyTorch use different strategies to split the input. "
            "Input rank must be known at export time.",
            input,
        )

    if rank == 0:
        return symbolic_helper._reshape_helper(g, input, [1])
    if rank == 1:
        return g.op("Identity", input)

    # TODO: remove this as onnx opset 11 spec allows negative axes
    if end_dim < 0:
        end_dim = rank + end_dim

    # ONNX's Flatten operator always yields a 2-D output, so it only applies
    # when the requested flatten range leaves exactly two dimensions.
    keeps_leading_axis = start_dim == 1 and end_dim == rank - 1
    if keeps_leading_axis:
        return g.op("Flatten", input, axis_i=start_dim)
    keeps_trailing_axis = start_dim == 0 and end_dim == rank - 2
    if keeps_trailing_axis:
        return g.op("Flatten", input, axis_i=end_dim + 1)

    return symbolic_helper._flatten_helper(g, input, start_dim, end_dim, rank)
| |
| |
@_onnx_symbolic("aten::nonzero")
@symbolic_helper.parse_args("v")
@_beartype.beartype
def nonzero(g: jit_utils.GraphContext, input):
    """Emitted from `torch.nonzero(x, as_tuple=False)`.

    Emits an ONNX ``NonZero`` node and passes its output through ``t``.
    """
    onnx_nonzero = g.op("NonZero", input)
    return t(g, onnx_nonzero)
| |
| |
@_onnx_symbolic("aten::nonzero_numpy")
# Emitted from `torch.nonzero(x, as_tuple=True)`
@_beartype.beartype
def nonzero_numpy(g: jit_utils.GraphContext, input, _outputs=None):
    """Export the tuple form of nonzero by unbinding ``nonzero`` along axis 1."""
    stacked = nonzero(g, input)
    return unbind(g, stacked, 1, _outputs=_outputs)
| |
| |
@_onnx_symbolic("aten::isnan")
@symbolic_helper.parse_args("v")
@_beartype.beartype
def isnan(g: jit_utils.GraphContext, input):
    """Export ``aten::isnan`` as a single ONNX ``IsNaN`` node."""
    return g.op("IsNaN", input)
| |
| |
@_onnx_symbolic("aten::any")
@_beartype.beartype
def _any(g: jit_utils.GraphContext, *args):
    """Export ``aten::any`` as ``ReduceSum(Cast(input, INT64)) > 0``.

    Supports ``aten::any(Tensor self)`` and
    ``aten::any(Tensor self, int dim, bool keepdim)``.
    """
    if len(args) != 1:
        # aten::any(Tensor self, int dim, bool keepdim)
        input, dim, keepdim = args
        dim = [symbolic_helper._parse_arg(dim, "i")]
        keepdim = symbolic_helper._parse_arg(keepdim, "i")
    else:
        # aten::any(Tensor self)
        input = args[0]
        dim, keepdim = None, 0

    as_int = g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.INT64)
    total = symbolic_helper._reducesum_helper(
        g, as_int, axes_i=dim, keepdims_i=keepdim
    )
    zero = g.op("Constant", value_t=torch.tensor(0, dtype=torch.long))
    return gt(g, total, zero)
| |
| |
@_onnx_symbolic("aten::all")
@_beartype.beartype
def _all(g: jit_utils.GraphContext, *args):
    """Export ``aten::all`` as ``Not(any(Not(input)))``.

    Supports ``aten::all(Tensor self)`` and
    ``aten::all(Tensor self, int dim, bool keepdim)``.
    """
    negated = g.op("Not", args[0])
    if len(args) == 1:
        # aten::all(Tensor self)
        any_false = _any(g, negated)
    else:
        # aten::all(Tensor self, int dim, bool keepdim)
        any_false = _any(g, negated, args[1], args[2])
    return g.op("Not", any_false)
| |
| |
@_onnx_symbolic("aten::narrow")
@symbolic_helper.parse_args("v", "i", "i", "i")
@_beartype.beartype
def narrow(g: jit_utils.GraphContext, input, dim, start, length):
    """Export ``aten::narrow`` as a slice ``[start, start + length)`` on ``dim``."""
    stop = start + length
    return symbolic_helper._slice_helper(
        g, input, axes=[dim], starts=[start], ends=[stop]
    )
| |
| |
@_onnx_symbolic("aten::argmax")
@symbolic_helper.parse_args("v", "v", "b")
@_beartype.beartype
def argmax(
    g: jit_utils.GraphContext,
    input: torch._C.Value,
    dim: torch._C.Value,
    keepdim: bool,
):
    """Export ``aten::argmax`` through the shared ArgMin/ArgMax helper."""
    onnx_op = "ArgMax"
    return symbolic_helper._argmin_argmax_helper(g, input, dim, keepdim, onnx_op)
| |
| |
@_onnx_symbolic("aten::argmin")
@symbolic_helper.parse_args("v", "v", "b")
@_beartype.beartype
def argmin(
    g: jit_utils.GraphContext,
    input: torch._C.Value,
    dim: torch._C.Value,
    keepdim: bool,
):
    """Export ``aten::argmin`` through the shared ArgMin/ArgMax helper."""
    onnx_op = "ArgMin"
    return symbolic_helper._argmin_argmax_helper(g, input, dim, keepdim, onnx_op)
| |
| |
@_onnx_symbolic("aten::scatter")
@symbolic_helper.parse_args("v", "i", "v", "v")
@_beartype.beartype
def scatter(g: jit_utils.GraphContext, self, dim, index, src):
    """Export ``aten::scatter`` as ONNX ``Scatter``.

    A tensor ``src`` is used directly. A scalar ``src`` is cast to the type of
    ``self`` when the types differ and is expanded to the shape of ``index``.
    """
    src_type = _type_utils.JitScalarType.from_value(
        src, _type_utils.JitScalarType.UNDEFINED
    )
    src = symbolic_helper._maybe_get_scalar(src)
    if not symbolic_helper._is_value(src):
        # Check if scalar "src" has same type as self (PyTorch allows different
        # type for scalar src (but not when src is tensor)). If not, insert Cast node.
        self_scalar_type = _type_utils.JitScalarType.from_value(self)
        if self_scalar_type != src_type:
            src = g.op("Cast", src, to_i=self_scalar_type.onnx_type())
        expanded_src = expand_as(g, src, index)
        return g.op("Scatter", self, index, expanded_src, axis_i=dim)
    return g.op("Scatter", self, index, src, axis_i=dim)
| |
| |
@_onnx_symbolic("aten::scatter_add")
@symbolic_helper.parse_args("v", "i", "v", "v")
@_beartype.beartype
def scatter_add(g: jit_utils.GraphContext, self, dim, index, src):
    """Export ``aten::scatter_add`` as ``self + scatter(zeros, dim, index, src)``.

    The zero tensor is a ``Constant`` when the sizes of ``self`` are static and
    ``zeros_like(self)`` otherwise. Requires the dtype of ``self`` to be known.
    """
    scalar_type = symbolic_helper._try_get_scalar_type(self)
    if scalar_type is None:
        return symbolic_helper._unimplemented(
            "scatter_add", "input dtype not accessible", self
        )
    sizes = symbolic_helper._get_tensor_sizes(self, allow_nonstatic=False)
    if not sizes:
        zero_base = zeros_like(g, self, scalar_type)
    else:
        zero_base = g.op(
            "Constant", value_t=torch.zeros(sizes, dtype=scalar_type.dtype())
        )
    scattered = symbolic_helper._scatter_helper(g, zero_base, dim, index, src)
    return add(g, self, scattered)
| |
| |
@_onnx_symbolic("aten::log2")
@_beartype.beartype
def log2(g: jit_utils.GraphContext, self):
    """Export ``aten::log2`` as ``log(self) / ln(2)``."""
    _ln2 = 0.693147180559945309
    natural_log = log(g, self)
    ln2_const = g.op("Constant", value_t=torch.tensor(_ln2))
    return g.op("Div", natural_log, ln2_const)
| |
| |
@_onnx_symbolic("aten::is_floating_point")
@_beartype.beartype
def is_floating_point(g: jit_utils.GraphContext, self):
    """Export ``aten::is_floating_point`` as a one-element boolean ``Constant``."""
    flag = 1 if symbolic_helper._is_fp(self) else 0
    return g.op("Constant", value_t=torch.BoolTensor([flag]))
| |
| |
@_onnx_symbolic("aten::__is_")
@_beartype.beartype
def __is_(g: jit_utils.GraphContext, self, other):
    """Export ``aten::__is_``.

    When ``other`` is None the answer is a boolean ``Constant`` telling whether
    ``self`` is None as well; otherwise the comparison is delegated to ``eq``.
    """
    if not symbolic_helper._is_none(other):
        return eq(g, self, other)
    flag = 1 if symbolic_helper._is_none(self) else 0
    return g.op("Constant", value_t=torch.BoolTensor([flag]))
| |
| |
@_onnx_symbolic("aten::__isnot_")
@wrap_logical_op_with_negation
@_beartype.beartype
def __isnot_(g: jit_utils.GraphContext, self, other):
    """Export ``aten::__isnot_`` as ``__is_`` wrapped by the negation decorator."""
    is_result = __is_(g, self, other)
    return is_result
| |
| |
@_onnx_symbolic("aten::one_hot")
@_beartype.beartype
def one_hot(g: jit_utils.GraphContext, self, num_classes):
    """Export ``aten::one_hot`` as ONNX ``OneHot`` on the last axis.

    ``num_classes`` is cast to INT64 when its type is UINT8, INT8, INT or INT16.
    """
    values = g.op("Constant", value_t=torch.LongTensor([0, 1]))
    # onnxruntime supports limited type combinations for OneHot.
    depth_type = _type_utils.JitScalarType.from_value(
        num_classes, _type_utils.JitScalarType.UNDEFINED
    )
    needs_int64_cast = depth_type in {
        _type_utils.JitScalarType.UINT8,
        _type_utils.JitScalarType.INT8,
        _type_utils.JitScalarType.INT,
        _type_utils.JitScalarType.INT16,
    }
    if needs_int64_cast:
        num_classes = g.op("Cast", num_classes, to_i=_C_onnx.TensorProtoDataType.INT64)
    return g.op("OneHot", self, num_classes, values, axis_i=-1)
| |
| |
@_onnx_symbolic("aten::gather")
@symbolic_helper.parse_args("v", "i", "v", "v")
@_beartype.beartype
def gather(g: jit_utils.GraphContext, self, dim, index, sparse_grad=False):
    """Export ``aten::gather`` using ``OneHot``, ``Mul`` and a sum reduction.

    ``sparse_grad == True`` is reported as unimplemented.
    """
    if symbolic_helper._maybe_get_const(sparse_grad, "i"):
        return symbolic_helper._unimplemented("gather", "sparse_grad == True", self)
    # NOTE: This workaround is needed since GatherElement is only supported
    # since opset 11, and Gather in ONNX is not the same as torch.gather.
    scalar_type = _type_utils.JitScalarType.from_value(self)
    values = g.op("Constant", value_t=torch.LongTensor([0, 1]))
    dim_const = g.op("Constant", value_t=torch.LongTensor([dim]))
    depth = size(g, self, dim_const)
    # One-hot encode the indices along ``dim`` and cast to the data type so the
    # encoding can be used as a multiplicative selection mask.
    one_hot_index = g.op("OneHot", index, depth, values, axis_i=dim)
    selector = g.op("Cast", one_hot_index, to_i=scalar_type.onnx_type())
    expanded_self = symbolic_helper._unsqueeze_helper(g, self, [dim + 1])
    selected = g.op("Mul", expanded_self, selector)
    return symbolic_helper._reducesum_helper(g, selected, axes_i=[dim], keepdims_i=0)
| |
| |
@symbolic_helper.parse_args("v", "is", "i", "i")
@_beartype.beartype
def _var_mean(g: jit_utils.GraphContext, input, dim, correction, keepdim):
    """Build the variance and mean of ``input`` and return them as ``(var, mean)``.

    With ``dim`` set to ``None`` the reduction covers all elements; otherwise
    it covers the listed axes. The variance is rescaled by
    ``N / (N - correction)`` unless ``correction`` is 0; a ``None`` correction
    is treated as 1.
    """
    if dim is None:
        mean = g.op("ReduceMean", input, keepdims_i=0)
        broadcast_mean = mean
        num_elements = numel(g, input)
    else:
        mean = g.op("ReduceMean", input, axes_i=dim, keepdims_i=keepdim)
        # Mean with kept dims, so it broadcasts against ``input`` below.
        broadcast_mean = g.op("ReduceMean", input, axes_i=dim, keepdims_i=1)
        input_shape = g.op("Shape", input)
        # dim could contain one or multiple dimensions
        dim_indices = g.op("Constant", value_t=torch.tensor(dim))
        reduced_dims = g.op("Gather", input_shape, dim_indices, axis_i=0)
        num_elements = g.op("ReduceProd", reduced_dims, keepdims_i=0)

    centered = g.op("Sub", input, broadcast_mean)
    squared = g.op("Mul", centered, centered)
    keepdim_mean = 0 if dim is None else keepdim
    variance = g.op("ReduceMean", squared, axes_i=dim, keepdims_i=keepdim_mean)

    # Correct bias in calculating variance, by dividing it over (N - correction) instead on N
    if correction is None:
        correction = 1
    if correction == 0:
        return variance, mean

    num_elements = g.op("Cast", num_elements, to_i=_C_onnx.TensorProtoDataType.FLOAT)
    correction_const = g.op(
        "Constant", value_t=torch.tensor(correction, dtype=torch.float)
    )
    scaled = g.op("Mul", variance, num_elements)
    variance = g.op("Div", scaled, g.op("Sub", num_elements, correction_const))
    return variance, mean
| |
| |
@_onnx_symbolic("aten::std")
@_beartype.beartype
def std(g: jit_utils.GraphContext, input, *args):
    """Export ``aten::std`` as ``Sqrt`` of the variance from ``var_mean``."""
    variance, _mean = var_mean(g, input, *args)
    return g.op("Sqrt", variance)
| |
| |
@_onnx_symbolic("aten::var")
@_beartype.beartype
def var(g: jit_utils.GraphContext, input, *args):
    """Export ``aten::var`` by returning the variance half of ``var_mean``."""
    variance, _mean = var_mean(g, input, *args)
    return variance
| |
| |
@_onnx_symbolic("aten::var_mean")
@_beartype.beartype
def var_mean(g: jit_utils.GraphContext, input, *args):
    """Export ``aten::var_mean`` by normalising its overloads for ``_var_mean``."""
    # var_mean (and all variance-related functions) has multiple signatures, so need to manually figure
    # out the correct arguments:
    # aten::var_mean(Tensor self, bool unbiased)
    # aten::var_mean(Tensor self, int[1] dim, bool unbiased, bool keepdim=False)
    # aten::var_mean(Tensor self, int[1]? dim=None, *, int? correction=None, bool keepdim=False)
    if len(args) != 1:
        return _var_mean(g, input, *args)
    # Single-argument overload: no dim and no keepdim were supplied.
    return _var_mean(g, input, None, args[0], None)
| |
| |
@_onnx_symbolic("aten::std_mean")
@_beartype.beartype
def std_mean(g: jit_utils.GraphContext, input, *args):
    """Export ``aten::std_mean`` as ``(Sqrt(var), mean)`` from ``var_mean``."""
    variance, mean = var_mean(g, input, *args)
    std_dev = g.op("Sqrt", variance)
    return std_dev, mean
| |
| |
@_onnx_symbolic("aten::logsumexp")
@symbolic_helper.parse_args("v", "is", "i")
@_beartype.beartype
def logsumexp(g: jit_utils.GraphContext, input, dim, keepdim):
    """Export ``aten::logsumexp`` as a single ``ReduceLogSumExp`` node."""
    reduced = g.op("ReduceLogSumExp", input, axes_i=dim, keepdims_i=keepdim)
    return reduced
| |
| |
@_onnx_symbolic("aten::arange")
@_beartype.beartype
def arange(g: jit_utils.GraphContext, *args):
    """Export ``aten::arange`` for its 2-, 4-, 5-, 6- and 7-argument overloads.

    The index sequence is produced as ``nonzero(ones(count))``, scaled by
    ``step`` and shifted by ``start`` where the overload supplies them, then
    cast to the resolved dtype. Any other argument count is reported as
    unimplemented. Under the Caffe2 ATen fallback the op is forwarded via
    ``g.at``.
    """
    if symbolic_helper.is_caffe2_aten_fallback():
        return g.at("arange", *args)

    @_beartype.beartype
    def _float_step_convert(range_tensor):
        # Floating point element counts are rounded up and cast to INT64.
        if not symbolic_helper._is_fp(range_tensor):
            return range_tensor
        return g.op(
            "Cast",
            g.op("Ceil", range_tensor),
            to_i=_type_utils.JitScalarType.INT64.onnx_type(),
        )

    num_args = len(args)
    if num_args in (2, 5):
        if num_args == 2:
            # aten::arange(Scalar end, Tensor out)
            requested_dtype = None
        else:
            # aten::arange(Scalar end, ScalarType dtype, Layout, Device, bool pin_memory)
            requested_dtype = symbolic_helper._maybe_get_const(args[1], "i")
        dtype, end, _, _ = symbolic_helper._arange_cast_helper(
            g, end=args[0], dtype=requested_dtype
        )
        end = symbolic_helper._unsqueeze_helper(g, end, [0])
        count = _float_step_convert(end)
        positions = nonzero(g, ones(g, count, dtype, None, None))
        arange_tensor = symbolic_helper._squeeze_helper(g, positions, [1])
    elif num_args in (4, 7):
        if num_args == 4:
            # aten::arange(Scalar start, Scalar end, Scalar step, Tensor out)
            requested_dtype = None
        else:
            # aten::arange(Scalar start, Scalar end, Scalar step, ScalarType dtype, Layout, Device, bool pin_memory)
            requested_dtype = symbolic_helper._maybe_get_const(args[3], "i")
        dtype, end, start, step = symbolic_helper._arange_cast_helper(
            g, start=args[0], end=args[1], step=args[2], dtype=requested_dtype
        )
        step = symbolic_helper._unsqueeze_helper(g, step, [0])
        end = symbolic_helper._unsqueeze_helper(g, end, [0])
        start = symbolic_helper._unsqueeze_helper(g, start, [0])
        span = g.op("Sub", end, start)
        count = _float_step_convert(g.op("Div", span, step))
        positions = nonzero(g, ones(g, count, None, None, None))
        arange_tensor = symbolic_helper._squeeze_helper(g, positions, [1])
        arange_tensor = g.op("Add", g.op("Mul", arange_tensor, step), start)
    elif num_args == 6:
        # aten::arange(Scalar start, Scalar end, ScalarType dtype, Layout, Device, bool pin_memory)
        requested_dtype = symbolic_helper._maybe_get_const(args[2], "i")
        dtype, end, start, _ = symbolic_helper._arange_cast_helper(
            g, start=args[0], end=args[1], dtype=requested_dtype
        )
        end = symbolic_helper._unsqueeze_helper(g, end, [0])
        start = symbolic_helper._unsqueeze_helper(g, start, [0])
        count = _float_step_convert(g.op("Sub", end, start))
        positions = nonzero(g, ones(g, count, dtype, *(args[3:])))
        arange_tensor = g.op(
            "Add",
            symbolic_helper._squeeze_helper(g, positions, [1]),
            start,
        )
    else:
        return symbolic_helper._unimplemented(
            "aten::arange", f"with {len(args)} arguments"
        )

    return g.op(
        "Cast", arange_tensor, to_i=_type_utils.JitScalarType(dtype).onnx_type()
    )
| |
| |
@_onnx_symbolic("aten::linspace")
@_beartype.beartype
def linspace(
    g: jit_utils.GraphContext, start, end, steps, dtype, layout, device, pin_memory
):
    """Export ``aten::linspace`` as ``arange(steps) * step + start``.

    ``step`` is ``(end - start) / (steps - 1)``. ``dtype``, ``layout``,
    ``device`` and ``pin_memory`` are not used by this symbolic.
    """
    range_tensor = symbolic_helper._arange_helper(g, steps, None)
    span = sub(g, end, start)
    one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))
    num_intervals = sub(g, steps, one)
    step = div(g, span, num_intervals)
    scaled = mul(g, range_tensor, step)
    return add(g, scaled, start)
| |
| |
@_onnx_symbolic("aten::lift")
@_beartype.beartype
def lift(g: jit_utils.GraphContext, self):
    """Export ``aten::lift`` by returning the input value unchanged."""
    # at::lift() is a no-op from the perspective of tracing for onnx
    passthrough = self
    return passthrough
| |
| |
@_onnx_symbolic("aten::masked_fill")
@_beartype.beartype
def masked_fill(g: jit_utils.GraphContext, self, mask, value):
    """Export ``aten::masked_fill`` as ``Where(mask, value, self)``.

    ``mask`` is cast to BOOL and ``value`` is passed through
    ``symbolic_helper._if_scalar_type_as`` against ``self``.
    """
    bool_mask = g.op("Cast", mask, to_i=_C_onnx.TensorProtoDataType.BOOL)
    fill_value = symbolic_helper._maybe_get_scalar(value)
    typed_fill = symbolic_helper._if_scalar_type_as(fill_value, self)
    return g.op("Where", bool_mask, typed_fill, self)
| |
| |
@_onnx_symbolic("aten::masked_fill_")
@_beartype.beartype
def masked_fill_(g: jit_utils.GraphContext, self, mask, value):
    """Export the in-place ``aten::masked_fill_`` by delegating to ``masked_fill``."""
    filled = masked_fill(g, self, mask, value)
    return filled
| |
| |
@_onnx_symbolic("aten::index")
@_beartype.beartype
def index(g: jit_utils.GraphContext, self, index):
    """Export ``aten::index`` (indexing with tensor or mask indices).

    A single index goes through ``_select_helper`` on axis 0. When several
    entries are given, a single non-None entry goes through ``index_select``,
    and two or more are lowered to a Transpose / Flatten / gather / Reshape
    sequence that needs the rank of ``self`` to be known. uint8 and bool
    indices are first converted to positions with ``nonzero``. Under the
    Caffe2 ATen fallback the op is forwarded via ``g.at``.
    """
    if symbolic_helper.is_caffe2_aten_fallback():
        return g.at("index", self, index, overload_name="Tensor")

    indices = (
        symbolic_helper._unpack_list(index)
        if symbolic_helper._is_packed_list(index)
        else [index]
    )

    @_beartype.beartype
    def try_mask_to_index(index):
        # None entries and non-mask indices are returned untouched.
        if symbolic_helper._is_none(index):
            return index
        is_byte = (
            _type_utils.JitScalarType.from_value(
                index, _type_utils.JitScalarType.UNDEFINED
            )
            == _type_utils.JitScalarType.UINT8
        )
        if not (is_byte or symbolic_helper._is_bool(index)):
            return index
        if g.opset < 9:
            raise errors.SymbolicValueError(
                "Exporting masked indices are only supported after ONNX opset 9.",
                self,
            )
        warnings.warn(
            "Exporting aten::index operator with indices of type Byte. "
            "Only 1-D indices are supported. In any other case, "
            "this will produce an incorrect ONNX graph."
        )
        return symbolic_helper._squeeze_helper(g, nonzero(g, index), [1])

    indices = [try_mask_to_index(idx) for idx in indices]
    if len(indices) == 1:
        return symbolic_helper._select_helper(
            g, self, 0, indices[0], apply_reshape=False
        )

    # Multiple tensors as indices. Each tensor could either be
    #   1. prim::Constant()
    #           representing ":" in python indexing. E.g. tensor[:, :]
    #   2. prim::Constant[value=...] or tensor output
    #           representing advanced indexing. E.g. tensor[[0, 1], [2, 0]].
    # For more info on advanced indexing,
    # check https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing

    # Consider a general case of
    #       t: [x_1, y_1, y_2, ..., x_m, ..., y_n]
    # where t is a tensor of rank m+n, {x_i} are axes where tensor index is provided, and {y_i} are axes for ":".
    # Same results can be achieved through transposing t into
    #       t: [x_1, x_2, ..., x_m, y_1, y_2, ..., y_n]
    # and use gatherND. However ONNX does not have gatherND, to use 1d gather we'll need to flatten t
    # and process the tensor indices.
    #       t: [x_1 * x_2 * ... * x_m, y_1 * y_2 * ... * y_n]
    #       tensor index = \sum_{i=1}^m (ind_i * \prod_{j=i+1}^m (x_j))
    # After gather, reshape and transpose back.
    adv_idx_indices = [
        axis for axis, idx in enumerate(indices) if not symbolic_helper._is_none(idx)
    ]

    if len(adv_idx_indices) == 0:
        return self
    if len(adv_idx_indices) == 1:
        only_axis = adv_idx_indices[0]
        return index_select(g, self, only_axis, indices[only_axis])

    rank = symbolic_helper._get_tensor_rank(self)
    if rank is None:
        return symbolic_helper._unimplemented(
            "aten::index",
            "operator of advanced indexing on tensor of unknown rank. "
            "Try turning on shape inference during export: "
            "torch.onnx._export(..., onnx_shape_inference=True).",
            self,
        )
    # TODO: If indexing is supported natively in ONNX in future opsets,
    #       update the warning to recommend exporting with higher opset version.
    warnings.warn(
        "Exporting aten::index operator of advanced indexing in opset "
        f"{GLOBALS.export_onnx_opset_version}"
        " is achieved by combination of multiple ONNX operators, "
        "including Reshape, Transpose, Concat, and Gather. "
        "If indices include negative values, the exported graph will produce incorrect results."
    )
    adv_idx_count = len(adv_idx_indices)
    first_adv_axis = adv_idx_indices[0]
    last_adv_axis = adv_idx_indices[-1]
    # Axes that are addressed with ":" (no tensor index supplied).
    basic_axes = [axis for axis in range(rank) if axis not in adv_idx_indices]

    shape_tensor = _shape_as_tensor(g, self)
    dim_tensor_list = [
        g.op(
            "Gather",
            shape_tensor,
            g.op("Constant", value_t=torch.LongTensor([axis])),
            axis_i=0,
        )
        for axis in range(rank)
    ]

    # Move the indexed axes to the front and collapse them into a single axis.
    self = g.op("Transpose", self, perm_i=adv_idx_indices + basic_axes)
    self = g.op("Flatten", self, axis_i=adv_idx_count)

    # Note that tensor indices will be broadcasted while accumulating. Thus we get the final subarray shape as well.
    cum_adv_index = indices[last_adv_axis]
    multiplier = dim_tensor_list[last_adv_axis]
    for axis in reversed(adv_idx_indices[:-1]):
        adv_index = g.op("Mul", indices[axis], multiplier)
        cum_adv_index = g.op("Add", cum_adv_index, adv_index)
        multiplier = g.op("Mul", multiplier, dim_tensor_list[axis])

    # perform gather
    self = index_select(g, self, 0, cum_adv_index)

    cum_adv_index_shape_tensor = _shape_as_tensor(g, cum_adv_index)
    # check if all advanced indices are consecutive.
    # Refer to https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#combining-advanced-and-basic-indexing
    # to understand how the subarray position is decided.
    is_consecutive = adv_idx_indices == list(range(first_adv_axis, last_adv_axis + 1))
    if is_consecutive:
        # unfold regular index axes
        folded_adv_idx_shape_list = [
            g.op("Constant", value_t=torch.LongTensor([-1]))
        ] + [dim_tensor_list[axis] for axis in basic_axes]
        folded_adv_idx_shape = g.op("Concat", *folded_adv_idx_shape_list, axis_i=0)
        self = symbolic_helper._reshape_helper(g, self, folded_adv_idx_shape)

        # Transpose folded advanced indexed axis to its original location.
        adv_idx_permute = (
            list(range(1, first_adv_axis + 1))
            + [0]
            + list(range(first_adv_axis + 1, rank - adv_idx_count + 1))
        )
        self = g.op("Transpose", self, perm_i=adv_idx_permute)

        # unfold advanced index axes
        leading_dims = [dim_tensor_list[axis] for axis in range(first_adv_axis)]
        trailing_dims = [
            dim_tensor_list[axis]
            for axis in range(first_adv_axis, rank)
            if axis not in adv_idx_indices
        ]
        final_shape_list = leading_dims + [cum_adv_index_shape_tensor] + trailing_dims
        final_shape = g.op("Concat", *final_shape_list, axis_i=0)
    else:
        final_shape = g.op(
            "Concat",
            cum_adv_index_shape_tensor,
            *[dim_tensor_list[axis] for axis in basic_axes],
            axis_i=0,
        )

    return symbolic_helper._reshape_helper(g, self, final_shape)
| |
| |
@_onnx_symbolic("aten::linalg_norm")
@symbolic_helper.parse_args("v", "v", "is", "b", "v")
@_beartype.beartype
def linalg_norm(
    g: jit_utils.GraphContext,
    self: torch._C.Value,
    ord: torch._C.Value,
    dim: Optional[Sequence[int]],
    keepdim: bool,
    dtype: torch._C.Value,
):
    """Export ``aten::linalg_norm`` by dispatching to the vector or matrix norm.

    The vector norm is used when ``dim`` has one entry, or when ``dim`` is
    ``None`` and the input is 1-D (a ``None`` ``ord`` together with a ``None``
    ``dim`` first flattens the input). Every other case goes to
    ``linalg_matrix_norm``, with ``dim`` defaulting to ``[0, 1]``.

    Args:
        g: Graph context used to emit ONNX nodes.
        self: Input tensor value.
        ord: Norm order value; a None value selects the 2-norm on vector paths.
        dim: Axes to reduce over, or ``None``.
        keepdim: Whether reduced axes are kept.
        dtype: Forwarded unchanged to the selected norm symbolic.

    Returns:
        The value produced by ``linalg_vector_norm`` or ``linalg_matrix_norm``,
        or the result of ``_unimplemented`` when the input rank is unknown.
    """
    # Conditions based on https://pytorch.org/docs/stable/generated/torch.linalg.norm.html
    ord_value = None
    if dim is None:
        if symbolic_helper._is_none(ord):
            self = symbolic_helper._reshape_helper(g, self, [-1])
            ord = g.op("Constant", value_t=torch.LongTensor([2]))
        self_dim = symbolic_helper._get_tensor_rank(self)
        if self_dim is None:
            return symbolic_helper._unimplemented(
                "dim", "Input rank must be known at export time.", self
            )
        if self_dim == 1:
            ord_value = symbolic_helper._parse_arg(ord, "f")
        else:
            dim = [0, 1]
    elif len(dim) == 1:
        if symbolic_helper._is_none(ord):
            ord = g.op("Constant", value_t=torch.LongTensor([2]))
        ord_value = symbolic_helper._parse_arg(ord, "f")

    # ord_value is only assigned on the vector-norm paths above. Compare with
    # None rather than relying on truthiness, so that ord == 0 (which parses to
    # the falsy 0.0) is still handled by linalg_vector_norm instead of falling
    # through to linalg_matrix_norm with a non-matrix ``dim``.
    if ord_value is not None:
        return linalg_vector_norm(g, self, ord_value, dim, keepdim, dtype)
    return linalg_matrix_norm(g, self, ord, dim, keepdim, dtype)
| |
| |
@_onnx_symbolic("aten::linalg_vector_norm")
@symbolic_helper.parse_args("v", "f", "is", "b", "v")
@_beartype.beartype
def linalg_vector_norm(
    g: jit_utils.GraphContext,
    self: torch._C.Value,
    ord: float,
    dim: Optional[Sequence[int]],
    keepdim: bool,
    dtype: torch._C.Value,
):
    """Export ``aten::linalg_vector_norm``.

    ``ord`` of +inf / -inf maps to ``ReduceMax`` / ``ReduceMin`` of ``|self|``,
    ``ord == 0`` is reported as unsupported in this opset, and any other order
    is computed as ``sum(|self| ** ord) ** (1 / ord)``. With ``dim`` set to
    ``None`` the input is flattened and ``keepdim`` is forced to ``False``.
    ``dtype`` is not used.
    """
    # Conditions based on https://pytorch.org/docs/stable/generated/torch.linalg.vector_norm.html
    if dim is None:
        self = symbolic_helper._reshape_helper(g, self, [-1])
        keepdim = False

    if ord == 0:
        return symbolic_helper._onnx_opset_unsupported_detailed(
            "linalg_vector_norm", 9, 11, "ord=0 not supported", self
        )

    if ord in (math.inf, -math.inf):
        reduce_op = "ReduceMax" if ord == math.inf else "ReduceMin"
        return g.op(reduce_op, g.op("Abs", self), axes_i=dim, keepdims_i=keepdim)

    exponent = g.op("Constant", value_t=torch.tensor(ord, dtype=torch.float32))
    powered = g.op("Pow", g.op("Abs", self), exponent)
    summed = symbolic_helper._reducesum_helper(
        g, powered, axes_i=dim, keepdims_i=keepdim
    )
    one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.float32))
    inverse_exponent = g.op("Div", one, exponent)
    return g.op("Pow", summed, inverse_exponent)
| |
| |
@_onnx_symbolic("aten::linalg_matrix_norm")
@symbolic_helper.parse_args("v", "v", "is", "b", "v")
@_beartype.beartype
def linalg_matrix_norm(
    g: jit_utils.GraphContext,
    self: torch._C.Value,
    ord: torch._C.Value,
    dim: List[int],
    keepdim: bool,
    dtype: torch._C.Value,
):
    """Export ``aten::linalg_matrix_norm``.

    Supported orders are ``"fro"`` (also used when ``ord`` is None), ``1``,
    ``-1``, ``inf`` and ``-inf``. ``"nuc"``, ``2`` and ``-2`` are reported as
    unimplemented. For the 1/inf family the absolute values are summed along
    one matrix axis and the max (positive ``ord``) or min (negative ``ord``) is
    taken along the other.

    Args:
        g: Graph context used to emit ONNX nodes.
        self: Input tensor value.
        ord: Norm order value (string or number constant).
        dim: The two matrix axes; negative entries are wrapped by the rank.
            This list is not modified.
        keepdim: Whether reduced axes are kept.
        dtype: Not used by this symbolic.

    Returns:
        The norm value, or the result of ``_unimplemented`` for unsupported
        orders or an unknown input rank.
    """
    # Conditions based on https://pytorch.org/docs/stable/generated/torch.linalg.matrix_norm.html
    ord_value = symbolic_helper._parse_arg(ord, "s")
    if ord_value == "fro":
        return frobenius_norm(g, self, dim, keepdim)
    if ord_value == "nuc":
        return symbolic_helper._unimplemented("linalg.matrix_norm", "ord==nuc", self)

    ord_value = symbolic_helper._parse_arg(ord, "f")
    if ord_value is None:
        return frobenius_norm(g, self, dim, keepdim)
    if ord_value == 2 or ord_value == -2:
        # ord = 2/-2 unimplemented due to lack of operators
        # used to calculate singular values
        return symbolic_helper._unimplemented("linalg.matrix_norm", "ord==2", self)
    # Wrap the dim vector to handle negative dim values
    self_dim = symbolic_helper._get_tensor_rank(self)
    if self_dim is None:
        return symbolic_helper._unimplemented(
            "linalg.matrix_norm", "Input rank must be known at export time.", self
        )

    # Common implementation for cases with
    # ord = 1/-1 and ord = inf/-inf
    # Work on local copies of the two axes so the caller's ``dim`` list is
    # never mutated.
    sum_axis, select_axis = dim[0], dim[1]
    if sum_axis < 0:
        sum_axis += self_dim
    if select_axis < 0:
        select_axis += self_dim

    if ord_value == math.inf or ord_value == -math.inf:
        sum_axis, select_axis = select_axis, sum_axis
    # Without keepdim the summed axis disappears, shifting later axes down.
    if select_axis > sum_axis and not keepdim:
        select_axis -= 1

    abs_sum = symbolic_helper._reducesum_helper(
        g, g.op("Abs", self), axes_i=[sum_axis], keepdims_i=keepdim
    )
    # ``max`` / ``min`` here are this module's symbolic functions.
    if ord_value > 0:
        result, _indices = max(
            g,
            abs_sum,
            dim_or_y=g.op("Constant", value_t=torch.LongTensor([select_axis])),
            keepdim=keepdim,
        )
    else:
        result, _indices = min(
            g,
            abs_sum,
            dim_or_y=g.op("Constant", value_t=torch.LongTensor([select_axis])),
            keepdim=keepdim,
        )
    return result
| |
| |
@_onnx_symbolic("aten::linalg_cross")
@symbolic_helper.parse_args("v", "v", "i")
@_beartype.beartype
def linalg_cross(g: jit_utils.GraphContext, input, other, dim=-1):
    """Export ``aten::linalg_cross`` by delegating to ``cross``."""
    product = cross(g, input, other, dim)
    return product
| |
| |
@_onnx_symbolic("aten::frobenius_norm")
@symbolic_helper.parse_args("v", "is", "b")
@_beartype.beartype
def frobenius_norm(g: jit_utils.GraphContext, self, dim=None, keepdim=False):
    """Export ``aten::frobenius_norm`` as ``Sqrt(sum(self * self))`` over ``dim``."""
    squared = g.op("Mul", self, self)
    sum_of_squares = symbolic_helper._reducesum_helper(
        g, squared, axes_i=dim, keepdims_i=keepdim
    )
    return g.op("Sqrt", sum_of_squares)
| |
| |
@_onnx_symbolic("aten::multinomial")
@symbolic_helper.parse_args("v", "i", "b", "v")
@_beartype.beartype
def multinomial(
    g: jit_utils.GraphContext, input, num_samples, replacement=False, generator=None
):
    """Export ``aten::multinomial`` as an ONNX ``Multinomial`` node.

    The node is fed ``log(input)`` and is configured to emit INT64 output with
    ``sample_size`` set to ``num_samples``.

    Args:
        g: Graph context the nodes are added to.
        input: Value passed through ``log`` before sampling.
        num_samples: Number of samples (parsed as an int).
        replacement: Whether sampling is with replacement.
        generator: Random generator; reported as unimplemented when given.

    Returns:
        The ``Multinomial`` node output.
    """
    has_generator = generator is not None and not symbolic_helper._is_none(generator)
    if has_generator:
        # NOTE: the helper's result is intentionally not returned here.
        symbolic_helper._unimplemented(
            "Multinomial", "generator is not supported for multinomial", input
        )

    without_replacement_multi = not replacement and num_samples > 1
    if without_replacement_multi:
        symbolic_helper._unimplemented(
            "Multinomial",
            "replacement=False when num_samples > 1 is not supported for multinomial",
            input,
        )

    return g.op(
        "Multinomial",
        log(g, input),
        dtype_i=_C_onnx.TensorProtoDataType.INT64,
        sample_size_i=num_samples,
    )
| |
| |
@_onnx_symbolic("aten::baddbmm")
@_beartype.beartype
def baddbmm(g: jit_utils.GraphContext, self, batch1, batch2, beta, alpha):
    """Export ``aten::baddbmm`` as ``alpha * (batch1 @ batch2) + beta * self``.

    ``alpha`` and ``beta`` are cast to the scalar type of ``self`` before
    being multiplied in.

    Args:
        g: Graph context the nodes are added to.
        self: Value that is scaled by ``beta`` and added.
        batch1: Left operand of the matrix product.
        batch2: Right operand of the matrix product.
        beta: Scale applied to ``self``.
        alpha: Scale applied to the matrix product.

    Returns:
        The sum of the two scaled terms.
    """
    self_scalar_type = _type_utils.JitScalarType.from_value(self)
    product = matmul(g, batch1, batch2)

    alpha_cast = g.op("Cast", alpha, to_i=self_scalar_type.onnx_type())
    scaled_product = mul(g, product, alpha_cast)

    beta_cast = g.op("Cast", beta, to_i=self_scalar_type.onnx_type())
    scaled_self = mul(g, self, beta_cast)

    return add(g, scaled_product, scaled_self)
| |
| |
@_onnx_symbolic("aten::meshgrid")
@symbolic_helper.parse_args("v", "s")
@_beartype.beartype
def meshgrid(g: jit_utils.GraphContext, tensor_list, indexing: Optional[str] = None):
    """Export ``aten::meshgrid`` using Reshape, Concat and Expand nodes.

    Each input is flattened to 1-D, reshaped so that only its own axis is
    non-singleton, and then expanded to the concatenation of all input shapes.

    Args:
        g: Graph context the nodes are added to.
        tensor_list: Packed list of input tensors.
        indexing: ``"ij"`` (used when ``None``) or ``"xy"``.

    Returns:
        A ``prim::ListConstruct`` node holding one expanded grid per input.

    Raises:
        errors.SymbolicValueError: If ``indexing`` is neither ``"ij"`` nor ``"xy"``.
    """
    if indexing is None:
        indexing = "ij"
    elif indexing not in {"ij", "xy"}:
        raise errors.SymbolicValueError(
            f"Unsupported indexing: {indexing}", tensor_list
        )

    inputs_unpacked = symbolic_helper._unpack_list(tensor_list)
    swap_first_two = indexing == "xy"
    if swap_first_two:
        # "xy" indexing: swap the first two inputs, build the grids, swap back below.
        inputs_unpacked[:2] = inputs_unpacked[1::-1]

    flat_inputs = []
    for item in inputs_unpacked:
        flatten_shape = g.op("Constant", value_t=torch.LongTensor([-1]))
        flat_inputs.append(symbolic_helper._reshape_helper(g, item, flatten_shape))

    flat_shapes = [g.op("Shape", flat) for flat in flat_inputs]
    grid_shape = g.op("Concat", *flat_shapes, axis_i=0)

    grids = []
    for axis, flat in enumerate(flat_inputs):
        # One shared size-1 constant fills every axis except this input's own.
        unit_dim = g.op("Constant", value_t=torch.ones(1, dtype=torch.int64))
        per_axis_shape = [unit_dim] * len(flat_inputs)
        per_axis_shape[axis] = flat_shapes[axis]
        target_shape = g.op("Concat", *per_axis_shape, axis_i=0)
        reshaped = _reshape_from_tensor(g, flat, target_shape)
        grids.append(g.op("Expand", reshaped, grid_shape))

    if swap_first_two:
        grids[0], grids[1] = grids[1], grids[0]
    return g.op("prim::ListConstruct", *grids)
| |
| |
@_onnx_symbolic("aten::remainder")
@_beartype.beartype
def remainder(g: jit_utils.GraphContext, input, other):
    """Export ``aten::remainder`` as ``input - floor_divide(input, other) * other``.

    Args:
        g: Graph context the nodes are added to.
        input: Dividend.
        other: Divisor.

    Returns:
        The ``Sub`` node holding the remainder.
    """
    quotient = _floor_divide(g, input, other)
    whole_part = g.op("Mul", quotient, other)
    return g.op("Sub", input, whole_part)
| |
| |
@_onnx_symbolic("aten::gelu")
@symbolic_helper.parse_args("v", "s")
@_beartype.beartype
def gelu(g: jit_utils.GraphContext, self: torch._C.Value, approximate: str = "none"):
    """Export ``aten::gelu`` using either the erf or the tanh formulation.

    Args:
        g: Graph context the nodes are added to.
        self: Input value.
        approximate: ``"tanh"`` selects the tanh formulation; any other value
            selects the erf formulation.

    Returns:
        The value computing GELU of ``self``.
    """
    if approximate != "tanh":
        # 0.5 * x * (1 + erf(x / sqrt(2)))
        sqrt_two = torch.tensor(1.4142135623730951, dtype=torch.double)
        erf_term = g.op("Erf", g.op("Div", self, sqrt_two))
        one_const = g.op("Constant", value_t=torch.tensor(1, dtype=torch.double))
        erf_plus_one = add(g, erf_term, one_const)
        weighted = mul(g, self, erf_plus_one)
        half_const = g.op("Constant", value_t=torch.tensor(0.5, dtype=torch.double))
        return mul(g, weighted, half_const)

    # 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))
    sqrt_two_over_pi = torch.tensor(math.sqrt(2 / math.pi), dtype=torch.double)
    cubic_coeff = torch.tensor(0.044715, dtype=torch.double)
    unit = torch.tensor(1.0, dtype=torch.double)
    one_half = torch.tensor(0.5, dtype=torch.double)

    squared = mul(g, self, self)
    cubed = mul(g, self, squared)
    scaled_cube = mul(g, cubic_coeff, cubed)
    polynomial = add(g, self, scaled_cube)
    tanh_arg = mul(g, sqrt_two_over_pi, polynomial)
    tanh_out = g.op("Tanh", tanh_arg)
    one_plus_tanh = add(g, unit, tanh_out)
    gated = mul(g, self, one_plus_tanh)
    return mul(g, one_half, gated)
| |
| |
@_onnx_symbolic("aten::group_norm")
@symbolic_helper.quantized_args(True, False, False, False)
@symbolic_helper.parse_args("v", "i", "v", "v", "f", "i")
@_beartype.beartype
def group_norm(
    g: jit_utils.GraphContext, input, num_groups, weight, bias, eps, cudnn_enabled
):
    """Export ``aten::group_norm`` via ``InstanceNormalization`` on a reshaped input.

    The input is reshaped to ``[0, num_groups, -1]`` (0 keeps that dimension
    unchanged), normalized with unit scale and zero shift, reshaped back to
    the input's shape, and only then multiplied by ``weight`` and offset by
    ``bias``.

    Args:
        g: Graph context the nodes are added to.
        input: Value to normalize.
        num_groups: Number of groups.
        weight: Optional scale; a constant ``[1.0]`` is used when it is None.
        bias: Optional shift; a constant ``[0.0]`` is used when it is None.
        eps: Epsilon forwarded to ``InstanceNormalization``.
        cudnn_enabled: Only forwarded on the Caffe2 ATen fallback path.

    Returns:
        The normalized, scaled and shifted value; the ATen fallback op on the
        Caffe2 path; or the ``_unimplemented`` result when the input rank is
        unknown.
    """
    if symbolic_helper.is_caffe2_aten_fallback():
        return g.at(
            "group_norm",
            input,
            weight,
            bias,
            num_groups_i=num_groups,
            eps_f=eps,
            cudnn_enabled_i=cudnn_enabled,
        )

    num_channels = symbolic_helper._get_tensor_dim_size(input, 1)
    if num_channels is not None:
        assert num_channels % num_groups == 0

    rank = symbolic_helper._get_tensor_rank(input)
    if rank is None:
        return symbolic_helper._unimplemented("group_norm", "unknown input rank", input)

    # A 0 entry in the target shape leaves that dimension as it is.
    grouped_shape = g.op("Constant", value_t=torch.LongTensor([0, num_groups, -1]))
    grouped_input = symbolic_helper._reshape_helper(g, input, grouped_shape)

    # The grouped layout differs from the input layout, so the real weight and
    # bias are applied after normalization and the reshape back; the
    # InstanceNormalization node itself gets a neutral scale and shift.
    input_dtype = _type_utils.JitScalarType.from_value(input).dtype()
    neutral_scale = g.op(
        "Constant",
        value_t=torch.tensor([1.0] * num_groups, dtype=input_dtype),
    )
    neutral_shift = g.op(
        "Constant",
        value_t=torch.tensor([0.0] * num_groups, dtype=input_dtype),
    )

    grouped_norm = g.op(
        "InstanceNormalization",
        grouped_input,
        neutral_scale,
        neutral_shift,
        epsilon_f=eps,
    )
    normalized = symbolic_helper._reshape_helper(
        g, grouped_norm, g.op("Shape", input)
    )

    if weight is None or weight.node().mustBeNone():
        weight = g.op("Constant", value_t=torch.tensor([1.0], dtype=input_dtype))
    if bias is None or bias.node().mustBeNone():
        bias = g.op("Constant", value_t=torch.tensor([0.0], dtype=input_dtype))

    # Unsqueeze weight and bias on axes 1..rank-2 before combining them with
    # the normalized value.
    trailing_axes = list(range(1, rank - 1))
    scaled = mul(
        g, normalized, symbolic_helper._unsqueeze_helper(g, weight, trailing_axes)
    )
    return add(g, scaled, symbolic_helper._unsqueeze_helper(g, bias, trailing_axes))
| |
| |
@_onnx_symbolic("aten::_weight_norm")
@symbolic_helper.parse_args("v", "v", "i")
@_beartype.beartype
def _weight_norm(g: jit_utils.GraphContext, weight_v, weight_g, dim):
    """Export ``aten::_weight_norm`` as ``weight_g * (weight_v / ||weight_v||)``.

    The norm is an L2 norm over every axis except ``dim``. ``dim = None`` and
    ``dim = -1`` both reduce over all axes.

    Args:
        g: Graph context the nodes are added to.
        weight_v: Direction tensor.
        weight_g: Magnitude tensor.
        dim: Axis excluded from the norm, or ``None`` / ``-1`` for no exclusion.

    Returns:
        The re-normalized weight, or the ATen fallback op when the rank is
        unknown and the Caffe2 ATen fallback is active.

    Raises:
        errors.SymbolicValueError: If the rank of ``weight_v`` is unknown and
            no fallback is available.
    """
    rank = symbolic_helper._get_tensor_rank(weight_v)
    if rank is None:
        if symbolic_helper.is_caffe2_aten_fallback():
            return g.at("_weight_norm", weight_v, weight_g, dim_i=dim)
        raise errors.SymbolicValueError(
            "Unsupported: ONNX export of _weight_norm for tensor of unknown rank.",
            weight_v,
        )

    # -1 is treated here as "all dims" rather than "last axis", so only values
    # below -1 are wrapped around.
    reduce_axes = list(range(rank))
    if dim is not None:
        if dim < -1:
            dim += rank
        if dim != -1:
            reduce_axes.remove(dim)

    v_norm = norm(g, weight_v, 2, reduce_axes, 1)
    direction = g.op("Div", weight_v, v_norm)
    return g.op("Mul", direction, weight_g)
| |
| |
@_onnx_symbolic("aten::dim")
@_beartype.beartype
def dim(g: jit_utils.GraphContext, self):
    """Export ``aten::dim`` as ``Size(Shape(self))``.

    There is no single op used for this here, so the number of dimensions is
    obtained as the number of elements in the shape.
    """
    return g.op("Size", g.op("Shape", self))
| |
| |
@_onnx_symbolic("aten::__contains_")
@_beartype.beartype
def __contains_(g: jit_utils.GraphContext, self, element):
    """Export ``aten::__contains_`` by evaluating membership at export time.

    Supported only when every list item and ``element`` are constants; the
    result is emitted as a ``Constant`` node holding the boolean answer.

    Raises:
        errors.SymbolicValueError: If the list or the element is not constant.
    """
    items = symbolic_helper._unpack_list(self)
    all_constant = all(
        symbolic_helper._is_constant(entry) for entry in items
    ) and symbolic_helper._is_constant(element)
    if not all_constant:
        raise errors.SymbolicValueError(
            "Unsupported: ONNX export of __contains__ for non-constant list or element.",
            self,
        )

    needle = symbolic_helper._node_get(element.node(), "value")
    found = needle in (
        symbolic_helper._node_get(entry.node(), "value") for entry in items
    )
    return g.op("Constant", value_t=torch.tensor(found))
| |
| |
@_onnx_symbolic("aten::__getitem_")
@_beartype.beartype
def __getitem_(g: jit_utils.GraphContext, self, i):
    """Export ``aten::__getitem_`` as a ``select`` along axis 0 at index ``i``."""
    axis_zero = g.op("Constant", value_t=torch.tensor([0]))
    return select(g, self, axis_zero, i)
| |
| |
@_onnx_symbolic("aten::item")
@_beartype.beartype
def item(g: jit_utils.GraphContext, self):
    """Export ``aten::item`` as a no-op: the input value is returned unchanged."""
    return self
| |
| |
@_onnx_symbolic("aten::take")
@_beartype.beartype
def take(g: jit_utils.GraphContext, self, index):
    """Export ``aten::take``: index into the flattened input, shaped like ``index``.

    Args:
        g: Graph context the nodes are added to.
        self: Input value, reshaped to ``[-1]`` before indexing.
        index: Indices into the flattened input.

    Returns:
        The selected elements, reshaped via ``reshape_as`` to match ``index``.
    """
    flat_shape = g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64))
    flattened = symbolic_helper._reshape_helper(g, self, flat_shape)
    gathered = index_select(g, flattened, 0, index)
    return reshape_as(g, gathered, index)
| |
| |
@_beartype.beartype
def _kl_div_log_target_impl(g: jit_utils.GraphContext, input, target):
    """Return ``exp(target) * (target - input)``, the KL term for a log-space target."""
    delta = sub(g, target, input)
    target_prob = exp(g, target)
    return mul(g, target_prob, delta)
| |
| |
@_beartype.beartype
def _kl_div_non_log_target_impl(g: jit_utils.GraphContext, input, target):
    """Return ``target * (log(target) - input)`` where ``target > 0``, else zero.

    The ``where`` mask replaces entries with a non-positive target by zeros.
    """
    log_target_value = log(g, target)
    delta = sub(g, log_target_value, input)
    pointwise = mul(g, target, delta)
    zero_fill = zeros_like(g, pointwise)
    zero_const = g.op("Constant", value_t=torch.tensor(0))
    positive_mask = gt(g, target, zero_const)
    return where(g, positive_mask, pointwise, zero_fill)
| |
| |
@_onnx_symbolic("aten::kl_div")
@symbolic_helper.parse_args("v", "v", "i", "b")
@_beartype.beartype
def kl_div(g: jit_utils.GraphContext, input, target, reduction, log_target):
    """Export ``aten::kl_div`` with reduction codes 0, 1 and 2.

    Args:
        g: Graph context the nodes are added to.
        input: Input value.
        target: Target value.
        reduction: 0 returns the pointwise result, 1 applies ``ReduceMean``,
            2 applies the reduce-sum helper.
        log_target: Selects the log-target implementation when true.

    Returns:
        The (possibly reduced) loss, or the ``_onnx_unsupported`` result for
        any other reduction code.
    """
    pointwise_impl = (
        _kl_div_log_target_impl if log_target else _kl_div_non_log_target_impl
    )
    loss = pointwise_impl(g, input, target)

    if reduction == 0:
        return loss
    if reduction == 1:
        return g.op("ReduceMean", loss, keepdims_i=0)
    if reduction == 2:
        return symbolic_helper._reducesum_helper(g, loss, keepdims_i=0)
    return symbolic_helper._onnx_unsupported(
        "kl_div with reduction other than none, mean, or sum.", input
    )
| |
| |
@_onnx_symbolic("aten::mse_loss")
@symbolic_helper.parse_args("v", "v", "i")
@_beartype.beartype
def mse_loss(g: jit_utils.GraphContext, input, target, reduction):
    """Export ``aten::mse_loss`` as ``(input - target) * (input - target)``.

    Args:
        g: Graph context the nodes are added to.
        input: Input value.
        target: Target value.
        reduction: 0 returns the pointwise result, 1 applies ``ReduceMean``,
            2 applies the reduce-sum helper.

    Returns:
        The (possibly reduced) loss, or the ``_onnx_unsupported`` result for
        any other reduction code.
    """
    # The difference is emitted twice, once per factor.
    left_diff = sub(g, input, target)
    right_diff = sub(g, input, target)
    squared_error = mul(g, left_diff, right_diff)

    if reduction == 0:
        return squared_error
    if reduction == 1:
        return g.op("ReduceMean", squared_error, keepdims_i=0)
    if reduction == 2:
        return symbolic_helper._reducesum_helper(g, squared_error, keepdims_i=0)
    return symbolic_helper._onnx_unsupported(
        "mse_loss with reduction other than none, mean, or sum.", input
    )
| |
| |
@_onnx_symbolic("aten::as_strided")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "v", "is", "i")
@_beartype.beartype
def as_strided(g: jit_utils.GraphContext, self, sizes, strides, offset=None):
    """Export ``aten::as_strided`` as a ``Gather`` over the flattened input.

    A flat index is built as ``sum_i(arange(sizes[i]) * strides[i]) + offset``,
    with each ``arange`` reshaped so that it varies along axis ``i`` only.
    When ``sizes`` is constant the index is computed eagerly as a tensor and
    stored in a ``Constant`` node; otherwise it is built from graph ops.

    Args:
        g: Graph context the nodes are added to.
        self: Input value; reshaped to ``[-1]`` before gathering.
        sizes: Output sizes, either a constant int list or a graph value.
        strides: Per-axis strides (constant int list).
        offset: Optional storage offset added to every index when truthy.

    Returns:
        The ``Gather`` node selecting the strided elements.
    """
    sizes = symbolic_helper._maybe_get_const(sizes, "is")
    rank = len(strides)
    self_1d = symbolic_helper._reshape_helper(
        g, self, g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64))
    )
    ind: Optional[torch.Tensor]
    if not symbolic_helper._is_value(sizes):
        # Static sizes: fold the whole index computation into one constant.
        ind = torch.tensor([0], dtype=torch.long)
        for i, (size, stride) in enumerate(zip(sizes, strides)):
            r_size = [1] * rank
            r_size[i] = -1
            ind = ind + torch.arange(size).view(r_size) * stride
        if offset:
            ind = ind + offset
        return g.op("Gather", self_1d, g.op("Constant", value_t=ind))
    else:
        # Dynamic sizes: emit the same computation as graph nodes.
        ind = None
        for i, stride in enumerate(strides):
            r_size = [1] * rank
            r_size[i] = -1
            size = select(
                g,
                sizes,
                g.op("Constant", value_t=torch.tensor([0])),
                g.op("Constant", value_t=torch.tensor(i)),
            )
            # NOTE(review): 4 is presumably the scalar-type code for int64 — confirm.
            tmp_ind = symbolic_helper._reshape_helper(
                g,
                arange(g, size, 4, None, None, None),
                g.op("Constant", value_t=torch.tensor(r_size)),
            )
            tmp_ind = g.op(
                "Mul", tmp_ind, g.op("Constant", value_t=torch.tensor([stride]))
            )
            if ind is None:
                ind = tmp_ind
            else:
                ind = g.op("Add", ind, tmp_ind)
        if offset:
            # The tensor must be given as the `value_t` attribute; passing it
            # positionally would make it an input of the Constant node.
            ind = g.op(
                "Add", ind, g.op("Constant", value_t=torch.tensor([offset]))
            )
        return g.op("Gather", self_1d, ind)
| |
| |
@_onnx_symbolic("aten::__derive_index")
@_beartype.beartype
def __derive_index(g: jit_utils.GraphContext, index, start, step):
    """Export ``aten::__derive_index`` as ``start + index * step``."""
    stride_offset = g.op("Mul", index, step)
    return g.op("Add", start, stride_offset)
| |
| |
@_onnx_symbolic("aten::__range_length")
# Source code for aten op can be found here: pytorch/torch/csrc/jit/runtime/register_prim_ops.cpp
# if (step > 0 && lo < hi) {
#   push(stack, 1 + (hi - 1 - lo) / step);
# } else if (step < 0 && lo > hi) {
#   push(stack, 1 + (lo - 1 - hi) / (0 - step));
# } else {
#   push(stack, 0);
# }
@_beartype.beartype
def __range_length(g: jit_utils.GraphContext, lo, hi, step):
    """Export ``aten::__range_length`` as ``ceil((hi - lo) / step)`` cast to INT64.

    Args:
        g: Graph context the nodes are added to.
        lo: Range start.
        hi: Range end.
        step: Range step.

    Returns:
        The ``Cast`` node holding the length as INT64.
    """
    span = g.op("Sub", hi, lo)
    step_count = g.op("Ceil", true_divide(g, span, step))
    return g.op("Cast", step_count, to_i=_C_onnx.TensorProtoDataType.INT64)
| |
| |
@_onnx_symbolic("aten::linear")
@_beartype.beartype
def linear(g: jit_utils.GraphContext, input, weight, bias):
    """Export ``aten::linear`` as ``input @ weight.T (+ bias)``.

    A rank-2 input with a bias goes through ``addmm`` with ``alpha`` and
    ``beta`` both set to 1; every other case uses ``matmul`` followed by an
    optional ``add``.

    Args:
        g: Graph context the nodes are added to.
        input: Input value.
        weight: Weight value; transposed with ``t`` before use.
        bias: Bias value; skipped when its node is a None constant.

    Returns:
        The output of the linear transformation.
    """
    input_rank = symbolic_helper._get_tensor_rank(input)
    weight = t(g, weight)
    bias_present = not bias.node().mustBeNone()

    if input_rank == 2 and bias_present:
        unit_alpha = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))
        unit_beta = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))
        return addmm(g, bias, input, weight, unit_alpha, unit_beta)

    result = matmul(g, input, weight)
    if bias_present:
        result = add(g, bias, result)
    return result
| |
| |
@_onnx_symbolic("aten::hann_window")
@symbolic_helper.parse_args("v", "b", "i", "v", "v", "v", "v")
@_beartype.beartype
def hann_window(
    g: jit_utils.GraphContext,
    window_length,
    periodic=True,
    dtype: Optional[int] = None,
    layout=None,
    device=None,
    pin_memory=None,
    requires_grad=False,
):
    """Export ``aten::hann_window`` as ``sin(pi * n / N) ** 2``.

    ``n`` is ``arange(window_length)`` cast to FLOAT. ``N`` is
    ``window_length``, reduced by one when ``periodic`` is ``False``.

    Args:
        g: Graph context the nodes are added to.
        window_length: Length of the window.
        periodic: When ``False`` the divisor is ``window_length - 1``.
        dtype: Scalar-type code for the output. When ``None``, the default
            torch dtype is used if it is floating point, else ``torch.float``.
        layout: Not used by this export.
        device: Not used by this export.
        pin_memory: Not used by this export.
        requires_grad: Not used by this export.

    Returns:
        The window, cast to the resolved output type.
    """
    if dtype is not None:
        scalar_type = _type_utils.JitScalarType(dtype)
    else:
        default_dtype = torch.get_default_dtype()
        if not default_dtype or not default_dtype.is_floating_point:
            default_dtype = torch.float
        scalar_type = _type_utils.JitScalarType.from_dtype(default_dtype)

    sample_indices = arange(g, window_length, 4, None, None, None)
    phase = g.op("Cast", sample_indices, to_i=_C_onnx.TensorProtoDataType.FLOAT)
    pi_const = g.op("Constant", value_t=torch.tensor(math.pi, dtype=torch.float))
    phase = mul(g, pi_const, phase)

    if periodic is False:
        one_const = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int))
        window_length = sub(g, window_length, one_const)

    phase = div(g, phase, window_length)
    window = square(g, sin(g, phase))
    return g.op("Cast", window, to_i=scalar_type.onnx_type())
| |
| |
@_onnx_symbolic("aten::mv")
@_beartype.beartype
def mv(g: jit_utils.GraphContext, self, vec):
    """Export ``aten::mv`` by delegating to the ``matmul`` symbolic."""
    product = matmul(g, self, vec)
    return product
| |
| |
@_onnx_symbolic("aten::dot")
@_beartype.beartype
def dot(g: jit_utils.GraphContext, self, other):
    """Export ``aten::dot`` by delegating to the ``matmul`` symbolic."""
    product = matmul(g, self, other)
    return product
| |
| |
@_onnx_symbolic("aten::movedim")
@symbolic_helper.parse_args("v", "t", "t")
@_beartype.beartype
def movedim(g: jit_utils.GraphContext, self, source, destination):
    """Export ``aten::movedim`` as a single ``Transpose`` node.

    This is a pythonic implementation mostly taken from
    aten/src/ATen/native/TensorShape.cpp::movedim. Negative entries in
    ``source`` and ``destination`` are wrapped into ``[0, rank)`` before the
    permutation is built, so the emitted ``perm`` holds only non-negative axes.

    Args:
        g: Graph context the nodes are added to.
        self: Input value; its rank must be known.
        source: Constant tensor of the dimensions to move.
        destination: Constant tensor of the positions they move to.

    Returns:
        ``self`` unchanged when no dimension actually moves, otherwise the
        ``Transpose`` node.

    Raises:
        errors.SymbolicValueError: If a dimension is outside ``[-rank, rank)``.
    """
    source = source.view(-1)
    destination = destination.view(-1)

    assert source.size() == destination.size()

    if (source == destination).all():
        return self

    self_rank = symbolic_helper._get_tensor_rank(self)
    assert self_rank is not None

    def _wrap_dim(d):
        # Map a possibly negative dimension onto its non-negative equivalent.
        if not -self_rank <= d < self_rank:
            raise errors.SymbolicValueError(
                f"movedim: dimension {d} is out of range for a tensor of rank "
                f"{self_rank}.",
                self,
            )
        return d % self_rank

    src_list = [_wrap_dim(d) for d in source.tolist()]
    dst_list = [_wrap_dim(d) for d in destination.tolist()]

    # e.g. source=-1 and destination=rank-1 name the same axis.
    if src_list == dst_list:
        return self

    perm = list(range(self_rank))

    src_dims = perm.copy()
    dst_dims = perm.copy()

    # Place the explicitly moved dimensions and mark both slots as consumed.
    for src, dst in zip(src_list, dst_list):
        perm[dst] = src
        src_dims[src] = -1
        dst_dims[dst] = -1

    src_dims = [d for d in src_dims if d != -1]
    dst_dims = [d for d in dst_dims if d != -1]

    # Fill the remaining slots with the untouched dimensions, in order.
    for src, dst in zip(src_dims, dst_dims):
        perm[dst] = src

    return g.op("Transpose", self, perm_i=perm)
| |
| |
@_onnx_symbolic("aten::fill")
@symbolic_helper.parse_args("v", "v")
@_beartype.beartype
def fill(g: jit_utils.GraphContext, self, value):
    """Export ``aten::fill`` via ``full_like``.

    The scalar type is taken from ``self``, with FLOAT passed as the default
    to ``JitScalarType.from_value``.
    """
    fallback_type = _type_utils.JitScalarType.FLOAT
    target_type = _type_utils.JitScalarType.from_value(self, fallback_type)
    return full_like(g, self, value, target_type)
| |
| |
@_onnx_symbolic("aten::index_add")
@_beartype.beartype
def index_add(g: jit_utils.GraphContext, self, dim, index, other, alpha=None):
    """Export ``aten::index_add`` through ``scatter_add``.

    ``other`` is unsqueezed at the end until it has the rank of ``self``, then
    expanded to the shape of ``self`` with ``dim`` sliced down to size 1.
    ``index`` is unsqueezed to the same rank and expanded to match ``other``.
    A warning is always emitted because duplicated indices are not supported.

    Args:
        g: Graph context the nodes are added to.
        self: Value that is added into.
        dim: Dimension to index along; must be a known constant.
        index: Indices along ``dim``.
        other: Value to add.
        alpha: Optional scale; only a value of 1 is supported.

    Returns:
        The ``scatter_add`` result, or the ``_unimplemented`` result when
        ``alpha`` is not 1.

    Raises:
        errors.SymbolicValueError: If ``dim`` or either rank is unknown, or if
            ``other`` is larger than ``self`` along ``dim``.
    """
    warnings.warn(
        "Warning: ONNX export does not support duplicated values in 'index' field, "
        "this will cause the ONNX model to be incorrect."
    )

    # ONNX does not support "alpha" argument, unlike aten index_add
    # See: https://github.com/pytorch/pytorch/pull/65993#issuecomment-953151102 for more context
    if alpha:
        alpha_value = symbolic_helper._scalar(symbolic_helper._maybe_get_scalar(alpha))
        if alpha_value != 1:
            return symbolic_helper._unimplemented("index_add", "alpha != 1", self)

    dim = symbolic_helper._maybe_get_const(dim, "i")
    if dim is None:
        raise errors.SymbolicValueError(
            "ONNX export does NOT support exporting 'index_add_()' function with "
            "unknown 'dim' value.",
            self,
        )

    self_rank = symbolic_helper._get_tensor_rank(self)
    other_rank = symbolic_helper._get_tensor_rank(other)
    if self_rank is None or other_rank is None:
        raise errors.SymbolicValueError(
            "ONNX export does NOT support exporting 'index_add_()' function while "
            "the rank of self tensor or tensor to be added is unknown.",
            self,
        )

    # Append trailing singleton axes to `other` until the ranks agree.
    for _ in range(self_rank - other_rank):
        other = symbolic_helper._unsqueeze_helper(
            g, other, [symbolic_helper._get_tensor_rank(other)]
        )

    other_size_at_dim = symbolic_helper._get_tensor_dim_size(other, dim)
    self_size_at_dim = symbolic_helper._get_tensor_dim_size(self, dim)
    if (
        other_size_at_dim is not None
        and self_size_at_dim is not None
        and other_size_at_dim > self_size_at_dim
    ):
        raise errors.SymbolicValueError(
            "ONNX export does not support exporting 'index_add_()' function with "
            "duplicated values in 'index' parameter yet.",
            self,
        )

    # Slice `self` down to size 1 along `dim` and keep every other axis whole;
    # the result serves as the expansion target for `other`.
    slice_axes = list(range(self_rank))
    slice_starts = [0] * self_rank
    slice_ends = [1 if axis == dim else sys.maxsize for axis in range(self_rank)]
    expand_target = symbolic_helper._slice_helper(
        g, self, axes=slice_axes, starts=slice_starts, ends=slice_ends
    )
    other = expand_as(g, other, expand_target)

    # Pad `index` with `dim` leading axes ...
    for _ in range(dim):
        index = symbolic_helper._unsqueeze_helper(g, index, [0])

    # ... and with trailing axes for the dimensions after `dim`.
    for _ in range(self_rank - dim - 1):
        index = symbolic_helper._unsqueeze_helper(
            g, index, [symbolic_helper._get_tensor_rank(index)]
        )

    expanded_index = expand_as(g, index, other)
    return scatter_add(g, self, dim, expanded_index, other)
| |
| |
@_onnx_symbolic("aten::roll")
@symbolic_helper.parse_args("v", "is", "is")
@_beartype.beartype
def roll(g: jit_utils.GraphContext, self, shifts, dims):
    """Export ``aten::roll`` as a slice-and-concat per ``(shift, dim)`` pair.

    For each pair, the slice starting at ``-shift`` is moved in front of the
    slice ending at ``-shift`` along that dimension.

    Args:
        g: Graph context the nodes are added to.
        self: Input value.
        shifts: Shift amounts; must have the same length as ``dims``.
        dims: Dimensions to roll along.

    Returns:
        The value after all rolls have been applied in order.
    """
    assert len(shifts) == len(dims)

    rolled = self
    for shift, axis in zip(shifts, dims):
        tail = symbolic_helper._slice_helper(
            g, rolled, axes=[axis], starts=[-shift], ends=[sys.maxsize]
        )
        head = symbolic_helper._slice_helper(
            g, rolled, axes=[axis], starts=[0], ends=[-shift]
        )
        rolled = g.op("Concat", tail, head, axis_i=axis)

    return rolled
| |
| |
@_onnx_symbolic("aten::cross")
@symbolic_helper.parse_args("v", "v", "i")
@_beartype.beartype
def cross(g: jit_utils.GraphContext, input, other, dim=None):
    """Export ``aten::cross`` as the difference of two products of rolled inputs.

    With A = [a, b, c] and B = [d, e, f] along ``dim``, the result is
    [(b*f - c*e), (c*d - a*f), (a*e - b*d)].

    Args:
        g: Graph context the nodes are added to.
        input: First operand.
        other: Second operand.
        dim: Dimension of the product, resolved by ``_get_dim_for_cross``.

    Returns:
        The ``sub`` of the two products.
    """
    dim = symbolic_helper._get_dim_for_cross(input, dim)

    # First pair of rolls: A' = [b, c, a], B' = [f, d, e] -> (b*f, c*d, a*e)
    input_by_two = roll(g, input, [2], [dim])
    other_by_one = roll(g, other, [1], [dim])
    # Second pair of rolls: A' = [c, a, b], B' = [e, f, d] -> (c*e, a*f, b*d)
    input_by_one = roll(g, input, [1], [dim])
    other_by_two = roll(g, other, [2], [dim])

    positive_terms = mul(g, input_by_two, other_by_one)
    negative_terms = mul(g, input_by_one, other_by_two)
    return sub(g, positive_terms, negative_terms)
| |
| |
@_onnx_symbolic("aten::cdist")
@_beartype.beartype
def cdist(
    g: jit_utils.GraphContext,
    x1,
    x2,
    p=2.0,
    compute_mode="use_mm_for_euclid_dist_if_necessary",
):
    """Export ``aten::cdist`` via broadcasting and ``pairwise_distance``.

    With X1.shape = (B * P * D) and X2.shape = (B * R * D), ``x1`` is
    unsqueezed at axis ``rank - 1`` and ``x2`` at axis ``rank - 2`` so that
    numpy-style broadcasting
    (https://github.com/onnx/onnx/blob/main/docs/Broadcasting.md) pairs every
    row of ``x1`` with every row of ``x2``.

    Args:
        g: Graph context the nodes are added to.
        x1: First input; its rank must be known.
        x2: Second input.
        p: Norm degree forwarded to ``pairwise_distance``.
        compute_mode: Ignored by this export.

    Returns:
        The ``pairwise_distance`` result, computed with ``eps=1e-06`` and
        ``keepdim=False``.
    """
    x1_rank = symbolic_helper._get_tensor_rank(x1)
    assert x1_rank is not None

    x1_expanded = symbolic_helper._unsqueeze_helper(g, x1, [x1_rank - 1])
    x2_expanded = symbolic_helper._unsqueeze_helper(g, x2, [x1_rank - 2])
    return pairwise_distance(
        g, x1_expanded, x2_expanded, p, eps=1e-06, keepdim=False
    )
| |
| |
@_onnx_symbolic("aten::lerp")
@_beartype.beartype
def lerp(g: jit_utils.GraphContext, self, end, weight):
    """Export ``aten::lerp`` with a weight-dependent formula.

    Where ``weight < 0.5`` the result is ``self + weight * (end - self)``;
    elsewhere it is ``end - (end - self) * (1 - weight)``. The split is for
    better numerics, as discussed in
    https://github.com/pytorch/pytorch/pull/18871

    Args:
        g: Graph context the nodes are added to.
        self: Start value.
        end: End value.
        weight: Interpolation weight.

    Returns:
        The ``where`` result selecting between the two formulas.
    """
    delta = g.op("Sub", end, self)

    half_const = g.op("Constant", value_t=torch.tensor(0.5))
    weight_is_low = g.op("Less", weight, half_const)

    from_start = g.op("Add", self, g.op("Mul", weight, delta))

    one_const = g.op("Constant", value_t=torch.tensor(1.0))
    remaining_weight = g.op("Sub", one_const, weight)
    from_end = g.op("Sub", end, g.op("Mul", delta, remaining_weight))

    return where(g, weight_is_low, from_start, from_end)
| |
| |
@_onnx_symbolic("aten::broadcast_tensors")
@_beartype.beartype
def broadcast_tensors(g: jit_utils.GraphContext, self):
    """Export ``aten::broadcast_tensors`` by expanding every input to a common shape.

    The common shape is obtained by adding all inputs onto a zero tensor,
    relying on the multidirectional broadcasting of the add operator.

    Args:
        g: Graph context the nodes are added to.
        self: Packed list of input tensors.

    Returns:
        A ``prim::ListConstruct`` node holding the expanded tensors.
    """
    tensors = symbolic_helper._unpack_list(self)

    shape_carrier = zeros_like(g, tensors[0])
    for tensor in tensors:
        shape_carrier = add(g, shape_carrier, tensor)

    expanded = [expand_as(g, tensor, shape_carrier) for tensor in tensors]
    return g.op("prim::ListConstruct", *expanded)
| |
| |
@_onnx_symbolic("aten::is_pinned")
def is_pinned(g: jit_utils.GraphContext, self, device=None):
    """Export ``aten::is_pinned`` as nothing: it is unused by ONNX, so ``None`` is returned."""
    return None
| |
| |
@_onnx_symbolic("prim::ConstantSplit")
@_beartype.beartype
def prim_constant_split(g: jit_utils.GraphContext, self, split_size, dim):
    """Export ``prim::ConstantSplit`` as an ONNX ``Split`` with explicit sizes.

    The dimension is divided into chunks of ``split_size``, plus one final
    smaller chunk when the size is not an exact multiple.

    Args:
        g: Graph context the nodes are added to.
        self: Value to split.
        split_size: Size of each full chunk.
        dim: Dimension to split along; its size must be known.

    Returns:
        The ``Split`` outputs, or the ``_unimplemented`` result when the
        dimension size is unknown.
    """
    dim_size = symbolic_helper._get_tensor_dim_size(self, dim)
    if dim_size is None:
        return symbolic_helper._unimplemented(
            "prim::ConstantSplit", "unknown dimension size", self
        )

    full_chunks, remainder_size = divmod(dim_size, split_size)
    chunk_sizes = [split_size] * full_chunks
    if remainder_size:
        chunk_sizes.append(remainder_size)
    return g.op(
        "Split", self, split_i=chunk_sizes, axis_i=dim, outputs=len(chunk_sizes)
    )
| |
| |
# TODO: It would be better to export this as a chunk directly, as this is
# less sensitive to changes in input size.
# TODO: Once we have proper scoping, stop reimplementing chunk, delete this
# method, and use the desugared version
@_onnx_symbolic("prim::ConstantChunk")
@_beartype.beartype
def prim_constant_chunk(g: jit_utils.GraphContext, self, chunks, dim):
    """Export ``prim::ConstantChunk`` by delegating to ``prim_constant_split``.

    The split size is the dimension size divided by ``chunks``, rounded up.

    Args:
        g: Graph context the nodes are added to.
        self: Value to chunk.
        chunks: Number of chunks requested.
        dim: Dimension to chunk along; its size must be known.

    Returns:
        The ``prim_constant_split`` result, or the ``_unimplemented`` result
        when the dimension size is unknown.
    """
    size_along_dim = symbolic_helper._get_tensor_dim_size(self, dim)
    if size_along_dim is None:
        return symbolic_helper._unimplemented(
            "prim::ConstantChunk", "unknown dimension size", self
        )

    per_chunk = (size_along_dim + chunks - 1) // chunks
    return prim_constant_split(g, self, per_chunk, dim)
| |
| |
@_onnx_symbolic("prim::shape")
@_beartype.beartype
def prim_shape(g: jit_utils.GraphContext, self):
    """Export ``prim::shape`` as an ONNX ``Shape`` node."""
    shape_value = g.op("Shape", self)
    return shape_value
| |
| |
@_onnx_symbolic("prim::max")
@_beartype.beartype
def prim_max(g: jit_utils.GraphContext, self, other):
    """Export ``prim::max`` as ONNX ``Max`` via ``_op_with_optional_float_cast``.

    The helper is called with ``opset_before=12``.
    """
    maximum = _op_with_optional_float_cast(g, "Max", self, other, opset_before=12)
    return maximum
| |
| |
@_onnx_symbolic("prim::min")
@_beartype.beartype
def prim_min(g: jit_utils.GraphContext, self, other=None):
    """Export ``prim::min`` by delegating to the ``min`` symbolic.

    With a truthy ``other`` this is ``min(g, self, other)``. Otherwise a packed
    list is first stacked along axis 0 and ``min(g, self)`` is returned.
    """
    if other:
        return min(g, self, other)

    if symbolic_helper._is_packed_list(self):
        stack_axis = g.op("Constant", value_t=torch.tensor([0]))
        self = stack(g, self, stack_axis)
    return min(g, self)
| |
| |
@_onnx_symbolic("prim::data")
@_beartype.beartype
def prim_data(g: jit_utils.GraphContext, self):
    """Export ``prim::data`` as a no-op: the input value is returned unchanged."""
    return self
| |
| |
@_onnx_symbolic("prim::layout")
def prim_layout(g: jit_utils.GraphContext, self):
    """Export ``prim::layout`` as the constant 0.

    Always return 'torch.strided'. Other layout types are not supported by
    JIT 'TensorType'. Layout class defined in 'c10/core/Layout.h'.
    """
    strided_layout = torch.tensor(0)
    return g.op("Constant", value_t=strided_layout)
| |
| |
@_onnx_symbolic("prim::ListConstruct")
@_beartype.beartype
def prim_list_construct(g: jit_utils.GraphContext, *inputs, **kwargs):
    """Symbolic for ``prim::ListConstruct``; emits no node and returns ``None``."""
    return None
| |
| |
@_onnx_symbolic("prim::ListUnpack")
@_beartype.beartype
def prim_list_unpack(
    g: jit_utils.GraphContext, *inputs, **kwargs
) -> Optional[List[_C.Value]]:
    """Cancel a ``prim::ListUnpack`` that directly follows a ``prim::ListConstruct``.

    Returns:
        The inputs of the producing ``prim::ListConstruct`` node when there is
        exactly one input and it comes from such a node; ``None`` otherwise.
    """
    if len(inputs) != 1 or inputs[0].node().kind() != "prim::ListConstruct":
        return None

    # TODO(justinchuby): Use a public method in the helper module
    return symbolic_helper._unpack_list(inputs[0])
| |
| |
@_onnx_symbolic("prim::TupleConstruct")
@_beartype.beartype
def prim_tuple_construct(g: jit_utils.GraphContext, *inputs, **kwargs):
    """Symbolic for ``prim::TupleConstruct``; emits no node and returns ``None``."""
    return None
| |
| |
@_onnx_symbolic("prim::Uninitialized")
@_beartype.beartype
def prim_uninitialized(g: jit_utils.GraphContext, *inputs, **kwargs):
    """Symbolic for ``prim::Uninitialized``; emits no node and returns ``None``."""
    return None
| |
| |
# prim::unchecked_cast exists to refine the type of the Value: if x is an
# optional Tensor, unchecked_cast will cast x to Tensor, so the rest of the
# graph knows that x is a Tensor. It does nothing at runtime and is a noop
# in ONNX.
@_onnx_symbolic("prim::unchecked_cast")
@_beartype.beartype
def prim_unchecked_cast(g: jit_utils.GraphContext, self):
    """Export ``prim::unchecked_cast`` as a no-op: the input value is returned unchanged."""
    return self
| |
| |
@_onnx_symbolic("prim::dtype")
@_beartype.beartype
def prim_dtype(g: jit_utils.GraphContext, self):
    """Export ``prim::dtype`` as a ``Constant`` holding the scalar type.

    FLOAT is used when the scalar type of ``self`` cannot be determined.
    """
    resolved_type = symbolic_helper._try_get_scalar_type(self)
    if resolved_type is None:
        resolved_type = _type_utils.JitScalarType.FLOAT
    # This node records a torch dtype as int
    dtype_code = torch.tensor(resolved_type)
    return g.op("Constant", value_t=dtype_code)
| |
| |
@_onnx_symbolic("prim::tolist")
@_beartype.beartype
def prim_tolist(g: jit_utils.GraphContext, input, dim_val, elem_ty_val):
    """tolist is currently supported only for 1D input tensors.

    dim_val and elem_ty_val represent dimension and type annotations
    that need to match dimension and type of the input tensor.

    Returns:
        ``input`` unchanged, or the ``_unimplemented`` result when the
        annotated dimension is greater than 1.
    """
    annotated_dim = symbolic_helper._maybe_get_const(dim_val, "i")
    if annotated_dim > 1:
        return symbolic_helper._unimplemented("prim::tolist", "dim_val > 1", input)
    return input
| |
| |
| # ----------------------------------------------------------------------------- |
| # Symbolic functions that need extra context |
| # ----------------------------------------------------------------------------- |
@_onnx_symbolic("prim::device")
@_beartype.beartype
def prim_device(g: jit_utils.GraphContext, *inputs, **kwargs) -> None:
    """Symbolic for ``prim::device``.

    Returns ``None`` when the original node's output type is a
    ``DeviceObjType``; any other output type is reported through
    ``_unimplemented``.
    """
    output_type = g.original_node.output().type()
    if not isinstance(output_type, _C.DeviceObjType):
        return symbolic_helper._unimplemented(
            "prim::device",
            f"output type should be 'DeviceObjType', not '{output_type.kind()}'",
            g.original_node.output(),
        )

    return None
| |
| |
@_onnx_symbolic("prim::Loop")
@_beartype.beartype
def prim_loop(g: jit_utils.GraphContext, *inputs, **attrs) -> List[_C.Value]:
    """Export ``prim::Loop`` as an ONNX ``Loop`` node with converted sub-blocks.

    A new ``Loop`` op is created with as many blocks as the original node.
    Input types are copied onto each original block's inputs, each block is
    converted with ``_jit_pass_onnx_block``, and the new node is then passed
    through the control-flow fixup pass.

    Args:
        g: Graph context; supplies the original node, ``env`` and ``params_dict``.
        *inputs: Inputs of the loop node.
        **attrs: Not used by this function.

    Returns:
        The outputs returned by ``_jit_pass_fixup_onnx_controlflow_node``.
    """
    node = g.original_node
    env = g.env
    params_dict = g.params_dict

    operator_export_type = GLOBALS.operator_export_type
    opset_version = GLOBALS.export_onnx_opset_version

    old_blocks = tuple(node.blocks())
    new_op_outputs, new_block_contexts, new_node = jit_utils.add_op_with_blocks(
        g, "Loop", *inputs, outputs=node.outputsSize(), n_blocks=len(old_blocks)
    )

    for old_block, new_block_context in zip(old_blocks, new_block_contexts):
        # Copy input metadata to subblock
        #
        #   prim::Loop(iter, cond, input_1, ..., input_n)
        #     block0(iter, input_1, ..., input_n)
        #
        # For `Loop` node, copy metadata for `iter`, `input_1`, ..., `input_n`.
        for i, b_in in enumerate(old_block.inputs()):
            # Block input 0 takes the type of loop input 0.
            if i == 0 and i < len(inputs):
                b_in.setType(inputs[i].type())
            # For optional block inputs, they may switch between None not-None inside
            # the loop body, so if the loop input is not optional, the block input may
            # still need to be optional.
            # Block input i (i > 0) maps to loop input i + 1: the loop node has
            # one extra input that the block signature does not have.
            if (
                i > 0
                and (i + 1) < len(inputs)
                and not isinstance(b_in.type(), _C.OptionalType)
            ):
                b_in.setType(inputs[i + 1].type())
        torch._C._jit_pass_onnx_block(
            old_block,
            new_block_context.block,
            operator_export_type,
            env,
            False,
        )
    fixed_outputs = torch._C._jit_pass_fixup_onnx_controlflow_node(
        new_node, opset_version
    )
    # Run shape type inference for Loop after subblock is converted.
    if GLOBALS.onnx_shape_inference:
        torch._C._jit_pass_onnx_node_shape_type_inference(
            new_node, params_dict, opset_version
        )
    return fixed_outputs
| |
| |
@_onnx_symbolic("prim::If")
@_beartype.beartype
def prim_if(g: jit_utils.GraphContext, *inputs, **attrs) -> List[_C.Value]:
    """Export ``prim::If``, folding it away when the condition is a constant.

    If the condition comes from an ``onnx::Constant`` node, only the selected
    branch is converted, directly into the current block, and its outputs are
    looked up in the returned ``env``. Otherwise an ONNX ``If`` node is
    created and every branch is converted into its own sub-block.

    Args:
        g: Graph context; supplies the original node, block, ``env`` and
            ``params_dict``.
        *inputs: Inputs of the ``If`` node; ``inputs[0]`` is the condition.
        **attrs: Not used by this function.

    Returns:
        The ONNX values for the selected branch's outputs (static case), or
        the outputs returned by the control-flow fixup pass (dynamic case).

    Raises:
        errors.SymbolicValueError: In the static case, if an output of the
            selected branch is missing from ``env`` after conversion.
    """
    n = g.original_node
    block = g.block
    env = g.env
    params_dict = g.params_dict

    operator_export_type = GLOBALS.operator_export_type
    opset_version = GLOBALS.export_onnx_opset_version

    static_if = inputs[0].node().kind() == "onnx::Constant"
    if static_if:
        # Fold static if
        #
        # The torch IR
        # graph(%embedding_matrix.1 : Float(10, 15, strides=[15, 1], requires_grad=0, device=cpu),
        #    %input.1 : Long(6, strides=[1], requires_grad=0, device=cpu), ...
        # %65 : Bool(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        # %21 : Long(device=cpu) = aten::eq(%20, %64)
        # %22 : Long(device=cpu) = prim::If(%21)
        #     block0():
        #     %23 : Long(device=cpu) = aten::is_floating_point(%input.1)
        #     -> (%23)
        #     block1():
        #     -> (%65)
        # %input.53 : Tensor, %weight : Tensor = prim::If(%22)
        #     block0():
        #     -> (%embedding_matrix.1, %input.1)
        #     block1():
        #     -> (%input.1, %embedding_matrix.1)
        # %26 : int[] = aten::size(%input.53)
        #
        # The converted ONNX graph
        # %10 : Bool(device=cpu) = onnx::Constant[value={0}]()
        # %14 : Bool(device=cpu) = onnx::Equal(%13, %8)
        # %15 : Bool(requires_grad=0, device=cpu) = onnx::Constant[value={0}]()
        # %16 : Long(1, strides=[1], device=cpu) = onnx::Shape(%input.1)
        input_flag = symbolic_helper._node_get(inputs[0].node(), "value").tolist()
        # A list-valued constant counts as true only if every element is truthy.
        const_value = (
            all(input_flag) if isinstance(input_flag, list) else bool(input_flag)
        )
        # Block 0 is taken when the condition is true, block 1 otherwise.
        block_idx = 0 if const_value else 1
        current_b = list(n.blocks())[block_idx]
        # Convert the chosen branch into the current block; the returned env
        # is used below to look up the branch outputs.
        env = torch._C._jit_pass_onnx_block(
            current_b,
            block,
            operator_export_type,
            env,
            True,
        )
        if_output_list = list(n.outputs())
        current_b_list = list(current_b.outputs())

        final_b_list = []
        for idx in range(len(if_output_list)):
            if current_b_list[idx] not in env:
                raise errors.SymbolicValueError(
                    f"The sub block ATen output {current_b_list[idx]} is not in env.",
                    current_b_list[idx],
                )  # type:ignore[operator]
            onnx_b = env[current_b_list[idx]]
            final_b_list.append(onnx_b)
        return final_b_list
    else:
        # Dynamic condition: keep both branches as sub-blocks of an ONNX If.
        old_blocks = tuple(n.blocks())
        new_op_outputs, new_block_contexts, new_node = jit_utils.add_op_with_blocks(
            g, "If", *inputs, outputs=n.outputsSize(), n_blocks=len(old_blocks)
        )

        for old_block, new_block_context in zip(old_blocks, new_block_contexts):
            torch._C._jit_pass_onnx_block(
                old_block,
                new_block_context.block,
                operator_export_type,
                env,
                False,
            )
        fixed_outputs = torch._C._jit_pass_fixup_onnx_controlflow_node(
            new_node, opset_version
        )
        # Run shape type inference for If after subblock is converted.
        if GLOBALS.onnx_shape_inference:
            torch._C._jit_pass_onnx_node_shape_type_inference(
                new_node, params_dict, opset_version
            )
        return fixed_outputs
| |
| |
@_onnx_symbolic("prim::Constant")
@_beartype.beartype
def prim_constant(g: jit_utils.GraphContext, *inputs, **attrs):
    """Convert a TorchScript ``prim::Constant`` node to ONNX.

    Handling by constant kind:

    * ``None`` constants and Device-typed constants map to ``None`` (no ONNX
      value is emitted).
    * Tensor (``"t"``) and string (``"s"``) values become an ONNX ``Constant``
      with ``value_t`` / ``value_s`` respectively.
    * Lists of ints or floats become a tensor-valued ``Constant``.
    * Lists of strings become a ``prim::ListConstruct`` of string ``Constant`` ops.

    Args:
        g: Graph context whose ``original_node`` is the constant being converted.
        *inputs: Unused.
        **attrs: Unused.

    Returns:
        The replacement value, or ``None`` for None/Device constants.

    Raises:
        errors.SymbolicValueError: If the constant matches none of the cases above.
    """
    node = g.original_node

    if node.mustBeNone():
        return None

    output_type = node.output().type()
    # This must go before checking for string values, because some device constants
    # have string values, but we want to keep them as unconverted Device types so
    # that eq() can work on them.
    if isinstance(output_type, _C.DeviceObjType):
        return None

    value_kind = node.kindOf("value")
    if value_kind == "t":
        return g.op("Constant", value_t=symbolic_helper._node_get(node, "value"))
    if value_kind == "s":
        return g.op("Constant", value_s=symbolic_helper._node_get(node, "value"))

    is_numeric_list = output_type.isSubtypeOf(
        _C.ListType.ofInts()
    ) or output_type.isSubtypeOf(_C.ListType.ofFloats())
    if is_numeric_list:
        numeric_values = symbolic_helper._node_get(node, "value")
        return g.op("Constant", value_t=torch.tensor(numeric_values))

    if output_type.isSubtypeOf(_C.ListType.ofStrings()):
        # One string Constant per element, gathered into a list construct.
        string_ops = [
            g.op("Constant", value_s=text)
            for text in symbolic_helper._node_get(node, "value")
        ]
        return g.op("prim::ListConstruct", *string_ops)

    raise errors.SymbolicValueError(
        f"Unsupported prim::Constant kind: '{value_kind}'. "
        f"Please send a bug report at {_constants.PYTORCH_GITHUB_ISSUES_URL}.",
        node.output(),
    )
| |
| |
@_onnx_symbolic("prim::type")
@_beartype.beartype
def prim_type(g: jit_utils.GraphContext, device_value: _C.Value, *args, **kwargs):
    """Convert ``prim::type`` applied to a device into a string ``Constant``.

    The device is only resolved when ``device_value`` is produced by a
    ``prim::device`` node and ``jit_utils.get_device_from_value`` returns a
    device for that node's input. In every other case the op is reported
    through ``symbolic_helper._unimplemented``.

    Args:
        g: Graph context used to emit the ONNX op.
        device_value: The value whose device type is requested.
        *args: Unused.
        **kwargs: Unused.

    Returns:
        An ONNX ``Constant`` holding ``str(device)``, or the result of
        ``symbolic_helper._unimplemented`` when the device is not known statically.
    """
    producer = device_value.node()
    device = (
        jit_utils.get_device_from_value(producer.input())
        if producer.kind() == "prim::device"
        else None
    )

    if device is None:
        return symbolic_helper._unimplemented(
            "prim::type",
            "Device type cannot be statically determined.",
            device_value,
        )

    return g.op("Constant", value_s=str(device))
| |
| |
@_onnx_symbolic("onnx::Placeholder")
@_beartype.beartype
def onnx_placeholder(g: jit_utils.GraphContext, *inputs, **attrs):
    """Convert an ``onnx::Placeholder`` node.

    All work is delegated to ``torch._C._jit_onnx_convert_pattern_from_subblock``,
    which receives the enclosing block, the original node and the value
    environment taken from ``g``.

    Args:
        g: Graph context supplying ``block``, ``original_node`` and ``env``.
        *inputs: Unused.
        **attrs: Unused.

    Returns:
        Whatever the C++ conversion routine returns for this node.
    """
    return torch._C._jit_onnx_convert_pattern_from_subblock(
        g.block, g.original_node, g.env
    )
| |
| |
@_onnx_symbolic("aten::resolve_conj")
@_onnx_symbolic("aten::resolve_neg")
@_beartype.beartype
def noop_complex_operators(g: jit_utils.GraphContext, input: _C.Value):
    """Export ``aten::resolve_conj`` / ``aten::resolve_neg`` as a no-op.

    ONNX does not have operators to *directly* manipulate real/imaginary
    components. However, a few torch APIs (e.g. ``.tolist()``) use complex
    operations when the input is real, which results in failures due to missing
    operators for complex numbers. Both ops registered here can safely be
    implemented as a pass-through.

    Args:
        g: Graph context (unused; nothing is emitted).
        input: The value the op is applied to.

    Returns:
        ``input`` unchanged.
    """
    return input
| |
| |
@_onnx_symbolic("aten::_conj")
@_onnx_symbolic("aten::conj_physical")
@_beartype.beartype
def unsupported_complex_operators(g: jit_utils.GraphContext, input: _C.Value):
    """Export ``aten::_conj`` / ``aten::conj_physical``.

    ONNX does not have operators to *directly* manipulate real/imaginary
    components. However, a few torch APIs (e.g. ``.tolist()``) use complex
    operations when the input is real, which results in failures due to missing
    operators for complex numbers.

    For real inputs these ops are a no-op and are forwarded to
    ``noop_complex_operators``. For complex inputs they are reported through
    ``symbolic_helper._onnx_unsupported``.

    Args:
        g: Graph context, passed through to ``noop_complex_operators``.
        input: The value the op is applied to.

    Returns:
        ``input`` unchanged when it is not complex; otherwise the result of
        ``symbolic_helper._onnx_unsupported``.
    """
    # Real numbers only: conjugation changes nothing, so pass the value through.
    if not symbolic_helper.is_complex_value(input):
        return noop_complex_operators(g, input)

    # FIXME(justinchuby): report correct name for symbolic being executed
    return symbolic_helper._onnx_unsupported(
        "aten::_conj, aten::conj_physical",
        input,
    )
| |
| |
@_onnx_symbolic("aten::logit")
@_beartype.beartype
def logit(g: jit_utils.GraphContext, self: torch._C.Value, eps: torch._C.Value):
    """Export ``aten::logit`` as ``Log(z / (1 - z))``.

    When ``eps`` is not None it is cast to the scalar type of ``self`` and
    ``z`` is ``self`` limited to the range ``[eps, 1 - eps]``, built from
    ``Greater``/``Less`` comparisons and ``Where`` selections. When ``eps`` is
    None, ``z`` is ``self`` itself.

    Args:
        g: Graph context used to emit the ONNX ops.
        self: The input value.
        eps: Clamp bound, or a None value to disable clamping.

    Returns:
        The output value of the final ONNX ``Log`` op.
    """
    one = g.op("Constant", value_t=torch.tensor(1.0))

    if symbolic_helper._is_none(eps):
        clamped = self
    else:
        eps_cast = g.op(
            "Cast", eps, to_i=_type_utils.JitScalarType.from_value(self).onnx_type()
        )
        upper_bound = g.op("Sub", one, eps_cast)
        # Upper clamp: keep self where (1 - eps) > self, otherwise use (1 - eps).
        is_below_upper = g.op("Greater", upper_bound, self)
        upper_clamped = g.op("Where", is_below_upper, self, upper_bound)

        # Lower clamp: replace anything smaller than eps with eps.
        is_below_eps = g.op("Less", upper_clamped, eps_cast)
        clamped = g.op("Where", is_below_eps, eps_cast, upper_clamped)

    complement = g.op("Sub", one, clamped)
    ratio = g.op("Div", clamped, complement)
    return g.op("Log", ratio)