# blob: 9b44912dc450df783a45ec7692c05a9f4e3b1b97 [file] [log] [blame]
# Owner(s): ["module: dynamo"]
import unittest
import torch._dynamo as torchdynamo
from torch._export import export
from torch._export.db.case import ExportCase, normalize_inputs, SupportLevel
from torch._export.db.examples import (
filter_examples_by_support_level,
get_rewrite_cases,
)
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
run_tests,
TestCase,
)
@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support")
class ExampleTests(TestCase):
    """Run ``export`` over the ExportDB example cases, grouped by support level.

    Every test is parametrized over the cases returned by
    ``filter_examples_by_support_level``. The whole class is skipped when
    ``torchdynamo.is_dynamo_supported()`` is false.

    The individual tests are documented with comments rather than docstrings
    so that unittest's verbose output keeps showing the parametrized test
    name instead of a shared docstring line.
    """

    # TODO Maybe we should make these tests actually show up in a file?
    @parametrize(
        "name,case",
        filter_examples_by_support_level(SupportLevel.SUPPORTED).items(),
        name_fn=lambda name, case: f"case_{name}",
    )
    def test_exportdb_supported(self, name: str, case: ExportCase) -> None:
        # A SUPPORTED case must export without raising, and calling the
        # exported program must give the same result as calling the model.
        model = case.model

        inputs = normalize_inputs(case.example_inputs)
        exported_program = export(
            model,
            inputs.args,
            inputs.kwargs,
            constraints=case.constraints,
        )
        # Also exercises print_readable() on the exported graph module.
        exported_program.graph_module.print_readable()

        self.assertEqual(
            exported_program(*inputs.args, **inputs.kwargs),
            model(*inputs.args, **inputs.kwargs),
        )

        if case.extra_inputs is not None:
            # Repeat the comparison with the case's additional inputs, reusing
            # the program that was exported from the example inputs.
            inputs = normalize_inputs(case.extra_inputs)
            self.assertEqual(
                exported_program(*inputs.args, **inputs.kwargs),
                model(*inputs.args, **inputs.kwargs),
            )

    @parametrize(
        "name,case",
        filter_examples_by_support_level(SupportLevel.NOT_SUPPORTED_YET).items(),
        name_fn=lambda name, case: f"case_{name}",
    )
    def test_exportdb_not_supported(self, name: str, case: ExportCase) -> None:
        # A NOT_SUPPORTED_YET case must make export() raise Unsupported.
        model = case.model
        # Normalize the inputs outside the assertRaises block so that only an
        # error raised by export() itself can satisfy the assertion.
        inputs = normalize_inputs(case.example_inputs)
        # pyre-ignore
        with self.assertRaises(torchdynamo.exc.Unsupported):
            export(
                model,
                inputs.args,
                inputs.kwargs,
                constraints=case.constraints,
            )

    @parametrize(
        "name,rewrite_case",
        [
            (name, rewrite_case)
            for name, case in filter_examples_by_support_level(
                SupportLevel.NOT_SUPPORTED_YET
            ).items()
            for rewrite_case in get_rewrite_cases(case)
        ],
        name_fn=lambda name, rewrite_case: f"case_{name}_{rewrite_case.name}",
    )
    def test_exportdb_not_supported_rewrite(
        self, name: str, rewrite_case: ExportCase
    ) -> None:
        # Each rewrite obtained from get_rewrite_cases() for a
        # NOT_SUPPORTED_YET case must export without raising; the exported
        # result itself is not inspected.
        # pyre-ignore
        inputs = normalize_inputs(rewrite_case.example_inputs)
        export(
            rewrite_case.model,
            inputs.args,
            inputs.kwargs,
            constraints=rewrite_case.constraints,
        )
# Instantiate the parametrized tests declared on ExampleTests.
instantiate_parametrized_tests(ExampleTests)


# Run the tests only when this file is executed as a script.
if __name__ == "__main__":
    run_tests()