| import copy |
| import re |
| import torch |
| import torch.nn as nn |
| from torch.ao.quantization import QuantType |
| from torch.ao.quantization.utils import is_per_tensor, is_per_channel |
| from torch.ao.quantization.quantize import is_activation_post_process |
| |
| from torch.fx import GraphModule, map_arg |
| |
| from torch.fx.graph import ( |
| Graph, |
| Node, |
| ) |
| from .custom_config import PrepareCustomConfig |
| |
| from typing import Callable, Optional, List, Dict, Any, Set, Tuple, Union, Type |
| from collections import namedtuple |
| import operator |
| import warnings |
| |
# Names exported by `from <this module> import *`.
# TODO: revisit this list. Many helper methods shouldn't be public
__all__ = [
    "all_node_args_except_first",
    "all_node_args_have_no_tensors",
    "assert_and_get_unique_device",
    "BIAS_INDEX_DICT",
    "collect_producer_nodes",
    "create_getattr_from_value",
    "create_node_from_old_node_preserve_meta",
    "create_qparam_nodes",
    "EMPTY_ARG_DICT",
    "get_custom_module_class_keys",
    "get_linear_prepack_op_for_dtype",
    "get_new_attr_name_with_prefix",
    "get_non_observable_arg_indexes_and_types",
    "get_per_tensor_qparams",
    "get_qconv_op",
    "get_qconv_prepack_op",
    "get_quantize_node_info",
    "graph_module_from_producer_nodes",
    "graph_pretty_str",
    "is_get_tensor_info_node",
    "maybe_get_next_module",
    "NodeInfo",
    "node_return_type_is_int",
    "NON_OBSERVABLE_ARG_DICT",
    "NON_QUANTIZABLE_WEIGHT_OPS",
    "quantize_node",
    "return_arg_list",
    "WEIGHT_INDEX_DICT",
    "get_skipped_module_name_and_classes",
]
| |
| |
# A dictionary for querying the weight index for a given op.
# Keys are functional ops; each value is a list of argument indices.
WEIGHT_INDEX_DICT = {
    torch.nn.functional.conv1d : [1],
    torch.nn.functional.conv2d : [1],
    torch.nn.functional.conv3d : [1],
    torch.nn.functional.linear : [1],
    torch.nn.functional.layer_norm : [2],
    torch.nn.functional.group_norm : [2],
    torch.nn.functional.instance_norm : [3],
}

# The normalization ops from WEIGHT_INDEX_DICT.
# NOTE(review): presumably these are ops whose weight argument is left
# unquantized -- confirm against the callers of this set.
NON_QUANTIZABLE_WEIGHT_OPS = {torch.nn.functional.layer_norm, torch.nn.functional.group_norm, torch.nn.functional.instance_norm}

# Same layout as WEIGHT_INDEX_DICT, but holding the bias index for each op.
BIAS_INDEX_DICT = {
    torch.nn.functional.conv1d : [2],
    torch.nn.functional.conv2d : [2],
    torch.nn.functional.conv3d : [2],
    torch.nn.functional.linear : [2],
    torch.nn.functional.layer_norm : [3],
    torch.nn.functional.group_norm : [3],
    torch.nn.functional.instance_norm : [4],
}
| |
def graph_pretty_str(g, shorten=True) -> str:
    """Render the nodes of graph `g` as a left-aligned text table.

    The table has one row per node with the columns name, op, target, args
    and kwargs. When `shorten` is True, op names, built-in targets and
    `activation_post_process` prefixes are abbreviated, and a trailing legend
    line explains the `obs` abbreviation.
    """
    func_pattern = re.compile('<built-in function (.*)>')
    meth_pattern = re.compile('<built-in method (.*) of type.*>')
    short_ops = {
        'placeholder': 'plchdr',
        'get_attr': 'gt_prm',
        'call_function': 'cl_fun',
        'call_module': 'cl_mod',
        'call_method': 'cl_meth',
    }
    long_obs, short_obs = "activation_post_process", "obs"

    headers = ("name", "op", "target", "args", "kwargs")
    # widest cell seen so far in each column, starting from the header width
    widths = {header: len(header) for header in headers}

    rows = []
    for n in g.nodes:
        name, op, target = str(n.name), str(n.op), str(n.target)
        args, kwargs = str(n.args), str(n.kwargs)

        if shorten:
            # activation_post_process_0 -> obs_0
            name = name.replace(long_obs, short_obs)
            # placeholder -> plchdr, and so on
            op = short_ops.get(op, op)
            # <built-in function foo> -> <bi_fun foo>, and so on
            func_match = func_pattern.search(target)
            if func_match:
                target = f"<bi_fun {func_match.group(1)}>"
            meth_match = meth_pattern.search(target)
            if meth_match:
                target = f"<bi_meth {meth_match.group(1)}>"
            target = target.replace(long_obs, short_obs)
            args = args.replace(long_obs, short_obs)

        row = [name, op, target, args, kwargs]
        for header, cell in zip(headers, row):
            widths[header] = max(widths[header], len(cell))
        rows.append(row)

    row_template = "{:<{name}} {:<{op}} {:<{target}} {:<{args}} {:<{kwargs}}\n"
    lines = [row_template.format(*headers, **widths)]
    lines.extend(row_template.format(*row, **widths) for row in rows)

    # print an extra note on abbreviations which change attribute names,
    # since users will have to un-abbreviate for further debugging
    if shorten:
        lines.append("*obs_{n} = activation_post_process_{n}\n")
    return "".join(lines)
