Interpreter for decomposing aten -> prims (#79989)
If an aten -> prim decomposition is needed *after* the initial trace
with make_fx, this interpreter can be used to perform the decomposition.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79989
Approved by: https://github.com/SherlockNoMad
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index 09aabb1..d8d3150 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -9,8 +9,9 @@
from torch.testing._internal.common_methods_invocations import op_db, wrapper_set_seed
from torch._subclasses.fake_tensor import DynamicOutputShapeException
+from torch._decomp import decomposition_table
from torch.testing._internal.common_device_type import ops
-from torch.fx.experimental.proxy_tensor import make_fx
+from torch.fx.experimental.proxy_tensor import make_fx, DecompositionInterpreter
# Copied from functorch
def xfail(op_name, variant_name='', *, device_type=None, dtypes=None):
@@ -156,6 +157,35 @@
self.assertTrue(all([isinstance(node.target, torch._ops.OpOverload)
for node in traced.graph.nodes if node.op == 'call_function']))
+ def test_decomposition_interpreter(self):
+ def fn(x):
+ return torch.nn.functional.silu(x)
+
+ x = torch.rand((4, 4))
+ fx_module = make_fx(fn, decomposition_table=None)(x)
+
+ found_silu = False
+ for n in fx_module.graph.nodes:
+ if n.target == torch.ops.aten.silu or n.target == torch.ops.aten.silu.default:
+ found_silu = True
+
+ self.assertTrue(found_silu)
+
+ new_graph = torch.fx.Graph()
+ silu_decomp_table = {torch.ops.aten.silu.default: decomposition_table[torch.ops.aten.silu.default]}
+ DecompositionInterpreter(
+ fx_module,
+ new_graph=new_graph,
+ decomposition_table=silu_decomp_table,
+ ).run(x)
+
+ decomposed_module = torch.fx.GraphModule(fx_module, new_graph)
+
+ for n in decomposed_module.graph.nodes:
+ self.assertTrue(n.target != torch.ops.aten.silu)
+ self.assertTrue(n.target != torch.ops.aten.silu.default)
+
+ self.assertEqual(fx_module(x), decomposed_module(x))
make_fx_failures = {
# unknown
diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py
index c6b7254..d0ea069 100644
--- a/torch/fx/experimental/proxy_tensor.py
+++ b/torch/fx/experimental/proxy_tensor.py
@@ -17,7 +17,7 @@
from torch.utils._python_dispatch import TorchDispatchMode
-__all__ = ["ProxyTensor", "PythonKeyTracer", "dispatch_trace", "make_fx", "enable_strict"]
+__all__ = ["ProxyTensor", "PythonKeyTracer", "dispatch_trace", "make_fx", "enable_strict", "DecompositionInterpreter"]
aten = torch.ops.aten
CURRENT_DECOMPOSITION_TABLE: Dict[torch._ops.OpOverload, Callable] = {}
@@ -225,6 +225,38 @@
return wrap_output(inner_res, proxy_res)
+class DecompositionInterpreter(torch.fx.Interpreter):
+ def __init__(self, module: torch.fx.GraphModule, new_graph: torch.fx.Graph, decomposition_table=None, **kwargs):
+ super().__init__(module, **kwargs)
+ self.new_graph = new_graph
+ self.tracer = torch.fx.proxy.GraphAppendingTracer(self.new_graph)
+ self.decomposition_table = decomposition_table
+ if self.decomposition_table is None:
+ self.decomposition_table = {}
+
+ def placeholder(self, target, args, kwargs):
+ out = super().placeholder(target, args, kwargs)
+ # TODO handle case where the first character of target is '*'
+ return ProxyTensor(out, torch.fx.Proxy(self.new_graph.placeholder(target), self.tracer))
+
+ def get_attr(self, target, args, kwargs):
+ out = super().get_attr(target, args, kwargs)
+ return ProxyTensor(out, torch.fx.Proxy(self.new_graph.get_attr(target), self.tracer))
+
+ # call_function, call_method, call_module get traced automatically by the ProxyTensors.
+
+ def output(self, target, args, kwargs):
+ out = super().output(target, args, kwargs)
+
+ def unwrap(e):
+ return e.proxy.node if isinstance(e, ProxyTensor) else e
+ self.new_graph.output(pytree.tree_map(unwrap, out))
+ return out
+
+ def run(self, *args, **kwargs):
+ with decompose(self.decomposition_table):
+ return super().run(*args, **kwargs)
+
def make_fx(f, decomposition_table=None, trace_factory_functions=True, use_fake=False):
if decomposition_table is None:
decomposition_table = {}