Interpreter for decomposing aten -> prims (#79989)

If an aten -> prim decomposition is needed *after* the initial trace
with make_fx, this interpreter can be used to perform the decomposition.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79989
Approved by: https://github.com/SherlockNoMad
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index 09aabb1..d8d3150 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -9,8 +9,9 @@
 from torch.testing._internal.common_methods_invocations import op_db, wrapper_set_seed
 from torch._subclasses.fake_tensor import DynamicOutputShapeException
 
+from torch._decomp import decomposition_table
 from torch.testing._internal.common_device_type import ops
-from torch.fx.experimental.proxy_tensor import make_fx
+from torch.fx.experimental.proxy_tensor import make_fx, DecompositionInterpreter
 
 # Copied from functorch
 def xfail(op_name, variant_name='', *, device_type=None, dtypes=None):
@@ -156,6 +157,35 @@
         self.assertTrue(all([isinstance(node.target, torch._ops.OpOverload)
                              for node in traced.graph.nodes if node.op == 'call_function']))
 
+    def test_decomposition_interpreter(self):
+        def fn(x):
+            return torch.nn.functional.silu(x)
+
+        x = torch.rand((4, 4))
+        fx_module = make_fx(fn, decomposition_table=None)(x)
+
+        found_silu = False
+        for n in fx_module.graph.nodes:
+            if n.target == torch.ops.aten.silu or n.target == torch.ops.aten.silu.default:
+                found_silu = True
+
+        self.assertTrue(found_silu)
+
+        new_graph = torch.fx.Graph()
+        silu_decomp_table = {torch.ops.aten.silu.default: decomposition_table[torch.ops.aten.silu.default]}
+        DecompositionInterpreter(
+            fx_module,
+            new_graph=new_graph,
+            decomposition_table=silu_decomp_table,
+        ).run(x)
+
+        decomposed_module = torch.fx.GraphModule(fx_module, new_graph)
+
+        for n in decomposed_module.graph.nodes:
+            self.assertTrue(n.target != torch.ops.aten.silu)
+            self.assertTrue(n.target != torch.ops.aten.silu.default)
+
+        self.assertEqual(fx_module(x), decomposed_module(x))
 
 make_fx_failures = {
     # unknown
diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py
index c6b7254..d0ea069 100644
--- a/torch/fx/experimental/proxy_tensor.py
+++ b/torch/fx/experimental/proxy_tensor.py
@@ -17,7 +17,7 @@
 
 from torch.utils._python_dispatch import TorchDispatchMode
 
-__all__ = ["ProxyTensor", "PythonKeyTracer", "dispatch_trace", "make_fx", "enable_strict"]
+__all__ = ["ProxyTensor", "PythonKeyTracer", "dispatch_trace", "make_fx", "enable_strict", "DecompositionInterpreter"]
 aten = torch.ops.aten
 
 CURRENT_DECOMPOSITION_TABLE: Dict[torch._ops.OpOverload, Callable] = {}
@@ -225,6 +225,38 @@
             return wrap_output(inner_res, proxy_res)
 
 
+class DecompositionInterpreter(torch.fx.Interpreter):
+    def __init__(self, module: torch.fx.GraphModule, new_graph: torch.fx.Graph, decomposition_table=None, **kwargs):
+        super().__init__(module, **kwargs)
+        self.new_graph = new_graph
+        self.tracer = torch.fx.proxy.GraphAppendingTracer(self.new_graph)
+        self.decomposition_table = decomposition_table
+        if self.decomposition_table is None:
+            self.decomposition_table = {}
+
+    def placeholder(self, target, args, kwargs):
+        out = super().placeholder(target, args, kwargs)
+        # TODO handle case where the first character of target is '*'
+        return ProxyTensor(out, torch.fx.Proxy(self.new_graph.placeholder(target), self.tracer))
+
+    def get_attr(self, target, args, kwargs):
+        out = super().get_attr(target, args, kwargs)
+        return ProxyTensor(out, torch.fx.Proxy(self.new_graph.get_attr(target), self.tracer))
+
+    # call_function, call_method, call_module get traced automatically by the ProxyTensors.
+
+    def output(self, target, args, kwargs):
+        out = super().output(target, args, kwargs)
+
+        def unwrap(e):
+            return e.proxy.node if isinstance(e, ProxyTensor) else e
+        self.new_graph.output(pytree.tree_map(unwrap, out))
+        return out
+
+    def run(self, *args, **kwargs):
+        with decompose(self.decomposition_table):
+            return super().run(*args, **kwargs)
+
 def make_fx(f, decomposition_table=None, trace_factory_functions=True, use_fake=False):
     if decomposition_table is None:
         decomposition_table = {}