blob: 7d3b0fc6ada43ccf48e0095b60065b573d1b9af5 [file] [log] [blame]
# mypy: allow-untyped-defs
# This module contains functions that *will be allowed* by dynamo
import functools
from typing import List
import torch
import torch.utils._pytree as pytree
try:
import numpy as np
except ModuleNotFoundError:
np = None # type: ignore[assignment]
def is_compiling() -> bool:
"""
Indicates whether we are tracing/compiling with torch.compile() or torch.export().
If need to check specifically that TorchDynamo is used, then use
torch.compiler.is_dynamo_compiling().
TODO(khabinov): we should deprecate this function and use one of these two:
* torch.compiler.is_compiling(),
* torch.compiler.is_dynamo_compiling().
It will depend on the context where to use what.
"""
return torch.compiler.is_compiling()
def wrap_inline(fn):
    """
    Return a thin wrapper that calls ``fn`` through one extra Python frame.

    The wrapper lives in this module, so the extra frame is not in skipfiles.
    ``functools.wraps`` copies ``fn``'s metadata (name, docstring, ``__wrapped__``)
    onto the wrapper.
    """

    def inner(*args, **kwargs):
        # Pure forwarding: same positional and keyword arguments, same result.
        return fn(*args, **kwargs)

    return functools.wraps(fn)(inner)
def call_hook(hook, *args):
    """
    Invoke ``hook(*args)`` on behalf of compiled autograd.

    A hook that returns ``None`` means "no replacement". In that case the first
    positional argument is handed back unchanged. Otherwise the hook's return
    value is used.
    """
    outcome = hook(*args)
    return args[0] if outcome is None else outcome
def wrap_numpy(f):
    r"""Adapt a NumPy-in/NumPy-out function ``f`` into a Tensor-in/Tensor-out one.

    Before calling ``f``, every ``torch.Tensor`` found anywhere in the (possibly
    nested) positional and keyword arguments is converted with ``.numpy()``.
    After the call, every ``np.ndarray`` in the (possibly nested) result is
    converted back with ``torch.as_tensor``. Other leaves pass through untouched.

    If NumPy could not be imported, ``f`` is returned as-is.
    """
    if np is None:
        # NumPy is unavailable, so there is nothing to adapt.
        return f

    def to_numpy(tensor):
        return tensor.numpy()

    def to_tensor(array):
        return torch.as_tensor(array)

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        np_args, np_kwargs = pytree.tree_map_only(
            torch.Tensor, to_numpy, (args, kwargs)
        )
        result = f(*np_args, **np_kwargs)
        return pytree.tree_map_only(np.ndarray, to_tensor, result)

    return wrapper
class FakeBackwardCFunction:
    """Stand-in for a ``BackwardCFunction`` with an explicit ``saved_tensors`` list.

    ``saved_tensors`` is stored directly on this object. Any attribute that is
    not found on this object is looked up on the wrapped ``real`` object instead.
    """

    def __init__(
        self,
        real: torch.autograd.function.BackwardCFunction,
        saved_tensors: List[torch.Tensor],
    ):
        self.real = real
        self.saved_tensors = saved_tensors

    def __getattr__(self, name):
        # __getattr__ only runs after normal lookup fails. If ``real`` itself is
        # missing (e.g. an instance created via __new__, or mid-way through
        # copy/unpickling before __dict__ is restored), forwarding would
        # re-enter this method forever. Fail with a normal AttributeError instead.
        if name == "real":
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute 'real'"
            )
        # Route any attribute that isn't defined on this object to ``real``.
        return getattr(self.real, name)
def call_backward(backward_c_function, saved_tensors, *args):
    """Eager implementation of a lifted ``autograd.Function.backward``.

    Builds a ``FakeBackwardCFunction`` context that exposes ``saved_tensors`` and
    forwards every other attribute to ``backward_c_function``. It then calls
    ``_forward_cls.backward`` with that context followed by ``args``.

    The result is always a tuple. As in eager mode, a result whose type is not
    exactly ``tuple`` is wrapped in a 1-tuple.
    """
    ctx = FakeBackwardCFunction(backward_c_function, saved_tensors)
    backward = ctx._forward_cls.backward  # type: ignore[attr-defined]
    grads = backward(ctx, *args)
    # Exact type check on purpose: tuple subclasses are wrapped too.
    return grads if type(grads) is tuple else (grads,)
def untyped_storage_size(x: torch.Tensor):
    """Return ``size()`` of the untyped storage that backs ``x``."""
    storage = x.untyped_storage()
    return storage.size()
class FakeCompiledAutogradEngine:
    """Static helpers that queue and run "final" callbacks on a caller-owned list.

    NOTE(review): the class name suggests it mirrors the compiled-autograd
    engine's callback API; confirm against callers before relying on that.
    """

    @staticmethod
    def queue_callback(final_callbacks, cb):
        """Append ``cb`` to ``final_callbacks`` (mutates the list in place)."""
        final_callbacks.append(cb)

    @staticmethod
    def exec_final_callbacks(final_callbacks):
        """Call each callback in ``final_callbacks`` in order, then empty the list.

        Callbacks appended to the list while it is being drained are also run.
        This works because ``len(final_callbacks)`` is re-read on every iteration.
        """
        # NOTE(review): the explicit index-based while loop looks deliberate
        # (presumably so tracing observes the list growing); confirm before
        # rewriting it as a ``for`` loop.
        i = 0
        while i < len(final_callbacks):
            cb = final_callbacks[i]
            cb()
            i += 1
        # Clear in place so the caller's list object is emptied, not rebound.
        final_callbacks.clear()

    @staticmethod
    def _exec_final_callbacks_stub():
        """No-op: takes no arguments and does nothing."""
        # NOTE(review): its role (e.g. a marker or substitution target) is not
        # visible from this file.
        pass
def call_hook_from_backward_state(*args, bw_state, hook_name: str, **kwargs):
    """Look up the hook named ``hook_name`` on ``bw_state`` and call it.

    All positional and keyword arguments other than ``bw_state`` and
    ``hook_name`` are forwarded to the hook. Its return value is returned as-is.
    """
    hook = getattr(bw_state, hook_name)
    return hook(*args, **kwargs)
def call_module_hooks_from_backward_state(
    _, result, *args, bw_state, hooks_name: str, module_name: str
):
    """Run a chain of module hooks stored on ``bw_state``.

    The module is ``getattr(bw_state, module_name)`` and the hooks are the
    iterable ``getattr(bw_state, hooks_name)``. Each hook is called as
    ``hook(module, result, *args)``. A non-``None`` return replaces ``result``,
    and later hooks see the replacement. A ``None`` return leaves ``result``
    unchanged. The first positional argument is ignored.
    """
    module = getattr(bw_state, module_name)
    for hook in getattr(bw_state, hooks_name):
        replacement = hook(module, result, *args)
        result = result if replacement is None else replacement
    return result