| # Copyright (c) Meta Platforms, Inc. and affiliates. |
| # All rights reserved. |
| # |
| # This source code is licensed under the BSD-style license found in the |
| # LICENSE file in the root directory of this source tree. |
| |
| from typing import Any, cast, Optional, Tuple |
| |
| import executorch.exir as exir |
| import torch |
| |
| from executorch.backends.xnnpack.utils.configs import ( |
| get_transform_passes, |
| get_xnnpack_capture_config, |
| get_xnnpack_edge_compile_config, |
| ) |
| from executorch.exir import ExportedProgram |
| from executorch.exir.dialects._ops import ops as exir_ops |
| |
| from torch._export.utils import ( |
| get_buffer, |
| get_lifted_tensor_constant, |
| get_param, |
| is_buffer, |
| is_lifted_tensor_constant, |
| is_param, |
| ) |
| |
| |
| ### XNNPACK Capture ### |
def capture_graph_for_xnnpack(
    module: torch.nn.Module,
    inputs: Tuple[torch.Tensor],
    enable_aot: Optional[bool] = None,
    unlift: Optional[bool] = None,
) -> exir.ExirExportedProgram:
    """
    Captures the given module and prepares it for the XNNPACK backend

    The module is captured with the XNNPACK capture config, converted to the
    edge dialect with the XNNPACK edge compile config, and then run through
    the passes returned by get_transform_passes()

    Args:
        module: module to capture
        inputs: example inputs handed to exir.capture
        enable_aot: forwarded as-is to get_xnnpack_capture_config
        unlift: forwarded as-is to get_xnnpack_capture_config

    Returns:
        The program produced after the transform passes have been applied
    """
    capture_config = get_xnnpack_capture_config(enable_aot=enable_aot, unlift=unlift)
    captured_program = exir.capture(module, inputs, capture_config)
    edge_program = captured_program.to_edge(get_xnnpack_edge_compile_config())
    return edge_program.transform(*get_transform_passes())
| |
| |
| ### XNNPACK Utils ### |
# Axis permutations for switching a 4-d tensor between the two memory layouts.
# Applying [0, 2, 3, 1] to (N, C, H, W) dims yields (N, H, W, C).
PERM_NCHW_TO_NHWC = [0, 2, 3, 1]
# Applying [0, 3, 1, 2] to (N, H, W, C) dims yields (N, C, H, W).
PERM_NHWC_TO_NCHW = [0, 3, 1, 2]
| |
| |
def check_or_raise(condition: bool, err: str) -> None:
    """
    Asserts that a condition holds, raising a RuntimeError otherwise

    Args:
        condition: value that must be truthy for the check to pass
        err: message carried by the RuntimeError when the check fails

    Raises:
        RuntimeError: if condition is falsy
    """
    if condition:
        return

    raise RuntimeError(err)
| |
| |
def is_node(node: Any) -> bool:
    """
    Tells whether the given object is a torch.fx.Node

    Args:
        node: arbitrary object to inspect

    Returns:
        True when node is an instance of torch.fx.Node, False otherwise
    """
    fx_node_type = torch.fx.Node
    return isinstance(node, fx_node_type)
| |
| |
def is_getitem(node: torch.fx.Node) -> bool:
    """
    Returns true if the given node is a call_function node whose target is
    named "getitem", otherwise false

    Args:
        node: node to inspect

    Returns:
        True only when node.op is "call_function" and the target's __name__
        is "getitem"
    """
    if node.op != "call_function":
        return False

    # Not every callable target carries a __name__ (e.g. functools.partial
    # objects), so look it up defensively instead of raising AttributeError.
    return getattr(node.target, "__name__", None) == "getitem"
| |
| |
def get_input_node(node: torch.fx.Node, input_index: int) -> torch.fx.Node:
    """
    Returns the positional argument of node at the given index

    The value is taken from node.args and is only cast to torch.fx.Node for
    the type checker; no runtime type check is performed

    Args:
        node: node whose positional arguments are read
        input_index: index into node.args
    """
    input_arg = node.args[input_index]
    return cast(torch.fx.Node, input_arg)
