"""This file exports ONNX ops for opset 16.

Note [ONNX Operators that are added/updated in opset 16]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-16-of-the-default-onnx-operator-set
New operators:
    GridSample https://github.com/onnx/onnx/pull/3557

Updated operators:
    Identity
    If
    LeakyRelu
    Loop
    PRelu
    RoiAlign
    Scan
    ScatterElements
    ScatterND
    Where
    GreaterOrEqual
    LessOrEqual
"""
| |
| # EDITING THIS FILE? READ THIS FIRST! |
| # see Note [Edit Symbolic Files] in README.md |
| |
| import functools |
| |
| import torch |
| from torch.nn.functional import ( |
| GRID_SAMPLE_INTERPOLATION_MODES, |
| GRID_SAMPLE_PADDING_MODES, |
| ) |
| from torch.onnx import _type_utils, symbolic_helper |
| from torch.onnx._internal import _beartype, jit_utils, registration |
| |
# Shorthand for `registration.onnx_symbolic` with `opset` pinned to 16; it is
# applied as a decorator (e.g. `@_onnx_symbolic("aten::...")`) to the symbolic
# functions defined in this module.
_onnx_symbolic = functools.partial(
    registration.onnx_symbolic,
    opset=16,
)
| |
| |
# note (mkozuki): Why `grid_sampler` instead of `grid_sample`?
# Because `torch.nn.functional.grid_sample` calls `torch.grid_sampler`.
@_onnx_symbolic("aten::grid_sampler")
@symbolic_helper.parse_args("v", "v", "i", "i", "b")
@_beartype.beartype
def grid_sampler(
    g: jit_utils.GraphContext,
    input,
    grid,
    mode_enum,
    padding_mode_enum,
    align_corners,
):
    """Export ``aten::grid_sampler`` as an ONNX ``GridSample`` node.

    Args:
        g: Graph context used to emit ONNX nodes.
        input: Tensor value to sample from.
        grid: Tensor value holding the sampling grid.
        mode_enum: Integer interpolation-mode code; translated back to its
            name through ``GRID_SAMPLE_INTERPOLATION_MODES``.
        padding_mode_enum: Integer padding-mode code; translated back to its
            name through ``GRID_SAMPLE_PADDING_MODES``.
        align_corners: Parsed as a bool and emitted as the integer
            ``align_corners`` attribute.

    Returns:
        The output value of the emitted ``GridSample`` node.

    Raises:
        Whatever ``symbolic_helper._onnx_unsupported`` raises when ``input``
        is known to be 5-D.
    """
    # GridSample as specified in ONNX opset 16 only supports spatial (4-D)
    # inputs. Reject volumetric (5-D) inputs explicitly instead of emitting a
    # node outside the spec. An unknown rank (None) is let through unchanged.
    if symbolic_helper._get_tensor_rank(input) == 5:
        return symbolic_helper._onnx_unsupported("GridSample with 5D volumetric input")

    # PyTorch hands over the modes as integer codes while GridSample expects
    # their string names, so invert the name -> code tables.
    mode_s = {v: k for k, v in GRID_SAMPLE_INTERPOLATION_MODES.items()}[mode_enum]  # type: ignore[call-arg]
    padding_mode_s = {v: k for k, v in GRID_SAMPLE_PADDING_MODES.items()}[padding_mode_enum]  # type: ignore[call-arg]
    return g.op(
        "GridSample",
        input,
        grid,
        align_corners_i=int(align_corners),
        mode_s=mode_s,
        padding_mode_s=padding_mode_s,
    )
| |
| |
@_onnx_symbolic("aten::scatter_add")
@symbolic_helper.parse_args("v", "i", "v", "v")
@_beartype.beartype
def scatter_add(g: jit_utils.GraphContext, self, dim, index, src):
    """Export ``aten::scatter_add`` as ONNX ``ScatterElements`` with ``reduction="add"``.

    Args:
        g: Graph context used to emit ONNX nodes.
        self: Tensor value that ``src`` is accumulated into.
        dim: Scatter axis (parsed to a Python ``int``); emitted as ``axis``.
        index: Index tensor value.
        src: Source tensor value whose elements are added into ``self``.

    Returns:
        The output value of the emitted node, or the result of
        ``symbolic_helper._unimplemented`` when the inputs cannot be exported.
    """
    if symbolic_helper.is_caffe2_aten_fallback():
        # Fall back to the ATen op that matches this symbolic. ``aten::scatter``
        # with the ``src`` overload would overwrite ``self`` instead of
        # accumulating into it.
        return g.at("scatter_add", self, dim, index, src)

    src_type = _type_utils.JitScalarType.from_value(
        src, _type_utils.JitScalarType.UNDEFINED
    )
    src_sizes = symbolic_helper._get_tensor_sizes(src)
    index_sizes = symbolic_helper._get_tensor_sizes(index)

    # `_get_tensor_sizes` can return None when no static shape information is
    # available. Without a rank, neither the dimensionality check nor the
    # `starts` constant below can be built, so report it instead of failing
    # with a TypeError on `len(None)`.
    if src_sizes is None or index_sizes is None:
        return symbolic_helper._unimplemented(
            "scatter_add",
            "the ranks of `index` and `src` must be statically known",
        )

    if len(src_sizes) != len(index_sizes):
        return symbolic_helper._unimplemented(
            "scatter_add",
            f"`index` ({index_sizes}) should have the same dimensionality as `src` ({src_sizes})",
        )

    # PyTorch only allows index shape <= src shape, so we can only consider
    # taking index as subset size to src, like PyTorch does. When sizes for src
    # and index are not matched or there are dynamic axes, we take index shape to
    # slice src to accommodate.
    if src_sizes != index_sizes or None in index_sizes:
        adjusted_shape = g.op("Shape", index)
        starts = g.op("Constant", value_t=torch.tensor([0] * len(index_sizes)))
        src = g.op("Slice", src, starts, adjusted_shape)

    src = symbolic_helper._maybe_get_scalar(src)
    if not symbolic_helper._is_value(src):
        # `src` was folded to a constant scalar. PyTorch allows a scalar src to
        # have a different type than `self` (but not a tensor src), so insert a
        # Cast to `self`'s type when they differ.
        self_type = _type_utils.JitScalarType.from_value(self)
        if self_type != src_type:
            src = g.op("Cast", src, to_i=self_type.onnx_type())

    return g.op(
        "ScatterElements",
        self,
        index,
        src,
        axis_i=dim,
        reduction_s="add",
    )