# blob: 26436f1a5ba02952ffe60b0756f6b25dcb912355 [file] [log] [blame]
import collections
import dataclasses
import warnings
from contextlib import contextmanager, nullcontext
from functools import wraps
from typing import Any, Callable, Dict, List, Optional, Tuple
import torch
import torch.fx.traceback as fx_traceback
import torch.nn as nn
import torch.utils._pytree as pytree
import torch.utils.dlpack
from torch import Tensor
from torch._subclasses import FakeTensorMode
from torch.fx import immutable_collections, Interpreter
from torch.fx.experimental.symbolic_shapes import ShapeEnv
from torch.nn.utils import stateless
from functorch import make_fx
from functorch._C import CompileCache
from functorch.experimental import functionalize
from torch._dispatch.python import enable_python_dispatcher
from . import config
from .named_members_polyfill import _named_buffers, _named_parameters
from .partitioners import default_partition
# torchdynamo is optional: when it cannot be imported, fall back to an
# identity decorator so that decorated callables are returned unchanged.
try:
    from torchdynamo import disable as disable_torchdynamo
except ImportError:
    def disable_torchdynamo(x):
        return x
# Same optional-import pattern for the timing decorator.
try:
    from torchdynamo.utils import dynamo_timed
except ImportError:
    def dynamo_timed(x):
        return x
# Register FX's immutable containers with pytree. Each registration supplies a
# flatten function (container -> (leaves, context)) and an unflatten function
# ((leaves, context) -> container).
pytree._register_pytree_node(
    immutable_collections.immutable_list,
    lambda seq: (list(seq), None),
    lambda leaves, _context: immutable_collections.immutable_list(leaves),
)
pytree._register_pytree_node(
    immutable_collections.immutable_dict,
    # The keys travel as the context so the dict can be rebuilt from its values.
    lambda mapping: (list(mapping.values()), list(mapping.keys())),
    lambda leaves, keys: immutable_collections.immutable_dict(
        dict(zip(keys, leaves))
    ),
)
# Short alias for the torch.ops.aten operator namespace.
aten = torch.ops.aten
@contextmanager
def preserve_rng_state():
    """Context manager that restores the RNG state captured on entry.

    The CPU generator state is always snapshotted; the CUDA generator state is
    snapshotted as well when CUDA is available. Both snapshots are written
    back when the ``with`` block exits, including when it exits by raising.
    """
    saved_cpu_state = torch.clone(torch.random.get_rng_state())
    saved_cuda_state = (
        torch.clone(torch.cuda.get_rng_state())
        if torch.cuda.is_available()
        else None
    )
    try:
        yield
    finally:
        torch.random.set_rng_state(saved_cpu_state)
        if torch.cuda.is_available():
            torch.cuda.set_rng_state(saved_cuda_state)
# Set up hooks so that during backward the fx's stack_trace is properly set
# True while an engine callback (queued by the first prehook that fires) is
# still pending; the callback resets it to False.
callback_set = False


def setup_stacktrace_preservation_hooks(roots: List):
    """Register hooks on every autograd node reachable from ``roots``.

    For each node found by walking ``next_functions`` from the given roots:

    * a pre-hook sets the FX stack trace to the list stored under the
      ``"traceback_"`` key of the node's ``metadata`` (empty list if absent);
    * a post-hook sets the FX stack trace to a copy of that list with one
      extra "gradient addition" marker line appended.

    The first pre-hook to fire while no callback is pending also queues an
    autograd-engine callback that puts back the stack captured at that moment.

    Args:
        roots: autograd nodes to start from; ``None`` entries are skipped.
    """

    def iter_graph(start_nodes):
        # Breadth-first traversal; each reachable node is yielded after its
        # unseen successors have been queued.
        if not start_nodes:
            return
        visited = set()
        pending = collections.deque()
        for root in start_nodes:
            if root is not None:
                visited.add(root)
                pending.append(root)

        while pending:
            current = pending.popleft()
            for next_fn, _input_nr in current.next_functions:
                if next_fn is None or next_fn in visited:
                    continue
                visited.add(next_fn)
                pending.append(next_fn)
            yield current

    def get_callback(saved_stack_):
        def callback():
            # Put back the stack that was active before the hooks took over.
            global callback_set
            fx_traceback.set_stack_trace(saved_stack_)
            callback_set = False

        return callback

    def get_prehook(stack_):
        def prehook(grad_output):
            global callback_set

            if not callback_set:
                engine = torch.autograd.variable.Variable._execution_engine
                engine.queue_callback(get_callback(fx_traceback.format_stack()))
                callback_set = True

            fx_traceback.set_stack_trace(stack_)

        return prehook

    def get_posthook(special_stack_):
        def posthook(grad_input, grad_output):
            fx_traceback.set_stack_trace(special_stack_)

        return posthook

    for autograd_node in iter_graph(roots):
        forward_stack = autograd_node.metadata.get("traceback_", [])
        autograd_node.register_prehook(get_prehook(forward_stack))

        # The post-hook gets its own copy so the marker line does not leak
        # into the list used by the pre-hook.
        accumulation_stack = forward_stack.copy()
        accumulation_stack.append(
            "Gradient addition node due to mulitple use of tensor around:"
        )
        autograd_node.register_hook(get_posthook(accumulation_stack))