| |
| |
def get_relu_fused_node(node: torch.fx.Node) -> Optional[torch.fx.Node]:
    """
    Finds the relu node that can be fused into the given node, if any

    Fusion is only considered when the node has exactly one user and that
    user's target is the edge dialect aten.relu.default op

    Args:
        node: node whose users are inspected

    Returns:
        The single relu user when the conditions above hold, otherwise None
    """
    consumers = list(node.users)
    if len(consumers) != 1:
        return None

    sole_consumer = consumers[0]
    if sole_consumer.target == exir_ops.edge.aten.relu.default:
        return sole_consumer

    return None
| |
| |
def is_get_attr_node(node: torch.fx.Node) -> bool:
    """
    Tells whether the given object is a torch.fx.Node with the "get_attr" op

    Args:
        node: object to inspect; anything that is not a torch.fx.Node
            yields False

    Returns:
        True only for torch.fx.Node instances whose op is "get_attr"
    """
    if not isinstance(node, torch.fx.Node):
        return False

    return node.op == "get_attr"
| |
| |
def is_param_node(exp_prog: ExportedProgram, node: torch.fx.Node) -> bool:
    """
    Returns true if the node refers to a tensor owned by the program: a
    get_attr node, a parameter, a buffer, or a lifted tensor constant

    Args:
        exp_prog: exported program the node belongs to
        node: node to classify
    """
    if is_get_attr_node(node):
        return True

    # Checked in order; evaluation stops at the first predicate that matches
    lifted_checks = (is_param, is_buffer, is_lifted_tensor_constant)
    return any(check(exp_prog, node) for check in lifted_checks)
| |
| |
def get_param_tensor(
    exp_prog: ExportedProgram, node: torch.fx.Node
) -> Optional[torch.Tensor]:
    """
    Looks up the tensor behind a parameter-like node

    Parameters, buffers and lifted tensor constants are resolved through the
    exported program; get_attr nodes are resolved by attribute lookup on the
    owning graph module

    Args:
        exp_prog: exported program the node belongs to
        node: node to resolve; None is accepted and yields None

    Returns:
        The value found for the node, or None when node is None

    Raises:
        RuntimeError: if the node is none of the supported kinds
    """
    if node is None:
        return None

    # (predicate, getter) pairs, tried in order
    lifted_lookups = (
        (is_param, get_param),
        (is_buffer, get_buffer),
        (is_lifted_tensor_constant, get_lifted_tensor_constant),
    )
    for matches, fetch in lifted_lookups:
        if matches(exp_prog, node):
            return fetch(exp_prog, node)

    if not is_get_attr_node(node):
        raise RuntimeError(f"unsupported param type, {node.op}.")

    # This is a hack to support both lifted and unlifted graph: try the
    # module that owns the node's graph first, then fall back to the
    # exported program's graph module
    try:
        return getattr(node.graph.owning_module, node.target)
    except AttributeError:
        return getattr(exp_prog.graph_module, node.target)
| |
| |
def get_source_fn(node: torch.fx.Node) -> Optional[torch.fx.Node]:
    """
    Returns the source fn of the given node, return None if something goes wrong

    The source fn is read from the last entry of node.meta["source_fn_stack"];
    each entry is indexed as a pair and its second element is returned

    Args:
        node: node whose metadata is inspected

    Returns:
        The second element of the last "source_fn_stack" entry, or None when
        the node is not a call_function node, or the stack is missing or empty
    """
    if node.op != "call_function":
        return None

    source_fn_stack = node.meta.get("source_fn_stack")
    # A missing stack and an empty stack are both treated as "no source fn";
    # indexing [-1] on an empty stack would otherwise raise IndexError
    if not source_fn_stack:
        return None

    return source_fn_stack[-1][1]