| |
def get_per_tensor_qparams(activation_post_process):
    """Return `(scale, zero_point, dtype)` for a per-tensor observer.

    `scale` is converted to a Python float and `zero_point` to a Python int.
    Asserts that the observer's qscheme is a per-tensor scheme.
    """
    observer = activation_post_process
    assert is_per_tensor(observer.qscheme), 'Only per tensor quantization is supported'
    raw_scale, raw_zero_point = observer.calculate_qparams()
    return float(raw_scale), int(raw_zero_point), observer.dtype
| |
def get_quantize_node_info(activation_post_process: Callable) -> Optional[Tuple[str, Union[Callable, str], Dict[str, Any]]]:
    ''' Describe the quantize op that corresponds to an activation_post_process module.

    Returns a tuple of
      * the node type (e.g. "call_function"),
      * the quantize op (e.g. torch.quantize_per_tensor, or the method name "to"),
      * a dict of qparams extracted from the module, in the order they are
        passed to the op,
    or None (after emitting a warning) when the dtype / compute_dtype
    combination of the module is not handled.
    '''
    observer = activation_post_process
    dtype = observer.dtype  # type: ignore[attr-defined]
    # compute_dtype is optional on the module
    compute_dtype = getattr(observer, "compute_dtype", None)

    if dtype in [torch.quint8, torch.qint8]:
        scale, zero_point = observer.calculate_qparams()  # type: ignore[attr-defined]
        if is_per_channel(observer.qscheme):  # type: ignore[attr-defined]
            # per channel: scale / zero_point are kept as returned by the observer
            per_channel_qparams = {
                "_scale_": scale,
                "_zero_point_": zero_point,
                "_axis_": int(observer.ch_axis),  # type: ignore[attr-defined]
                "_dtype_": dtype,
            }
            return "call_function", torch.quantize_per_channel, per_channel_qparams
        # per tensor: scale / zero_point become Python scalars
        per_tensor_qparams = {
            "_scale_": float(scale),
            "_zero_point_": int(zero_point),
            "_dtype_": dtype,
        }
        return "call_function", torch.quantize_per_tensor, per_tensor_qparams

    if dtype == torch.float16:
        return "call_method", "to", {"_dtype_": dtype}

    if dtype == torch.float32 and compute_dtype in [torch.quint8, torch.qint8, torch.float16]:
        # dynamic quantization
        # TODO: get reduce range from observer
        # reduce_range = activation_post_process.reduce_range
        uses_fbgemm = torch.backends.quantized.engine == "fbgemm"
        dynamic_qparams = {"_dtype_": compute_dtype, "_reduce_range_": uses_fbgemm}
        return "call_function", torch.quantize_per_tensor_dynamic, dynamic_qparams

    warnings.warn(f"Unsupported activation_post_process in get_quantize_node_info: {activation_post_process}")
    return None