def create_joint_forward_backward(fn):
    """Build a function that runs ``fn`` and its backward pass in one call.

    The returned callable takes ``(primals, tangents)``. It evaluates
    ``fn(*primals)``, then differentiates the outputs that are tensors
    requiring grad with respect to the primals that are tensors requiring
    grad, using the matching ``tangents`` as ``grad_outputs``.

    Returns:
        A callable returning ``(outs, grads)`` where ``grads`` has one entry
        per primal: the computed gradient for primals that require grad and
        ``None`` for the others.
    """

    def joint_forward_backward(
        primals: List[Any], tangents: List[Any]
    ) -> Tuple[List[Any], List[Any]]:
        # Forward pass.
        outs = fn(*primals)

        # One flag per primal: is it a tensor that requires grad?
        inputs_needs_grads = [
            isinstance(primal, Tensor) and primal.requires_grad for primal in primals
        ]
        grad_primals = [
            primal for primal, needed in zip(primals, inputs_needs_grads) if needed
        ]

        # Keep only the (output, tangent) pairs that take part in the backward.
        assert len(tangents) == len(outs)
        differentiable_pairs = [
            (out, tangent)
            for out, tangent in zip(outs, tangents)
            if isinstance(out, Tensor) and out.requires_grad
        ]
        needed_outs = [out for out, _ in differentiable_pairs]
        needed_tangents = [tangent for _, tangent in differentiable_pairs]

        setup_stacktrace_preservation_hooks([out.grad_fn for out in needed_outs])

        # Backward pass, skipped entirely when no primal requires grad.
        backward_out = []
        if grad_primals:
            with fx_traceback.override_stack_trace():
                backward_out = torch.autograd.grad(
                    needed_outs,
                    grad_primals,
                    grad_outputs=needed_tangents,
                    allow_unused=True,
                )

        # Spread the computed gradients back over the original primal positions.
        remaining_grads = iter(backward_out)
        grads = [
            next(remaining_grads) if needed else None for needed in inputs_needs_grads
        ]
        return outs, grads

    return joint_forward_backward
def normalize_as_list(x):
    """Return ``x`` in list form.

    A list is returned as-is (same object), a tuple is copied into a new list,
    and any other value is wrapped in a one-element list.
    """
    if isinstance(x, list):
        return x
    if isinstance(x, tuple):
        return list(x)
    return [x]
# Base decomposition table; create_aot_dispatcher_function merges the
# user-supplied decompositions on top of it.
aot_autograd_decompositions = {}
# This is a list since looking forward, we can have this arbitrarily nested.
graph_being_compiled: List[str] = []
# Counter that appears in the generated graph name; incremented by
# track_graph_compiling when it is called with increment_index=True.
nth_graph: int = 0
# Prefix of the generated graph name; changed through set_model_name.
model_name: str = "model"
def set_model_name(name):
    """Replace the module-level ``model_name`` with ``name``."""
    global model_name
    model_name = name
def get_aot_compilation_context() -> Tuple[List[str], str, int]:
    """Return the current compilation context as a tuple.

    Returns:
        ``(graphs, model_name, nth_graph)`` where ``graphs`` is a fresh list
        copied from the module-level ``graph_being_compiled``.
    """
    graphs_snapshot = [*graph_being_compiled]
    return graphs_snapshot, model_name, nth_graph
def get_aot_graph_name() -> str:
    """
    Returns the name of the graph being compiled.

    The name is built from the module-level state as
    ``<model_name>_<graph names joined by "_">_<nth_graph>``.
    """
    joined_graph_names = "_".join(graph_being_compiled)
    return f"{model_name}_{joined_graph_names}_{nth_graph}"


# Second name bound to the same function.
get_graph_being_compiled = get_aot_graph_name
@contextmanager
def track_graph_compiling(graph_name, increment_index=False):
    """Record ``graph_name`` as the graph being compiled for the ``with`` body.

    On entry the module-level ``graph_being_compiled`` is set to
    ``[graph_name]``. On exit it is reset to ``[]``; this now happens in a
    ``finally`` clause, so a compiler that raises no longer leaves a stale
    graph name behind for later calls to ``get_aot_graph_name``.

    Args:
        graph_name: name to expose while the body runs.
        increment_index: when true, ``nth_graph`` is incremented after the
            body completes without raising.
    """
    global graph_being_compiled, nth_graph
    graph_being_compiled = [graph_name]
    try:
        yield
        # Only reached when the body finished normally.
        if increment_index:
            nth_graph += 1
    finally:
        graph_being_compiled = []
def make_boxed_func(f):
    """Wrap ``f`` so that it takes its arguments as a single list.

    The wrapper unpacks the list into positional arguments for ``f`` and is
    tagged with ``_boxed_call = True`` so callers can detect the convention.
    """

    def g(args):
        return f(*args)

    setattr(g, "_boxed_call", True)
    return g
