Interpreter for decomposing aten -> prims (#79989)
If an aten -> prim decomposition is needed *after* the initial trace
with make_fx, the new DecompositionInterpreter can re-run the traced
graph and apply the decomposition then.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79989
Approved by: https://github.com/SherlockNoMad
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index 09aabb1..d8d3150 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -9,8 +9,9 @@
from torch.testing._internal.common_methods_invocations import op_db, wrapper_set_seed
from torch._subclasses.fake_tensor import DynamicOutputShapeException
+from torch._decomp import decomposition_table
from torch.testing._internal.common_device_type import ops
-from torch.fx.experimental.proxy_tensor import make_fx
+from torch.fx.experimental.proxy_tensor import make_fx, DecompositionInterpreter
# Copied from functorch
def xfail(op_name, variant_name='', *, device_type=None, dtypes=None):
@@ -156,6 +157,35 @@
self.assertTrue(all([isinstance(node.target, torch._ops.OpOverload)
for node in traced.graph.nodes if node.op == 'call_function']))
+ def test_decomposition_interpreter(self):
+ def fn(x):
+ return torch.nn.functional.silu(x)
+
+ x = torch.rand((4, 4))
+ fx_module = make_fx(fn, decomposition_table=None)(x)
+
+ found_silu = False
+ for n in fx_module.graph.nodes:
+ if n.target == torch.ops.aten.silu or n.target == torch.ops.aten.silu.default:
+ found_silu = True
+
+ self.assertTrue(found_silu)
+
+ new_graph = torch.fx.Graph()
+ silu_decomp_table = {torch.ops.aten.silu.default: decomposition_table[torch.ops.aten.silu.default]}
+ DecompositionInterpreter(
+ fx_module,
+ new_graph=new_graph,
+ decomposition_table=silu_decomp_table,
+ ).run(x)
+
+ decomposed_module = torch.fx.GraphModule(fx_module, new_graph)
+
+ for n in decomposed_module.graph.nodes:
+ self.assertTrue(n.target != torch.ops.aten.silu)
+ self.assertTrue(n.target != torch.ops.aten.silu.default)
+
+ self.assertEqual(fx_module(x), decomposed_module(x))
make_fx_failures = {
# unknown