Backend dialect is a special variant of edge dialect: it may contain backend-specific nodes and metadata introduced by backend-specific graph transformations. Backend dialect is an optional stage, needed only if we want to introduce backend awareness into the graph. More specifically, a graph in backend dialect may contain operators or delegated lowered modules (see the delegate doc) that are meaningful only to the target backend. One use case is operator fusion: for example, fusing consecutive addmm + relu into a single addmm_relu operator can be done here.
This document describes how to introduce backend specific operators.
Difference between custom ops and backend-specific ops: custom ops show up in eager mode, ATen dialect, and edge dialect, while backend-specific ops are introduced only by passes that run after edge dialect.
This dialect allows the introduction of operators that do not conform to the schema defined in the canonical ATen operator set, and which do not appear in any of the dialects above (ATen dialect and edge dialect). Consider using backend operators if your use case satisfies one or more of the following criteria:
- A fused operator, such as linear_relu (equivalent to linear followed by relu), that can be executed faster on a certain backend.

For an operator/subgraph replacement, the common flow is: identify the pattern to replace, register a backend operator for it, and write a pass that matches the pattern in the graph and swaps it for the backend operator.
In order to facilitate the process, we provide an API to help reduce the effort for ExecuTorch users to do these steps.
To lower edge ops to backend ops, a pass will perform pattern matching to identify the edge ops of interest in the graph, and then replace them with equivalent backend operators. There are two APIs to register such passes:
transform(). An API on ExportProgram that allows users to provide custom passes. Note that this is not guarded by any validator, so the soundness of the program is not guaranteed.

Example: one such pass is QuantFusion. This pass takes a "canonical quantization pattern", i.e. "dequant - some_op - quant", and fuses it into a single backend-specific operator, i.e. quantized_decomposed::some_op. Another, simpler example is here, where we replace sym_size operators with the ones that are understood by ExecuTorch.
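The match-and-replace idea behind such passes can be illustrated with a toy model, independent of the real pass infrastructure. This is only a sketch: the "graph" here is just a list of op-name strings, whereas the real passes operate on FX graphs via subgraph_rewriter.

```python
# Toy illustration of pattern matching and replacement in a pass.
# A "graph" is just a list of op names here; real passes work on torch.fx graphs.

def replace_pattern(graph, pattern, replacement):
    """Replace every non-overlapping run of consecutive ops matching
    `pattern` with the single fused op `replacement`."""
    out, i = [], 0
    while i < len(graph):
        if graph[i:i + len(pattern)] == pattern:
            out.append(replacement)  # fuse the matched subgraph into one op
            i += len(pattern)
        else:
            out.append(graph[i])
            i += 1
    return out

graph = ["aten.linear", "aten.relu", "aten.add"]
fused = replace_pattern(graph, ["aten.linear", "aten.relu"], "backend.linear_relu")
print(fused)  # ['backend.linear_relu', 'aten.add']
```

The real rewriter matches on graph structure (node targets, args, and dataflow) rather than a linear sequence, but the replace step is conceptually the same.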
We provide a decorator bind_pattern_to_op to help users easily register their backend operators into EXIR. This decorator takes:
- a torch.Library object, indicating which library or namespace this backend operator belongs to;
- a name or schema. If the operator is already registered in the torch.Library object, only a name is needed; otherwise, we can register the schema if a schema string is passed in.

This decorator should be added to the pattern we are trying to match (and then lower to this backend op) on edge dialect. This way we are registering this pattern as a CompositeImplicitAutograd kernel for this backend operator.
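The effect of registering the pattern as the operator's own kernel can be sketched with a toy registry. Note that `toy_bind_pattern_to_op` and `OP_REGISTRY` below are illustrative stand-ins, not the real EXIR API: they only show why binding the pattern itself as the implementation means no separate kernel is needed.

```python
# Toy sketch of the bind_pattern_to_op idea: the pattern function itself
# becomes the fused op's runnable (composite) implementation.
OP_REGISTRY = {}  # hypothetical registry standing in for torch.Library

def toy_bind_pattern_to_op(namespace, name):
    def wrap(pattern_fn):
        # Register the pattern as the op's implementation, so the user
        # never has to write a separate (CPU) kernel for eager execution.
        OP_REGISTRY[f"{namespace}::{name}"] = pattern_fn
        return pattern_fn
    return wrap

@toy_bind_pattern_to_op("foo_namespace", "add_relu")
def pattern(x, y):
    return max(x + y, 0)  # add followed by relu, on plain numbers

# Calling the fused op just runs (i.e., decomposes into) the pattern:
print(OP_REGISTRY["foo_namespace::add_relu"](3, -5))  # 0
print(OP_REGISTRY["foo_namespace::add_relu"](3, 5))   # 8
```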
Then the operator can be accessed/used from the passes. The CompositeImplicitAutograd kernel ensures the retrace-ability of the ExportProgram: once retraced, the backend operator will be decomposed into the ATen ops used in the pattern.

Let's assume a simple program that contains both add and relu operators:
```python
def f(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    z = x + y
    return torch.ops.aten.relu.default(z)
```
After lowering to edge dialect it becomes:
```
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %aten_add_tensor : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%arg0_1, %arg1_1), kwargs = {})
    %aten_relu_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.relu.default](args = (%aten_add_tensor,), kwargs = {})
    return (aten_relu_default,)
```
Now we want to write a pass to merge add and relu into add_relu. The first step is to write a pattern:
```python
# In the pattern, we can use edge ops and ATen ops interchangeably
def pattern(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    z = torch.ops.aten.add.Tensor(x, y)
    out = torch.ops.aten.relu.default(z)
    return out
```
Then we need to create an operator library from the fused operator namespace, then use the decorator on our pattern:
```python
lib = Library("foo_namespace", "DEF")

@bind_pattern_to_op(lib, "add_relu(Tensor self, Tensor other) -> Tensor")
def pattern(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    z = torch.ops.aten.add.Tensor(x, y)
    out = torch.ops.aten.relu.default(z)
    return out
```
This way we are registering the pattern as a kernel to add_relu and it is ready to be used in a pass. A simple pass looks like this:
```python
class AddReluFusionPass(ExportPass):
    def call(self, graph_module: GraphModule) -> PassResult:
        # The decorator registers this pattern as a CompositeImplicitAutograd
        # kernel, since no kernel was registered for it before.
        @bind_pattern_to_op(lib, "add_relu")
        def pattern(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            z = torch.ops.aten.add.Tensor(x, y)
            out = torch.ops.aten.relu.default(z)
            return out

        def replacement(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            return torch.ops.foo_namespace.add_relu.default(x, y)

        subgraph_rewriter.replace_pattern(
            graph_module,
            _trace_and_lower_to_edge_ops(pattern),
            _trace_and_lower_to_edge_ops(replacement),
        )
        return PassResult(graph_module, True)
```
The result graph looks like this:
```
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %foo_namespace_add_relu_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.foo_namespace.add_relu.default](args = (%arg0_1, %arg1_1), kwargs = {})
    return (foo_namespace_add_relu_default,)
```
These are the backend operators currently registered through the bind_pattern_to_op API:
- executorch_prims::add.int(SymInt a, SymInt b) -> SymInt
- executorch_prims::mul.int(SymInt a, SymInt b) -> SymInt
- executorch_prims::sub.int(SymInt a, SymInt b) -> SymInt
- executorch_prims::floordiv.int(SymInt a, SymInt b) -> SymInt
- executorch_prims::truediv.int(Scalar a, Scalar b) -> Scalar
- executorch_prims::sym_float.Scalar(Scalar a) -> Scalar
- executorch_prims::gt.int(SymInt a, SymInt b) -> bool
- executorch_prims::lt.int(SymInt a, SymInt b) -> bool
- executorch_prims::ge.int(SymInt a, SymInt b) -> bool
- executorch_prims::le.int(SymInt a, SymInt b) -> bool
- executorch_prims::eq.int(SymInt a, SymInt b) -> bool
- executorch_prims::mod.Scalar(SymInt a, SymInt b) -> SymInt
- executorch_prims::neg.Scalar(Scalar a) -> Scalar
- quantized_decomposed::embedding_byte(Tensor weight, Tensor weight_scales, Tensor weight_zero_points, int weight_quant_min, int weight_quant_max, Tensor indices) -> Tensor
- quantized_decomposed::add(Tensor a, float a_scale, int a_zero_point, int a_quant_min, int a_quant_max, Tensor b, float b_scale, int b_zero_point, int b_quant_min, int b_quant_max, float out_scale, int out_zero_point, int out_quant_min, int out_quant_max) -> Tensor qc
- quantized_decomposed::add.scalar(Tensor qa, float a_scale, int a_zero_point, int a_quant_min, int a_quant_max, ScalarType a_dtype, Scalar b, float out_scale, int out_zero_point, int out_quant_min, int out_quant_max, ScalarType out_dtype) -> Tensor
- quantized_decomposed::add_relu(Tensor a, float a_scale, int a_zero_point, int a_quant_min, int a_quant_max, Tensor b, float b_scale, int b_zero_point, int b_quant_min, int b_quant_max, float out_scale, int out_zero_point, int out_quant_min, int out_quant_max) -> Tensor qc
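Assuming the executorch_prims entries mirror the corresponding Python operators on SymInt/Scalar values (which is how the schemas above read), their semantics can be checked with plain integers:

```python
# Plain-Python semantics assumed to correspond to the executorch_prims
# schemas above: add.int -> +, mul.int -> *, floordiv.int -> //,
# mod.Scalar -> %, gt.int -> >, neg.Scalar -> unary -.
a, b = 7, 3
print(a + b)   # add.int      -> 10
print(a * b)   # mul.int      -> 21
print(a - b)   # sub.int      -> 4
print(a // b)  # floordiv.int -> 2
print(a % b)   # mod.Scalar   -> 1
print(a > b)   # gt.int       -> True
print(-a)      # neg.Scalar   -> -7
```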