blob: a438a6d2d050602dc2d64fb200cee5e197cec2e2 [file] [log] [blame]
"""
PYTEST_DONT_REWRITE (prevents pytest from rewriting assertions, which interferes
with test_functionalization_with_native_python_assertion)
"""
# Owner(s): ["oncall: export"]
import math
import operator
import unittest
from re import escape
from typing import List, Set
import torch
from functorch.experimental.control_flow import cond
from torch._dynamo.eval_frame import is_dynamo_supported
from torch._export.non_strict_utils import (
_fakify_script_objects,
_gather_constant_attrs,
)
from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse
from torch._export.passes.replace_set_grad_with_hop_pass import (
_is_set_grad_enabled_node,
_is_set_grad_enabled_sub_mod,
)
from torch._export.passes.replace_view_ops_with_view_copy_ops_pass import (
get_view_copy_of_view_op,
is_view_op,
ReplaceViewOpsWithViewCopyOpsPass,
)
from torch._export.utils import (
node_inline_,
nodes_count,
nodes_filter,
nodes_map,
sequential_split,
)
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.export import export
from torch.export._remove_auto_functionalized_pass import (
unsafe_remove_auto_functionalized_pass,
)
from torch.export._remove_effect_tokens_pass import _remove_effect_tokens
from torch.fx.experimental.symbolic_shapes import ShapeEnv
from torch.fx.passes.infra.partitioner import Partition
from torch.fx.passes.operator_support import OperatorSupport
from torch.library import _scoped_library, impl
from torch.testing._internal.common_utils import (
IS_WINDOWS,
run_tests,
skipIfTorchDynamo,
TestCase,
)
from torch.testing._internal.torchbind_impls import init_torchbind_implementations
from torch.utils import _pytree as pytree
def count_call_function(graph: torch.fx.Graph, target: torch._ops.OpOverload) -> int:
    """Return how many ``call_function`` nodes in ``graph`` call ``target``.

    Args:
        graph: The FX graph to scan.
        target: The callable compared (with ``==``) against each
            ``call_function`` node's ``target``, e.g. an ATen overload such as
            ``torch.ops.aten.view.default``.

    Returns:
        The number of matching ``call_function`` nodes (0 if there are none).
    """
    # NOTE: the annotation previously read ``torch.ops.OpOverload``; attribute
    # access on ``torch.ops`` lazily creates an op *namespace*, so that did not
    # name the OpOverload class at all.
    return sum(
        1
        for node in graph.nodes
        if node.op == "call_function" and node.target == target
    )
class _AddOperatorSupport(OperatorSupport):
def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
return node.op == "call_function" and node.target in {operator.add}
class _AtenAddOperatorSupport(OperatorSupport):
def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
return node.op == "call_function" and node.target in {torch.ops.aten.add.Tensor}
def _to_partition_names(partitions: List[Partition]) -> List[Set[str]]:
return [{n.name for n in p.nodes} for p in partitions]
def _get_output_names(gm: torch.fx.GraphModule) -> List[str]:
output_node = next(n for n in gm.graph.nodes if n.op == "output")
args = pytree.tree_leaves(output_node.args)
# if isinstance(args, tuple) and len(args) == 1:
# args = args[0]
return [str(arg) for arg in args]
def _make_foo_pytree_attr():
    """Return a fresh list/set/dict container holding ``_Foo`` script objects.

    Objects are constructed in the order ``_Foo(1, 2)``, ``_Foo(3, 4)``,
    ``_Foo(5, 6)``; the second lives in a set, the third under key ``"foo"``.
    """
    first = torch.classes._TorchScriptTesting._Foo(1, 2)
    in_set = torch.classes._TorchScriptTesting._Foo(3, 4)
    in_dict = torch.classes._TorchScriptTesting._Foo(5, 6)
    return [first, {in_set}, {"foo": in_dict}]


