| # Copyright (c) Meta Platforms, Inc. and affiliates. |
| # All rights reserved. |
| # |
| # This source code is licensed under the BSD-style license found in the |
| # LICENSE file in the root directory of this source tree. |
| |
| from typing import Optional |
| |
| import torch |
| from executorch.exir import ExportedProgram |
| |
| from torch._export.utils import ( |
| get_buffer, |
| get_lifted_tensor_constant, |
| get_param, |
| is_buffer, |
| is_lifted_tensor_constant, |
| is_param, |
| ) |
| |
| |
def is_get_attr_node(node: torch.fx.Node) -> bool:
    """
    Check whether ``node`` is an FX ``get_attr`` node.

    Anything that is not a ``torch.fx.Node`` instance (including ``None``)
    is reported as ``False`` rather than raising.
    """
    if not isinstance(node, torch.fx.Node):
        return False
    return node.op == "get_attr"
| |
| |
def is_param_node(exp_prog: ExportedProgram, node: torch.fx.Node) -> bool:
    """
    Check whether ``node`` refers to a tensor held by the model.

    ``True`` when ``node`` is a ``get_attr`` node, or when ``exp_prog``'s
    ``is_param`` / ``is_buffer`` / ``is_lifted_tensor_constant`` helpers
    (from ``torch._export.utils``) recognize it. The checks are evaluated
    in that order and stop at the first match.
    """
    if is_get_attr_node(node):
        return True
    predicates = (is_param, is_buffer, is_lifted_tensor_constant)
    return any(predicate(exp_prog, node) for predicate in predicates)
| |
| |
def _fetch_qualified_attr(root: object, qualified_name: str) -> object:
    """
    Resolve a possibly dotted attribute path (e.g. ``"sub.weight"``) on ``root``.

    Raises ``AttributeError`` if any segment of the path is missing, including
    when ``root`` itself is ``None``.
    """
    obj = root
    for atom in qualified_name.split("."):
        obj = getattr(obj, atom)
    return obj


def get_param_tensor(
    exp_prog: ExportedProgram, node: Optional[torch.fx.Node]
) -> Optional[torch.Tensor]:
    """
    Return the tensor value that ``node`` refers to in ``exp_prog``.

    Lookup order:
      1. ``None`` node -> ``None``.
      2. Parameter, buffer, or lifted tensor constant of ``exp_prog``, fetched
         through the corresponding ``torch._export.utils`` getter.
      3. ``get_attr`` node -> the attribute named by ``node.target`` on the
         node's owning module, falling back to ``exp_prog.graph_module``.
         ``node.target`` may be a dotted qualified name (e.g. ``"sub.weight"``).

    Raises:
        RuntimeError: if ``node`` matches none of the kinds above.
        AttributeError: if a ``get_attr`` target is missing from both modules.
    """
    if node is None:
        return None
    elif is_param(exp_prog, node):
        return get_param(exp_prog, node)
    elif is_buffer(exp_prog, node):
        return get_buffer(exp_prog, node)
    elif is_lifted_tensor_constant(exp_prog, node):
        return get_lifted_tensor_constant(exp_prog, node)
    elif is_get_attr_node(node):
        # This is a hack to support both lifted and unlifted graph: try the
        # node's own module first, then the exported program's graph module.
        # A plain getattr() cannot follow dotted FX targets, so resolve the
        # qualified name segment by segment.
        try:
            return _fetch_qualified_attr(node.graph.owning_module, node.target)
        except AttributeError:
            return _fetch_qualified_attr(exp_prog.graph_module, node.target)
    raise RuntimeError(f"unsupported param type, {node.op}.")