| |
def quantize_node(
        in_node: Node,
        obs_module: torch.nn.Module,
        obs_node: Node,
        modules: Dict[str, torch.nn.Module],
        quantized_graph: Graph,
        node_name_to_scope: Dict[str, Tuple[str, type]],
        is_input: bool,
        output_prefix: str = "_output") -> Node:
    ''' Add a quantize node (e.g. quantize_per_tensor/per_channel) for `in_node` to
    `quantized_graph`, using the qparams calculated from the activation_post_process
    (`obs_module`).
    The observer node (`obs_node`) is used to find the FQN of the user of act_post_process.
    e.g. Given input `node` in `node = self.conv(x)`, insert node:
    `quantized_node = torch.quantize_per_tensor(x, self._scale_0, self._zero_point_0, self._dtype_0)`
    where self._scale_0, self._zero_point_0 and self._dtype_0 are
    calculated from `obs_module`
    '''
    # Pick the node whose scope decides the name prefix of the qparam attributes.
    if is_input:
        # Quantizing the input of an op: use the first user of the observer node,
        # preferring the first functional linear call if there is one.
        prefix = "_input"
        observer_users = list(obs_node.users)
        scope_node = next(
            (
                user for user in observer_users
                if user.op == "call_function" and user.target == torch.nn.functional.linear
            ),
            observer_users[0] if observer_users else None,
        )
    else:
        # Quantizing the output of an op: the observer's input node gives the path.
        prefix = output_prefix
        scope_node = in_node

    if scope_node and scope_node.name in node_name_to_scope:
        module_path, _ = node_name_to_scope[scope_node.name]
    else:
        # TODO: it's not used, so actually we can skip quantization
        # but this requires changing return type of quantize_node
        # we can fix it later if needed
        module_path = ""

    root_module = modules['']
    quantize_info = get_quantize_node_info(obs_module)
    assert quantize_info is not None, \
        f"Expecting quantize node info not to be None, observer: {obs_module}"
    node_type, quantize_op, qparams = quantize_info

    quantize_args = [in_node]
    for qparam_name, qparam_value in qparams.items():
        if qparam_name in ['_scale_', '_zero_point_']:
            # scale and zero_point are registered as buffers on the root module
            # and referenced through get_attr nodes
            qparam_value = create_getattr_from_value(
                root_module, quantized_graph, module_path + prefix + qparam_name, qparam_value)
        # every other qparam (like axis, dtype) is stored as a literal in the graph
        quantize_args.append(qparam_value)
    return quantized_graph.create_node(node_type, quantize_op, tuple(quantize_args), {})
| |
def get_custom_module_class_keys(custom_module_mapping: Dict[QuantType, Dict[Type, Type]]) -> List[Any]:
    r""" Collect the unique float custom module classes of a custom module mapping.

    The mapping is keyed by quant type; the keys of the inner dicts for
    STATIC, DYNAMIC and WEIGHT_ONLY are merged and de-duplicated, e.g.

    Input:
        {
            QuantType.STATIC: {
                CustomModule1: ObservedCustomModule
            },
            QuantType.DYNAMIC: {
                CustomModule2: DynamicObservedCustomModule
            },
            QuantType.WEIGHT_ONLY: {
                CustomModule3: WeightOnlyObservedCustomModule
            },
        }

    Output:
        [CustomModule1, CustomModule2, CustomModule3]

    Quant types missing from the mapping contribute nothing.
    """
    handled_quant_types = (QuantType.STATIC, QuantType.DYNAMIC, QuantType.WEIGHT_ONLY)
    # a set takes care of classes registered under several quant types
    unique_classes : Set[Any] = set()
    for quant_type in handled_quant_types:
        per_type_config = custom_module_mapping.get(quant_type, {})
        unique_classes.update(set(per_type_config.keys()))
    return list(unique_classes)
| |
def get_linear_prepack_op_for_dtype(dtype):
    """Return the quantized linear prepack op for the given weight dtype.

    Args:
        dtype: torch.float16 or torch.qint8

    Returns:
        torch.ops.quantized.linear_prepack_fp16 for torch.float16,
        torch.ops.quantized.linear_prepack for torch.qint8.

    Raises:
        ValueError: for any other dtype. ValueError is a subclass of
            Exception, so callers catching Exception keep working.
    """
    if dtype == torch.float16:
        return torch.ops.quantized.linear_prepack_fp16
    if dtype == torch.qint8:
        return torch.ops.quantized.linear_prepack
    # Format the dtype into the message; passing it as a second exception
    # argument would render the error as a tuple.
    raise ValueError(f"can't get linear prepack op for dtype: {dtype}")