class ModelsWithScriptObjectAttr:
    """Namespace of ``nn.Module`` fixtures carrying TorchBind ``_Foo`` attributes.

    NOTE(review): instantiating these presumably requires the
    ``_TorchScriptTesting`` classes to be registered first (``TestPasses.setUp``
    calls ``init_torchbind_implementations``).
    """

    class Simple(torch.nn.Module):
        """Module with a single script-object attribute ``attr``."""

        def __init__(self):
            super().__init__()
            self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)

    class SimpleWithAttrInContainer(torch.nn.Module):
        """Like ``Simple``, plus script objects nested in ``pytree_attr2``."""

        def __init__(self):
            super().__init__()
            self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)
            self.pytree_attr2 = _make_foo_pytree_attr()

    class NestedWithAttrInContainer(torch.nn.Module):
        """Container attrs plus ``Simple`` / ``SimpleWithAttrInContainer`` children."""

        def __init__(self):
            super().__init__()
            self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)
            self.pytree_attr2 = _make_foo_pytree_attr()
            self.sub_mod = ModelsWithScriptObjectAttr.Simple()
            self.sub_mod2 = ModelsWithScriptObjectAttr.SimpleWithAttrInContainer()

    class MoreNestedWithAttrInContainer(torch.nn.Module):
        """Container attrs plus ``Simple`` / ``NestedWithAttrInContainer`` children."""

        def __init__(self):
            super().__init__()
            self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)
            self.pytree_attr2 = _make_foo_pytree_attr()
            self.sub_mod = ModelsWithScriptObjectAttr.Simple()
            self.sub_mod2 = ModelsWithScriptObjectAttr.NestedWithAttrInContainer()
def _set_grad_enabled_tests():
    """Build pre-dispatch graph modules whose forward toggles grad mode.

    Returns:
        A dict mapping a test name to ``(graph_module, example_args)``. Entries
        whose name ends in ``_no_grad`` / ``_under_no_grad`` are exported with
        grad mode disabled around the export call; the others with it enabled.
    """
    from torch.export._trace import _export

    # Toggles grad mode with raw ``torch._C._set_grad_enabled`` calls.
    class SetGradOp(torch.nn.Module):
        def forward(self, x):
            x = x + 1
            torch._C._set_grad_enabled(True)
            c = x.sin().sum()
            torch._C._set_grad_enabled(False)
            d = c + 1
            torch._C._set_grad_enabled(True)
            e = d - 1
            return d, e

    # Same computation as SetGradOp, written with grad-mode context managers.
    class SetGradCtxManager(torch.nn.Module):
        def forward(self, x):
            x = x + 1
            with torch.enable_grad():
                c = x.sin().sum()
            with torch.no_grad():
                d = c + 1
            with torch.enable_grad():
                e = d - 1
            return d, e

    # Two independent chains (sin / cos) inside each grad-mode region, so each
    # region carries two values instead of one.
    class SetGradCtxManagerMultiDep(torch.nn.Module):
        def forward(self, x):
            x = x + 1
            with torch.enable_grad():
                c1 = x.sin().sum()
                c2 = x.cos().sum()
            with torch.no_grad():
                d1 = c1 + 1
                d2 = c2 + 1
            with torch.enable_grad():
                e1 = d1 - 1
                e2 = d2 - 1
            return d1, d2, e1, e2

    x = torch.randn(2, 2)

    def _get_predispatch_module(mod, args, ambient_grad_enabled=True):
        # Export in pre-dispatch mode with the ambient grad mode set to
        # ``ambient_grad_enabled`` for the duration of the export call.
        with torch.set_grad_enabled(ambient_grad_enabled):
            return _export(mod, args, pre_dispatch=True).module()

    return {
        "ctx_manager": (_get_predispatch_module(SetGradCtxManager(), (x,)), (x,)),
        "ctx_manager_under_no_grad": (
            _get_predispatch_module(SetGradCtxManager(), (x,), False),
            (x,),
        ),
        "ctx_manager_multi_dep": (
            _get_predispatch_module(SetGradCtxManagerMultiDep(), (x,)),
            (x,),
        ),
        "ctx_manager_multi_dep_no_grad": (
            _get_predispatch_module(SetGradCtxManagerMultiDep(), (x,), False),
            (x,),
        ),
        "op": (_get_predispatch_module(SetGradOp(), (x,)), (x,)),
        "op_under_no_grad": (_get_predispatch_module(SetGradOp(), (x,), False), (x,)),
    }
