Interpreter for decomposing aten -> prims (#79989)

If an aten -> prims decomposition is needed *after* the initial trace
with make_fx, the DecompositionInterpreter added here can be used to
perform the decomposition.
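
A minimal sketch of the intended usage (mirroring the test added below;
fn, x, and the silu-only table are illustrative):

    import torch
    from torch._decomp import decomposition_table
    from torch.fx.experimental.proxy_tensor import make_fx, DecompositionInterpreter

    def fn(x):
        return torch.nn.functional.silu(x)

    x = torch.rand(4, 4)
    # Initial trace; silu is not decomposed at this point.
    traced = make_fx(fn)(x)

    # Re-interpret the traced graph, recording decomposed ops into a new graph.
    new_graph = torch.fx.Graph()
    table = {torch.ops.aten.silu.default: decomposition_table[torch.ops.aten.silu.default]}
    DecompositionInterpreter(traced, new_graph=new_graph, decomposition_table=table).run(x)
    decomposed = torch.fx.GraphModule(traced, new_graph)
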
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79989
Approved by: https://github.com/SherlockNoMad
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index 09aabb1..d8d3150 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -9,8 +9,9 @@
 from torch.testing._internal.common_methods_invocations import op_db, wrapper_set_seed
 from torch._subclasses.fake_tensor import DynamicOutputShapeException
 
+from torch._decomp import decomposition_table
 from torch.testing._internal.common_device_type import ops
-from torch.fx.experimental.proxy_tensor import make_fx
+from torch.fx.experimental.proxy_tensor import make_fx, DecompositionInterpreter
 
 # Copied from functorch
 def xfail(op_name, variant_name='', *, device_type=None, dtypes=None):
@@ -156,6 +157,38 @@
         self.assertTrue(all([isinstance(node.target, torch._ops.OpOverload)
                              for node in traced.graph.nodes if node.op == 'call_function']))
 
+    def test_decomposition_interpreter(self):
+        def fn(x):
+            return torch.nn.functional.silu(x)
+
+        x = torch.rand((4, 4))
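+        # Trace without decompositions; aten.silu should survive in the graph.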
+        fx_module = make_fx(fn, decomposition_table=None)(x)
+
+        found_silu = False
+        for n in fx_module.graph.nodes:
+            if n.target in (torch.ops.aten.silu, torch.ops.aten.silu.default):
+                found_silu = True
+
+        self.assertTrue(found_silu)
+
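+        # Re-run the traced module, decomposing silu into a fresh graph.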
+        new_graph = torch.fx.Graph()
+        silu_decomp_table = {torch.ops.aten.silu.default: decomposition_table[torch.ops.aten.silu.default]}
+        DecompositionInterpreter(
+            fx_module,
+            new_graph=new_graph,
+            decomposition_table=silu_decomp_table,
+        ).run(x)
+
+        decomposed_module = torch.fx.GraphModule(fx_module, new_graph)
+
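+        # The decomposed graph should no longer contain silu.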
+        for n in decomposed_module.graph.nodes:
+            self.assertNotEqual(n.target, torch.ops.aten.silu)
+            self.assertNotEqual(n.target, torch.ops.aten.silu.default)
+
+        self.assertEqual(fx_module(x), decomposed_module(x))
 
 make_fx_failures = {
     # unknown