| |
def get_qconv_prepack_op(conv_op: Callable) -> Callable:
    """Return the quantized prepack op matching a functional conv op.

    Supports torch.nn.functional.conv1d/conv2d/conv3d; asserts for any
    other op.
    """
    functional = torch.nn.functional
    quantized = torch.ops.quantized
    conv_to_prepack = {
        functional.conv1d: quantized.conv1d_prepack,
        functional.conv2d: quantized.conv2d_prepack,
        functional.conv3d: quantized.conv3d_prepack,
    }
    found = conv_to_prepack.get(conv_op)
    assert found, f"Didn't find prepack op for {conv_op}"
    return found
| |
def get_qconv_op(conv_op: Callable, has_relu: bool) -> Callable:
    """Return the quantized conv op matching a functional conv op.

    Args:
        conv_op: torch.nn.functional.conv1d/conv2d/conv3d
        has_relu: selects the fused `*_relu` variant when True

    Asserts when `conv_op` has no quantized counterpart.
    """
    functional = torch.nn.functional
    quantized = torch.ops.quantized
    fused_with_relu = {
        functional.conv1d: quantized.conv1d_relu,
        functional.conv2d: quantized.conv2d_relu,
        functional.conv3d: quantized.conv3d_relu,
    }
    plain = {
        functional.conv1d: quantized.conv1d,
        functional.conv2d: quantized.conv2d,
        functional.conv3d: quantized.conv3d,
    }
    candidates = {True: fused_with_relu, False: plain}[has_relu]
    found = candidates.get(conv_op)
    assert found, f"Can't find corresponding quantized conv op for {conv_op} {has_relu}"
    return found
| |
# Builds a name generator for a given prefix. The returned function takes a
# module and returns an attribute name that is not yet used on it, e.g.
# >> get_new_observer_name = get_new_attr_name_with_prefix('_observer')
# >> new_name = get_new_observer_name(module)
# new_name will be an unused attribute name on module, e.g. `_observer_1`
def get_new_attr_name_with_prefix(prefix: str) -> Callable:
    # dots would be read as nested attribute access, so flatten them
    sanitized_prefix = prefix.replace(".", "_")

    def get_new_attr_name(module: torch.nn.Module):
        # try prefix0, prefix1, ... until a free name is found
        suffix = 0
        while hasattr(module, f"{sanitized_prefix}{suffix}"):
            suffix += 1
        return f"{sanitized_prefix}{suffix}"

    return get_new_attr_name
| |
def collect_producer_nodes(node: Node) -> Optional[List[Node]]:
    r''' Walk backwards from `node` through its args and kwargs and collect
    every node that produces it, stopping at nodes without Node inputs and at
    `getattr` function calls. This extracts the chain of operators
    leading to the target node, for example
    def forward(self, x):
      observed = self.observer(self.weight)
      return F.linear(x, observed)
    collect_producer_nodes(observed) returns the list of nodes that
    produce the observed node (starting with the node itself), or None if a
    placeholder is reached, i.e. no self contained graph without free
    variables (inputs of the forward function) can be extracted.
    '''
    collected = [node]
    pending = [node]
    while pending:
        current = pending.pop()
        for arg in (*current.args, *current.kwargs.values()):
            if not isinstance(arg, Node):
                continue
            if arg.op == 'placeholder':
                # hit input, can't fold in this case
                return None
            collected.append(arg)
            # a getattr call is collected but not expanded further
            is_getattr_call = arg.op == 'call_function' and arg.target == getattr
            if not is_getattr_call:
                pending.append(arg)
    return collected