def _sequential_split_inline_tests():
from torch.export._trace import _export
class Simple(torch.nn.Module):
def forward(self, x):
x = x + 1
c = x.sin().sum()
d = c + 1
e = d - 1
return d, e
class MultiDep(torch.nn.Module):
def forward(self, x1, x2):
x1 = x1 + 1
x2 = x2 + 1
c1 = x1.sin()
c2 = x2.cos()
d1 = c1 + 1
d2 = c2 + 1
e1 = d1 - 1
e2 = d2 - 1
return d1, d2, e1, e2
def _get_predispatch_module(mod, args):
return _export(mod, args, pre_dispatch=True).module()
def _insert_dilimiter_nodes(gm: torch.fx.GraphModule, step: int = 1):
insert_locs = []
for i, node in enumerate(
nodes_filter(gm.graph.nodes, lambda n: n.op == "call_function")
):
if i % step == 0:
insert_locs.append(node)
for i, node in enumerate(insert_locs):
with gm.graph.inserting_before(node):
gm.graph.call_function(
torch._C._set_grad_enabled, (True if i % 2 == 0 else False,), {}
)
return gm
x = torch.randn(2, 2)
simple = _get_predispatch_module(Simple(), (x,))
simple1 = _get_predispatch_module(Simple(), (x,))
multi_dep = _get_predispatch_module(MultiDep(), (x, x.sin()))
multi_dep1 = _get_predispatch_module(MultiDep(), (x, x.sin()))
return {
"simple_step1": (_insert_dilimiter_nodes(simple1, 1), (x,)),
"simple_step2": (_insert_dilimiter_nodes(simple, 2), (x,)),
"multi_dep_step2": (_insert_dilimiter_nodes(multi_dep, 2), (x, x.sin())),
"multi_dep_step3": (_insert_dilimiter_nodes(multi_dep1, 3), (x, x.sin())),
}
@skipIfTorchDynamo("recursively running dynamo on export is unlikely")
@unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported")
class TestPasses(TestCase):
    """Tests for export passes: runtime asserts, view -> view_copy replacement,
    set_grad handling, sequential split / inline, effect-token removal and
    auto_functionalized removal."""

    def setUp(self):
        super().setUp()
        # Fixtures are rebuilt for every test (and cleared in tearDown), so each
        # test gets fresh graph modules.
        self.SEQUENTIAL_SPLIT_INLINE_TESTS = _sequential_split_inline_tests()
        self.SET_GRAD_ENABLED_TESTS = _set_grad_enabled_tests()
        init_torchbind_implementations()

    def tearDown(self):
        self.SEQUENTIAL_SPLIT_INLINE_TESTS.clear()
        self.SET_GRAD_ENABLED_TESTS.clear()
        super().tearDown()

    def test_runtime_assert_one_dim(self) -> None:
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return x.cos()

        x = torch.zeros(2, 2, 3)

        dim1_x = torch.export.Dim("dim1_x", min=2, max=6)
        ep = torch.export.export(M(), (x,), dynamic_shapes={"x": {1: dim1_x}})

        with self.assertRaisesRegex(
            RuntimeError,
            escape("Expected input at *args[0].shape[1] to be <= 6, but got 7"),
        ):
            ep.module()(torch.zeros(2, 7, 3))

        self.assertEqual(
            ep.module()(torch.ones(2, 4, 3)), M().forward(torch.ones(2, 4, 3))
        )

    def test_runtime_assert_multiple_dims(self) -> None:
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y):
                return x.cos().sum() + y.sin().sum()

        x = torch.zeros(4, 2, 3)
        y = torch.zeros(5, 5, 5)

        dim1_x = torch.export.Dim("dim1_x", min=2, max=6)
        dim0_x, dim0_y = torch.export.dims("dim0_x", "dim0_y", min=3)

        ep = torch.export.export(
            M(), (x, y), dynamic_shapes={"x": {0: dim0_x, 1: dim1_x}, "y": {0: dim0_y}}
        )

        with self.assertRaisesRegex(
            RuntimeError,
            escape("Expected input at *args[0].shape[1] to be <= 6, but got 7"),
        ):
            ep.module()(torch.zeros(4, 7, 3), torch.ones(5, 5, 5))

        with self.assertRaisesRegex(
            RuntimeError,
            escape("Expected input at *args[1].shape[0] to be >= 3, but got 2"),
        ):
            ep.module()(torch.zeros(4, 2, 3), torch.ones(2, 5, 5))

    def test_runtime_assert_some_dims_not_specified(self) -> None:
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y):
                return x.cos().sum() + y.sin().sum()

        x = torch.zeros(4, 2, 3)
        y = torch.zeros(5, 5, 5)

        dim1_x = torch.export.Dim("dim1_x", min=2, max=6)
        dim0_x = torch.export.Dim("dim0_x", min=3)

        ep = torch.export.export(
            M(), (x, y), dynamic_shapes={"x": {0: dim0_x, 1: dim1_x}, "y": None}
        )

        with self.assertRaisesRegex(
            RuntimeError,
            escape("Expected input at *args[0].shape[1] to be <= 6, but got 7"),
        ):
            ep.module()(torch.zeros(4, 7, 3), torch.ones(5, 5, 5))

        # y is specialized to 5
        with self.assertRaisesRegex(
            RuntimeError,
            escape("Expected input at *args[1].shape[0] to be equal to 5, but got 2"),
        ):
            ep.module()(torch.zeros(4, 2, 3), torch.ones(2, 5, 5))

        # Since we didn't insert the constraint for x[1] >= 2, it should work for case where x[1] == 1
        gm_result_for_1_size = ep.module()(torch.ones(3, 1, 3), torch.ones(5, 5, 5))
        eager_result_for_1_size = M().forward(torch.ones(3, 1, 3), torch.ones(5, 5, 5))

        self.assertEqual(gm_result_for_1_size, eager_result_for_1_size)

    def test_runtime_assert_some_inps_not_used(self) -> None:
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y):
                return y.cos().sum()

        x = torch.zeros(4, 2, 3)
        y = torch.zeros(5, 5, 5)

        dim1_y = torch.export.Dim("dim1_y", min=3, max=6)
        ep = torch.export.export(
            M(), (x, y), dynamic_shapes={"x": None, "y": {1: dim1_y}}
        )

        # x has no dynamic dims, so its shape is specialized to (4, 2, 3).
        with self.assertRaisesRegex(RuntimeError, escape("shape[1] to be equal to 2")):
            ep.module()(torch.zeros(4, 7, 3), torch.ones(5, 5, 5))

        # y is specialized to 5
        with self.assertRaisesRegex(
            RuntimeError,
            escape("Expected input at *args[1].shape[0] to be equal to 5, but got 2"),
        ):
            ep.module()(torch.zeros(4, 2, 3), torch.ones(2, 5, 5))

        # Inputs that satisfy every exported constraint must match eager mode.
        gm_result = ep.module()(torch.zeros(4, 2, 3), torch.ones(5, 5, 5))
        eager_result = M().forward(torch.zeros(4, 2, 3), torch.ones(5, 5, 5))

        self.assertEqual(gm_result, eager_result)

    def test_view_to_view_copy(self) -> None:
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                z = x.view(x.shape)
                return z.cos().sum()

        x = torch.zeros(4, 2, 3)

        ep = export(M(), (x,))
        self.assertEqual(count_call_function(ep.graph, torch.ops.aten.view.default), 1)

        ep = ep._transform_do_not_use(ReplaceViewOpsWithViewCopyOpsPass())
        self.assertEqual(count_call_function(ep.graph, torch.ops.aten.view.default), 0)

    def test_functionalization_with_view_copy(self) -> None:
        class Module(torch.nn.Module):
            def forward(self, x):
                y = x + 4
                y.add_(4)
                z = y.view(y.shape)
                return x.cos() + z.cos()

        x = torch.zeros(4, 2, 3)
        foo = Module()
        ep = export(foo, (x,))._transform_do_not_use(
            ReplaceViewOpsWithViewCopyOpsPass()
        )
        # After this pass, there shouldn't be any view nodes in the graph
        self.assertEqual(count_call_function(ep.graph, torch.ops.aten.view.default), 0)
        self.assertGreater(
            count_call_function(ep.graph, torch.ops.aten.view_copy.default), 0
        )

    def test_views_op_having_view_copy(self) -> None:
        aten_prefix = "aten::"
        schemas = torch._C._dispatch_get_registrations_for_dispatch_key("")
        aten_schemas = [
            s[len(aten_prefix) :] for s in schemas if s.startswith(aten_prefix)
        ]

        for aten_schema in aten_schemas:
            # "name" or "name.overload"
            val = aten_schema.split(".")
            self.assertLessEqual(len(val), 2)
            if len(val) == 1:
                name, overload = val[0], "default"
            else:
                name, overload = val

            op_overload = getattr(getattr(torch.ops.aten, name), overload)
            if torch.Tag.core in op_overload.tags and is_view_op(op_overload._schema):
                self.assertIsNotNone(get_view_copy_of_view_op(op_overload._schema))

    def test_custom_obj_tuple_out(self):
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)

            def forward(self, x):
                a = torch.ops._TorchScriptTesting.takes_foo_tuple_return(self.attr, x)
                y = a[0] + a[1]
                b = torch.ops._TorchScriptTesting.takes_foo(self.attr, y)
                return b

        m = MyModule()
        inputs = (torch.ones(2, 3),)
        ep = torch.export.export(m, inputs, strict=False)

        inp = torch.randn(2, 3)
        orig_res = m(inp)
        ep_res = ep.module()(inp)

        without_token_ep = _remove_effect_tokens(ep)
        without_token_ep.verifier().check(without_token_ep)
        without_token_res = without_token_ep.module()(inp)

        self.assertTrue(torch.allclose(orig_res, ep_res))
        self.assertTrue(torch.allclose(orig_res, without_token_res))

    def test_fakify_script_objects(self):
        for m in [
            ModelsWithScriptObjectAttr.Simple(),
            ModelsWithScriptObjectAttr.SimpleWithAttrInContainer(),
            ModelsWithScriptObjectAttr.NestedWithAttrInContainer(),
            ModelsWithScriptObjectAttr.MoreNestedWithAttrInContainer(),
        ]:
            constant_attrs = _gather_constant_attrs(m)
            fake_mode = FakeTensorMode(
                shape_env=ShapeEnv(tracked_fakes=[]),
                allow_non_fake_inputs=True,
            )
            with _fakify_script_objects(m, (), {}, fake_mode) as (
                patched_mod,
                _,
                _,
                fake_constant_attrs,
                fake_to_real,
            ):
                self.assertEqual(len(fake_constant_attrs), len(constant_attrs))
                for fake_obj, fqn in fake_constant_attrs.items():
                    self.assertEqual(constant_attrs[fake_to_real[fake_obj]], fqn)

    # TODO: _gather_constants doesn't recursively look into the pytree containers.
    @unittest.expectedFailure
    def test_fakify_script_objects_properly_handle_containers(self):
        m = ModelsWithScriptObjectAttr.SimpleWithAttrInContainer()
        constant_attrs = _gather_constant_attrs(m)
        fake_mode = FakeTensorMode(
            shape_env=ShapeEnv(tracked_fakes=[]),
            allow_non_fake_inputs=True,
        )
        with _fakify_script_objects(m, (), {}, fake_mode) as (
            patched_mod,
            _,
            _,
            fake_constant_attrs,
            fake_to_real,
        ):
            self.assertIn("attr", fake_constant_attrs.values())
            self.assertIn("pytree_attr2", fake_constant_attrs.values())

    def test_runtime_assert_inline_constraints_for_item(self) -> None:
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                b = x.item()
                torch._check(b >= 2)
                torch._check(b <= 5)
                return b

        x = torch.tensor([2])
        mod = M()
        ep = export(mod, (x,))

        # ``u\d+`` matches any unbacked symbol (u0, u12, ...); the previous
        # ``u[\d+]`` only matched a single digit.
        with self.assertRaisesRegex(
            RuntimeError, r"Runtime assertion failed for expression u\d+ <= 5"
        ):
            ep.module()(torch.tensor([6]))

        new_inp = torch.tensor([5])
        self.assertEqual(mod(new_inp), ep.module()(new_inp))

    def test_runtime_assert_inline_constraints_for_nonzero(self) -> None:
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                b = x.nonzero()
                torch._check(b.shape[0] >= 3)
                torch._check(b.shape[0] <= 5)
                return b

        x = torch.tensor([2, 1, 2, 3, 5, 0])

        mod = M()
        dim0_x = torch.export.Dim("dim0_x")
        ep = torch.export.export(mod, (x,), dynamic_shapes={"x": {0: dim0_x}})

        num_assert = count_call_function(
            ep.graph, torch.ops.aten._assert_scalar.default
        )

        self.assertEqual(num_assert, 2)
        num_constrain_range = count_call_function(
            ep.graph, torch.ops.aten.sym_constrain_range.default
        )
        self.assertEqual(num_constrain_range, 0)

        with self.assertRaisesRegex(
            RuntimeError,
            r"Runtime assertion failed for expression u\d+ >= 3",
        ):
            ep.module()(torch.tensor([1, 1, 0, 0, 0]))

        with self.assertRaisesRegex(
            RuntimeError,
            r"Runtime assertion failed for expression u\d+ <= 5",
        ):
            ep.module()(torch.ones(6))

        new_inp = torch.tensor([1, 1, 1, 1])
        self.assertEqual(mod(new_inp), ep.module()(new_inp))

    @unittest.skipIf(IS_WINDOWS, "Windows not supported")
    @unittest.expectedFailure
    # TODO(pianpwk): add back runtime asserts to subgraphs
    def test_runtime_assert_inline_constraints_for_cond(self) -> None:
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, pred, x, y):
                def true_fn(x, y):
                    b = x.item()
                    torch._check(b >= 2)
                    torch._check(b <= 5)
                    return x - b

                def false_fn(x, y):
                    c = y.item()
                    torch._check(c >= 2)
                    torch._check(c <= 5)
                    return y - c

                ret = cond(pred, true_fn, false_fn, [x, y])
                return ret

        x = torch.tensor([2])
        y = torch.tensor([5])
        mod = M()
        ep = export(mod, (torch.tensor(True), x, y))

        with self.assertRaisesRegex(
            RuntimeError, "is outside of inline constraint \\[2, 5\\]."
        ):
            ep.module()(torch.tensor(False), torch.tensor([6]), torch.tensor([6]))

    def test_math_ops(self):
        class Module(torch.nn.Module):
            def forward(self, x):
                return (
                    torch.tensor([math.ceil(x.item())]),
                    torch.tensor([math.floor(x.item())]),
                )

        func = Module()
        x = torch.randn(1, dtype=torch.float32)
        ep = torch.export.export(func, args=(x,))
        _ExportPassBaseDeprecatedDoNotUse()(ep.graph_module)

    # NOTE(review): "predispatceh" is a typo; the name is kept so the test id
    # stays stable for anything that references it.
    def test_predispatceh_set_grad(self):
        def _check_node_users_in_the_same_graph(gm):
            for node in gm.graph.nodes:
                for user in node.users:
                    self.assertTrue(user.graph is gm.graph)

        mod, args = self.SET_GRAD_ENABLED_TESTS["op"]
        _check_node_users_in_the_same_graph(mod)
        self.assertExpectedInline(
            mod.code.strip("\n"),
            """\
def forward(self, x):
    x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    add = torch.ops.aten.add.Tensor(x, 1);  x = None
    sin = torch.ops.aten.sin.default(add);  add = None
    sum_1 = torch.ops.aten.sum.default(sin);  sin = None
    submod_4 = self.submod_2
    add_1 = torch._higher_order_ops.wrap.wrap_with_set_grad_enabled(False, submod_4, sum_1);  submod_4 = sum_1 = None
    sub = torch.ops.aten.sub.Tensor(add_1, 1)
    return pytree.tree_unflatten((add_1, sub), self._out_spec)
    """,
        )

        mod, args = self.SET_GRAD_ENABLED_TESTS["op_under_no_grad"]
        _check_node_users_in_the_same_graph(mod)
        self.assertExpectedInline(
            mod.code.strip("\n"),
            """\
def forward(self, x):
    x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    add = torch.ops.aten.add.Tensor(x, 1);  x = None
    sin = torch.ops.aten.sin.default(add);  add = None
    sum_1 = torch.ops.aten.sum.default(sin);  sin = None
    submod_4 = self.submod_2
    add_1 = torch._higher_order_ops.wrap.wrap_with_set_grad_enabled(False, submod_4, sum_1);  submod_4 = sum_1 = None
    sub = torch.ops.aten.sub.Tensor(add_1, 1)
    return pytree.tree_unflatten((add_1, sub), self._out_spec)
    """,
        )

        mod, args = self.SET_GRAD_ENABLED_TESTS["ctx_manager"]
        _check_node_users_in_the_same_graph(mod)
        self.assertExpectedInline(
            mod.code.strip("\n"),
            """\
def forward(self, x):
    x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    add = torch.ops.aten.add.Tensor(x, 1);  x = None
    sin = torch.ops.aten.sin.default(add);  add = None
    sum_1 = torch.ops.aten.sum.default(sin);  sin = None
    submod_3 = self.submod_1
    add_1 = torch._higher_order_ops.wrap.wrap_with_set_grad_enabled(False, submod_3, sum_1);  submod_3 = sum_1 = None
    sub = torch.ops.aten.sub.Tensor(add_1, 1)
    return pytree.tree_unflatten((add_1, sub), self._out_spec)
    """,
        )

        mod, args = self.SET_GRAD_ENABLED_TESTS["ctx_manager_under_no_grad"]
        _check_node_users_in_the_same_graph(mod)
        self.assertExpectedInline(
            mod.code.strip("\n"),
            """\
def forward(self, x):
    x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    add = torch.ops.aten.add.Tensor(x, 1);  x = None
    submod_5 = self.submod_1
    sum_1 = torch._higher_order_ops.wrap.wrap_with_set_grad_enabled(True, submod_5, add);  submod_5 = add = None
    add_1 = torch.ops.aten.add.Tensor(sum_1, 1);  sum_1 = None
    submod_6 = self.submod_3
    sub = torch._higher_order_ops.wrap.wrap_with_set_grad_enabled(True, submod_6, add_1);  submod_6 = None
    return pytree.tree_unflatten((add_1, sub), self._out_spec)
    """,
        )

        mod, args = self.SET_GRAD_ENABLED_TESTS["ctx_manager_multi_dep"]
        _check_node_users_in_the_same_graph(mod)
        self.assertExpectedInline(
            mod.code.strip("\n"),
            """\
def forward(self, x):
    x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    add = torch.ops.aten.add.Tensor(x, 1);  x = None
    sin = torch.ops.aten.sin.default(add)
    sum_1 = torch.ops.aten.sum.default(sin);  sin = None
    cos = torch.ops.aten.cos.default(add);  add = None
    sum_2 = torch.ops.aten.sum.default(cos);  cos = None
    submod_3 = self.submod_1
    wrap_with_set_grad_enabled = torch._higher_order_ops.wrap.wrap_with_set_grad_enabled(False, submod_3, sum_1, sum_2);  submod_3 = sum_1 = sum_2 = None
    add_1 = wrap_with_set_grad_enabled[0]
    add_2 = wrap_with_set_grad_enabled[1];  wrap_with_set_grad_enabled = None
    sub = torch.ops.aten.sub.Tensor(add_1, 1)
    sub_1 = torch.ops.aten.sub.Tensor(add_2, 1)
    return pytree.tree_unflatten((add_1, add_2, sub, sub_1), self._out_spec)
    """,  # noqa: B950
        )

        mod, args = self.SET_GRAD_ENABLED_TESTS["ctx_manager_multi_dep_no_grad"]
        _check_node_users_in_the_same_graph(mod)
        self.assertExpectedInline(
            mod.code.strip("\n"),
            """\
def forward(self, x):
    x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    add = torch.ops.aten.add.Tensor(x, 1);  x = None
    submod_5 = self.submod_1
    wrap_with_set_grad_enabled = torch._higher_order_ops.wrap.wrap_with_set_grad_enabled(True, submod_5, add);  submod_5 = add = None
    sum_1 = wrap_with_set_grad_enabled[0]
    sum_2 = wrap_with_set_grad_enabled[1];  wrap_with_set_grad_enabled = None
    add_1 = torch.ops.aten.add.Tensor(sum_1, 1);  sum_1 = None
    add_2 = torch.ops.aten.add.Tensor(sum_2, 1);  sum_2 = None
    submod_6 = self.submod_3
    wrap_with_set_grad_enabled_1 = torch._higher_order_ops.wrap.wrap_with_set_grad_enabled(True, submod_6, add_1, add_2);  submod_6 = None
    sub = wrap_with_set_grad_enabled_1[0]
    sub_1 = wrap_with_set_grad_enabled_1[1];  wrap_with_set_grad_enabled_1 = None
    return pytree.tree_unflatten((add_1, add_2, sub, sub_1), self._out_spec)
    """,  # noqa: B950
        )

    def test_sequential_split(self):
        for gm, args in self.SEQUENTIAL_SPLIT_INLINE_TESTS.values():
            set_grad_counts = nodes_count(gm.graph.nodes, _is_set_grad_enabled_node)
            new_gm = sequential_split(gm, _is_set_grad_enabled_node)
            new_set_grad_counts = nodes_count(
                new_gm.graph.nodes, _is_set_grad_enabled_sub_mod
            )
            self.assertEqual(set_grad_counts, new_set_grad_counts)
            self.assertEqual(gm(*args), new_gm(*args))

    def test_sequential_split_graph(self):
        gm, args = self.SEQUENTIAL_SPLIT_INLINE_TESTS["multi_dep_step2"]

        new_gm = sequential_split(gm, _is_set_grad_enabled_node)
        self.assertEqual(gm(*args), new_gm(*args))
        self.assertExpectedInline(
            new_gm.code.strip("\n"),
            """\
def forward(self, x1, x2):
    x1, x2, = fx_pytree.tree_flatten_spec(([x1, x2], {}), self._in_spec)
    submod_1 = self.submod_1(x1, x2);  x1 = x2 = None
    getitem = submod_1[0]
    getitem_1 = submod_1[1];  submod_1 = None
    submod_2 = self.submod_2(getitem, getitem_1);  getitem = getitem_1 = None
    getitem_2 = submod_2[0]
    getitem_3 = submod_2[1];  submod_2 = None
    submod_3 = self.submod_3(getitem_2, getitem_3);  getitem_2 = getitem_3 = None
    getitem_4 = submod_3[0]
    getitem_5 = submod_3[1];  submod_3 = None
    submod_4 = self.submod_4(getitem_4, getitem_5)
    getitem_6 = submod_4[0]
    getitem_7 = submod_4[1];  submod_4 = None
    return pytree.tree_unflatten((getitem_4, getitem_5, getitem_6, getitem_7), self._out_spec)
    """,
        )
        self.assertExpectedInline(
            new_gm.submod_1.code.strip("\n"),
            """\
def forward(self, x1, x2):
    _set_grad_enabled = torch._C._set_grad_enabled(True)
    add = torch.ops.aten.add.Tensor(x1, 1);  x1 = None
    add_1 = torch.ops.aten.add.Tensor(x2, 1);  x2 = None
    return (add, add_1)
    """,
        )
        self.assertExpectedInline(
            new_gm.submod_2.code.strip("\n"),
            """\
def forward(self, add, add_1):
    _set_grad_enabled_1 = torch._C._set_grad_enabled(False)
    sin = torch.ops.aten.sin.default(add);  add = None
    cos = torch.ops.aten.cos.default(add_1);  add_1 = None
    return (sin, cos)
    """,
        )
        self.assertExpectedInline(
            new_gm.submod_3.code.strip("\n"),
            """\
def forward(self, sin, cos):
    _set_grad_enabled_2 = torch._C._set_grad_enabled(True)
    add_2 = torch.ops.aten.add.Tensor(sin, 1);  sin = None
    add_3 = torch.ops.aten.add.Tensor(cos, 1);  cos = None
    return (add_2, add_3)
    """,
        )

    def test_inline_(self):
        for gm, args in self.SEQUENTIAL_SPLIT_INLINE_TESTS.values():
            before_str = gm.print_readable(print_output=False)
            new_gm = sequential_split(gm, _is_set_grad_enabled_node)
            nodes_map(
                new_gm.graph.nodes,
                lambda node: node_inline_(node) if node.op == "call_module" else node,
            )
            after_inline_str = new_gm.print_readable(print_output=False)
            self.assertEqual(before_str, after_inline_str)
            self.assertEqual(gm(*args), new_gm(*args))

    def test_remove_auto_functionalized_pass(self) -> None:
        with _scoped_library("DO_NOT_USE_TEST_ONLY", "DEF") as lib:
            lib.define("custom_mutator(Tensor x, Tensor(a!) y) -> Tensor")

            @impl(lib, "custom_mutator", "Meta")
            def custom_mutator_meta(
                x: torch.Tensor,
                y: torch.Tensor,
            ) -> torch.Tensor:
                return torch.empty_like(x)

            @impl(lib, "custom_mutator", "CompositeExplicitAutograd")
            def custom_mutator(
                x: torch.Tensor,
                y: torch.Tensor,
            ) -> torch.Tensor:
                return x + y.add_(1)

            class M(torch.nn.Module):
                def __init__(self):
                    super().__init__()
                    self.register_buffer("state", torch.zeros(1))

                def forward(self, x):
                    return torch.ops.DO_NOT_USE_TEST_ONLY.custom_mutator(x, self.state)

            mod = M()
            x = torch.randn([3, 3])
            ep = export(mod, (x,))
            inplace_ep = unsafe_remove_auto_functionalized_pass(ep)
            nodes = inplace_ep.graph.nodes
            for node in nodes:
                if node.op == "call_function":
                    self.assertIsNot(node.target, auto_functionalized)
                    self.assertIsNot(node.target, operator.getitem)

            for spec in inplace_ep.graph_signature.output_specs:
                self.assertNotIn("getitem", spec.arg.name)

    def test_remove_auto_functionalized_pass_tuple(self) -> None:
        with _scoped_library("DO_NOT_USE_TEST_ONLY", "DEF") as lib:
            lib.define(
                "custom_mutator_tuple(Tensor x, Tensor(a!) y) -> (Tensor, Tensor)"
            )

            @impl(lib, "custom_mutator_tuple", "Meta")
            def custom_mutator_tuple_meta(
                x: torch.Tensor,
                y: torch.Tensor,
            ):
                return (torch.empty_like(x), torch.empty_like(x))

            @impl(lib, "custom_mutator_tuple", "CompositeExplicitAutograd")
            def custom_mutator_tuple(
                x: torch.Tensor,
                y: torch.Tensor,
            ):
                return (x, x + y.add_(1))

            class M(torch.nn.Module):
                def __init__(self):
                    super().__init__()
                    self.register_buffer("state", torch.zeros(1))

                def forward(self, x):
                    return torch.ops.DO_NOT_USE_TEST_ONLY.custom_mutator_tuple(
                        x, self.state
                    )

            mod = M()
            x = torch.randn([3, 3])
            ep = export(mod, (x,))
            inplace_ep = unsafe_remove_auto_functionalized_pass(ep)

            nodes = inplace_ep.graph.nodes
            getitems = 0
            for node in nodes:
                if node.op == "call_function":
                    self.assertIsNot(node.target, auto_functionalized)
                    if node.target is operator.getitem:
                        getitems += 1
            self.assertEqual(getitems, 2)  # tuple return of len 2

            out_specs = inplace_ep.graph_signature.output_specs
            self.assertEqual(out_specs[0].arg.name, "b_state")  # state
            self.assertEqual(out_specs[1].arg.name, "getitem")  # tuple return 1
            self.assertEqual(out_specs[2].arg.name, "getitem_1")  # tuple return 2
if __name__ == "__main__":
    # Entry point when the file is executed directly: run the tests through
    # ``run_tests`` from torch.testing._internal.common_utils.
    run_tests()