| # Copyright (c) Meta Platforms, Inc. and affiliates. |
| # All rights reserved. |
| # |
| # This source code is licensed under the BSD-style license found in the |
| # LICENSE file in the root directory of this source tree. |
| |
| # pyre-strict |
| |
| from typing import Callable, List, Optional, Union |
| |
| import torch |
| import torch.fx.passes.infra.pass_manager as fx |
| import torch.utils._pytree as pytree |
| from executorch.exir.error import ExportError, ExportErrorType |
| from torch.fx.passes.infra.pass_base import PassResult |
| from typing_extensions import TypeAlias |
| |
| PassType: TypeAlias = Callable[[torch.fx.GraphModule], Optional[PassResult]] |
| |
| |
| class PassManager(fx.PassManager): |
| """ |
| Class to run multiple passes on a given graph module. The PassManager is |
| callable so to run it, we can just call the PassManager instance. |
| |
| Private Attributes: |
| * **passes**: A list of callable passes |
| * **params**: An instance of PassManagerParams containing the result of the |
| flags set in the constructor. |
| """ |
| |
| def __init__( |
| self, |
| passes: Optional[Union[List[PassType], List[List[PassType]]]] = None, |
| run_checks_after_each_pass: bool = False, |
| suppress_check_failures: bool = False, |
| ) -> None: |
| r""" |
| Args: |
| passes: A list of passes |
| enable_debug_pass: set to true to enable the debug passes |
| run_checks_after_each_pass: whether to run checks and linting after each pass |
| """ |
| |
| # Flatten the passes to a list of callables |
| passes = passes if passes else [] |
| flattened_passes = [ |
| fx.pass_result_wrapper(fn) for fn in pytree.tree_flatten(passes)[0] |
| ] |
| |
| super().__init__( |
| flattened_passes, |
| run_checks_after_each_pass=run_checks_after_each_pass, |
| suppress_check_failures=suppress_check_failures, |
| ) |
| |
| def check(self, module: torch.nn.Module) -> None: |
| """ |
| Runs various checks on the given graph module to make sure it contains |
| the needed data for passes. |
| |
| Some checks that need to be run: |
| - Ensure that types of operator node match the types specified in |
| the node's spec field (ex. if the op returns a tuple then the |
| node's spec field is a tuple) |
| - Ensure that the graph module has type torch.fx.GraphModule |
| """ |
| assert isinstance(module, fx.GraphModule) |
| module.recompile() |
| module.graph.lint() |
| # TODO(qihan): use verifier.check_is_exir |
| |
| for node in module.graph.nodes: |
| if node.op == "call_method": |
| raise ExportError( |
| ExportErrorType.NOT_SUPPORTED, |
| f"call_method `{node}` is not supported except for backend delegate.", |
| ) |