blob: fb4f1fc2cdaa9dc38e4050ba70033a166e6055bd [file] [log] [blame]
# Owner(s): ["module: onnx"]
from __future__ import annotations
import tempfile
from typing import Mapping, Tuple
import onnx
import onnx.inliner
import pytorch_test_common
import torch
import transformers # type: ignore[import]
from torch import nn
from torch._subclasses import fake_tensor
from torch.nn import functional as F
from torch.onnx import dynamo_export, ExportOptions
from torch.onnx._internal.diagnostics import infra
from torch.onnx._internal.fx import diagnostics, registration
from torch.testing._internal import common_utils
def assert_has_diagnostics(
diagnostic_context: diagnostics.DiagnosticContext,
rule: infra.Rule,
level: infra.Level,
expected_node: str,
):
rule_level_pairs = (rule.id, level.name.lower())
sarif_log = diagnostic_context.sarif_log()
actual_results = []
for run in sarif_log.runs:
if run.results is None:
continue
for result in run.results:
id_level_pair = (result.rule_id, result.level)
actual_results.append(id_level_pair)
if (
rule_level_pairs == id_level_pair
and result.message.text
and result.message.markdown
and expected_node in result.message.text
):
return
raise AssertionError(
f"Expected diagnostic results of rule id and level pair {rule_level_pairs} "
f"not found with expected error node {expected_node} and "
f"Actual diagnostic results: {actual_results}"
)
@common_utils.instantiate_parametrized_tests
class TestFxToOnnx(pytorch_test_common.ExportTestCase):
def setUp(self):
super().setUp()
self.export_options = ExportOptions()
def tearDown(self):
super().tearDown()
def test_simple_function(self):
def func(x):
y = x + 1
z = y.relu()
return (y, z)
_ = dynamo_export(
func, torch.randn(1, 1, 2), export_options=self.export_options
)
def test_empty(self):
# Since `torch.empty` returns tensor with uninitialized data, we cannot
# test this under `test_fx_to_onnx_with_onnxruntime.py` with result comparison.
def func(x):
return torch.empty(x.size(), dtype=torch.int64)
tensor_x = torch.randn(1, 1, 2)
_ = dynamo_export(func, tensor_x, export_options=self.export_options)
def test_args_used_for_export_is_not_converted_to_fake_tensors(self):
def func(x, y):
return x + y
tensor_x = torch.randn(1, 1, 2)
tensor_y = torch.randn(1, 1, 2)
_ = dynamo_export(func, tensor_x, tensor_y, export_options=self.export_options)
self.assertNotIsInstance(tensor_x, fake_tensor.FakeTensor)
self.assertNotIsInstance(tensor_y, fake_tensor.FakeTensor)
@common_utils.parametrize(
"diagnostic_rule",
[
common_utils.subtest(
diagnostics.rules.find_opschema_matched_symbolic_function,
name="optional_inputs",
),
common_utils.subtest(
diagnostics.rules.op_level_debugging,
name="get_attr_node_in_op_level_debug",
),
],
)
def test_mnist_exported_with_no_warnings(self, diagnostic_rule):
class MNISTModel(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(1, 32, 3, 1, bias=False)
self.conv2 = nn.Conv2d(32, 64, 3, 1, bias=False)
self.fc1 = nn.Linear(9216, 128, bias=False)
self.fc2 = nn.Linear(128, 10, bias=False)
def forward(self, tensor_x: torch.Tensor):
tensor_x = self.conv1(tensor_x)
tensor_x = F.sigmoid(tensor_x)
tensor_x = self.conv2(tensor_x)
tensor_x = F.sigmoid(tensor_x)
tensor_x = F.max_pool2d(tensor_x, 2)
tensor_x = torch.flatten(tensor_x, 1)
tensor_x = self.fc1(tensor_x)
tensor_x = F.sigmoid(tensor_x)
tensor_x = self.fc2(tensor_x)
output = F.log_softmax(tensor_x, dim=1)
return output
tensor_x = torch.rand((64, 1, 28, 28), dtype=torch.float32)
onnx_program = dynamo_export(
MNISTModel(), tensor_x, export_options=ExportOptions(op_level_debug=True)
)
assert_has_diagnostics(
onnx_program.diagnostic_context,
diagnostic_rule,
diagnostics.levels.NONE,
expected_node="aten.convolution.default",
)
def test_no_warnings_on_complex_dtype_in_op_level_debug(self):
class ComplexModel(torch.nn.Module):
def forward(self, input):
return torch.ops.aten.mul(input, input)
real = torch.tensor([1, 2], dtype=torch.float32)
imag = torch.tensor([3, 4], dtype=torch.float32)
x = torch.complex(real, imag)
onnx_program = dynamo_export(
ComplexModel(), x, export_options=ExportOptions(op_level_debug=True)
)
assert_has_diagnostics(
onnx_program.diagnostic_context,
diagnostics.rules.op_level_debugging,
diagnostics.levels.NONE,
expected_node="aten.mul.Tensor",
)
def test_trace_only_op_with_evaluator(self):
model_input = torch.tensor([[1.0, 2.0, 3.0], [1.0, 1.0, 2.0]])
class ArgminArgmaxModel(torch.nn.Module):
def forward(self, input):
return (
torch.argmin(input),
torch.argmax(input),
torch.argmin(input, keepdim=True),
torch.argmax(input, keepdim=True),
torch.argmin(input, dim=0, keepdim=True),
torch.argmax(input, dim=1, keepdim=True),
)
_ = dynamo_export(
ArgminArgmaxModel(), model_input, export_options=self.export_options
)
def test_multiple_outputs_op_with_evaluator(self):
class TopKModel(torch.nn.Module):
def forward(self, x):
values, _ = torch.topk(x, 3)
return torch.sum(values)
x = torch.arange(1.0, 6.0, requires_grad=True)
onnx_program = dynamo_export(TopKModel(), x, export_options=self.export_options)
def test_unsupported_indices_fake_tensor_generated_with_op_level_debug(self):
class EmbedModelWithoutPaddingIdx(torch.nn.Module):
def forward(self, input, emb):
return torch.nn.functional.embedding(input, emb)
model = EmbedModelWithoutPaddingIdx()
x = torch.randint(4, (4, 3, 2))
embedding_matrix = torch.rand(10, 3)
onnx_program = dynamo_export(
model,
x,
embedding_matrix,
export_options=ExportOptions(op_level_debug=True),
)
assert_has_diagnostics(
onnx_program.diagnostic_context,
diagnostics.rules.op_level_debugging,
diagnostics.levels.WARNING,
expected_node="aten.embedding.default",
)
def test_unsupported_function_schema_raises_diagnostic_warning_when_found_nearest_match(
self,
):
class TraceModel(torch.nn.Module):
def forward(self, input):
return input.new_zeros(())
x = torch.randn((2, 3), dtype=torch.float32)
onnx_program = dynamo_export(TraceModel(), x)
assert_has_diagnostics(
onnx_program.diagnostic_context,
diagnostics.rules.find_opschema_matched_symbolic_function,
diagnostics.levels.WARNING,
expected_node="aten.new_zeros.default",
)
def test_perfect_match_on_sequence_and_bool_attributes(
self,
):
class TraceModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv2 = torch.nn.Conv2d(
16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1)
)
def forward(self, input):
return self.conv2(input)
x = torch.randn(20, 16, 50, 50)
onnx_program = dynamo_export(
TraceModel(), x, export_options=ExportOptions(op_level_debug=False)
)
assert_has_diagnostics(
onnx_program.diagnostic_context,
diagnostics.rules.find_opschema_matched_symbolic_function,
diagnostics.levels.NONE,
expected_node="aten.convolution.default",
)
def test_dispatch_overload_fall_back_default_raise_diagnostic_warning(self):
class TraceModel(torch.nn.Module):
def forward(self, input):
return torch.ops.aten.add.Tensor(input, input)
onnx_registry = torch.onnx.OnnxRegistry()
self.assertTrue(
onnx_registry.is_registered_op(
namespace="aten", op_name="add", overload="Tensor"
)
)
aten_add_Tensor = registration.OpName.from_name_parts(
namespace="aten", op_name="add", overload="Tensor"
)
onnx_registry._registry.pop(aten_add_Tensor)
x = torch.tensor(3)
onnx_program = dynamo_export(
TraceModel(), x, export_options=ExportOptions(onnx_registry=onnx_registry)
)
assert_has_diagnostics(
onnx_program.diagnostic_context,
diagnostics.rules.find_operator_overloads_in_onnx_registry,
diagnostics.levels.WARNING,
expected_node="aten.add.Tensor",
)
def test_aten_clone_does_not_raise_warning_of_lack_of_memory_format(self):
class CustomModule(torch.nn.Module):
def forward(self, input):
return torch.ops.aten.clone(input, memory_format=torch.preserve_format)
x = torch.tensor(3)
onnx_program = dynamo_export(CustomModule(), x)
assert_has_diagnostics(
onnx_program.diagnostic_context,
diagnostics.rules.find_opschema_matched_symbolic_function,
diagnostics.levels.NONE,
expected_node="aten.clone.default",
)
def test_missing_complex_onnx_variant_raises_errors_in_dispatcher(self):
registry = torch.onnx.OnnxRegistry()
# NOTE: simulate unsupported nodes
aten_mul_tensor = registration.OpName.from_name_parts(
namespace="aten", op_name="mul", overload="Tensor"
)
# Only keep real aten.mul to test missing complex aten.mul
registry._registry[aten_mul_tensor] = [
onnx_func
for onnx_func in registry._registry[aten_mul_tensor]
if not onnx_func.is_complex
]
class TraceModel(torch.nn.Module):
def forward(self, input):
return torch.ops.aten.mul.Tensor(input, input)
x = torch.tensor([1 + 2j, 3 + 4j], dtype=torch.complex64)
with self.assertRaises(torch.onnx.OnnxExporterError) as e:
torch.onnx.dynamo_export(
TraceModel(),
x,
export_options=torch.onnx.ExportOptions(onnx_registry=registry),
)
try:
torch.onnx.dynamo_export(
TraceModel(),
x,
export_options=torch.onnx.ExportOptions(onnx_registry=registry),
)
except torch.onnx.OnnxExporterError as e:
assert_has_diagnostics(
e.onnx_program.diagnostic_context,
diagnostics.rules.no_symbolic_function_for_call_function,
diagnostics.levels.ERROR,
expected_node="aten.mul.Tensor",
)
def test_symbolic_shape_of_values_inside_function_is_exported_as_graph_value_info(
self,
):
class SubModule(torch.nn.Module):
def forward(self, x, y, bias):
output = x @ y
return output + bias
class Module(torch.nn.Module):
def __init__(self):
super().__init__()
self.submodule = SubModule()
def forward(self, x, y, bias):
return self.submodule(x, y, bias)
x = torch.randn(2, 3)
y = torch.randn(3, 4)
bias = torch.randn(4)
onnx_program = torch.onnx.dynamo_export(
Module(),
x,
y,
bias,
export_options=torch.onnx.ExportOptions(dynamic_shapes=True),
)
model_proto = onnx_program.model_proto
# Assert value_info for values inside local function can be retrieved
def _assert_node_outputs_has_value_info(
node: onnx.NodeProto,
value_infos: Mapping[str, onnx.ValueInfoProto],
local_functions: Mapping[Tuple[str, str], onnx.FunctionProto],
prefix: str = "",
):
for output in node.output:
output_prefix = f"{prefix}/{output}" if prefix else output
self.assertIn(output_prefix, value_infos)
if node.domain.startswith("pkg.onnxscript.torch_lib"):
# No shape info available for values inside torchlib functions.
return
if (
function := local_functions.get((node.domain, node.op_type))
) is not None:
for node in function.node:
node_prefix = (
f"{prefix}/{function.name}" if prefix else function.name
)
_assert_node_outputs_has_value_info(
node, value_infos, local_functions, node_prefix
)
type_infos = {vi.name: vi for vi in model_proto.graph.value_info}
functions = {(f.domain, f.name): f for f in model_proto.functions}
for node in model_proto.graph.node:
_assert_node_outputs_has_value_info(node, type_infos, functions)
def test_dynamo_export_retains_readable_parameter_and_buffer_names(self):
class SubModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv2 = nn.Conv2d(32, 64, 3, 1, bias=False)
self.fc1 = nn.Linear(9216, 128, bias=False)
self.register_buffer("buffer", torch.randn(1, 128))
def forward(self, tensor_x: torch.Tensor):
tensor_x = self.conv2(tensor_x)
tensor_x = F.sigmoid(tensor_x)
tensor_x = F.max_pool2d(tensor_x, 2)
tensor_x = torch.flatten(tensor_x, 1)
tensor_x = self.fc1(tensor_x)
tensor_x = tensor_x + self.buffer
tensor_x = F.sigmoid(tensor_x)
return tensor_x
class MNISTModel(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(1, 32, 3, 1, bias=False)
self.submodule = SubModule()
self.fc2 = nn.Linear(128, 10, bias=False)
def forward(self, tensor_x: torch.Tensor):
tensor_x = self.conv1(tensor_x)
tensor_x = F.sigmoid(tensor_x)
tensor_x = self.submodule(tensor_x)
tensor_x = self.fc2(tensor_x)
output = F.log_softmax(tensor_x, dim=1)
return output
tensor_x = torch.rand((64, 1, 28, 28), dtype=torch.float32)
model = MNISTModel()
onnx_program = torch.onnx.dynamo_export(model, tensor_x)
model_proto = onnx_program.model_proto
self.assertEqual(
{initializer.name for initializer in model_proto.graph.initializer},
{*model.state_dict().keys()},
)
def test_fake_tensor_mode_simple(self):
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(2, 2)
def forward(self, x):
out = self.linear(x)
return out
with torch.onnx.enable_fake_mode() as fake_context:
x = torch.rand(5, 2, 2)
model = Model()
export_options = ExportOptions(fake_context=fake_context)
onnx_program = torch.onnx.dynamo_export(
model, x, export_options=export_options
)
assert (
onnx_program is not None
), "ONNXProgram must be created on successful export"
assert (
onnx_program.model_proto is not None
), "A model protobuf must be created on a successful export"
onnx.checker.check_model(onnx_program.model_proto, full_check=True)
assert (
len(onnx_program.model_proto.graph.initializer) == 0
), "Initializers cannot exist when fake mode is enabled"
# Variant 1: Save ONNX proto using Model's state_dict()
with tempfile.NamedTemporaryFile(suffix=".onnx") as tmp_onnx_file:
model_state_dict = Model().state_dict() # Create a state_dict for testing
onnx_program.save(tmp_onnx_file.name, model_state_dict=model_state_dict)
assert (
len(onnx.load(tmp_onnx_file.name).graph.initializer) == 2
), "Initializers must be present after loading it from model_state_dict"
# Variant 2: Save ONNX proto using Model checkpoint file
with tempfile.NamedTemporaryFile(
suffix=".onnx"
) as tmp_onnx_file, tempfile.NamedTemporaryFile(
suffix=".pt"
) as tmp_checkpoint_file:
torch.save(
Model().state_dict(), tmp_checkpoint_file.name
) # Create checkpoint file for testing
onnx_program.save(
tmp_onnx_file.name, model_state_dict=tmp_checkpoint_file.name
)
assert (
len(onnx.load(tmp_onnx_file.name).graph.initializer) == 2
), "Initializers must be present after loading it from model_state_dict"
def test_fake_tensor_mode_simple_invalid_input(self):
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(2, 2)
def forward(self, x):
out = self.linear(x)
return out
real_model = Model()
real_x = torch.rand(5, 2, 2)
with torch.onnx.enable_fake_mode() as fake_context:
fake_model = Model()
fake_x = torch.rand(5, 2, 2)
# TODO: Split each scenario on its own test case
# Scenario 1: Fake model and fake input WITHOUT ExportOptions(fake_context=...)
with self.assertRaises(torch.onnx.OnnxExporterError):
export_options = ExportOptions(fake_context=None)
_ = torch.onnx.dynamo_export(
fake_model, fake_x, export_options=export_options
)
# Scenario 2: Fake model and real input WITHOUT fake_context
with self.assertRaises(torch.onnx.OnnxExporterError):
export_options = ExportOptions(fake_context=None)
_ = torch.onnx.dynamo_export(
fake_model, real_x, export_options=export_options
)
# Scenario 3: Real model and real input WITH fake_context
with self.assertRaises(torch.onnx.OnnxExporterError):
export_options = ExportOptions(fake_context=fake_context)
_ = torch.onnx.dynamo_export(
real_model, real_x, export_options=export_options
)
# Scenario 4: Fake model and real input WITH fake_context
with self.assertRaises(torch.onnx.OnnxExporterError):
export_options = ExportOptions(fake_context=fake_context)
_ = torch.onnx.dynamo_export(
fake_model, real_x, export_options=export_options
)
@pytorch_test_common.xfail(
error_message="Dynamic control flow is not supported at the moment."
)
def test_fake_tensor_mode_huggingface_llama(self):
config = transformers.LlamaConfig(
vocab_size=8096, hidden_size=256, num_hidden_layers=2, num_attention_heads=2
)
batch, seq = 4, 256
with torch.onnx.enable_fake_mode() as fake_context:
model = transformers.LlamaModel(config).eval()
input_ids = torch.randint(0, config.vocab_size, (batch, seq))
attention_mask = torch.ones(batch, seq, dtype=torch.bool)
position_ids = torch.arange(0, seq, dtype=torch.long)
position_ids = position_ids.unsqueeze(0).view(-1, seq)
export_options = torch.onnx.ExportOptions(fake_context=fake_context)
onnx_program = torch.onnx.dynamo_export(
model,
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
export_options=export_options,
)
onnx.checker.check_model(onnx_program.model_proto)
onnx.shape_inference.infer_shapes(onnx_program.model_proto)
@pytorch_test_common.xfail(
error_message="Dynamic control flow is not supported at the moment."
)
def test_fake_tensor_mode_huggingface_tiiuae_falcon(self):
config = transformers.FalconConfig()
batch, seq = 4, 256
with torch.onnx.enable_fake_mode() as fake_context:
model = transformers.FalconModel(config).eval()
input_ids = torch.randint(0, config.vocab_size, (batch, seq))
attention_mask = torch.ones(batch, seq, dtype=torch.bool)
export_options = torch.onnx.ExportOptions(fake_context=fake_context)
onnx_program = torch.onnx.dynamo_export(
model,
input_ids=input_ids,
attention_mask=attention_mask,
export_options=export_options,
)
onnx.checker.check_model(onnx_program.model_proto)
onnx.shape_inference.infer_shapes(onnx_program.model_proto)
def test_exported_program_input_with_custom_fx_tracer(self):
from torch.onnx._internal import exporter
from torch.onnx._internal.fx import dynamo_graph_extractor
class Model(torch.nn.Module):
def forward(self, x):
return x + 1
x = torch.randn(1, 1, 2)
exported_program = torch.export.export(Model(), args=(x,))
export_options = torch.onnx.ExportOptions()
export_options = exporter.ResolvedExportOptions(
export_options, model=exported_program
)
export_options.fx_tracer = (
dynamo_graph_extractor.DynamoExport()
) # Override fx_tracer to an unsupported tracer
with self.assertRaises(torch.onnx.OnnxExporterError):
onnx_program = torch.onnx.dynamo_export(
exported_program,
x,
export_options=export_options,
)
self.assertTrue(onnx_program._export_exception is not None)
with self.assertRaises(torch.onnx.InvalidExportOptionsError):
raise self._export_exception
def test_exported_program_torch_distributions_normal_Normal(self):
class Model(torch.nn.Module):
def __init__(self):
self.normal = torch.distributions.normal.Normal(0, 1)
super().__init__()
def forward(self, x):
return self.normal.sample(x.shape)
x = torch.randn(2, 3)
with torch.no_grad():
exported_program = torch.export.export(Model(), args=(x,))
_ = torch.onnx.dynamo_export(
exported_program,
x,
)
def test_aten_div_no_opmath_type_promotion(self):
class Model(torch.nn.Module):
def forward(self, input):
return input / 2
model = Model()
input = torch.randn(3, 5, requires_grad=True, dtype=torch.float16)
model_proto = torch.onnx.dynamo_export(model, input).model_proto
model_proto = onnx.inliner.inline_local_functions(model_proto)
div_node = next(
node for node in model_proto.graph.node if node.op_type == "Div"
)
# The input of Div node should be the input of the model,
# with no Cast node in between.
self.assertEqual(div_node.input[0], model_proto.graph.input[0].name)
def test_exported_program_as_input_with_model_signature(self):
class Model(torch.nn.Module):
def forward(self, x):
return x + 1.0
x = torch.randn(1, 1, 2, dtype=torch.float)
exported_program = torch.export.export(Model(), args=(x,))
onnx_program = torch.onnx.dynamo_export(
exported_program,
x,
)
self.assertTrue(onnx_program.model_signature, torch.export.ExportGraphSignature)
if __name__ == "__main__":
common_utils.run_tests()