# Owner(s): ["module: onnx"]
"""Test consistency between the output values of torch.onnx FX exported operators
and torch operators given the same inputs.

Usage:

    pytest test/onnx/test_op_consistency.py

To run tests on a specific operator (e.g. torch.ceil):

    pytest test/onnx/test_op_consistency.py -k ceil
    pytest test/onnx/test_op_consistency.py -k nn_functional_scaled_dot_product_attention

Read more on Running and writing tests:
https://github.com/pytorch/pytorch/wiki/Running-and-writing-tests
Note:
When new ops are supported, please scroll down to modify the EXPECTED_SKIPS_OR_FAILS and
TESTED_OPS lists. See "Modify this section"
"""
from __future__ import annotations

import copy
from typing import Any, Callable, Collection, Optional, Tuple, Union

import onnx_test_common
import parameterized
import torch
from onnx_test_common import skip, xfail
from torch.testing._internal import (
    common_device_type,
    common_methods_invocations,
    common_utils,
)
from torch.testing._internal.opinfo import core as opinfo_core

# Modify this section ##########################################################
# NOTE: Modify this section as more ops are supported. The list should be sorted
# alphabetically.
#
# For example, to add a test for torch.ceil:
# 1. Add "ceil" to TESTED_OPS then run pytest.
# 2. If the test fails, fix the error or add a new entry to EXPECTED_SKIPS_OR_FAILS.
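#
# A new EXPECTED_SKIPS_OR_FAILS entry could look like the sketch below (the
# dtypes and reason are illustrative, not verified against a real failure):
#
#     xfail(
#         "ceil", dtypes=onnx_test_common.BOOL_TYPES,
#         reason=onnx_test_common.reason_onnx_does_not_support("Ceil", "bool"),
#     ),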
# TODO: Directly modify DecorateInfo in each OpInfo in op_db when all ops are enabled.
# Ops to be tested for numerical consistency between onnx and pytorch
TESTED_OPS: frozenset[str] = frozenset(
[
"abs",
"acos",
"acosh",
"add",
"addmm",
"all",
"allclose",
"amax",
"amin",
"any",
"arange",
"argmax",
"argmin",
"as_strided",
"asin",
"asinh",
"atan",
"atanh",
"atleast_1d",
"atleast_2d",
"atleast_3d",
"baddbmm",
"bmm",
"broadcast_to",
"cat",
"ceil",
"chunk",
"clamp",
"clamp_max",
"clamp_min",
"clone",
# "col2im", extra opinfo needed
"constant_pad_nd",
"contiguous",
# "copy", copy is not in OPS_DB
"cos",
"cosh",
"cross",
"cumsum",
        # "detach", detach is not in OPS_DB
"div",
"dot",
# "empty", non-deterministic
# "empty_like", non-deterministic
# "empty_strided", empty_strided is not in OPS_DB
"eq",
"equal",
"erf",
"exp",
"exp2",
"expand",
"expand_as",
"fill",
"flip",
"floor",
"fmod",
"full",
"full_like",
"gather",
"hstack", # aten::cat is invoked instead
"index_put",
"logit",
"mean",
"native_batch_norm",
# "new_empty", non-deterministic
# "new_empty_strided", non-deterministic
"new_full",
"new_ones",
"new_zeros",
"nn.functional.adaptive_avg_pool1d",
"nn.functional.adaptive_avg_pool2d",
"nn.functional.adaptive_avg_pool3d",
"nn.functional.avg_pool1d",
"nn.functional.avg_pool2d",
"nn.functional.avg_pool3d",
"nn.functional.batch_norm",
"nn.functional.conv1d",
# "nn.functional.conv2d", AssertionError: The values for attribute 'shape' do not match in float32
# "nn.functional.conv3d", extra opinfo needed
        # "nn.functional.convolution", extra opinfo needed
        "nn.functional.celu",
        "nn.functional.cross_entropy",
"nn.functional.dropout",
"nn.functional.elu",
"nn.functional.embedding",
"nn.functional.embedding_bag",
"nn.functional.max_pool1d",
"nn.functional.max_pool2d",
"nn.functional.max_pool3d",
"nn.functional.nll_loss",
# "nn.functional.scaled_dot_product_attention" non-deterministic
"nonzero",
"scatter_add",
"scatter_reduce",
"square",
"stft",
"sum",
"unflatten",
"var_mean",
"vstack", # aten::cat is invoked instead
]
)

COMPLEX_TESTED_OPS = frozenset(
[
"abs",
"stft",
]
)

# NOTE: For ATen signature modifications that will break ONNX export,
# use **xfail_torchlib_forward_compatibility** and **skip_torchlib_forward_compatibility** instead of xfail or skip
# to make the signal apparent for maintainers.
def xfail_torchlib_forward_compatibility(
op_name: str,
variant_name: str = "",
*,
reason: str,
github_issue: str,
opsets: Optional[Collection[Union[int, Callable[[int], bool]]]] = None,
dtypes: Optional[Collection[torch.dtype]] = None,
matcher: Optional[Callable[[Any], bool]] = None,
enabled_if: bool = True,
):
"""Prefer using this (xfail) over skip when possible.
Only skip when the test is not failing consistently.
"""
return xfail(
op_name,
variant_name=variant_name,
reason=f"{reason}. GitHub Issue: {github_issue}",
opsets=opsets,
dtypes=dtypes,
matcher=matcher,
enabled_if=enabled_if,
)
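
# Example usage (a sketch; the op name and issue URL are placeholders):
#
#     xfail_torchlib_forward_compatibility(
#         "some_op",
#         reason=onnx_test_common.reason_onnx_script_does_not_support("aten.some_op.default"),
#         github_issue="https://github.com/microsoft/onnxscript/issues/<id>",
#     ),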
def skip_torchlib_forward_compatibility(
op_name: str,
variant_name: str = "",
*,
reason: str,
github_issue: str,
opsets: Optional[Collection[Union[int, Callable[[int], bool]]]] = None,
dtypes: Optional[Collection[torch.dtype]] = None,
matcher: Optional[Callable[[Any], Any]] = None,
enabled_if: bool = True,
):
"""Prefer using xfail_torchlib_forward_compatibility over this (skip) when possible.
Only skip when the test is not failing consistently.
"""
return skip(
op_name,
variant_name=variant_name,
reason=f"{reason}. GitHub Issue: {github_issue}",
opsets=opsets,
dtypes=dtypes,
matcher=matcher,
enabled_if=enabled_if,
)
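
# Example usage (a sketch; prefer the xfail variant above unless the failure is
# flaky; the op name, kwarg, and issue URL are placeholders):
#
#     skip_torchlib_forward_compatibility(
#         "some_op",
#         matcher=lambda sample: sample.kwargs.get("mode") == "exotic",
#         reason="flaky on this configuration",
#         github_issue="https://github.com/microsoft/onnxscript/issues/<id>",
#     ),
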
# fmt: off
# Turn off black formatting to keep the list compact
# Expected failures for onnx export.
# The list should be sorted alphabetically by op name.
# Q: When should I use fixme vs skip vs xfail?
# A: Prefer xfail over skip when possible.
#    - If a test now fails with an unexpected success (xpass) because a previous
#      error has been fixed, remove the corresponding xfail.
#    - If a test is not failing consistently, use skip.
EXPECTED_SKIPS_OR_FAILS: Tuple[onnx_test_common.DecorateMeta, ...] = (
xfail(
"add", dtypes=onnx_test_common.BOOL_TYPES,
reason=onnx_test_common.reason_onnx_does_not_support("Add")
),
xfail(
"add",
dtypes=(torch.uint8, torch.int8, torch.int16,),
reason=onnx_test_common.reason_onnx_script_does_not_support(
"Add", "int8, int16, uint8 have type issue."
),
),
xfail(
"addmm", dtypes=onnx_test_common.BOOL_TYPES,
reason=onnx_test_common.reason_onnx_does_not_support("Addmm")
),
xfail_torchlib_forward_compatibility(
"all",
reason=onnx_test_common.reason_onnx_script_does_not_support("aten.all.dims"),
github_issue="https://github.com/microsoft/onnxscript/pull/1084"
),
xfail(
"allclose", dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES + onnx_test_common.FLOAT_TYPES,
reason=onnx_test_common.reason_dynamo_does_not_support("Allclose")
),
xfail(
"amax",
dtypes=(torch.int16, *onnx_test_common.BOOL_TYPES),
        reason=onnx_test_common.reason_onnx_does_not_support("ReduceMax", "bool, int16"),
),
xfail(
"amin", dtypes=(torch.int16, *onnx_test_common.BOOL_TYPES),
        reason=onnx_test_common.reason_onnx_does_not_support("ReduceMin", "bool, int16")
),
xfail_torchlib_forward_compatibility(
"any",
reason=onnx_test_common.reason_onnx_script_does_not_support("aten.any.dims"),
github_issue="https://github.com/microsoft/onnxscript/pull/1084"
),
xfail(
"arange",
dtypes=(torch.uint8,),
        reason=onnx_test_common.reason_onnx_script_does_not_support("Arange", "uint8"),
),
xfail(
"arange",
dtypes=(torch.int16, torch.int32),
reason="AssertionError: The values for attribute 'shape' do not match",
),
xfail(
"argmax",
dtypes=(
torch.int16,
torch.int64,
),
reason=onnx_test_common.reason_onnx_runtime_does_not_support(
"ArgMax", "int16, int64"
),
),
xfail(
"argmin",
dtypes=(
torch.uint8,
torch.int8,
torch.int16,
torch.int64,
),
reason=onnx_test_common.reason_onnx_runtime_does_not_support(
"ArgMin", "uint8, int8, int16, int64"
),
),
skip(
"as_strided",
variant_name="partial_views",
reason="ONNX doesn't have partial view for tensor; [PostInline][ORT] segfaults",
),
xfail(
"baddbmm",
dtypes=(
torch.uint8,
torch.int8,
torch.int16,
),
reason=onnx_test_common.reason_onnx_runtime_does_not_support(
"Matmul", "uint8, int8, int16"
),
),
xfail(
"bmm",
dtypes=(
torch.uint8,
torch.int8,
torch.int16,
),
reason=onnx_test_common.reason_onnx_runtime_does_not_support(
"Matmul", "uint8, int8, int16"
),
),
skip(
"ceil", dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES,
reason=onnx_test_common.reason_onnx_does_not_support("Ceil", "bool and int")
),
xfail(
"chunk", dtypes=onnx_test_common.BOOL_TYPES,
reason=onnx_test_common.reason_onnx_runtime_does_not_support("Chunk", "bool")
),
xfail(
"chunk",
dtypes=(torch.uint8, torch.int8, torch.int16, torch.float16,),
reason=onnx_test_common.reason_onnx_runtime_does_not_support(
"Chunk", "uint8, int8, int16, float16"
),
),
xfail(
"clamp",
dtypes=(torch.uint8, torch.int8, torch.int16,),
reason=onnx_test_common.reason_onnx_runtime_does_not_support(
"Max", "uint8, int8, int16"
),
),
xfail(
"clamp_max", dtypes=onnx_test_common.BOOL_TYPES,
reason=onnx_test_common.reason_onnx_script_does_not_support("Clamp_max", "bool")
),
xfail(
"clamp_max",
dtypes=(torch.uint8, torch.int8, torch.int16,),
reason=onnx_test_common.reason_onnx_runtime_does_not_support(
"Max", "uint8, int8, int16"
),
),
xfail(
"clamp_min",
dtypes=(torch.uint8, torch.int8, torch.int16,),
reason=onnx_test_common.reason_onnx_runtime_does_not_support(
"Max", "uint8, int8, int16"
),
),
xfail(
"clamp_min", dtypes=onnx_test_common.BOOL_TYPES,
reason=onnx_test_common.reason_onnx_script_does_not_support("Clamp_min", "bool")
),
xfail(
"constant_pad_nd",
dtypes=(torch.int16,),
reason=onnx_test_common.reason_onnx_runtime_does_not_support(
"Constant_pad_nd", "int16"
),
),
    xfail(
        "cross",
        reason=onnx_test_common.reason_onnx_script_does_not_support("linalg_cross"),
    ),
    xfail(
        "cumsum", dtypes=onnx_test_common.BOOL_TYPES + (torch.uint8, torch.int8, torch.int16,),
        reason=onnx_test_common.reason_onnx_does_not_support("Cumsum", "bool, uint8, int8, int16")
    ),
    # See https://github.com/pytorch/pytorch/issues/111454
    xfail(
        "cumsum", dtypes=(torch.float16,),
        reason=onnx_test_common.reason_onnx_runtime_does_not_support("RUNTIME_EXCEPTION : \
Exception during initialization: /onnxruntime_src/onnxruntime/core/framework/\
allocation_planner.cc:230 int& onnxruntime::PlannerImpl::\
UseCount(onnxruntime::OrtValueIndex) n >= 0 && static_cast<size_t>(n) \
< ort_value_info_.size() was false.")
    ),
xfail(
"dot", dtypes=(torch.uint8, torch.int8, torch.int16,),
reason=onnx_test_common.reason_onnx_does_not_support("MatMul", "uint8, int8, int16")
),
xfail(
"eq",
dtypes=(torch.uint8, torch.int8, torch.int16,),
reason=onnx_test_common.reason_onnx_runtime_does_not_support("Equal", "uint8, int8, int16"),
),
xfail(
"equal",
reason=onnx_test_common.reason_dynamo_does_not_support("aten.equal.default")
),
xfail(
"floor",
dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES,
reason=onnx_test_common.reason_onnx_does_not_support("Floor", "bool, int"),
),
xfail(
"index_put",
dtypes=onnx_test_common.BOOL_TYPES,
reason=onnx_test_common.reason_onnx_script_does_not_support("index_put", "bool"),
),
xfail(
"index_put",
dtypes=(torch.uint8, torch.int8, torch.int16,),
reason=onnx_test_common.reason_onnx_script_does_not_support("Add", "int8, int16"),
),
xfail(
"nn.functional.adaptive_avg_pool2d",
reason=onnx_test_common.reason_onnx_script_does_not_support("RecursionError: \
maximum recursion depth exceeded while calling a Python object"),
),
xfail(
"nn.functional.adaptive_avg_pool3d",
reason=onnx_test_common.reason_onnx_script_does_not_support("aten._adaptive_avg_pool3d.default"),
),
xfail(
"nn.functional.avg_pool1d",
dtypes=onnx_test_common.INT_TYPES,
reason=onnx_test_common.reason_onnx_does_not_support("AveragePool", "int"),
),
xfail(
"nn.functional.avg_pool2d",
dtypes=onnx_test_common.INT_TYPES,
reason=onnx_test_common.reason_onnx_does_not_support("AveragePool", "int"),
),
xfail(
"nn.functional.avg_pool3d",
dtypes=onnx_test_common.INT_TYPES,
reason=onnx_test_common.reason_onnx_does_not_support("AveragePool", "int"),
),
xfail(
"nn.functional.conv1d",
dtypes=(torch.int64,),
reason=onnx_test_common.reason_onnx_does_not_support("Conv1d", "int64"),
),
xfail(
"nn.functional.conv2d",
dtypes=(torch.int64,),
reason=onnx_test_common.reason_onnx_does_not_support("Conv2d", "int64"),
),
xfail(
"nn.functional.dropout",
reason=onnx_test_common.reason_dynamo_does_not_support("Dropout"),
),
xfail(
"nn.functional.max_pool2d",
dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES,
reason=onnx_test_common.reason_onnx_does_not_support("Max_pool2d"),
),
xfail(
"nn.functional.max_pool3d",
dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES,
reason=onnx_test_common.reason_onnx_does_not_support("Max_pool3d"),
),
xfail(
"nonzero",
dtypes=(torch.int8, torch.int16),
reason=onnx_test_common.reason_onnx_runtime_does_not_support("NonZero", "int8, int16"),
),
xfail(
"scatter_add",
dtypes=(torch.float16,),
reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterElements reduction=sum", "float16"),
),
xfail(
"scatter_reduce",
variant_name="sum",
dtypes=(torch.float16,),
reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterElements reduction=sum", "float16"),
),
xfail(
"scatter_reduce",
variant_name="prod",
dtypes=(torch.float16,),
reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterElements reduction=prod", "float16"),
),
xfail(
"scatter_reduce",
variant_name="amin",
dtypes=onnx_test_common.BOOL_TYPES + (torch.float16,),
reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterElements reduction=amin", "float16"),
),
xfail(
"scatter_reduce",
variant_name="amax",
dtypes=onnx_test_common.BOOL_TYPES + (torch.float16,),
reason=onnx_test_common.reason_onnx_runtime_does_not_support("ScatterElements reduction=amax", "float16"),
),
xfail(
"scatter_reduce",
variant_name="mean",
reason="ONNX doesn't support reduce='mean' option",
),
xfail(
"square",
dtypes=(torch.int8, torch.uint8, torch.int16),
reason=onnx_test_common.reason_onnx_runtime_does_not_support("Pow", "int8, uint8, int16"),
),
xfail(
"stft",
reason=onnx_test_common.reason_dynamo_does_not_support("aten._fft_r2c.default"),
),
xfail(
"unflatten", dtypes=onnx_test_common.BOOL_TYPES,
reason=onnx_test_common.reason_onnx_does_not_support("Unflatten")
),
)
# fmt: on
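
# NOTE: A matcher receives an OpInfo SampleInput and can inspect sample.input,
# sample.args, and sample.kwargs to select individual sub-tests. A minimal
# sketch (the op name and condition are illustrative):
#
#     skip(
#         "some_op",
#         matcher=lambda sample: sample.kwargs.get("dim") is None,
#         reason="fixme: exporter requires an explicit 'dim'",
#     ),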
SKIP_XFAIL_SUBTESTS: tuple[onnx_test_common.DecorateMeta, ...] = (
    xfail(
        "addmm",  # dtypes alone cannot catch all failing cases here; a matcher is needed
matcher=lambda sample: sample.input.dtype
in (torch.uint8, torch.int8, torch.int16),
reason=onnx_test_common.reason_onnx_script_does_not_support(
"Add", "int8, int16, uint8"
),
),
skip(
"amax",
matcher=lambda sample: len(sample.input.shape) == 0,
reason="Op (ReduceMax) [ShapeInferenceError] axis must be in [-rank, rank-1]. input rank was 0",
),
skip(
"amin",
matcher=lambda sample: len(sample.input.shape) == 0,
        reason="Op (ReduceMin) [ShapeInferenceError] axis must be in [-rank, rank-1]. input rank was 0",
),
skip(
"cat",
matcher=lambda sample: sample.input[0].equal(torch.tensor([])),
reason="core dump - cat does not support zero-dim tensors yet",
),
xfail(
"index_put",
matcher=lambda sample: (sample.args[0][0].dtype == torch.bool)
and (sample.kwargs.get("accumulate") is False),
reason=onnx_test_common.reason_dynamo_does_not_support(
"https://github.com/pytorch/pytorch/issues/101150"
),
),
xfail(
"native_batch_norm",
matcher=lambda sample: sample.args[4]
and (
isinstance(sample.args[0], torch.Tensor) and sample.args[0].shape == (1,)
        ),  # Edge case with training=True and the mean being a 1-d single-element tensor.
reason="AssertionError: The values for attribute 'shape' do not match: torch.Size([1]) != torch.Size([]).",
),
xfail(
"nn.functional.avg_pool1d",
matcher=lambda sample: (sample.kwargs.get("ceil_mode") is True)
and (
sample.kwargs.get("count_include_pad") is True
or sample.input.shape[2]
% (
sample.args[0][0]
if isinstance(sample.args[0], tuple)
else sample.args[0]
)
!= 0
),
reason="fixme: ORT doesn't match PyTorch when ceil_mode=True until opset 19",
),
xfail(
"nn.functional.avg_pool2d",
matcher=lambda sample: (len(sample.args) > 5 and sample.args[5] is not None)
or (sample.kwargs.get("divisor_override") is not None),
reason="ONNX doesn't support divisor_override argument",
),
xfail(
"nn.functional.avg_pool3d",
matcher=lambda sample: sample.kwargs.get("ceil_mode") is True,
reason="fixme: ORT doesn't match PyTorch when ceil_mode=True until opset 19",
),
xfail(
"nn.functional.avg_pool3d",
matcher=lambda sample: (len(sample.args) > 5 and sample.args[5] is not None)
or (sample.kwargs.get("divisor_override") is not None),
reason="ONNX doesn't support divisor_override argument",
),
skip(
"nn.functional.conv1d",
matcher=lambda sample: isinstance(sample.kwargs.get("padding"), str),
reason="String padding is not accepted by aten::conv1d",
),
skip(
"nn.functional.conv2d",
matcher=lambda sample: isinstance(sample.kwargs.get("padding"), str),
reason="String padding is not accepted by aten::conv2d",
),
skip(
"nn.functional.cross_entropy",
matcher=lambda sample: not isinstance(sample.kwargs.get("weight"), int),
        reason="ONNX SoftmaxCrossEntropyLoss only accepts an int-typed 'weight' argument",
),
skip_torchlib_forward_compatibility(
"nn.functional.embedding_bag",
        # The trailing "or True" makes this matcher fire for every sample; see the reason below.
        matcher=lambda sample: sample.kwargs.get("padding_idx") is not None or True,
reason=onnx_test_common.reason_onnx_script_does_not_support(
"'padding_idx' overload for _embedding_bag and _embedding_bag_forward_only. "
"'padding_idx=-1' is emitted for aten op when 'padding_idx' is not provided"
),
github_issue="https://github.com/microsoft/onnxscript/issues/1056",
),
skip(
"nn.functional.max_pool3d",
matcher=lambda sample: sample.kwargs.get("ceil_mode") is True
and sample.kwargs.get("padding") == 1,
        reason="fixme: enable after https://github.com/microsoft/onnxruntime/issues/15446 is fixed",
),
xfail(
"nonzero",
matcher=lambda sample: len(sample.input.shape) == 0
and sample.kwargs.get("as_tuple", False) is False,
        reason="Output 'shape' does not match: torch.Size([0, 1]) != torch.Size([0, 0]).",
),
xfail(
"scatter_add",
matcher=lambda sample: len(sample.input.shape) == 0,
        reason="fixme: a rank-0 input makes ORT fail because the if/else branches produce results of different ranks",
),
skip(
"scatter_reduce",
        # ONNX has no include_self parameter; its behavior matches include_self=True.
        matcher=lambda sample: sample.kwargs.get("include_self") is False,
        reason="ONNX doesn't support the include_self=False option",
),
xfail(
"unflatten",
reason="Logic not implemented for size 0 inputs in op.Reshape",
matcher=lambda sample: any(dim == 0 for dim in sample.input.shape),
),
)
# END OF SECTION TO MODIFY #####################################################

OPS_DB = copy.deepcopy(common_methods_invocations.op_db)

OP_WITH_SKIPPED_XFAIL_SUBTESTS = frozenset(meta.op_name for meta in SKIP_XFAIL_SUBTESTS)
ALL_OPS_IN_DB = frozenset(op_info.name for op_info in OPS_DB)
# Assert that all ops in TESTED_OPS are in OPS_DB.
assert TESTED_OPS.issubset(ALL_OPS_IN_DB), f"{TESTED_OPS - ALL_OPS_IN_DB} not in OPS_DB"

class SingleOpModel(torch.nn.Module):
    """Test model to wrap around a single op for export."""

def __init__(self, op, kwargs):
super().__init__()
self.operator = op
self.kwargs = kwargs

    def forward(self, *args):
return self.operator(*args, **self.kwargs)
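
# Example (a sketch; torch.add stands in for any OpInfo op):
#
#     model = SingleOpModel(torch.add, {"alpha": 2.0})
#     model(torch.ones(2), torch.ones(2))  # -> tensor([3., 3.])
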
def _should_skip_xfail_test_sample(
    op_name: str, sample
) -> Tuple[Optional[str], Optional[str]]:
    """Return (test_behavior, reason) for a sample, or (None, None) if it should run normally."""
if op_name not in OP_WITH_SKIPPED_XFAIL_SUBTESTS:
return None, None
for decorator_meta in SKIP_XFAIL_SUBTESTS:
        # Linear search on SKIP_XFAIL_SUBTESTS. That's fine because the list is small.
if decorator_meta.op_name == op_name:
assert decorator_meta.matcher is not None, "Matcher must be defined"
if decorator_meta.matcher(sample):
return decorator_meta.test_behavior, decorator_meta.reason
return None, None
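
# Example returns from _should_skip_xfail_test_sample (a sketch): a matched
# sample yields something like ("xfail", "ONNX doesn't support ..."), while an
# unmatched sample yields (None, None) and runs normally.
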
def _run_test_output_match(
test_suite: onnx_test_common._TestONNXRuntime,
device: str,
dtype: torch.dtype,
op: opinfo_core.OpInfo,
):
    # device is provided by instantiate_device_type_tests, but we only want to run on CPU.
assert device == "cpu"
samples = op.sample_inputs(
device,
dtype,
requires_grad=False,
)
for i, cpu_sample in enumerate(samples):
inputs = (cpu_sample.input, *cpu_sample.args)
# Provide the repr to subtest because tensors are not serializable in parallel test runs
with test_suite.subTest(
opset=test_suite.opset_version,
sample_num=i,
inputs=repr(inputs),
kwargs=repr(cpu_sample.kwargs),
):
test_behavior, reason = _should_skip_xfail_test_sample(op.name, cpu_sample)
with onnx_test_common.normal_xfail_skip_test_behaviors(
test_behavior, reason
):
model = SingleOpModel(op.op, cpu_sample.kwargs)
model.eval()
if dtype == torch.float32:
# Relax atol and rtol for float32 based on empirical results
rtol = 1e-5
atol = 2e-5
elif (
dtype == torch.float16
and op.name in test_suite.fp16_low_precision_list
):
rtol = 1e-2
atol = 1e-3
else:
rtol = None
atol = None
# Run the test
test_suite.run_test_with_fx_to_onnx_exporter_and_onnx_runtime(
model, inputs, rtol=rtol, atol=atol
)
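
# The rtol/atol chosen above feed the ONNX-vs-eager comparison inside
# run_test_with_fx_to_onnx_exporter_and_onnx_runtime; assuming it follows
# torch.testing.assert_close semantics, a mismatch is flagged when
#
#     |actual - expected| > atol + rtol * |expected|
#
# so e.g. rtol=1e-2, atol=1e-3 tolerates roughly 1% relative error on fp16 ops.
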
def _get_test_class_name(cls, num, params_dict) -> str:
del cls # unused
del num # unused
return params_dict["name"]
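
# parameterized_class below generates one concrete test class per opset in
# onnx_test_common.FX_TESTED_OPSETS, named by _get_test_class_name, e.g.
# "TestOnnxModelOutputConsistency_opset18" (assuming opset 18 is tested).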
@parameterized.parameterized_class(
[
{
"name": f"TestOnnxModelOutputConsistency_opset{opset}",
"opset_version": opset,
}
for opset in onnx_test_common.FX_TESTED_OPSETS
],
class_name_func=_get_test_class_name,
)
class TestOnnxModelOutputConsistency(onnx_test_common._TestONNXRuntime):
"""Test output consistency between exported ONNX models and PyTorch eager mode.
This is a parameterized test suite.
"""
    opset_version = -1  # Overwritten per generated class by parameterized_class above.
op_level_debug: bool = False
dynamic_shapes: bool = False
fp16_low_precision_list = [
"nn.functional.batch_norm",
"native_batch_norm",
"dot",
"logit",
]

    @common_device_type.ops(
[op for op in OPS_DB if op.name in TESTED_OPS],
allowed_dtypes=onnx_test_common.TESTED_DTYPES,
)
def test_output_match(self, device: str, dtype: torch.dtype, op):
"""Test the ONNX exporter."""
_run_test_output_match(self, device, dtype, op)

    @common_device_type.ops(
[op for op in OPS_DB if op.name in COMPLEX_TESTED_OPS],
allowed_dtypes=onnx_test_common.COMPLEX_TYPES,
)
def test_output_match_complex(self, device: str, dtype: torch.dtype, op):
"""Test the ONNX exporter with complex dtype."""
_run_test_output_match(self, device, dtype, op)

for opset in onnx_test_common.FX_TESTED_OPSETS:
# The name needs to match the parameterized_class name.
test_class_name = f"TestOnnxModelOutputConsistency_opset{opset}"
onnx_test_common.add_decorate_info(
OPS_DB,
test_class_name,
"test_output_match",
opset=opset,
skip_or_xfails=EXPECTED_SKIPS_OR_FAILS,
)
onnx_test_common.add_decorate_info(
OPS_DB,
test_class_name,
"test_output_match_complex",
opset=opset,
skip_or_xfails=EXPECTED_SKIPS_OR_FAILS,
)
common_device_type.instantiate_device_type_tests(
globals()[test_class_name], globals(), only_for="cpu"
)

if __name__ == "__main__":
common_utils.run_tests()