blob: 03c48039b937ad2c42a6baaa03c88901801bb3b4 [file] [log] [blame] [edit]
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional
import torch
from executorch.exir import ExportedProgram
from torch._export.utils import (
get_buffer,
get_lifted_tensor_constant,
get_param,
is_buffer,
is_lifted_tensor_constant,
is_param,
)
def is_get_attr_node(node: torch.fx.Node) -> bool:
    """
    Check whether ``node`` is a ``torch.fx.Node`` whose opcode is ``get_attr``.

    Only the opcode is inspected; the attribute being fetched is not checked.
    Any value that is not a ``torch.fx.Node`` (including ``None``) yields False.
    """
    if not isinstance(node, torch.fx.Node):
        return False
    return node.op == "get_attr"
def is_param_node(exp_prog: ExportedProgram, node: torch.fx.Node) -> bool:
    """
    Return True when ``node`` refers to stored model state.

    A node qualifies if it is a ``get_attr`` node, or if the
    ``torch._export.utils`` helpers classify it as a parameter, a buffer, or a
    lifted tensor constant of ``exp_prog``. Checks run in that order and stop
    at the first match.
    """
    if is_get_attr_node(node):
        return True
    state_predicates = (is_param, is_buffer, is_lifted_tensor_constant)
    return any(predicate(exp_prog, node) for predicate in state_predicates)
def _resolve_qualified_attr(root, qualified_name: str):
    """Resolve a possibly dotted attribute path (e.g. ``"sub.weight"``) from ``root``."""
    # A flat name resolves exactly like getattr(root, name); a missing component
    # (or a None root) raises AttributeError just like plain getattr would.
    obj = root
    for atom in qualified_name.split("."):
        obj = getattr(obj, atom)
    return obj


def get_param_tensor(
    exp_prog: ExportedProgram, node: Optional[torch.fx.Node]
) -> Optional[torch.Tensor]:
    """
    Fetch the tensor that backs a parameter-like node.

    Args:
        exp_prog: The exported program whose state ``node`` may refer to.
        node: The graph node to resolve. ``None`` is accepted and yields ``None``.

    Returns:
        The parameter, buffer, or lifted tensor constant of ``exp_prog`` that
        ``node`` refers to, or the attribute a ``get_attr`` node points at.
        ``None`` when ``node`` is ``None``.

    Raises:
        RuntimeError: If ``node`` is none of the supported kinds.
        AttributeError: If a ``get_attr`` target cannot be found on either the
            node's owning module or ``exp_prog.graph_module``.
    """
    if node is None:
        return None
    if is_param(exp_prog, node):
        return get_param(exp_prog, node)
    if is_buffer(exp_prog, node):
        return get_buffer(exp_prog, node)
    if is_lifted_tensor_constant(exp_prog, node):
        return get_lifted_tensor_constant(exp_prog, node)
    if is_get_attr_node(node):
        # This is a hack to support both lifted and unlifted graphs: look the
        # attribute up on the node's own owning module first, then fall back to
        # the program's graph module. get_attr targets are fully-qualified names
        # that may be dotted (nested submodules), so walk the path rather than
        # calling getattr() with the whole string.
        try:
            return _resolve_qualified_attr(node.graph.owning_module, node.target)
        except AttributeError:
            return _resolve_qualified_attr(exp_prog.graph_module, node.target)
    raise RuntimeError(f"unsupported param type, {node.op}.")