def make_boxed_compiler(compiler):
    """Wrap ``compiler`` so the callable it produces takes boxed arguments.

    The returned compiler has the same ``(fx_g, inps)`` signature; its result
    is the output of ``compiler`` passed through ``make_boxed_func``.
    """

    @wraps(compiler)
    def f(fx_g, inps):
        return make_boxed_func(compiler(fx_g, inps))

    return f
def call_func_with_args(f, args, steal_args=False):
    """Call ``f`` with ``args`` and return its result as a list.

    Args:
        f: callable to invoke. If it carries a ``_boxed_call`` attribute it is
            called with the argument list itself; otherwise a warning is
            emitted and the arguments are unpacked positionally.
        args: the arguments. Copied into a new list unless ``steal_args`` is
            true, in which case ``args`` must already be a list and is passed
            on as-is.
        steal_args: skip the defensive copy of ``args``.

    Returns:
        The result of ``f`` normalized with ``normalize_as_list``.
    """
    if not steal_args:
        args = list(args)
    assert isinstance(args, list)

    if not hasattr(f, "_boxed_call"):
        # TODO: Please remove soon
        # https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670
        warnings.warn(
            "Your compiler for AOTAutograd is returning a a function that doesn't take boxed arguments. "
            "Please wrap it with functorch.compile.make_boxed_func or handle the boxed arguments yourself. "
            "See https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670 for rationale."
        )
        return normalize_as_list(f(*args))

    return normalize_as_list(f(args))
@dataclasses.dataclass
class AOTConfig:
    """
    Bundle of user-supplied callables and tables consumed by the AOT dispatch
    functions.
    """

    # Called as fw_compiler(fx_module, example_inputs) for the forward/inference graph.
    fw_compiler: Callable
    # Called as bw_compiler(fx_module, example_inputs) for the backward graph.
    bw_compiler: Callable
    # Called as partition_fn(joint_graph, joint_inputs); returns (fw_module, bw_module).
    partition_fn: Callable
    # Decomposition table handed to make_fx.
    decompositions: Dict[Callable, Callable]
def aot_dispatch_base(flat_fn, flat_args: List[Tensor], aot_config: AOTConfig):
    """Trace and compile ``flat_fn`` as a forward-only ("inference") graph.

    ``flat_fn`` is traced with ``make_fx`` using the configured
    decompositions, and the resulting module is handed to
    ``aot_config.fw_compiler`` while the graph name is tracked as
    ``"inference"``.

    Returns:
        A callable taking a single sequence of arguments and returning the
        compiled function's outputs as a list.
    """
    tracer = make_fx(flat_fn, aot_config.decompositions)
    fw_module = tracer(*flat_args)
    with track_graph_compiling("inference"):
        compiled_fw = aot_config.fw_compiler(fw_module, flat_args)

    @wraps(compiled_fw)
    def new_fn(args):
        return call_func_with_args(compiled_fw, args)

    return new_fn
