| import collections |
| import contextlib |
| import copy |
| import cProfile |
| import dataclasses |
| import datetime |
| import dis |
| import enum |
| import functools |
| import gc |
| import inspect |
| import itertools |
| import logging |
| import math |
| import operator |
| import os |
| import pstats |
| import sys |
| import textwrap |
| import time |
| import types |
| import typing |
| import weakref |
| from contextlib import contextmanager |
| from functools import lru_cache, wraps |
| from typing import Any, Dict, Optional, Tuple, Union
| |
| import torch._logging |
| from torch._guards import detect_fake_mode, GuardSource  # noqa: F401
| from . import config |
| |
| try: |
| import numpy as np |
| |
| HAS_NUMPY = True |
| except ModuleNotFoundError: |
| np = None # type: ignore[assignment] |
| HAS_NUMPY = False |
| |
| try: |
| import torch_np |
| |
| HAS_NUMPY_TORCH_INTEROP = True |
| except ModuleNotFoundError: |
| torch_np = None |
| HAS_NUMPY_TORCH_INTEROP = False |
| |
| import importlib |
| |
| import torch |
| import torch.fx.experimental.symbolic_shapes |
| from torch import fx |
| from torch._dispatch.python import enable_python_dispatcher |
| from torch._subclasses.fake_tensor import FakeTensor |
| from torch.nn.modules.lazy import LazyModuleMixin |
| from torch.utils._pytree import tree_map |
| |
| counters = collections.defaultdict(collections.Counter) |
| troubleshooting_url = "https://pytorch.org/docs/master/compile/troubleshooting.html" |
| nnmodule_doc_url = "https://pytorch.org/docs/master/compile/nn-module.html" |
| nnmodule_doc_url_msg = f"See {nnmodule_doc_url} for more information and limitations." |
| |
| log = logging.getLogger(__name__) |
| |
| # profiling compilation time |
| compilation_metrics = collections.OrderedDict() |
| |
| timer_counter = itertools.count() |
| |
| |
| def tabulate(rows, headers): |
| try: |
| import tabulate |
| |
| return tabulate.tabulate(rows, headers=headers) |
| except ImportError: |
| return "\n".join( |
| ", ".join(map(str, row)) for row in itertools.chain([headers], rows) |
| ) |
| |
| |
| def dynamo_profiled(func): |
| @wraps(func) |
| def profile_wrapper(*args, **kwargs): |
|         global timer_counter
|         iteration = next(timer_counter)
|         # Name the data file sensibly; reuse the same iteration number in the
|         # printout below so it matches the dumped .profile file.
|         datafn = f"{func.__name__}{iteration}.profile"
|         prof = cProfile.Profile()
|         prof.enable()
|         retval = prof.runcall(func, *args, **kwargs)
|         prof.disable()
|         print(f"### Cprofile for {func.__name__} iter {iteration} ###")
| ps = pstats.Stats(prof) |
| ps.sort_stats(pstats.SortKey.TIME).print_stats(20) |
| ps.sort_stats(pstats.SortKey.CUMULATIVE).print_stats(20) |
| prof.dump_stats(datafn) |
| return retval |
| |
| return profile_wrapper |
| |
| |
| frame_phase_timing = collections.OrderedDict() |
| |
| curr_frame = 0 |
| |
| |
| # Note: this is called for you by dynamo; you almost never want to invoke it yourself.
| def increment_frame(): |
| global curr_frame |
| curr_frame = curr_frame + 1 |
| |
| |
| # Note: this is called for you by dynamo; you almost never want to invoke it yourself.
| def reset_frame_count(): |
| global curr_frame |
| frame_phase_timing.clear() |
| curr_frame = 0 |
| |
| |
| op_count = 0 |
| |
| |
| def increment_op_count(cnt): |
| global op_count |
| op_count += cnt |
| |
| |
| # Print a report of time spent so far |
| # Ex: |
| # TIMING: |
| # entire_frame_compile:8.574629999999999 |
| # backend_compile:5.26806 |
| def print_time_report(): |
| total = 0 |
| total_by_key = {} |
| for frame, timings in frame_phase_timing.items(): |
| for key, timing in timings.items(): |
| total += timing |
| if key not in total_by_key: |
| total_by_key[key] = timing |
| else: |
| total_by_key[key] += timing |
| |
| out = "TIMING:" |
| for key, value in total_by_key.items(): |
| out = f"{out} {key}:{round(value, 5)}" |
| |
| print(out) |
| |
| |
| # dynamo_timed is a function decorator.
| # By wrapping a function in dynamo_timed, we can store a record in compilation_metrics
| # where the key is the function's name.
| # For example:
| #
| #   @dynamo_timed
| #   def _foo(...):
| #
| # would show up as an entry in our timing dict:
| #   OrderedDict([('bar.<locals>._foo', [0.083690, 0.23949, 3.1425e-05])])
| # This is extremely useful for granular debugging.
| #
| # For a higher-level mode, pass a phase_name into dynamo_timed.
| # A phase_name adds an extra record to a separate compilation timing structure,
| # keyed on frame+name rather than function.
| # The frame is incremented outside of this function, in increment_frame() above.
| def dynamo_timed(original_function=None, phase_name=None): |
| def dynamo_timed_inner(func): |
| @wraps(func) |
| def time_wrapper(*args, **kwargs): |
| key = func.__qualname__ |
| if key not in compilation_metrics: |
| compilation_metrics[key] = [] |
| with torch.profiler.record_function(f"{key} (dynamo_timed)"): |
| t0 = time.time() |
| r = func(*args, **kwargs) |
| time_spent = time.time() - t0 |
| # print(f"Dynamo timer: key={key}, latency={latency:.2f} sec") |
| compilation_metrics[key].append(time_spent) |
| if phase_name: |
| frame_key = str(curr_frame) |
| if frame_key not in frame_phase_timing: |
| frame_phase_timing[frame_key] = {} |
| assert ( |
| phase_name not in frame_phase_timing[frame_key] |
| ), f"Duplicate phase name {phase_name} for frame {frame_key}" |
| frame_phase_timing[frame_key][phase_name] = time_spent |
| return r |
| |
| return time_wrapper |
| |
| if original_function: |
| return dynamo_timed_inner(original_function) |
| return dynamo_timed_inner |
| |
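| # Usage sketch (the phase name below is illustrative, not a real dynamo phase):
| #
| #     @dynamo_timed
| #     def _foo(...): ...                          # keyed by __qualname__
| #
| #     @dynamo_timed(phase_name="my_phase")
| #     def _bar(...): ...                          # also recorded per-frame
| #
| # Note the assert above: a given phase_name may be recorded at most once per
| # frame, so phase-level timing assumes one such call per compiled frame.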
| |
| def compile_times(repr="str", aggregate=False): |
| """ |
| Get metrics about torchdynamo frontend/backend compilation times. |
| |
| Accumulates information from functions tagged with `@dynamo_timed`. |
| |
|     repr='str' returns a printable string for user interaction, and 'csv'
|     returns (headers, values), which can be logged for output.
|
|     aggregate causes values from multiple compilations (e.g. split graphs)
|     to be accumulated into one value. If False, expect more than one value
|     per metric.
| """ |
| |
| def fmt_fn(values, item_fn=lambda x: x): |
| if aggregate: |
| return item_fn(sum(values)) |
| return ", ".join(map(item_fn, values)) |
| |
| if repr == "str": |
| rows = [ |
| (k, fmt_fn(compilation_metrics[k], item_fn=lambda x: f"{x:.4f}")) |
| for k in compilation_metrics |
| ] |
| out = "TorchDynamo compilation metrics:\n" |
| out += tabulate(rows, headers=("Function", "Runtimes (s)")) |
| return out |
| elif repr == "csv": |
| values = [ |
| fmt_fn(v, item_fn=lambda x: f"{x:.6f}") |
| for v in compilation_metrics.values() |
| ] |
| headers = list(compilation_metrics.keys()) |
| return headers, values |
| |
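| # Reporting sketch: after exercising a compiled model, either print the table
| # or log the raw numbers.
| #
| #     print(compile_times())                          # tabulated summary
| #     headers, values = compile_times(repr="csv", aggregate=True)
| #
| # With aggregate=True, runtimes from multiple compilations (e.g. split graphs)
| # are summed into a single value per function.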
| |
| tensortype_to_dtype = { |
| torch.FloatTensor: (torch.float32, torch.float), |
| torch.DoubleTensor: (torch.float64, torch.double), |
| torch.HalfTensor: (torch.float16, torch.half), |
| torch.BFloat16Tensor: (torch.bfloat16,), |
| torch.ByteTensor: (torch.uint8,), |
| torch.CharTensor: (torch.int8,), |
| torch.LongTensor: (torch.int64, torch.long), |
| torch.IntTensor: (torch.int32, torch.int), |
| torch.ShortTensor: (torch.int16, torch.short), |
| torch.BoolTensor: (torch.bool,), |
| } |
| |
| |
| class DuplicateWarningChecker: |
| def __init__(self, maxsize=4096): |
| self.maxsize = maxsize |
| self.reset() |
| |
| def reset(self): |
| self.set = collections.OrderedDict() |
| |
| def add(self, key): |
| if key in self.set: |
| self.set.move_to_end(key, last=True) |
| if not config.verbose: |
| return False |
| else: |
| self.set[key] = None |
| while len(self.set) > self.maxsize: |
| self.set.popitem(last=False) |
| return True |
| |
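| # Dedup behavior sketch: the first add() of a key returns True (emit the
| # warning); repeats return False unless config.verbose is set, and old keys
| # age out LRU-style once maxsize is exceeded.
| #
| #     checker = DuplicateWarningChecker(maxsize=2)
| #     checker.add("reason-a")                  # True  -> warn
| #     checker.add("reason-a")                  # False -> suppressed
| #     checker.add("b"); checker.add("c")       # "reason-a" evicted
| #     checker.add("reason-a")                  # True again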
| |
| graph_break_dup_warning_checker = DuplicateWarningChecker() |
| |
| |
| def setup_compile_debug(): |
| compile_debug = bool(os.environ.get("TORCH_COMPILE_DEBUG", False)) |
| |
| if compile_debug: |
| torch._logging.set_logs( |
| dynamo=logging.DEBUG, |
| aot=logging.DEBUG, |
| inductor=logging.DEBUG, |
| output_code=True, # this is off by default |
| ) |
| return add_file_handler() |
| |
| return contextlib.ExitStack() |
| |
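| # How to enable (sketch):
| #
| #     TORCH_COMPILE_DEBUG=1 python my_script.py
| #
| # This switches dynamo/aot/inductor to DEBUG logging (plus output_code) and,
| # via add_file_handler() below, mirrors dynamo logs into
| # <debug_dir>/torchdynamo/debug.log until the returned ExitStack closes.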
| |
| def reset_graph_break_dup_checker(): |
| graph_break_dup_warning_checker.reset() |
| |
| |
| def add_file_handler(): |
| log_path = os.path.join(get_debug_dir(), "torchdynamo") |
| if not os.path.exists(log_path): |
| os.makedirs(log_path) |
| |
| log_file_handler = logging.FileHandler(os.path.join(log_path, "debug.log")) |
| logger = logging.getLogger("torch._dynamo") |
| logger.addHandler(log_file_handler) |
| |
| exitstack = contextlib.ExitStack() |
| exitstack.callback(lambda: logger.removeHandler(log_file_handler)) |
| return exitstack |
| |
| |
| def setup_log_file(): |
| exitstack = contextlib.ExitStack() |
| if config.log_file_name is not None: |
| log_file_handler = logging.FileHandler(config.log_file_name) |
|         # stdlib logging has no get_loggers(); torch._logging._internal.get_loggers()
|         # returns the registered torch loggers. Bind `logger` as a default arg so
|         # each callback removes the handler from its own logger, not whichever
|         # logger the loop variable last pointed to.
|         for logger in torch._logging._internal.get_loggers():
|             logger.addHandler(log_file_handler)
|             exitstack.callback(
|                 lambda logger=logger: logger.removeHandler(log_file_handler)
|             )
| return exitstack |
| |
| return exitstack |
| |
| |
| def gen_record_file_name(exc, code): |
| return f"{get_debug_dir()}/error_recordings/\ |
| {code.co_name}_{type(exc).__name__}_{code.co_firstlineno}.rec" |
| |
| |
| def write_record_to_file(filename, exec_record): |
| try: |
| if os.path.exists(filename): |
| log.warning( |
| "Unable to write execution record %s; file already exists.", filename |
| ) |
| else: |
| os.makedirs(os.path.dirname(filename), exist_ok=True) |
| with open(filename, "wb") as f: |
| exec_record.dump(f) |
| except Exception: |
| log.error("Unable to write execution record %s", filename, exc_info=1) |
| |
| |
| def count_calls(g: fx.Graph): |
| c = 0 |
| for n in g.nodes: |
| if "call" in n.op: |
| c += 1 |
| return c |
| |
| |
| def identity(x): |
| return x |
| |
| |
| def nothing(*args, **kwargs): |
| pass |
| |
| |
| class ExactWeakKeyDictionary: |
| """Similar to weakref.WeakKeyDictionary, but use `is`/`id` rather than `==` to compare equality""" |
| |
| def __init__(self): |
| self.values = dict() |
| self.refs = dict() |
| |
| def __getitem__(self, key): |
| return self.values[id(key)] |
| |
| def get(self, key, default=None): |
| return self.values.get(id(key), default) |
| |
| def __contains__(self, key): |
| return id(key) in self.values |
| |
| def __setitem__(self, key, value): |
| idx = id(key) |
| if idx not in self.refs: |
| self.refs[idx] = weakref.ref(key, lambda ref: self._remove_id(idx)) |
| self.values[idx] = value |
| |
| def _remove_id(self, idx): |
| if idx in self.values: |
| del self.values[idx] |
| if idx in self.refs: |
| del self.refs[idx] |
| |
| def clear(self): |
| self.refs.clear() |
| self.values.clear() |
| |
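| # Identity-keyed semantics sketch: keys that compare equal but are distinct
| # objects do not collide (the Key class here is purely illustrative).
| #
| #     class Key:
| #         def __eq__(self, other): return True
| #         def __hash__(self): return 0
| #
| #     a, b = Key(), Key()
| #     d = ExactWeakKeyDictionary()
| #     d[a] = "first"
| #     assert a in d and b not in d             # keyed by id(), not ==
| #
| # Entries disappear automatically once the key object is garbage collected.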
| |
| def istype(obj, allowed_types): |
| """isinstance() without subclasses""" |
| if isinstance(allowed_types, (tuple, list, set)): |
| return type(obj) in allowed_types |
| return type(obj) is allowed_types |
| |
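| # Example: istype is exact, so subclasses do not match.
| #
| #     p = torch.nn.Parameter(torch.ones(1))
| #     isinstance(p, torch.Tensor)              # True
| #     istype(p, torch.Tensor)                  # False (type is Parameter)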
| |
| def is_typing(value): |
| if sys.version_info < (3, 9): |
| return isinstance(value, typing._GenericAlias) |
| else: |
| return isinstance( |
| value, (typing._SpecialGenericAlias, typing._UnionGenericAlias) |
| ) |
| |
| |
| def is_numpy_int_type(value): |
| if HAS_NUMPY: |
| return istype( |
| value, |
| ( |
| np.int8, |
| np.int16, |
| np.int32, |
| np.int64, |
| np.uint8, |
| np.uint16, |
| np.uint32, |
| np.uint64, |
| ), |
| ) |
| else: |
| return False |
| |
| |
| def is_numpy_float_type(value): |
| if HAS_NUMPY: |
| return istype( |
| value, |
| ( |
| np.float16, |
| np.float32, |
| np.float64, |
| ), |
| ) |
| else: |
| return False |
| |
| |
| def is_numpy_ndarray(value): |
| if HAS_NUMPY: |
| return istype(value, np.ndarray) |
| else: |
| return False |
| |
| |
| def istensor(obj): |
| """Check of obj is a tensor""" |
|     tensor_list = (
|         torch.Tensor,
|         torch.nn.Parameter,
|         *config.traceable_tensor_subclasses,
|         torch._subclasses.FakeTensor,
|     )
|     return istype(obj, tensor_list)
| |
| |
| def is_lazy_module(mod): |
| return isinstance(mod, LazyModuleMixin) |
| |
| |
| @functools.lru_cache(4096) |
| def print_once(*args): |
| print(*args) |
| |
| |
| def make_cell(val=None): |
| """Some black magic to create a cell object that usually only exists in a closure""" |
| x = val |
| |
| def f(): |
| return x |
| |
| assert len(f.__closure__) == 1 |
| return f.__closure__[0] |
| |
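| # Example:
| #
| #     cell = make_cell(42)
| #     cell.cell_contents                       # 42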
| |
| def proxy_args_kwargs(args, kwargs): |
| try: |
| proxy_args = tuple(arg.as_proxy() for arg in args) |
| proxy_kwargs = {key: arg.as_proxy() for key, arg in kwargs.items()} |
| return proxy_args, proxy_kwargs |
| except NotImplementedError as e: |
| from .exc import unimplemented |
| from .variables.base import typestr |
| |
| raise unimplemented( |
| f"call_function args: {typestr(*args)} {typestr(*list(kwargs.values()))}" |
| ) from e |
| |
| |
| @dataclasses.dataclass |
| class CleanupHook: |
| """Remove a global variable when hook is called""" |
| |
| scope: Dict[str, Any] |
| name: str |
| |
| def __call__(self, *args): |
| CleanupManager.count -= 1 |
| del self.scope[self.name] |
| |
| @staticmethod |
| def create(scope, name, val): |
| assert name not in scope |
| CleanupManager.count += 1 |
| scope[name] = val |
| return CleanupHook(scope, name) |
| |
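| # Sketch ("_tmp" is a hypothetical name): create() installs a value into a
| # scope and hands back a hook that undoes the installation.
| #
| #     hook = CleanupHook.create(some_scope, "_tmp", value)
| #     ...
| #     hook()                                   # removes _tmp, decrements count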
| |
| class CleanupManager(ExactWeakKeyDictionary): |
| count = 0 |
| |
| def _remove_id(self, idx): |
| for hook in self.values[idx]: |
| hook() |
| super()._remove_id(idx) |
| |
| |
| CleanupManager.instance = CleanupManager() |
| |
| |
| def clone_tensor(x): |
| """Clone the tensor and its gradient""" |
| y = x.clone().requires_grad_(x.requires_grad) |
| if x.is_leaf and x.grad is not None: |
| y.grad = x.grad.clone() |
| return y |
| |
| |
| def clone_input(x): |
| """copy while preserving strides""" |
| # TODO: this is questionable |
| if isinstance(x, torch._subclasses.FakeTensor): |
| # this func fails on fake tensors in __torch_dispatch__ |
| return x |
| |
| def torch_clone(x): |
| y = torch.clone(x) |
| if x.is_leaf: |
| y.requires_grad_(x.requires_grad) |
| if x.is_leaf and x.grad is not None: |
| y.grad = clone_input(x.grad) |
| if hasattr(x, "_dynamo_dynamic_indices"): |
| y._dynamo_dynamic_indices = x._dynamo_dynamic_indices.copy() |
| return y |
| |
| with torch.no_grad(): |
| if x.device.type == "xla": |
|             # Accessing data_ptr() on an XLA tensor will cause a crash
| return torch_clone(x) |
| |
| needed_size = sum( |
| (shape - 1) * stride for shape, stride in zip(x.size(), x.stride()) |
| ) |
| if x.is_quantized: |
| result = torch.empty_quantized((needed_size + 32,), x) |
| else: |
| result = torch.empty(needed_size + 32, dtype=x.dtype, device=x.device) |
| cache_line_offset = ( |
| (x.data_ptr() - result.data_ptr()) % 32 |
| ) // x.element_size() |
| result.as_strided_(x.size(), x.stride(), cache_line_offset) |
| try: |
| result.copy_(x.clone()) |
| if x.is_leaf: |
| result.requires_grad_(x.requires_grad) |
| if x.is_leaf and x.grad is not None: |
| result.grad = clone_input(x.grad) |
| except RuntimeError: |
| # RuntimeError: unsupported operation: more than one element of the written-to |
| # tensor refers to a single memory location. Please clone() the tensor before |
| # performing the operation. |
| return torch_clone(x) |
| if hasattr(x, "_dynamo_dynamic_indices"): |
| result._dynamo_dynamic_indices = x._dynamo_dynamic_indices.copy() |
| return result |
| |
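| # Sketch of what the over-allocation above buys: clone_input reproduces the
| # input's exact strides and its 32-element alignment, which a plain clone()
| # would normalize away for a non-dense view.
| #
| #     x = torch.empty(64)[3:15:2]              # offset, strided view
| #     y = clone_input(x)
| #     assert y.stride() == x.stride()
| #     assert y.data_ptr() % 32 == x.data_ptr() % 32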
| |
| def clone_inputs(example_inputs): |
| if type(example_inputs) is dict: |
| res = dict(example_inputs) |
| for key, value in res.items(): |
| assert isinstance(value, torch.Tensor) |
| res[key] = clone_input(value) |
| return res |
| |
| res = list(example_inputs) |
| for i in range(len(res)): |
| if isinstance(res[i], torch.Tensor): |
| res[i] = clone_input(res[i]) |
| return res |
| |
| |
| @contextmanager |
| def preserve_rng_state(): |
| rng = torch.clone(torch.random.get_rng_state()) |
| if torch.cuda.is_available(): |
| cuda_rng = torch.clone(torch.cuda.get_rng_state()) |
| try: |
| yield |
| finally: |
| torch.random.set_rng_state(rng) |
| if torch.cuda.is_available(): |
| torch.cuda.set_rng_state(cuda_rng) |
| |
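| # Usage sketch: RNG consumed inside the block is invisible outside it.
| #
| #     with preserve_rng_state():
| #         torch.rand(10)                       # perturbs RNG only inside
| #     # CPU (and, if available, CUDA) RNG state is now restored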
| |
| def is_jit_model(model0): |
| return isinstance( |
| model0, |
| ( |
| torch.jit._trace.TopLevelTracedModule, |
| torch.jit._script.RecursiveScriptModule, |
| torch.jit.ScriptFunction, |
| torch.jit.ScriptModule, |
| ), |
| ) |
| |
| |
| def torchscript(model, example_inputs, verbose=False): |
| if is_jit_model(model): |
| # already done? |
| return model |
| |
| try: |
| return torch.jit.trace(model, example_inputs) |
| except Exception: |
| try: |
| return torch.jit.script(model) |
| except Exception: |
| if verbose: |
| log.exception("jit error") |
| else: |
| log.error("Both torch.jit.trace and torch.jit.script failed") |
| return None |
| |
| |
| def getfile(obj): |
| try: |
| return inspect.getfile(obj) |
| except TypeError: |
| return None |
| |
| |
| def is_namedtuple(obj): |
| """Test if an object is a namedtuple or a torch.return_types.* quasi-namedtuple""" |
| return is_namedtuple_cls(type(obj)) |
| |
| |
| def is_namedtuple_cls(cls): |
| """Test if an object is a namedtuple or a torch.return_types.* quasi-namedtuple""" |
| try: |
| if issubclass(cls, tuple): |
| bases = getattr(cls, "__bases__", []) or [None] |
| module = getattr(cls, "__module__", None) |
| return module == "torch.return_types" or ( |
| bases[0] is tuple and hasattr(cls, "_make") and hasattr(cls, "_fields") |
| ) |
| except TypeError: |
| pass |
| return False |
| |
| |
| @functools.lru_cache(1) |
| def namedtuple_fields(cls): |
| """Get the fields of a namedtuple or a torch.return_types.* quasi-namedtuple""" |
| if cls is slice: |
| return ["start", "stop", "step"] |
| |
| assert issubclass(cls, tuple) |
| if hasattr(cls, "_fields"): |
| # normal namedtuples |
| return cls._fields |
| |
| @dataclasses.dataclass |
| class Marker: |
| index: int |
| |
| # frustrating ones e.g. torch.return_types.max |
| assert cls.__module__ == "torch.return_types" |
| obj = cls(map(Marker, range(cls.n_fields))) |
| fields = [None] * cls.n_fields |
| for name in dir(obj): |
| if name[0] != "_" and isinstance(getattr(obj, name), Marker): |
| fields[getattr(obj, name).index] = name |
| return fields |
| |
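| # Example:
| #
| #     Point = collections.namedtuple("Point", ["x", "y"])
| #     namedtuple_fields(Point)                 # ("x", "y")
| #
| #     ret = torch.ones(3).max(dim=0)           # a torch.return_types.max
| #     namedtuple_fields(type(ret))             # ["values", "indices"]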
| |
| def checkpoint_params(gm): |
| with torch.no_grad(): |
| rng_state = torch.clone(torch.random.get_rng_state()) |
| if torch.cuda.is_available(): |
| cuda_rng_state = torch.clone(torch.cuda.get_rng_state()) |
| saved_state = [] |
| for param in itertools.chain(gm.parameters(), gm.buffers()): |
| saved_state.append((param, param._version, torch.clone(param))) |
| |
| def restore(): |
| with torch.no_grad(): |
| torch.random.set_rng_state(rng_state) |
| if torch.cuda.is_available(): |
| torch.cuda.set_rng_state(cuda_rng_state) |
| for param, version, original_value in saved_state: |
| if param._version != version: |
| param.copy_(original_value) |
| |
| return restore |
| |
| |
| def timed(model, example_inputs, times=1): |
| if torch.cuda.is_available(): |
| synchronize = torch.cuda.synchronize |
| else: |
| synchronize = nothing |
| |
| synchronize() |
| gc.collect() |
| torch.manual_seed(1337) |
| t0 = time.perf_counter() |
| for _ in range(times): |
| result = model(*example_inputs) |
| synchronize() |
| t1 = time.perf_counter() |
| return result, t1 - t0 |
| |
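| # Benchmarking sketch (model/example_inputs are whatever you are measuring):
| #
| #     result, seconds = timed(model, example_inputs, times=10)
| #     print(f"avg latency: {seconds / 10:.4f}s")
| #
| # The manual seed and the CUDA synchronize around the loop keep repeated
| # measurements comparable.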
| |
| def check_is_cuda(gm, example_inputs): |
| return all(x.is_cuda for x in itertools.chain(example_inputs, gm.parameters(True))) |
| |
| |
| @lru_cache(32) |
| def rot_n_helper(n): |
| assert n > 1 |
| vars = [f"v{i}" for i in range(n)] |
| rotated = reversed(vars[-1:] + vars[:-1]) |
| fn = eval(f"lambda {','.join(vars)}: ({','.join(rotated)})") |
| fn.__name__ = f"rot_{n}_helper" |
| return fn |
| |
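| # Example (n=3): the generated lambda rotates its arguments.
| #
| #     rot_n_helper(3)(1, 2, 3)                 # (2, 1, 3)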
| |
| def is_safe_constant(v): |
| if istype(v, (tuple, frozenset)): |
| return all(map(is_safe_constant, v)) |
| return isinstance(v, (enum.Enum, type)) or istype( |
| v, |
| ( |
| types.CodeType, |
| int, |
| float, |
| bool, |
| str, |
| bytes, |
| type(None), |
| slice, |
| type(type), |
| torch.device, |
| torch.dtype, |
| ), |
| ) |
| |
| |
| def check_constant_args(args, kwargs): |
| return all(x.is_python_constant() for x in itertools.chain(args, kwargs.values())) |
| |
| |
| def check_unspec_python_args(args, kwargs): |
| from .variables.constant import ConstantVariable |
| from .variables.tensor import UnspecializedPythonVariable |
| |
| unspec_count = 0 |
| for x in itertools.chain(args, kwargs.values()): |
| if isinstance(x, UnspecializedPythonVariable): |
| unspec_count += 1 |
| elif not isinstance(x, (UnspecializedPythonVariable, ConstantVariable)): |
| return False |
| else: |
| pass |
| |
| return unspec_count > 0 |
| |
| |
| def specialize_args_kwargs(tx, args, kwargs): |
| specialized_args = [] |
| specialized_kwargs = {} |
| for x in args: |
| specialized_args.append(x.as_specialized(tx)) |
| for k, v in kwargs.items(): |
| specialized_kwargs.update({k: v.as_specialized(tx)}) |
| return specialized_args, specialized_kwargs |
| |
| |
| dict_values = type(dict().values()) |
| odict_values = type(collections.OrderedDict().values()) |
| tuple_iterator = type(iter(tuple())) |
| tuple_iterator_len = tuple_iterator.__length_hint__ |
| object_new = object.__new__ |
| |
| |
| def nn_module_new(cls): |
| obj = object_new(cls) |
| torch.nn.Module.__init__(obj) |
| return obj |
| |
| |
| def product(it): |
| return functools.reduce(operator.mul, it, 1) |
| |
| |
| def tuple_iterator_getitem(it, index): |
| _, (obj,), start = it.__reduce__() |
| return obj[start + index] |
| |
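| # Example: peek at a position relative to the iterator's current offset
| # without consuming it.
| #
| #     it = iter((10, 20, 30))
| #     next(it)                                 # 10
| #     tuple_iterator_getitem(it, 0)            # 20; `it` is not advanced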
| |
| def enum_repr(value, local): |
| enum_name = str(value) |
| |
| name, val = enum_name.split(".") |
| scope = "L" if local else "G" |
| local_name = f'{scope}["{name}"].{val}' |
| return local_name |
| |
| |
| def dict_param_key_ids(value): |
| return {id(k) for k in value.keys() if isinstance(k, torch.nn.Parameter)} |
| |
| |
| def dict_const_keys(value): |
| return {k for k in value.keys() if not isinstance(k, torch.nn.Parameter)} |
| |
| |
| def dict_const_keys_repr(const_keys, *, local): |
| if any(isinstance(k, enum.Enum) for k in const_keys): |
|         # Work around repr(Enum) returning an invalid global reference before
|         # Python 3.11 by calling enum_repr and removing quotes so the enum
|         # renders correctly in guard code.
| const_keys_str = f"{ {enum_repr(k, local=local) if isinstance(k, enum.Enum) else repr(k) for k in const_keys} }".replace( |
| "'", "" |
| ) |
| else: |
| const_keys_str = f"{const_keys!r}" |
| return const_keys_str |
| |
| |
| def global_key_name(key): |
| return f"__dict_key_{id(key)}" |
| |
| |
| from torch._subclasses import ( # noqa: F401 |
| FakeTensorMode, |
| UnsupportedFakeTensorException, |
| ) |
| |
| |
| def wrap_fake_exception(fn): |
| try: |
| return fn() |
| except UnsupportedFakeTensorException as e: |
| from .exc import unimplemented |
| |
| msg = f"Unsupported: {e.reason} with fake tensor propagation." |
| log.warning(msg) |
| raise unimplemented(msg) from e |
| |
| |
| def deepcopy_to_fake_tensor(obj, fake_mode): |
| with torch._subclasses.fake_tensor.FakeCopyMode(fake_mode): |
| return wrap_fake_exception(lambda: copy.deepcopy(obj)) |
| |
| |
| def rmse(ref, res): |
| """ |
| Calculate root mean squared error |
| """ |
| return torch.sqrt(torch.mean(torch.square(ref - res))) |
| |
| |
| def same( |
| ref, |
| res, |
| fp64_ref=None, |
| cos_similarity=False, |
| tol=1e-4, |
| equal_nan=False, |
| exact_dtype=True, |
| relax_numpy_equality=False, |
| ): |
| """Check correctness to see if ref and res match""" |
| if fp64_ref is None: |
| fp64_ref = ref |
| if isinstance(ref, (list, tuple, torch.nn.ParameterList, torch.Size)): |
| assert isinstance(res, (list, tuple)), f"type mismatch {type(ref)} {type(res)}" |
| return len(ref) == len(res) and all( |
| same( |
| ai, |
| bi, |
| fp64_refi, |
| cos_similarity, |
| tol, |
| equal_nan, |
| exact_dtype, |
| relax_numpy_equality, |
| ) |
| for ai, bi, fp64_refi in zip(ref, res, fp64_ref) |
| ) |
| elif isinstance(ref, dict): |
| assert isinstance(res, dict) |
|         assert set(ref.keys()) == set(
|             res.keys()
|         ), f"keys mismatch {set(ref.keys())} != {set(res.keys())}"
| for k in sorted(ref.keys()): |
| if not ( |
| same( |
| ref[k], |
| res[k], |
| fp64_ref[k], |
| cos_similarity=cos_similarity, |
| tol=tol, |
| equal_nan=equal_nan, |
| exact_dtype=exact_dtype, |
| relax_numpy_equality=relax_numpy_equality, |
| ) |
| ): |
| log.error("Accuracy failed for key name %s", k) |
| return False |
| return True |
| elif isinstance(ref, torch.Tensor): |
| assert not isinstance(ref, torch._subclasses.FakeTensor) |
| assert not isinstance(res, torch._subclasses.FakeTensor) |
| |
|         # Check the type before touching res.is_sparse, so a non-tensor res
|         # fails with a clear message instead of an AttributeError.
|         assert isinstance(res, torch.Tensor), f"type mismatch {type(ref)} {type(res)}"
|         if ref.is_sparse:
|             assert res.is_sparse
|             ref = ref.to_dense()
|             res = res.to_dense()
| if exact_dtype: |
| if ref.dtype != res.dtype: |
| log.error("dtype mismatch %s, %s", ref.dtype, res.dtype) |
| return False |
| if ref.dtype == torch.bool: |
| # triton stores bool as int8, so add this for more accurate checking |
| r = torch.allclose( |
| ref.to(dtype=torch.uint8), |
| res.to(dtype=torch.uint8), |
| atol=tol, |
| rtol=tol, |
| equal_nan=equal_nan, |
| ) |
| if not r: |
| log.error("Accuracy failed: uint8 tensor did not match") |
| return r |
| if cos_similarity: |
| ref = ref.flatten().to(torch.float32) |
| res = res.flatten().to(torch.float32) |
| if torch.allclose(ref, res, atol=tol, rtol=tol, equal_nan=True): |
| # early exit that handles zero/nan better |
| # cosine_similarity(zeros(10), zeros(10), dim=0) is 0 |
| return True |
| score = torch.nn.functional.cosine_similarity(ref, res, dim=0, eps=1e-6) |
| if score < 0.99: |
| log.warning("Similarity score=%s", score.cpu().detach().item()) |
| return score >= 0.99 |
| else: |
| if not exact_dtype: |
| ref = ref.to(res.dtype) |
| |
| # First try usual allclose |
| if torch.allclose(ref, res, atol=tol, rtol=tol, equal_nan=equal_nan): |
| return True |
| |
| # Check error from fp64 version |
| if fp64_ref.dtype == torch.float64: |
| ref_error = rmse(fp64_ref, ref).item() |
| res_error = rmse(fp64_ref, res).item() |
| multiplier = 2.0 |
| |
| if ( |
| fp64_ref.numel() < 1000 |
| or (ref.ndim == 4 and ref.shape[-1] == ref.shape[-2] == 1) |
| # large tol means a benchmark has been specified as REQUIRE_HIGHER_TOLERANCE |
| or tol >= 2 * 1e-2 |
| ): |
| # In the presence of noise, noise might dominate our error |
| # metric for smaller tensors. |
|                     # Similarly, for 1x1 kernels, there seems to be high noise with amp.
| multiplier = 3.0 |
| |
| passes_test = res_error <= (multiplier * ref_error + tol / 10.0) |
| if not passes_test: |
| log.error( |
| "RMSE (res-fp64): %.5f, (ref-fp64): %.5f and shape=%s", |
| res_error, |
| ref_error, |
| res.size(), |
| ) |
| return passes_test |
| |
| log.error("Accuracy failed: allclose not within tol=%s", tol) |
| return False |
| elif isinstance(ref, (str, int, type(None), bool, torch.device)): |
| r = ref == res |
| if not r: |
| log.error("Accuracy failed (%s): %s != %s", type(ref), ref, res) |
| return r |
| elif isinstance(ref, float): |
| r = math.isclose(ref, res, rel_tol=tol, abs_tol=tol) |
| if not r: |
| log.error( |
| "Accuracy failed (float): %s != %s (within tol=%s)", ref, res, tol |
| ) |
| return r |
| elif is_numpy_int_type(ref) or is_numpy_float_type(ref): |
| if relax_numpy_equality and not ( |
| is_numpy_int_type(res) or is_numpy_float_type(res) |
| ): |
| ref = ref.item() |
| r = (type(ref) is type(res)) and (ref == res) |
| if not r: |
| log.error("Accuracy failed (numpy): %s != %s", ref, res) |
| return r |
| elif is_numpy_ndarray(ref): |
| return (type(ref) is type(res)) and (ref == res).all() |
| elif type(ref).__name__ in ( |
| "MaskedLMOutput", |
| "Seq2SeqLMOutput", |
| "CausalLMOutputWithCrossAttentions", |
| "LongformerMaskedLMOutput", |
| "Instances", |
| "SquashedNormal", |
| "Boxes", |
| "Normal", |
| "TanhTransform", |
| "Foo", |
| "Variable", |
| ): |
| assert type(ref) is type(res) |
| return all( |
| same( |
| getattr(ref, key), |
| getattr(res, key), |
| getattr(fp64_ref, key), |
| cos_similarity=cos_similarity, |
| tol=tol, |
| equal_nan=equal_nan, |
| exact_dtype=exact_dtype, |
| relax_numpy_equality=relax_numpy_equality, |
| ) |
| for key in ref.__dict__.keys() |
| ) |
| else: |
| raise RuntimeError(f"unsupported type: {type(ref).__name__}") |
| |
| |
| def format_func_info(code): |
| short_filename = code.co_filename.split("/")[-1] |
| return f"'{code.co_name}' ({short_filename}:{code.co_firstlineno})" |
| |
| |
| @contextlib.contextmanager |
| def disable_cache_limit(): |
| prior = config.cache_size_limit |
| config.cache_size_limit = sys.maxsize |
| |
| try: |
| yield |
| finally: |
| config.cache_size_limit = prior |
| |
| |
| # map from transformed code back to original user code |
| orig_code_map = ExactWeakKeyDictionary() |
| |
| # keep a record of code_obj -> list of guard failure reasons for logging |
| guard_failures = collections.defaultdict(list) |
| |
| # keep record of compiled code, if we are in "error if recompile" |
| # to track code that dynamo has compiled previously |
| seen_code_map = ExactWeakKeyDictionary() |
| |
| |
| class CompileProfiler: |
| """Utility for profiling how and what dynamo would compile. |
| |
|     Can be used for:
|     * diagnosing recompilation issues
|     * determining an appropriate compile cache limit
|     * (TODO) confirming which functions got compiled/skipped
|     """
| |
| def __init__(self): |
| self.frame_count = 0 |
| self.op_count = 0 |
| self.backend_ctx_ctor = lambda: disable_cache_limit() |
| |
| def __call__(self, gm: torch.fx.GraphModule, example_inputs): |
| self.frame_count += 1 |
| for node in gm.graph.nodes: |
| if "call" in node.op: |
| self.op_count += 1 |
| return gm.forward |
| |
| def get_metrics(self): |
| return {"guard_failures": guard_failures} |
| |
| def report(self): |
| metrics = self.get_metrics() |
| gf = metrics["guard_failures"] |
| |
| def num_recompiles(code): |
| return len(gf[code]) |
| |
| def recompile_reasons(code): |
| return "\n".join([str(x) for x in gf[code]]) |
| |
| summarized_gf = [ |
| [format_func_info(code), num_recompiles(code), recompile_reasons(code)] |
| for code in gf |
| ] |
| |
| def graph_break_report(): |
| if "graph_break" in counters: |
| graph_breaks = counters["graph_break"] |
| return tabulate( |
| [[msg, graph_breaks[msg]] for msg in graph_breaks], |
| headers=["Graph Break Reason", "Count"], |
| ) |
| |
| def recompilation_report(): |
| if len(gf): |
| max_recompiles = max([num_recompiles(code) for code in gf]) |
| recomp_table = tabulate( |
| summarized_gf, |
| headers=["Function", "Recompiles", "Recompile Reasons"], |
| ) |
| return recomp_table + textwrap.dedent( |
| f""" |
| |
| Set torch._dynamo.config.cache_size_limit to {max_recompiles} to avoid being cache limited. |
| """ |
| ) |
| |
| report = textwrap.dedent( |
| """ |
| Torchdynamo Profiler Report |
| =========================== |
| |
| Graph Breaks |
| ------------ |
|             Graph breaks happen when torchdynamo encounters code it can't safely trace.
|             If you want to find out why breaks are happening, check below for each break reason.
|             You may gain additional insight by passing `fullgraph=True` to torch.compile
|             to stop at the first break.
| |
| """ |
| ) |
| report += graph_break_report() or "No graph breaks detected." |
| report += textwrap.dedent( |
| """ |
| |
| Recompilation |
| ------------- |
|             These subgraphs were recompiled more than once due to guard failures.
|             Guard failures indicate that some condition assumed to be static by the tracer changed,
|             making it unsafe to reuse the compiled program.
| |
| """ |
| ) |
| report += recompilation_report() or "No recompilation detected.\n" |
| return report |
| |
| |
| # return same dir unless user changes config between calls |
| @functools.lru_cache(None) |
| def _get_debug_dir(root_dir): |
| dir_name = ( |
| "run_" |
| + datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S_%f") |
| # use pid to avoid conflicts among ranks |
| + "-pid_" |
| + str(os.getpid()) |
| ) |
| return os.path.join(root_dir, dir_name) |
| |
| |
| def get_debug_dir(): |
| debug_root = config.debug_dir_root |
| return _get_debug_dir(debug_root) |
| |
| |
| def get_fake_value(node, tx): |
| """ |
| Run the computation represented by `node` using fake tensors and return the result. |
| """ |
| from .exc import ( |
| TorchRuntimeError, |
| unimplemented, |
| Unsupported, |
| UserError, |
| UserErrorType, |
| ) |
| |
| op = node.op |
| |
| def fake_wrapper(e): |
| if isinstance(e, torch.Tensor): |
| assert isinstance(e, FakeTensor) |
| return e |
| |
| def visit(n: torch.fx.Node): |
| return n.meta["example_value"] |
| |
| args, kwargs = torch.fx.node.map_arg((node.args, node.kwargs), visit) |
| args = tree_map(fake_wrapper, args) |
| kwargs = tree_map(fake_wrapper, kwargs) |
| |
| nnmodule = None |
| if op == "call_method" and len(args) > 0 and isinstance(args[0], torch.nn.Module): |
|         # If the first argument is an nn.Module, copy it to fake mode.
| args = (deepcopy_to_fake_tensor(args[0], tx.fake_mode),) + tuple(args[1:]) |
| |
| if op == "call_module": |
| nnmodule = tx.output.nn_modules[node.target] |
| |
| if is_lazy_module(nnmodule) and hasattr(nnmodule, "_initialize_hook"): |
| # In the case of a lazy module, we want to run |
| # the pre-hooks which initialize it. |
| # Afterwards, lazy module deletes its pre-hooks |
| # to avoid treating it as lazy on subsequent recompile. |
| nnmodule._infer_parameters(nnmodule, args) |
| |
|         # Whether or not it's a lazy module, we should copy it to fake mode.
| nnmodule = deepcopy_to_fake_tensor(nnmodule, tx.fake_mode) |
| |
| try: |
| with tx.fake_mode, enable_python_dispatcher(): |
| return wrap_fake_exception( |
| lambda: run_node(tx.output, node, args, kwargs, nnmodule) |
| ) |
| except Unsupported: |
| raise |
| except RuntimeError as e: |
| cause = e |
| if e.__cause__ is not None: |
| cause = e.__cause__ |
| if isinstance( |
| cause, torch._subclasses.fake_tensor.DataDependentOutputException |
| ): |
| unimplemented(f"data dependent operator: {cause.func}") |
| elif isinstance( |
| cause, torch._subclasses.fake_tensor.DynamicOutputShapeException |
| ): |
| unimplemented(f"dynamic shape operator: {cause.func}") |
| elif isinstance( |
| cause, torch._subclasses.fake_tensor.UnsupportedOperatorException |
| ): |
| unimplemented( |
| f"unsupported operator: {cause.func} (see " |
| "https://docs.google.com/document/d/1GgvOe7C8_NVOMLOCwDaYV1mXXyHMXY7ExoewHqooxrs/edit#heading=h.64r4npvq0w0" |
| " for how to fix)" |
| ) |
| elif isinstance( |
| cause, torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode |
| ): |
| unimplemented("guard on data-dependent symbolic int/float") |
| elif isinstance(cause, torch.utils._sympy.value_ranges.ValueRangeError): |
| raise UserError(UserErrorType.CONSTRAIN_VIOLATION, e.args[0]) from e |
| raise TorchRuntimeError() from e |
| |
| |
| def run_node(output_graph, node, args, kwargs, nnmodule): |
| """ |
| Runs a given node, with the given args and kwargs. |
| |
|     Behavior is dictated by a node's op.
| |
| run_node is useful for extracting real values out of nodes. |
| See get_real_value for more info on common usage. |
| |
| Note: The output_graph arg is only used for 'get_attr' ops |
| Note: The nnmodule arg is only used for 'call_module' ops |
| |
| Nodes that are not call_function, call_method, call_module, or get_attr will |
| raise an AssertionError. |
| """ |
| op = node.op |
| try: |
| if op == "call_function": |
| return node.target(*args, **kwargs) |
| elif op == "call_method": |
| return getattr(args[0], node.target)(*args[1:], **kwargs) |
| elif op == "call_module": |
| assert nnmodule is not None |
| return nnmodule(*args, **kwargs) |
| elif op == "get_attr": |
| return output_graph.get_submodule(node.target) |
| elif op == "placeholder": |
| assert "example_value" in node.meta |
| return node.meta["example_value"] |
| except Exception as e: |
| raise RuntimeError( |
| f"Failed running {op} {node.target}(*{args}, **{kwargs}):\n{e}\n(scroll up for backtrace)" |
| ) from e |
| raise AssertionError(op) |
| |
| |
| def get_real_value(node, output_graph): |
| """ |
| Run the actual computation represented by `node` and return the result. |
| This will execute any dependent nodes in the graph as well. |
| """ |
| cache = output_graph.real_value_cache |
| if node in cache: |
| return cache[node] |
| |
| op = node.op |
| args, kwargs = torch.fx.node.map_arg( |
| (node.args, node.kwargs), |
| lambda n: get_real_value(n, output_graph), |
| ) |
| |
| if op == "call_module": |
| nn_module = output_graph.nn_modules[node.target] |
| if not is_lazy_module(nn_module): |
| nn_module = copy.deepcopy(nn_module) |
| else: |
| # In the case of a lazy module, we want to run |
| # the pre-hooks which initialize it |
| nn_module(*args, **kwargs) |
| else: |
| nn_module = None |
| |
| try: |
| real_value = run_node(output_graph, node, args, kwargs, nn_module) |
| cache[node] = real_value |
| except RuntimeError as e: |
| raise TorchRuntimeError() from e |
| return real_value |
| |
| |
| def assert_no_fake_params_or_buffers(gm): |
| from torch._subclasses.fake_tensor import FakeTensorConfig |
| |
| def stack_or_hint(t): |
| if FakeTensorConfig.debug: |
| import traceback |
| |
| return f"FAKE TENSOR CREATION TRACEBACK: \n {traceback.format_list(t._debug_trace)}" |
| else: |
| return "Enable TORCH_FAKE_TENSOR_DEBUG=1 to get creation stack traces on fake tensors." |
| |
| for name, buffer in gm.named_buffers(): |
| assert not isinstance( |
| buffer, torch._subclasses.FakeTensor |
| ), f"Unexpected fake buffer {name} {stack_or_hint(buffer)}" |
| for name, param in gm.named_parameters(): |
| assert not isinstance( |
| param, torch._subclasses.FakeTensor |
| ), f"Unexpected fake param {name} {stack_or_hint(param)}" |
| |
| |
| def fqn(obj: Any): |
| """ |
| Returns the fully qualified name of the object. |
| """ |
| return f"{obj.__module__}.{obj.__qualname__}" |
| |
| |
| def ifdyn(count1, count2): |
| if torch._dynamo.config.dynamic_shapes: |
| return count1 |
| else: |
| return count2 |
| |
| |
| def ifunspec(count1, count2): |
| if torch._dynamo.config.dynamic_shapes and not torch._dynamo.config.specialize_int: |
| return count1 |
| else: |
| return count2 |
| |
| |
| def import_submodule(mod: types.ModuleType): |
| """ |
| Ensure all the files in a given submodule are imported |
| """ |
| for filename in sorted(os.listdir(os.path.dirname(mod.__file__))): |
| if filename.endswith(".py") and filename[0] != "_": |
| importlib.import_module(f"{mod.__name__}.{filename[:-3]}") |
| |
| |
| def object_has_getattribute(value: Any): |
| try: |
| if isinstance( |
| inspect.getattr_static(type(value), "__getattribute__"), |
| types.FunctionType, |
| ): |
| return True |
| except AttributeError: |
| pass |
| return False |
| |
| |
| def get_custom_getattr(value: Any): |
| try: |
| getattr_fn = inspect.getattr_static(type(value), "__getattr__") |
| except AttributeError: |
| getattr_fn = None |
| if getattr_fn is torch.nn.Module.__getattr__: |
| # ignore this case of getattr |
| getattr_fn = None |
| return getattr_fn |
| |
| |
| class TensorStaticReason(enum.Enum): |
| PARAMETER = 2 |
| CONFIG_NOT_DYN = 3 |
| NOT_TENSOR = 4 |
| NN_MODULE_PROPERTY = 5 |
| |
| |
| def tensor_static_reason_to_message(reason: TensorStaticReason): |
| if reason == TensorStaticReason.PARAMETER: |
| return "mark_dynamic on parameter, parameters are always static today." |
| if reason == TensorStaticReason.CONFIG_NOT_DYN: |
| return "mark_dynamic usage with dynamic_shapes=False is not yet supported" |
| if reason == TensorStaticReason.NOT_TENSOR: |
| return "mark_dynamic on a non tensor, how did this happen?" |
| if reason == TensorStaticReason.NN_MODULE_PROPERTY: |
| return "tensor is static because it is nn module associated." |
| raise AssertionError(f"Illegal reason {reason}") |
| |
| |
| def tensor_always_has_static_shape(
|     tensor: Union[torch.Tensor, Any], is_tensor: bool, guard_source: GuardSource
| ) -> Tuple[bool, Optional[TensorStaticReason]]:
| """ |
| Given a tensor, source, and is_tensor flag, determine if a shape should be static. |
| |
| Args: |
|         tensor - the real tensor to evaluate; parameters force a static shape.
|         is_tensor - internal dynamo check, essentially "is_tensor": target_cls is TensorVariable;
|         tensors not in a TensorVariable for whatever reason are forced static.
| |
| Returns a tuple, where the first element is the bool of whether or not this tensor should have a static shape. |
| The second element is a TensorStaticReason, useful for passing to tensor_static_reason_to_message if needed. |
| """ |
| if type(tensor) is torch.nn.Parameter: |
| return True, TensorStaticReason.PARAMETER |
| if config.dynamic_shapes is False: |
| return True, TensorStaticReason.CONFIG_NOT_DYN |
| if not is_tensor: |
| return True, TensorStaticReason.NOT_TENSOR |
| if guard_source.is_nn_module(): |
| return True, TensorStaticReason.NN_MODULE_PROPERTY |
| return False, None |
| |
| |
| class LazyString: |
| def __init__(self, func, *args, **kwargs): |
| self.func = func |
| self.args = args |
| self.kwargs = kwargs |
| |
| def __str__(self): |
| return self.func(*self.args, **self.kwargs) |
| |
| |
| def lazy_format_graph_code(name, gm, maybe_id=None): |
| def format_name(): |
| if maybe_id is not None: |
| return f"{name} {maybe_id}" |
| else: |
| return name |
| |
| return LazyString( |
| lambda: _format_graph_code( |
| f"===== {format_name()} =====\n", |
| gm.forward.__code__.co_filename, |
| gm.print_readable(print_output=False), |
| ) |
| ) |
| |
| |
| def _format_graph_code(name, filename, graph_str): |
| return f"TRACED GRAPH\n {name} {filename} {graph_str}\n" |
| |
| |
| def lazy_format_graph_tabular(fn_name, gm): |
| def inner(): |
| try: |
|             from tabulate import tabulate
| except ImportError: |
|             return (
|                 "Tabulate module missing, please install tabulate to log the graph in tabular format, logging code instead:\n"
|                 + str(lazy_format_graph_code(fn_name, gm))
|             )
| |
| node_specs = [ |
| [n.op, n.name, n.target, n.args, n.kwargs] for n in gm.graph.nodes |
| ] |
| graph_str = tabulate( |
| node_specs, headers=["opcode", "name", "target", "args", "kwargs"] |
| ) |
| return _format_graph_code(fn_name, gm.forward.__code__.co_filename, graph_str) |
| |
| return LazyString(inner) |
| |
| |
| def format_bytecode(prefix, name, filename, line_no, code): |
| return f"{prefix} {name} {filename} line {line_no} \n{dis.Bytecode(code).dis()}\n" |
| |
| |
| def nnmodule_has_hooks( |
| mod, |
| check_forward_hooks=False, |
| check_backward_hooks=False, |
| check_state_dict_hooks=False, |
| ): |
| """ |
| Sometimes its useful to differentiate between types of hooks such as forward/backward/pre |
| hooks executed during module.__call__, and state_dict hooks which are executed separately. |
| """ |
| hook_dicts_to_check = [] |
| check_all_hooks = ( |
| not check_forward_hooks |
| and not check_backward_hooks |
| and not check_state_dict_hooks |
| ) |
| if check_forward_hooks or check_all_hooks: |
| hook_dicts_to_check.extend( |
| [ |
| "_forward_pre_hooks", |
| "_forward_hooks", |
| ] |
| ) |
| if check_backward_hooks or check_all_hooks: |
| hook_dicts_to_check.extend( |
| [ |
| "_backward_pre_hooks", |
| "_backward_hooks", |
| ] |
| ) |
| if check_state_dict_hooks: |
| hook_dicts_to_check.extend( |
| [ |
| "_state_dict_pre_hooks", |
| "_state_dict_hooks", |
| "_load_state_dict_pre_hooks", |
| "_load_state_dict_post_hooks", |
| ] |
| ) |
| return any(len(getattr(mod, x)) > 0 for x in hook_dicts_to_check if hasattr(mod, x)) |
| |
| |
| def to_numpy_helper(___tmp_0):
|     """Unwrap torch_np.ndarrays (including inside a tuple) to plain numpy arrays."""
|
|     def convert(obj):
| if isinstance(obj, torch_np.ndarray): |
| return obj.tensor.numpy() |
| else: |
| return obj |
| |
| if isinstance(___tmp_0, tuple): |
| return tuple([convert(obj) for obj in ___tmp_0]) |
| return convert(___tmp_0) |
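|
|
| # Conversion sketch (requires torch_np; names are illustrative):
| #
| #     arr = to_numpy_helper(tnp_array)         # -> numpy.ndarray
| #     a, b = to_numpy_helper((tnp_a, tnp_b))   # tuples convert element-wise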