Interpreter for decomposing aten -> prims (#79989)

If an aten -> prims decomposition is needed *after* the initial trace with make_fx, this interpreter can be used to perform the decomposition.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79989
Approved by: https://github.com/SherlockNoMad
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py index 09aabb1..d8d3150 100644 --- a/test/test_proxy_tensor.py +++ b/test/test_proxy_tensor.py
@@ -9,8 +9,9 @@ from torch.testing._internal.common_methods_invocations import op_db, wrapper_set_seed from torch._subclasses.fake_tensor import DynamicOutputShapeException +from torch._decomp import decomposition_table from torch.testing._internal.common_device_type import ops -from torch.fx.experimental.proxy_tensor import make_fx +from torch.fx.experimental.proxy_tensor import make_fx, DecompositionInterpreter # Copied from functorch def xfail(op_name, variant_name='', *, device_type=None, dtypes=None): @@ -156,6 +157,33 @@ self.assertTrue(all([isinstance(node.target, torch._ops.OpOverload) for node in traced.graph.nodes if node.op == 'call_function'])) + def test_decomposition_interpreter(self): + """Decomposing aten.silu after make_fx removes it from the graph and preserves results.""" + def fn(x): + return torch.nn.functional.silu(x) + + x = torch.rand((4, 4)) + fx_module = make_fx(fn, decomposition_table=None)(x) + + # Tracing with no decomposition table should leave silu in the graph. + silu_ops = (torch.ops.aten.silu, torch.ops.aten.silu.default) + self.assertTrue(any(n.target in silu_ops for n in fx_module.graph.nodes)) + + # Re-interpret the traced graph, decomposing only aten.silu.default. + new_graph = torch.fx.Graph() + silu_decomp_table = {torch.ops.aten.silu.default: decomposition_table[torch.ops.aten.silu.default]} + DecompositionInterpreter( + fx_module, + new_graph=new_graph, + decomposition_table=silu_decomp_table, + ).run(x) + + decomposed_module = torch.fx.GraphModule(fx_module, new_graph) + + for n in decomposed_module.graph.nodes: + self.assertNotIn(n.target, silu_ops) + + self.assertEqual(fx_module(x), decomposed_module(x)) make_fx_failures = { # unknown