def aot_dispatch_autograd(flat_fn, flat_args: List[Tensor], aot_config: AOTConfig):
    """Trace, partition and compile ``flat_fn`` for use under autograd.

    Steps performed here:

    1. Collapse duplicate entries of ``flat_args`` (see the comment below).
    2. Run ``flat_fn`` once to obtain example outputs, which serve as the
       tangents when tracing the joint forward/backward function.
    3. Trace the joint function with ``make_fx`` (through ``functionalize``
       when ``config.use_functionalize`` is set).
    4. Split the joint graph with ``aot_config.partition_fn`` and compile the
       forward half with ``aot_config.fw_compiler``.
    5. Wrap everything in a ``torch.autograd.Function`` whose backward is
       compiled with ``aot_config.bw_compiler`` the first time it runs.

    Returns:
        A callable taking the original (non-deduplicated) flat arguments
        positionally and returning a tuple of outputs.
    """
    # Deduplicate inputs. Suppose you have:
    #
    # [a, b, a, c]
    #
    # We want:
    #
    # remove_dupe_args([a, b, a, c]) == [a, b, c]
    # add_dupe_args([a, b, c]) == [a, b, a, c]
    #
    # This is done via (respectively):
    #
    # seen_args = {2} # what to drop
    # add_dupe_map = { # how to get args from the deduped list
    # 0: 0,
    # 1: 1,
    # 2: 0,
    # 3: 2,
    # }
    #
    # Whether to use flat_args or deduped_flat_args? flat_fn takes flat_args,
    # and the autograd.Function must take deduped_flat_args; everything
    # else is just getting the types right.
    # NOTE(review): in the code below seen_args is a dict mapping each argument
    # to its index in the deduped list (not a set of dropped positions as the
    # sketch above suggests); the dropped positions are recorded as False
    # entries of keep_arg_mask.
    seen_args = {}
    keep_arg_mask = []
    dropped_args = False
    add_dupe_map = {}
    duped_arg_len = len(flat_args)
    j = 0 # index into deduped_flat_args
    for i, t in enumerate(flat_args):
        if t in seen_args:
            keep_arg_mask.append(False)
            dropped_args = True
            add_dupe_map[i] = seen_args[t]
            continue
        keep_arg_mask.append(True)
        seen_args[t] = j
        add_dupe_map[i] = j
        j += 1
    # NB: Hot path, avoid set lookups here
    def remove_dupe_args(args):
        if not dropped_args:
            return args
        return [t for t, keep in zip(args, keep_arg_mask) if keep]
    def add_dupe_args(args):
        if not dropped_args:
            return args
        return [args[add_dupe_map[i]] for i in range(duped_arg_len)]
    deduped_flat_args = remove_dupe_args(flat_args)
    # The joint function receives deduped primals and re-expands them before
    # calling flat_fn, which still expects the full argument list.
    joint_forward_backward = create_joint_forward_backward(lambda *args: flat_fn(*add_dupe_args(args)))
    # Example outputs: detached, contiguous copies are used as the tangent
    # inputs of the joint trace.
    out = flat_fn(*flat_args)
    out = pytree.tree_map(
        lambda x: x.detach().contiguous() if isinstance(x, Tensor) else x,
        out,
    )
    if isinstance(out, (list, tuple)):
        _num_outs = len(out)
    else:
        _num_outs = 1
    joint_inputs = (deduped_flat_args, out)
    if config.use_functionalize:
        # Trace once without decompositions, into a graph of ATen ops.
        # NB: tracing_mode is real, as it's assumed the calling context setup
        # fake tensor mode / symbolic shapes if that is needed
        fx_g = make_fx(joint_forward_backward)(*joint_inputs)
        def fake_fn(primals, tangents):
            with torch.fx.traceback.override_stack_trace():
                return torch.fx.Interpreter(fx_g).run(primals, tangents)
        # Trace a second time, running functionalization, and THEN running decompositions.
        # functionalization only acts on ATen today, and doesn't currently handle
        # view and inplace ops that come from primtorch.
        # Eventually, functionalization should support primtorch view/inplace ops,
        # which will make it ok to run decompositions before functionalization.
        fx_g = make_fx(functionalize(fake_fn), aot_config.decompositions)(*joint_inputs)
        fx_g.graph.eliminate_dead_code()
        fx_g.recompile()
    else:
        fx_g = make_fx(joint_forward_backward, aot_config.decompositions)(*joint_inputs)
    if config.debug_joint:
        print("====== Joint graph ======")
        fx_g.print_readable()
    with torch.no_grad():
        with track_graph_compiling("joint"):
            fw_module, bw_module = aot_config.partition_fn(fx_g, joint_inputs)
        if config.debug_graphs:
            print("====== Forward graph ======")
            fw_module.print_readable()
            print("====== Backward graph ======")
            bw_module.print_readable()
        with track_graph_compiling("forward"):
            compiled_fw_func = aot_config.fw_compiler(fw_module, deduped_flat_args)
        if config.debug_partitioner:
            # Run the compiled forward once and report the storage size of
            # everything it returns beyond the first _num_outs outputs.
            fw_outs = call_func_with_args(compiled_fw_func, deduped_flat_args)
            activation_sizes = 0
            for out in fw_outs[_num_outs:]:
                if isinstance(out, torch.Tensor):
                    activation_sizes += out.storage().nbytes()
            print(f"Real Activations Stored(GB): {activation_sizes/1e9}")
    class CompiledFunction(torch.autograd.Function):
        # State is kept on the class: the compiled forward, the lazily
        # compiled backward, and the number of user-visible outputs.
        compiled_fw = compiled_fw_func
        compiled_bw = None
        num_outs = _num_outs
        @staticmethod
        @disable_torchdynamo
        def forward(ctx, *deduped_flat_tensor_args):
            fw_outs = call_func_with_args(
                CompiledFunction.compiled_fw, deduped_flat_tensor_args
            )
            num_outs = CompiledFunction.num_outs
            # The compiled forward returns the real outputs first; everything
            # after them is saved for the backward pass.
            ctx.save_for_backward(*fw_outs[num_outs:])
            return tuple(fw_outs[0:num_outs])
        @staticmethod
        @disable_torchdynamo
        def backward(ctx, *flat_args):
            # NOTE(review): .contiguous() is called on every incoming gradient,
            # which assumes each one is a Tensor — confirm that None cannot
            # reach here (e.g. for non-tensor forward outputs).
            contiguous_args = [t.contiguous() for t in flat_args]
            all_args = list(ctx.saved_tensors) + list(contiguous_args)
            # Compile the backward graph on first use only.
            if CompiledFunction.compiled_bw is None:
                with track_graph_compiling("backward", True):
                    CompiledFunction.compiled_bw = aot_config.bw_compiler(
                        bw_module, all_args
                    )
            ctx.maybe_clear_saved_tensors()
            # all_args is a fresh list built above, so it can be handed over
            # without another copy.
            out = call_func_with_args(
                CompiledFunction.compiled_bw, all_args, steal_args=True
            )
            return tuple(out)
    @wraps(CompiledFunction.apply)
    def compiled_function(*args):
        return CompiledFunction.apply(*remove_dupe_args(args))
    return compiled_function