| |
def graph_module_from_producer_nodes(
        root: GraphModule, producer_nodes: List[Node]) -> GraphModule:
    r''' Construct a graph module from extracted producer nodes
    from `collect_producer_nodes` function
    Args:
      root: the root module for the original graph
      producer_nodes: a list of nodes we use to construct the graph, ordered
        from the target node back to its producers. The list is not modified.
    Return:
      A graph module constructed from the producer nodes; its output is the
      copy of the first node of `producer_nodes`.
    '''
    assert len(producer_nodes) > 0, 'list of producer nodes can not be empty'
    # The nodes were collected by tracing back from the target node, so they
    # have to be copied in reverse. Work on a reversed copy so the caller's
    # list is left untouched.
    ordered_nodes = list(reversed(producer_nodes))
    graph = Graph()
    # maps each original node to its copy in the new graph
    env: Dict[Any, Any] = {}

    def load_arg(a):
        return map_arg(a, lambda node: env[node])

    for producer_node in ordered_nodes:
        env[producer_node] = graph.node_copy(producer_node, load_arg)
    graph.output(load_arg(ordered_nodes[-1]))
    return GraphModule(root, graph)
| |
def assert_and_get_unique_device(module: torch.nn.Module) -> Any:
    """
    Return the single device that holds the parameters and buffers of
    `module`, or None when the module has neither.
    Asserts if more than one device is found.
    """
    found_devices = {param.device for param in module.parameters()}
    found_devices.update({buf.device for buf in module.buffers()})
    assert len(found_devices) <= 1, (
        "prepare only works with cpu or single-device CUDA modules, "
        "but got devices {}".format(found_devices)
    )
    return next(iter(found_devices), None)
| |
def create_getattr_from_value(module: torch.nn.Module, graph: Graph, prefix: str, value: Any) -> Node:
    """
    Given a value of any type, creates a getattr node corresponding to the value and
    registers the value as a buffer to the module.

    Args:
        module: module that receives the new buffer
        graph: graph in which the `get_attr` node is created
        prefix: prefix of the buffer name; a numeric suffix is appended so
            that the name is unused on `module`
        value: value to store. Tensors are copied, anything else is converted
            with `torch.tensor`. The buffer is placed on the module's device,
            if the module has one.

    Returns:
        The `get_attr` node that refers to the new buffer.
    """
    get_new_attr_name = get_new_attr_name_with_prefix(prefix)
    attr_name = get_new_attr_name(module)
    device = assert_and_get_unique_device(module)
    if isinstance(value, torch.Tensor):
        # torch.tensor(tensor) emits a copy-construct UserWarning; copy
        # explicitly instead, detached from the source tensor's autograd history
        buffer_value = value.detach().clone()
        if device is not None:
            buffer_value = buffer_value.to(device)
    else:
        buffer_value = torch.tensor(value, device=device)
    module.register_buffer(attr_name, buffer_value)
    # Create get_attr with value
    attr_node = graph.create_node("get_attr", attr_name)
    return attr_node
| |
def create_qparam_nodes(
        node_name: str,
        scale: Any,
        zero_point: Any,
        modules: Dict[str, torch.nn.Module],
        quantized_graph: Graph,
        node_name_to_scope: Dict[str, Tuple[str, type]]
) -> Tuple[Node, Node]:
    """
    Create the getattr nodes for a scale and a zero point value in the
    quantized graph and return them as `(scale_node, zero_point_node)`.
    Both values are registered on the root module (`modules['']`); their
    attribute names are prefixed with the module path of `node_name`.
    """
    root_module = modules['']
    module_path, _ = node_name_to_scope[node_name]

    def register_qparam(suffix: str, value: Any) -> Node:
        return create_getattr_from_value(root_module, quantized_graph, (module_path + suffix), value)

    scale_node = register_qparam("_scale_", scale)
    zero_point_node = register_qparam("_zero_point_", zero_point)
    return (scale_node, zero_point_node)
