blob: 8a261b13d646a75378d65bc91af501a9da3f3e93 [file] [log] [blame]
# Owner(s): ["module: dynamo"]
import unittest
from unittest.mock import patch
import torch
import torch._dynamo as torchdynamo
from torch._export import export
from torch._export.serde.serialize import GraphModuleOpUpgrader
from torch._export.serde.upgrade import get_target_version, get_upgraders
from torch.testing._internal.common_utils import (
run_tests,
TestCase,
)
TEST_UPGRADERS = {
"aten::div__Scalar_mode_0_3": (
"div.Scalar_mode(Tensor self, Scalar other, *, str rounding_mode) -> Tensor",
"""
from typing import Any, Optional
def div__Scalar_mode_0_3(self: torch.Tensor, other: Any, *, rounding_mode: Optional[str]=None) -> torch.Tensor:
return self.divide(other, rounding_mode=rounding_mode)
""",
),
"aten::gelu_0_9": (
"gelu(Tensor self) -> Tensor",
"""
def gelu_0_9(self: Tensor) -> Tensor:
return torch.gelu(self, approximate='none')
""",
),
}
TEST_UPGRADERS_ENTRY_MAP = {
"div__Scalar_mode_0_3":
"""
from typing import Any, Optional
def div__Scalar_mode_0_3(self: torch.Tensor, other: Any, *, rounding_mode: Optional[str]=None) -> torch.Tensor:
return self.divide_(other, rounding_mode=rounding_mode)"""
}
TEST_OP_VERSION_MAP = {
"aten::div_.Scalar_mode": [
torch._C._UpgraderEntry(
4,
"div__Scalar_mode_0_3",
"aten::div_.Scalar_mode(Tensor(a!) self, Scalar other, *, str? rounding_mode) -> Tensor(a!)"
)
]
}
def count_op(graph, target_str):
return len(
[n for n in graph.nodes if isinstance(n.target, torch._ops.OpOverload) and n.target.name() == target_str])
@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support")
class TestUpgrade(TestCase):
def test_get_upgraders(self):
with patch.object(torch._C, "_get_upgraders_entry_map", return_value=TEST_UPGRADERS_ENTRY_MAP), \
patch.object(torch._C, "_get_operator_version_map", return_value=TEST_OP_VERSION_MAP):
op_upgraders = get_upgraders()
self.assertEqual(op_upgraders, {
"div__Scalar_mode_0_3": (
"aten::div_.Scalar_mode(Tensor(a!) self, Scalar other, *, str? rounding_mode) -> Tensor(a!)",
"""
from typing import Any, Optional
def div__Scalar_mode_0_3(self: torch.Tensor, other: Any, *, rounding_mode: Optional[str]=None) -> torch.Tensor:
return self.divide_(other, rounding_mode=rounding_mode)""",
)})
def test_get_upgraders_missing_from_entry_map_raises(self):
with patch.object(torch._C, "_get_upgraders_entry_map", return_value={}), \
patch.object(torch._C, "_get_operator_version_map", return_value=TEST_OP_VERSION_MAP):
with self.assertRaises(RuntimeError):
get_upgraders()
def test_upgrader_with_invalid_format_throws_exception(self):
"""Invalid upgrader function string should throw exception"""
upgraders = [("div(Tensor a, Tensor b) -> Tensor", "TEST")]
with self.assertRaises(RuntimeError):
GraphModuleOpUpgrader._populate_passes(upgraders)
def test_get_target_version_invalid_format_throws_exception(self):
with self.assertRaises(RuntimeError):
get_target_version("div_0")
with self.assertRaises(RuntimeError):
get_target_version("div_0_")
with self.assertRaises(RuntimeError):
get_target_version("div")
def test_creates_upgrader_pass(self):
compiler_opset_version = {"aten": 4}
model_opset_version = {"aten": 3}
upgrader = GraphModuleOpUpgrader(compiler_opset_version, model_opset_version, TEST_UPGRADERS)
self.assertEqual(len(upgrader.upgrader_passes), 1)
def test_div_upgrader_replaces_op_with_old_version(self):
def fn(a: torch.Tensor, b):
return torch.ops.aten.div.Scalar_mode(a, b, rounding_mode='trunc')
inputs = (torch.ones([2, 3]) * 4, 2.)
ep = export(fn, inputs, [])
compiler_opset_version = {"aten": 4}
model_opset_version = {"aten": 3}
upgrader = GraphModuleOpUpgrader(compiler_opset_version, model_opset_version, TEST_UPGRADERS)
upgraded = ep._transform(*upgrader.upgrader_passes)
upgraded.graph_module.print_readable()
count = count_op(upgraded.graph, "aten::div.Scalar_mode")
self.assertEqual(count, 0)
custom_op_count = count_op(upgraded.graph, "aten::div__Scalar_mode_0_3")
self.assertEqual(custom_op_count, 1)
def test_div_upgrader_pass_return_new_op_after_retrace(self):
def fn(a: torch.Tensor, b):
return torch.ops.aten.div.Scalar_mode(a, b, rounding_mode='trunc')
inputs = (torch.ones([2, 3]) * 4, 2.)
ep = export(fn, inputs, {}, [])
compiler_opset_version = {"aten": 4}
model_opset_version = {"aten": 3}
upgrader = GraphModuleOpUpgrader(compiler_opset_version, model_opset_version, TEST_UPGRADERS)
count = count_op(ep.graph, "aten::div.Scalar_mode")
self.assertEqual(count, 1)
# upgrade: replace op (div.Scalar_mode -> div__Scalar_mode_0_3) then retrace
upgraded_ep = upgrader.upgrade(ep)
upgraded_ep.graph_module.print_readable()
# no old version of op (div__Scalar_mode_0_3) anymore.
custom_op_count = count_op(upgraded_ep.graph, "aten::div__Scalar_mode_0_3")
self.assertEqual(custom_op_count, 0)
# div__Scalar_mode_0_3 decomposes into div.Tensor.
decomposed_op_count = count_op(upgraded_ep.graph, "aten::div.Tensor_mode")
self.assertEqual(decomposed_op_count, 1)
if __name__ == '__main__':
run_tests()