@dynamo_timed
def create_aot_dispatcher_function(
    flat_fn, flat_args: List[Tensor], aot_config: AOTConfig
):
    """
    Main entry point: pick and run the right AOT dispatch path for ``flat_fn``.

    The user decompositions in ``aot_config`` are merged on top of
    ``aot_autograd_decompositions`` (the merged table is stored back on
    ``aot_config``). Depending on ``config``, the inputs are converted to fake
    tensors and the Python dispatcher / a ``ShapeEnv`` are enabled. The RNG
    state is preserved around the whole compilation.

    If any tensor input requires grad and grad mode is enabled, the joint
    forward/backward graph is traced, partitioned and compiled via
    ``aot_dispatch_autograd`` and the result is wrapped with
    ``make_boxed_func``; otherwise the forward-only path
    ``aot_dispatch_base`` is used.
    """
    # TODO: Chillee argues that dynamo itself should pass in fake tensors to
    # the list of arguments when compiling; at the moment we do not do this
    user_decompositions = aot_config.decompositions
    if user_decompositions is None:
        user_decompositions = {}
    aot_config.decompositions = {
        **aot_autograd_decompositions,
        **user_decompositions,
    }

    # NB: don't bother setting allow_fallback_kernels; this should not actually
    # be configurable in fake tensor, we should automatically do the right
    # thing
    fake_mode = FakeTensorMode() if config.use_fake_tensor else nullcontext()
    if config.use_dynamic_shapes:
        python_dispatcher_mode = enable_python_dispatcher()
        shape_env = ShapeEnv()
    else:
        python_dispatcher_mode = nullcontext()
        shape_env = None

    with preserve_rng_state(), fake_mode, python_dispatcher_mode:
        # Swap tensor inputs for fake tensors when fake-tensor mode is on.
        if config.use_fake_tensor:
            fake_flat_tensor_args = pytree.tree_map_only(
                Tensor,
                lambda x: fake_mode.from_tensor(x, shape_env=shape_env),
                flat_args,
            )
        else:
            fake_flat_tensor_args = flat_args

        any_input_requires_grad = any(
            x.requires_grad for x in fake_flat_tensor_args if isinstance(x, Tensor)
        )
        needs_autograd = any_input_requires_grad and torch.is_grad_enabled()

        # crappy version of dispatcher
        # TODO: Do this properly
        if not needs_autograd:
            return aot_dispatch_base(flat_fn, fake_flat_tensor_args, aot_config)
        return make_boxed_func(
            aot_dispatch_autograd(flat_fn, fake_flat_tensor_args, aot_config)
        )
class _CompileCache(CompileCache):
    """Subclass of ``functorch._C.CompileCache`` that adds no behavior."""
# using a C++-based pytree reduces the overhead by about 50%
# Module-wide cache instance; created by aot_function when it is still None.
compile_cache = None
# Inspired by autodidax (thanks!)
class PytreeThunk:
    """Late-bound holder for an output tree spec.

    The spec is not known until the wrapped function has been traced, so it is
    filled in later through :meth:`set` and then used by :meth:`unflatten` to
    rebuild the output structure from a flat list.
    """

    spec = None
    # These are some kinda dumb microoptimizations that save about 3-4 us of overhead.
    # if the output spec is a tuple/list, we won't bother unflattening it.
    is_simple = None
    # if the output spec is a LeafSpec
    is_really_simple = None

    def set(self, spec):
        """Store ``spec``; a previously stored spec must compare equal to it."""
        assert self.spec is None or self.spec == spec
        self.spec = spec

        # NOTE(review): this inspects the type of the spec object itself; if
        # specs are always pytree TreeSpec instances this test can never be
        # true and `is_simple` stays None — confirm whether `spec.type` was
        # meant.
        spec_is_sequence = type(self.spec) in [tuple, list]
        if spec_is_sequence and all(
            isinstance(child, pytree.LeafSpec) for child in spec.children_specs
        ):
            self.is_simple = True

        if isinstance(self.spec, pytree.LeafSpec):
            self.is_really_simple = True

    def unflatten(self, x):
        """Rebuild the output from the flat list ``x`` using the stored spec."""
        if self.is_really_simple:
            # Single leaf: the flat list has exactly one meaningful element.
            return x[0]
        return x if self.is_simple else pytree.tree_unflatten(x, self.spec)
def filter_tensor_and_static_args(args, static_argnums):
    """
    Split ``args`` by position into non-static and static arguments.

    Positions listed in ``static_argnums`` are treated as static; for each of
    them the value of ``arg.__hash__()`` is recorded as well.

    Returns:
        ``(tensor_args, static_args, static_args_hashed)``, three lists that
        keep the original relative order of the arguments.
    """
    tensor_args, static_args, static_args_hashed = [], [], []
    for position, value in enumerate(args):
        if position in static_argnums:
            static_args.append(value)
            static_args_hashed.append(value.__hash__())
        else:
            tensor_args.append(value)
    return tensor_args, static_args, static_args_hashed
