# Owner(s): ["module: dynamo"]
import unittest
from typing import List

import torch
from functorch.experimental import control_flow
from torch._dynamo.eval_frame import is_dynamo_supported
from torch._export import export
from torch._export.constraints import constrain_as_value
from torch._export.pass_base import _ExportPassBase
from torch.fx.passes.infra.pass_base import PassResult
from torch.testing._internal.common_utils import run_tests, TestCase


@unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported")
class TestPassInfra(TestCase):
    def test_export_pass_base(self) -> None:
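        # A no-op pass should preserve each call_function node's op, target,
        # and stack trace, as well as the overall graph structure.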
        def f(x: torch.Tensor) -> List[torch.Tensor]:
            y = torch.cat([x, x])
            return torch.ops.aten.tensor_split.sections(y, 2)
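
        # A pass that overrides nothing; _transform should rebuild the graph
        # without changing it.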
        class NullPass(_ExportPassBase):
            pass

        ep = export(f, (torch.ones(3, 2),))
        old_nodes = ep.graph.nodes

        ep = ep._transform(NullPass())
        new_nodes = ep.graph.nodes
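
        # Stack traces on call_function nodes should survive the no-op pass.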
        for node in new_nodes:
            if node.op != "call_function":
                continue
            self.assertTrue(hasattr(node, "stack_trace"))
            self.assertIsNotNone(node.stack_trace)
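
        # The transformed graph should be node-for-node identical in op and
        # target.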
        self.assertEqual(len(new_nodes), len(old_nodes))
        for new_node, old_node in zip(new_nodes, old_nodes):
            self.assertEqual(new_node.op, old_node.op)
            self.assertEqual(new_node.target, old_node.target)

    def test_cond(self) -> None:
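        # Running a no-op pass over a program that contains a
        # control_flow.cond, with value-range constraints inside both
        # branches, should succeed.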
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, pred, x, y):
                def true_fn(x, y):
                    b = x.item()
                    constrain_as_value(b, min=2, max=5)
                    return x - y

                def false_fn(x, y):
                    c = y.item()
                    constrain_as_value(c, min=2, max=5)
                    return x + y

                ret = control_flow.cond(pred, true_fn, false_fn, [x, y])
                return ret

        x = torch.tensor([2])
        y = torch.tensor([5])
        mod = M()
        _ = export(mod, (torch.tensor(True), x, y))._transform(_ExportPassBase())

    def test_node_name_stability(self) -> None:
        # Tests that node names stay the same for nodes that are not touched
        # during transformation
        class CustomModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

                # Define a parameter
                self.my_parameter = torch.nn.Parameter(torch.tensor(2.0))

                # Define two buffers
                self.register_buffer("my_buffer1", torch.tensor(3.0))
                self.register_buffer("my_buffer2", torch.tensor(4.0))

            def forward(self, x1, x2):
                # Use the parameter, buffers, and both inputs in the forward method
                output = (x1 + self.my_parameter) * self.my_buffer1 + x2 * self.my_buffer2

                # Mutate one of the buffers (e.g., increment it by 1)
                self.my_buffer2.add_(1.0)
                return output

        inps = (torch.rand(1), torch.rand(1))
        m = CustomModule()
        ep_before = export(m, inps)

        # A no-op transformation that doesn't make any meaningful changes to
        # the nodes
        ep_after = ep_before._transform(_ExportPassBase())

        for before_node, after_node in zip(ep_before.graph.nodes, ep_after.graph.nodes):
            self.assertEqual(before_node.name, after_node.name)

    def test_graph_signature_updated_after_transformation(self) -> None:
        # Checks that pass infra correctly updates graph signature
        # after transformations.
        class CustomModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.my_parameter = torch.nn.Parameter(torch.tensor(2.0))
                self.register_buffer("my_buffer1", torch.tensor(3.0))
                self.register_buffer("my_buffer2", torch.tensor(4.0))

            def forward(self, x1, x2):
                # Use the parameter, buffers, and both inputs in the forward method
                output = (
                    x1 + self.my_parameter
                ) * self.my_buffer1 + x2 * self.my_buffer2
                return output

        my_module = CustomModule()

        # Test the custom module with two input tensors
        input_tensor1 = torch.tensor(5.0)
        input_tensor2 = torch.tensor(6.0)

        ep_before = export(my_module, (input_tensor1, input_tensor2))

        def modify_input_output_pass(gm):
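            # Renaming the call_function nodes also renames the graph's
            # outputs, so the exported program's signature has to be updated
            # to match.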
            for node in gm.graph.nodes:
                if node.op == "call_function":
                    node.name = node.name + "_modified"
            gm.recompile()
            return PassResult(gm, True)

        ep_after = ep_before._transform(modify_input_output_pass)
        new_signature = ep_after.graph_signature
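
        # Every user output should now carry the pass's "_modified" suffix.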
        for node_name in new_signature.user_outputs:
            self.assertIn("_modified", node_name)

        old_signature = ep_before.graph_signature
        self.assertNotEqual(new_signature.user_outputs, old_signature.user_outputs)


if __name__ == "__main__":
    run_tests()