blob: 3913370f9e46ff77f0a54aef8cc7e65f38068482 [file] [log] [blame]
# Owner(s): ["oncall: export"]
import unittest
from unittest.mock import patch
import torch
import torch._dynamo as torchdynamo
from torch._export.serde.serialize import GraphModuleOpUpgrader
from torch._export.serde.upgrade import get_target_version, get_upgraders
from torch.export import export
from torch.testing._internal.common_utils import IS_WINDOWS, run_tests, TestCase
TEST_UPGRADERS = {
"aten::div__Scalar_mode_0_3": (
"div.Scalar_mode(Tensor self, Scalar other, *, str rounding_mode) -> Tensor",
"""
from typing import Any, Optional
def div__Scalar_mode_0_3(self: torch.Tensor, other: Any, *, rounding_mode: Optional[str]=None) -> torch.Tensor:
return self.divide(other, rounding_mode=rounding_mode)
""",
),
"aten::gelu_0_9": (
"gelu(Tensor self) -> Tensor",
"""
def gelu_0_9(self: Tensor) -> Tensor:
return torch.gelu(self, approximate='none')
""",
),
}
TEST_UPGRADERS_ENTRY_MAP = {
"div__Scalar_mode_0_3": """
from typing import Any, Optional
def div__Scalar_mode_0_3(self: torch.Tensor, other: Any, *, rounding_mode: Optional[str]=None) -> torch.Tensor:
return self.divide_(other, rounding_mode=rounding_mode)"""
}
# Fake return value for torch._C._get_operator_version_map: maps an operator
# name to its upgrader entries.  _UpgraderEntry fields are
# (bumped_at_version, upgrader_name, old_schema).
TEST_OP_VERSION_MAP = {
    "aten::div_.Scalar_mode": [
        torch._C._UpgraderEntry(
            4,
            "div__Scalar_mode_0_3",
            "aten::div_.Scalar_mode(Tensor(a!) self, Scalar other, *, str? rounding_mode) -> Tensor(a!)",
        )
    ]
}
def count_op(graph, target_str):
    """Return how many nodes in ``graph`` call the OpOverload named ``target_str``.

    Nodes whose target is not an ``OpOverload`` (e.g. placeholders, string
    targets) are ignored.
    """
    total = 0
    for node in graph.nodes:
        target = node.target
        if isinstance(target, torch._ops.OpOverload) and target.name() == target_str:
            total += 1
    return total