def rearrange(tensor_args, static_args, static_argnums):
    """
    Interleave ``tensor_args`` and ``static_args`` back into call order.

    ``static_argnums`` must be sorted and hold, for each static argument, the
    position it occupies in the rebuilt argument list. Once one of the two
    inputs is exhausted, whatever remains of the other is appended in order.
    """
    assert len(static_args) == len(static_argnums)
    tensor_count = len(tensor_args)
    static_count = len(static_args)

    merged = []
    tensor_pos = 0
    static_pos = 0
    position = 0
    while tensor_pos < tensor_count and static_pos < static_count:
        if static_argnums[static_pos] == position:
            merged.append(static_args[static_pos])
            static_pos += 1
        else:
            merged.append(tensor_args[tensor_pos])
            tensor_pos += 1
        position += 1

    # At most one of these two tails is non-empty.
    merged.extend(tensor_args[tensor_pos:])
    merged.extend(static_args[static_pos:])
    return merged
# Leaf types accepted in the flattened output of a function wrapped by
# aot_function; any other leaf type raises a RuntimeError during tracing.
KNOWN_TYPES = [torch.Tensor, int, str, float, bool]
def aot_function(
    fn: Callable,
    fw_compiler: Callable,
    bw_compiler: Optional[Callable] = None,
    partition_fn: Callable = default_partition,
    decompositions: Optional[Dict] = None,
    hasher_type: str = "StaticShapeHasher",
    static_argnums: Optional[Tuple[int]] = None,
) -> Callable:
    """
    Traces the forward and backward graph of :attr:`fn` using torch dispatch
    mechanism, and then compiles the generated forward and backward graphs
    through :attr:`fw_compiler` and :attr:`bw_compiler`.

    :func:`aot_function` traces the forward and backward graph ahead of time,
    and generates a joint forward and backward graph. :attr:`partition_fn` is
    then used to separate out forward and backward graphs. The partitioner
    function can be used to perform optimizations such as recomputation. One can
    set the `decompositions` dictionary to decompose the operators into a
    sequence of core or simpler operators supported by the backend compilers.

    :func:`aot_function` uses a compilation cache, based on input tensor
    properties, to detect when there is a need of recompilation. By default, its
    behavior is static, i.e., it recompiles if the shape of any input tensor
    changes.

    :attr:`static_argnums` allows the user to mark the arguments of the original
    :attr:`fn` as static. This is useful when an argument is a non-tensor, e.g.,
    ``int`` or ``bool``. A change in the actual value of a static arg causes
    recompilation.

    .. warning::
        This API is experimental and likely to change.

    Args:
        fn (Callable): A Python function that takes one or more arguments. Must
            return one or more Tensors.
        fw_compiler (Callable): A Python function that accepts an Fx graph with
            Aten ops and input args, and returns a Callable that semantically is
            equivalent to the input Fx graph.
        bw_compiler (Optional[Callable]): A Python function that accepts an
            Fx graph with Aten ops and input args, and returns a Callable that
            semantically is equivalent to the input Fx graph. Default: None
            (when None, it defaults to the :attr:`fw_compiler`)
        partition_fn (Callable): A Python function that takes a joint forward
            and backward graph, and partitions it into separate forward and
            backward graphs.
        decompositions (Dict): A dictionary to define the decomposition of
            larger Aten ops into simpler or core Aten ops.
        hasher_type (str): Passed unchanged to the compilation cache on every
            lookup and insertion.
        static_argnums (Optional[Tuple[Int]]): An optional tuple of ints to mark
            the arguments of the function as static. A single int is also
            accepted.

    Returns:
        Returns a ``Callable`` that retains the eager behavior of the original
        :attr:`fn`, but with forward and backward graph compiled via
        :attr:`fw_compiler` and :attr:`bw_compiler`.

    A simple example usage of :func:`aot_function` is as follows. This example
    will print the forward and backward graphs of the function ``fn``

        >>> fn = lambda x : x.sin().cos()
        >>> def print_compile_fn(fx_module, args):
        >>>     print(fx_module)
        >>>     return fx_module
        >>> aot_fn = aot_function(fn, print_compile_fn)
        >>> x = torch.randn(4, 5, requires_grad=True)
        >>> aot_fn(x)

    The static argnums are used to mark the non-tensor arguments as static. An
    example is as follows where the dropout probability is an argument to the
    original function.

        >>> def fn(input, bias, residual, p: float):
        >>>     a = torch.add(input, bias)
        >>>     b = torch.nn.functional.dropout(a, p, training=True)
        >>>     c = b + residual
        >>>     return c
        >>> aot_fn = aot_function(fn, print_compile_fn, static_argnums=(3,))
    """
    global compile_cache
    if compile_cache is None:
        compile_cache = CompileCache()
    if bw_compiler is None:
        bw_compiler = fw_compiler
    aot_config = AOTConfig(
        fw_compiler=fw_compiler,
        bw_compiler=bw_compiler,
        partition_fn=partition_fn,
        decompositions=decompositions,
    )
    cached_res = None
    # The cache is keyed on object identities of the function and compilers.
    fn_id = id(fn)
    fw_compiler_id = id(fw_compiler)
    bw_compiler_id = id(bw_compiler)
    # Normalize static_argnums to either None or a sorted list of ints.
    if isinstance(static_argnums, int):
        static_argnums = [static_argnums]
    elif static_argnums is not None and len(static_argnums) == 0:
        static_argnums = None
    elif static_argnums is not None:
        static_argnums = list(static_argnums)
        static_argnums.sort()
    @wraps(fn)
    def returned_function(*args, **kwargs):
        # NOTE(review): compile_cache is read from the module global at call
        # time; if it has been reset to None after this wrapper was created,
        # the .at() lookup below would fail — confirm against
        # clear_compile_cache.
        global compile_cache
        nonlocal cached_res
        # Separate out static args if static_argnums is present
        tensor_args = args
        static_args = []
        # TODO - move the hashing part of static_args to C++.
        static_args_hashed = []
        if static_argnums is not None:
            (
                tensor_args,
                static_args,
                static_args_hashed,
            ) = filter_tensor_and_static_args(args, static_argnums)
        # Now flatten the tensor args
        flat_tensor_args, _ = pytree.tree_flatten((tensor_args, kwargs))
        # Check if the fn is already compiled
        num_tensor_args = len(flat_tensor_args)
        # Static args take part in the cache key through their hashes.
        flat_args_for_cache = flat_tensor_args + static_args_hashed
        cached_res = compile_cache.at(
            fn_id,
            fw_compiler_id,
            bw_compiler_id,
            num_tensor_args,
            hasher_type,
            *flat_args_for_cache,
        )
        # Compile the function and save it in the cache
        if cached_res is None:
            # Save the args_spec for flat_tensor_args to unflatten while tracing
            _, tensor_args_spec = pytree.tree_flatten((tensor_args, kwargs))
            # Filled in by flat_fn once the output structure is known.
            out_spec = PytreeThunk()
            def flat_fn(*flat_tensor_args):
                # The input are flattened tensor args. Prepare the args in the
                # order that original function expects. Add static args as well.
                # They will appear as tensor constants in the traced graph.
                nonlocal out_spec, static_args
                tensor_args, kwargs = pytree.tree_unflatten(
                    flat_tensor_args, tensor_args_spec
                )
                if static_argnums is None:
                    args = tensor_args
                else:
                    args = rearrange(tensor_args, static_args, static_argnums)
                tree_out = fn(*args, **kwargs)
                flat_out, spec = pytree.tree_flatten(tree_out)
                # Reject output leaves whose type is not in KNOWN_TYPES.
                for i in flat_out:
                    is_known_type = False
                    for j in KNOWN_TYPES:
                        if isinstance(i, j):
                            is_known_type = True
                            break
                    if not is_known_type:
                        raise RuntimeError(
                            f"Found {type(i)} in output, which is not a known type. "
                            "If this type holds tensors, you need to register a pytree for it. "
                            "See https://github.com/pytorch/functorch/issues/475 for a brief "
                            "explanation why. If you don't need to register a pytree, please "
                            "leave a comment explaining your use case and we'll make this more "
                            "ergonomic to deal with"
                        )
                out_spec.set(spec)
                return flat_out
            compiled_fn = create_aot_dispatcher_function(
                flat_fn,
                flat_tensor_args,
                aot_config,
            )
            cached_res = (compiled_fn, out_spec)
            # Save the compiled_fn in the cache
            compile_cache.insert(
                fn_id,
                fw_compiler_id,
                bw_compiler_id,
                num_tensor_args,
                hasher_type,
                cached_res,
                *flat_args_for_cache,
            )
        cached_fn, out_spec = cached_res
        # The compiled function takes the flat args as one list.
        out = cached_fn(flat_tensor_args)
        return out_spec.unflatten(out)
    return returned_function