| |
| |
def all_node_args_have_no_tensors(node: Node, modules: Dict[str, torch.nn.Module], cache: Dict[Node, bool]) -> bool:
    """
    If we know for sure that all of this node's args have no
    tensors (are primitives), return True. If we either
    find a tensor or are not sure, return False. Note: this
    function is not exact.

    Args:
        node: the node to inspect. Values that are not a Node are
            treated as primitives and yield True.
        modules: maps `call_module` targets to modules; used to see through
            activation_post_process (observer) modules
        cache: memoizes the result per Node and is updated in place. An
            empty dict is a valid (and the usual) initial cache; pass None
            to disable caching.
    """
    if not isinstance(node, Node):
        # primitives are never cached: the cache is keyed by Node only
        return True

    # `cache is not None` rather than truthiness, so that an empty dict
    # is still consulted and filled
    use_cache = cache is not None
    if use_cache and node in cache:
        return cache[node]

    if node.op == 'placeholder':
        result = False
    elif node.op == 'call_module':
        assert isinstance(node.target, str)
        if is_activation_post_process(modules[node.target]):
            # observers pass their input through
            result = all_node_args_have_no_tensors(node.args[0], modules, cache)  # type: ignore[arg-type]
        else:
            result = False
    elif node.op == 'call_function' and node.target is operator.getitem:
        result = all_node_args_have_no_tensors(node.args[0], modules, cache)  # type: ignore[arg-type]
    elif node.op == 'get_attr':
        result = False
    elif node.target is getattr and node.args[1] in ['ndim', 'shape']:
        # x1 = x0.ndim
        result = True
    elif node.op == 'call_method' and node.target == 'size':
        # x1 = x0.size(0)
        result = True
    else:
        # Once one arg may hold a tensor the answer is False, so any() stops
        # recursing at the first such arg.
        # TODO(future PR): remove this entire function and
        # change to dtype inference without recursion.
        result = not any(
            _arg_may_hold_tensor(arg, modules, cache) for arg in node.args
        )

    if use_cache:
        cache[node] = result
    return result


def _arg_may_hold_tensor(arg: Any, modules: Dict[str, torch.nn.Module], cache: Dict[Node, bool]) -> bool:
    """Return True if a single node arg is, or may be, a tensor."""
    if isinstance(arg, list):
        # only Node elements of a list are inspected
        return any(
            isinstance(list_el, Node)
            and not all_node_args_have_no_tensors(list_el, modules, cache)
            for list_el in arg
        )
    if isinstance(arg, int):
        return False
    if isinstance(arg, Node):
        return not all_node_args_have_no_tensors(arg, modules, cache)
    # anything else is of unknown type, so be conservative
    return True
| |
def all_node_args_except_first(node: Node) -> List[int]:
    """
    Returns the indices of all positional args of `node` except index 0
    """
    every_index = list(range(len(node.args)))
    return every_index[1:]
| |
def return_arg_list(arg_indices: List[int]) -> Callable[[Node], List[int]]:
    """
    Builds a function that, given a node, returns those entries of
    `arg_indices` that are smaller than the number of args of the node
    """
    def arg_indices_func(node: Node) -> List[int]:
        in_range: List[int] = []
        for index in arg_indices:
            if index < len(node.args):
                in_range.append(index)
        return in_range

    return arg_indices_func
| |
# (op, target) pair of a node, used as the lookup key of NON_OBSERVABLE_ARG_DICT
NodeInfo = namedtuple("NodeInfo", "op target")

# this dict identifies which indices of a node are non tensors
# so that they can be propagated correctly since inserting observers
# for them would cause errors
#
# Layout: NodeInfo -> {type of the non tensor args -> function that takes the
# node and returns the indices of the args of that type}
#
# NOTE(review): the torch.transpose and torch.unsqueeze entries pair the op
# "call_method" with a function object. call_method targets are normally
# method-name strings, so these entries may never match -- confirm whether
# "call_function" was intended.

NON_OBSERVABLE_ARG_DICT: Dict[NodeInfo, Dict[Union[type, torch.dtype], Callable[[Node], List[int]]]] = {
    NodeInfo("call_method", "masked_fill") : {
        torch.bool: return_arg_list([1]),
        float: return_arg_list([2])
    },
    NodeInfo("call_method", "permute") : {
        int: all_node_args_except_first
    },
    NodeInfo("call_method", "repeat") : {
        int: all_node_args_except_first
    },
    NodeInfo("call_method", "reshape") : {
        int: all_node_args_except_first
    },
    NodeInfo("call_method", "size") : {
        int: return_arg_list([1])
    },
    NodeInfo("call_method", "transpose") : {
        int: all_node_args_except_first
    },
    NodeInfo("call_method", torch.transpose) : {
        int: all_node_args_except_first
    },
    NodeInfo("call_method", "unsqueeze") : {
        int: return_arg_list([1])
    },
    NodeInfo("call_method", "unsqueeze_") : {
        int: return_arg_list([1])
    },
    NodeInfo("call_method", torch.unsqueeze) : {
        int: return_arg_list([1])
    },
    NodeInfo("call_method", "view") : {
        int: all_node_args_except_first
    },
}