@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support")
class TestUpgrade(TestCase):
    """Tests for the operator-version upgrade machinery used when loading
    serialized exported programs whose opset version is older than the
    compiler's (GraphModuleOpUpgrader / get_upgraders)."""

    def test_get_upgraders(self):
        # get_upgraders() joins the operator version map (which knows the old
        # schema) with the upgraders entry map (which holds the upgrader
        # source), keyed by upgrader name.
        with patch.object(
            torch._C, "_get_upgraders_entry_map", return_value=TEST_UPGRADERS_ENTRY_MAP
        ), patch.object(
            torch._C, "_get_operator_version_map", return_value=TEST_OP_VERSION_MAP
        ):
            op_upgraders = get_upgraders()
            self.assertEqual(
                op_upgraders,
                {
                    "div__Scalar_mode_0_3": (
                        "aten::div_.Scalar_mode(Tensor(a!) self, Scalar other, *, str? rounding_mode) -> Tensor(a!)",
                        """
from typing import Any, Optional
def div__Scalar_mode_0_3(self: torch.Tensor, other: Any, *, rounding_mode: Optional[str]=None) -> torch.Tensor:
  return self.divide_(other, rounding_mode=rounding_mode)""",
                    )
                },
            )

    def test_get_upgraders_missing_from_entry_map_raises(self):
        # A version-map entry whose upgrader name is missing from the entry
        # map is an internal inconsistency and must raise rather than be
        # silently skipped.
        with patch.object(
            torch._C, "_get_upgraders_entry_map", return_value={}
        ), patch.object(
            torch._C, "_get_operator_version_map", return_value=TEST_OP_VERSION_MAP
        ):
            with self.assertRaises(RuntimeError):
                get_upgraders()

    def test_upgrader_with_invalid_format_throws_exception(self):
        """Invalid upgrader function string should throw exception"""
        upgraders = [("div(Tensor a, Tensor b) -> Tensor", "TEST")]
        with self.assertRaises(RuntimeError):
            GraphModuleOpUpgrader._populate_passes(upgraders)

    def test_get_target_version_invalid_format_throws_exception(self):
        # Upgrader names must end with "_<start>_<end>"; anything else is
        # rejected.
        with self.assertRaises(RuntimeError):
            get_target_version("div_0")
        with self.assertRaises(RuntimeError):
            get_target_version("div_0_")
        with self.assertRaises(RuntimeError):
            get_target_version("div")

    def test_creates_upgrader_pass(self):
        # A single opset-version gap (3 -> 4) with a matching upgrader should
        # produce exactly one upgrader pass.
        compiler_opset_version = {"aten": 4}
        model_opset_version = {"aten": 3}
        upgrader = GraphModuleOpUpgrader(
            compiler_opset_version, model_opset_version, TEST_UPGRADERS
        )
        self.assertEqual(len(upgrader.upgrader_passes), 1)

    def test_div_upgrader_replaces_op_with_old_version(self):
        class Foo(torch.nn.Module):
            def forward(self, a: torch.Tensor, b):
                return torch.ops.aten.div.Scalar_mode(a, b, rounding_mode="trunc")

        fn = Foo()
        inputs = (torch.ones([2, 3]) * 4, 2.0)
        # Fixed: was `export(fn, inputs, [])`, which passes a list where the
        # third parameter (the `kwargs` mapping, or None) is expected; matches
        # the call in test_div_upgrader_pass_return_new_op_after_retrace.
        ep = export(fn, inputs)

        compiler_opset_version = {"aten": 4}
        model_opset_version = {"aten": 3}
        upgrader = GraphModuleOpUpgrader(
            compiler_opset_version, model_opset_version, TEST_UPGRADERS
        )
        # The upgrade pass rewrites the new op into the versioned custom op
        # without retracing.
        upgraded = ep._transform_do_not_use(*upgrader.upgrader_passes)
        upgraded.graph_module.print_readable()

        count = count_op(upgraded.graph, "aten::div.Scalar_mode")
        self.assertEqual(count, 0)
        custom_op_count = count_op(upgraded.graph, "aten::div__Scalar_mode_0_3")
        self.assertEqual(custom_op_count, 1)

    @unittest.skipIf(IS_WINDOWS, "Test case not supported on Windows")
    def test_div_upgrader_pass_return_new_op_after_retrace(self):
        class Foo(torch.nn.Module):
            def forward(self, a: torch.Tensor, b):
                return torch.ops.aten.div.Scalar_mode(a, b, rounding_mode="trunc")

        fn = Foo()
        inputs = (torch.ones([2, 3]) * 4, 2.0)
        ep = export(fn, inputs)

        compiler_opset_version = {"aten": 4}
        model_opset_version = {"aten": 3}
        upgrader = GraphModuleOpUpgrader(
            compiler_opset_version, model_opset_version, TEST_UPGRADERS
        )
        count = count_op(ep.graph, "aten::div.Scalar_mode")
        self.assertEqual(count, 1)

        # upgrade: replace op (div.Scalar_mode -> div__Scalar_mode_0_3) then retrace
        upgraded_ep = upgrader.upgrade(ep)
        upgraded_ep.graph_module.print_readable()

        # no old version of op (div__Scalar_mode_0_3) anymore.
        custom_op_count = count_op(upgraded_ep.graph, "aten::div__Scalar_mode_0_3")
        self.assertEqual(custom_op_count, 0)

        # div__Scalar_mode_0_3 decomposes into div.Tensor_mode.
        decomposed_op_count = count_op(upgraded_ep.graph, "aten::div.Tensor_mode")
        self.assertEqual(decomposed_op_count, 1)
if __name__ == "__main__":
    # Standard PyTorch test entry point (torch.testing._internal.common_utils).
    run_tests()