def num_of_recompilations():
    """
    Returns the number of recompilations since the last time the cache was
    cleared, i.e. the number of entries in the compilation cache.

    Returns 0 when no cache currently exists.
    """
    global compile_cache
    return 0 if compile_cache is None else compile_cache.size()
def clear_compile_cache():
    """
    Clears the compilation cache.

    The cache object is emptied in place and kept alive rather than being
    replaced by ``None``. Functions previously returned by ``aot_function``
    look the cache up through the module global on every call, so dropping the
    object made them fail on their next call with an ``AttributeError`` on
    ``None``. Does nothing when no cache has been created yet.
    """
    global compile_cache
    if compile_cache is not None:
        compile_cache.clear()
def aot_module(mod: nn.Module, *args, **kwargs) -> nn.Module:
    """
    Compile the forward and backward of :attr:`mod` through
    :func:`aot_function`.

    The parameters and buffers of the module are lifted into explicit inputs
    of a functional wrapper, and that wrapper is what gets handed to
    :func:`aot_function`. The parameters and buffers are re-read from
    :attr:`mod` on every forward call.

    .. warning::
        This API is experimental and likely to change.

    Args:
        mod (Callable): A ``nn.Module`` module.
        args : args to be passed to :func:`aot_function`
        kwargs : kwargs to be passed to :func:`aot_function`

    Returns:
        Returns a ``nn.Module`` that retains the eager behavior of the original
        :attr:`mod`, but with forward and backward graph compiled.
    """

    def functional_call(named_params, named_buffers, *args, **kwargs):
        # Buffers are merged after parameters, so a buffer wins on a name clash.
        params_and_buffers = dict(named_params)
        params_and_buffers.update(named_buffers)
        return stateless.functional_call(mod, params_and_buffers, args, kwargs)

    compiled_f = aot_function(functional_call, *args, **kwargs)

    class AOTModule(nn.Module):
        def __init__(self):
            super().__init__()
            self.orig_module = mod

        def forward(self, *args, **kwargs):
            current_params = dict(_named_parameters(mod, remove_duplicate=False))
            current_buffers = dict(_named_buffers(mod, remove_duplicate=False))
            return compiled_f(current_params, current_buffers, *args, **kwargs)

    return AOTModule()