# Shared empty mapping with the same value type as the entries of
# NON_OBSERVABLE_ARG_DICT
EMPTY_ARG_DICT: Dict[Union[type, torch.dtype], Callable[[Node], List[int]]] = {}
| |
def get_non_observable_arg_indexes_and_types(node: Node) -> Dict[Union[type, torch.dtype], Callable[[Node], List[int]]]:
    """
    Looks up the non-observable args of `node` by its (op, target) pair.

    Returns a dict with non float tensor types as keys; each value is a
    function that takes the node and returns the list of arg indices of that
    type. Returns EMPTY_ARG_DICT for nodes without an entry in
    NON_OBSERVABLE_ARG_DICT.
    """
    lookup_key = NodeInfo(node.op, node.target)
    if lookup_key in NON_OBSERVABLE_ARG_DICT:
        return NON_OBSERVABLE_ARG_DICT[lookup_key]
    return EMPTY_ARG_DICT
| |
def node_return_type_is_int(node: Node) -> bool:
    """
    Returns true if this node results in an integer, even if some of the args
    are Tensors. Only `size` method calls are recognized.
    """
    if node.op != 'call_method':
        return False
    return node.target == 'size'
| |
| |
def is_get_tensor_info_node(node: Node) -> bool:
    """ Returns True if this node takes a Tensor as input and outputs some
    meta information about the Tensor. Only `getattr(x, "shape")` function
    calls are recognized.
    """
    if node.op != "call_function" or node.target != getattr:
        return False
    return node.args[1] == "shape"  # type: ignore[return-value]
| |
def maybe_get_next_module(
        node: Node,
        modules: Dict[str, nn.Module],
        target_module_type: Optional[Type[nn.Module]] = None,
        target_functional_type: Any = None,
) -> Optional[Node]:
    """ Returns the first user of `node` that matches the requested module
    type or functional, or None if there is no such user.

    Args:
        node: The node whose users we want to look at
        modules: Maps `call_module` targets to the modules they refer to
        target_module_type: Module type that we want to check
        target_functional_type: Functional type that we want to check
    """
    for user in node.users:
        if user.op == 'call_module':
            if target_module_type is not None and \
                    isinstance(modules[str(user.target)], target_module_type):
                return user
        elif user.op == 'call_function':
            if target_functional_type is not None and \
                    user.target == target_functional_type:
                return user

    return None
| |
def create_node_from_old_node_preserve_meta(
        quantized_graph: Graph,
        create_node_args: Tuple[Any, ...],
        old_node: Node,
) -> Node:
    """
    Creates a node in `quantized_graph` from `create_node_args` (the
    positional arguments of `Graph.create_node`) and carries the stack trace
    of `old_node` over to it.
    """
    created = quantized_graph.create_node(*create_node_args)
    # the stack trace is the only metadata that is carried over
    created.stack_trace = old_node.stack_trace
    return created
| |
def get_skipped_module_name_and_classes(
        prepare_custom_config: PrepareCustomConfig,
        is_standalone_module: bool) -> Tuple[List[str], List[Type[Any]]]:
    """
    Returns `(skipped_module_names, skipped_module_classes)` for a prepare
    config. Both start as copies of the non traceable module names / classes
    of the config. Unless `is_standalone_module` is set, the standalone module
    names and classes and the float custom module classes are added as well.
    """
    names = copy.copy(prepare_custom_config.non_traceable_module_names)
    classes = copy.copy(prepare_custom_config.non_traceable_module_classes)
    if is_standalone_module:
        return names, classes

    # standalone module and custom module config are applied in top level module
    names.extend(prepare_custom_config.standalone_module_names.keys())
    classes.extend(prepare_custom_config.standalone_module_classes.keys())
    classes.extend(get_custom_module_class_keys(prepare_custom_config.float_to_observed_mapping))
    return names, classes