def aot_module_simplified(mod: nn.Module, *top_args, **top_kwargs) -> nn.Module:
    """
    This is the simplified or low overhead version of aot_module. For frontends
    like TorchDynamo, the input functions/modules to AOT are static and have
    unpacked inputs/outputs. This gives us an opportunity to remove the
    (1) pytree overhead to parse inputs/outputs,
    (2) AOT Autograd cache,
    (3) Reading of params/buffers in every forward call

    :func:`aot_module_simplified` removes these overheads.

    Args:
        mod: the module to compile; its parameters and buffers are read once,
            when this function is called.
        top_args, top_kwargs: forwarded to the local ``aot_function_simplified``
            (compilers, partition function, decompositions, ...).

    Returns:
        A plain function ``forward`` (not an ``nn.Module`` instance, despite
        the annotation) that prepends the captured parameters and buffers to
        its arguments. ``zero_grad`` and ``named_parameters`` of ``mod`` are
        attached to it as attributes.
    """
    #########################################################
    # Snapshot parameters and buffers once; buffers win on a name clash.
    params = {
        **dict(_named_parameters(mod, remove_duplicate=False)),
        **dict(_named_buffers(mod, remove_duplicate=False)),
    }
    params_flat, params_spec = pytree.tree_flatten(params)
    params_flat = tuple(params_flat)
    params_len = len(params_flat)
    def functional_call(*args, **kwargs):
        # The first params_len positional args are the flattened
        # parameters/buffers; the rest are the real module inputs.
        with stateless._reparametrize_module(
            mod, pytree.tree_unflatten(args[:params_len], params_spec)
        ):
            if isinstance(mod, torch.fx.GraphModule):
                # GraphModules are run through an Interpreter under anomaly
                # detection, with the warning it emits on enabling suppressed.
                with fx_traceback.override_stack_trace(), warnings.catch_warnings():
                    warnings.filterwarnings(
                        "ignore", "Anomaly Detection has been enabled."
                    )
                    with torch.autograd.detect_anomaly(check_nan=False):
                        out = Interpreter(mod).run(*args[params_len:], **kwargs)
            else:
                out = mod(*args[params_len:], **kwargs)
        if not isinstance(out, (tuple, list)):
            raise RuntimeError(
                "Graph output must be a tuple(). This is so that we can avoid "
                "pytree processing of the ouputs. Please change the module to "
                "have tuple outputs or use aot_module instead."
            )
        return out
    def aot_function_simplified(
        fn: Callable,
        fw_compiler: Callable,
        bw_compiler: Optional[Callable] = None,
        partition_fn: Callable = default_partition,
        decompositions: Optional[Dict] = None,
        hasher_type: str = "StaticShapeHasher",
        static_argnums: Optional[Tuple[int]] = None,
    ) -> Callable:
        # hasher_type is accepted for signature parity with aot_function but
        # is not used in this body; static_argnums must be left unset.
        assert static_argnums is None
        if bw_compiler is None:
            bw_compiler = fw_compiler
        aot_config = AOTConfig(
            fw_compiler=fw_compiler,
            bw_compiler=bw_compiler,
            partition_fn=partition_fn,
            decompositions=decompositions,
        )
        compiled_fn = None
        @wraps(fn)
        def new_func(*args):
            # Compile once, on the first call, using that call's arguments;
            # later calls reuse the compiled function unconditionally.
            nonlocal compiled_fn
            if compiled_fn is None:
                compiled_fn = create_aot_dispatcher_function(
                    fn,
                    args,
                    aot_config,
                )
            return compiled_fn(args)
        return new_func
    compiled_f = aot_function_simplified(functional_call, *top_args, **top_kwargs)
    if top_kwargs:
        # NOTE(review): compiled_f is new_func, which only accepts positional
        # arguments, so calling this variant with non-empty kwargs would raise
        # a TypeError — confirm whether keyword inputs are expected here.
        def forward(*args, **kwargs):
            return compiled_f(
                *params_flat,
                *args,
                **kwargs,
            )
    else:
        def forward(*args):
            return compiled_f(
                *params_flat,
                *args,
            )
    forward.zero_grad = mod.zero_grad
    forward.named_parameters = mod.named_parameters
    return forward
# Alternate names bound to the same entry points.
compiled_function = aot_function
compiled_module = aot_module