# mypy: allow-untyped-defs
import dataclasses
from typing import Any, Callable, Optional, Protocol

from .. import _C, _ops, autograd, Tensor
from ..utils import _pytree
from . import utils


class InfoProtocol(Protocol):
    _backward_fn: Optional[Callable]
    _setup_context_fn: Optional[Callable]


@dataclasses.dataclass
class Info:
    _backward_fn: Optional[Callable]
    _setup_context_fn: Optional[Callable]


def make_autograd_impl(op: _ops.OpOverload, info: InfoProtocol) -> Callable:
    name: str = f"GeneratedBackwardFor_{op._namespace}_{op._opname}_{op._overloadname}"

    saved_keyset = None
    saved_keyword_only_args = None
    has_kwarg_only_args = utils.has_kwarg_only_args(op._schema)

    def forward(ctx, *args):
        with _C._AutoDispatchBelowAutograd():
            nonlocal saved_keyset, saved_keyword_only_args
            keyset = saved_keyset
            assert keyset is not None, "Should have been set by autograd_impl"
            saved_keyset = None
            kwargs = saved_keyword_only_args
            assert kwargs is not None, "Should have been set by autograd_impl"
            saved_keyword_only_args = None

            result = op.redispatch(keyset & _C._after_autograd_keyset, *args, **kwargs)
            if info._setup_context_fn:
                # The Dispatcher will remove args that are equal to their default
                # values from (args, kwargs). We're going to add them back so
                # that the user can access them.
                #
                # This is OK to do: the Dispatcher removed the args for
                # serialization FC/BC reasons (that is, a graph will not store
                # args that are equal to their default values), but that doesn't
                # matter here. If the user adds a new default arg, then they
                # must update their setup_context (along with the rest of their
                # operator registrations).
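                # Illustrative example (hypothetical schema, not from this
                # file): for "mylib::foo(Tensor x, int alpha=1) -> Tensor",
                # a call foo(x) reaches this point with alpha stripped; after
                # utils.fill_defaults below, setup_context sees args=(x, 1),
                # so every argument is visible to the user again.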
                args, kwargs = utils.fill_defaults(op._schema, args, kwargs)

                if has_kwarg_only_args:
                    info._setup_context_fn(
                        ctx=ctx, inputs=args, keyword_only_inputs=kwargs, output=result
                    )
                else:
                    info._setup_context_fn(ctx=ctx, inputs=args, output=result)
            return result

    def backward(ctx, *grads):
        if info._backward_fn:
            result = info._backward_fn(ctx, *grads)
            return result
        raise RuntimeError(
            f"Trying to backward through {op} but no autograd "
            f"formula was registered. "
            f"Please use register_autograd to add one."
        )

    Generated = type(
        name,
        (autograd.Function,),
        {
            "forward": staticmethod(forward),
            "backward": staticmethod(backward),
        },
    )

    schema = op._schema
    if any(
        utils.is_tensorlist_like_type(a.type)
        for a in (*schema.arguments, *schema.returns)
    ):
        Generated = supports_tensorlist(Generated)

    # The dispatcher passes any keyword-only-args as kwargs and the
    # rest of the args (even if specified as kwargs) as args.
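    # Illustrative example (hypothetical schema, not from this file): for
    # "mylib::bar(Tensor x, float scale, *, float eps) -> Tensor", the call
    # bar(x, scale=2.0, eps=1e-5) arrives below as args=(x, 2.0) and
    # keyword_only_args={"eps": 1e-5}, since only `eps` is keyword-only.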
    def autograd_impl(keyset, *args, **keyword_only_args):
        # We set a nonlocal to ferry keyset from here to the forward.
        # This supports recursive calls (we implement the forward carefully so
        # that it'll read saved_keyset before making a recursive call to the op).
        nonlocal saved_keyset, saved_keyword_only_args
        assert saved_keyset is None
        saved_keyset = keyset
        assert saved_keyword_only_args is None
        saved_keyword_only_args = keyword_only_args
        result = Generated.apply(*args)  # type: ignore[attr-defined]
        return result

    return autograd_impl
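

# Illustrative usage sketch (an assumption about the caller, not part of this
# file; `lib`, `backward_fn`, and `setup_context_fn` are hypothetical names):
#
#   info = Info(_backward_fn=backward_fn, _setup_context_fn=setup_context_fn)
#   autograd_kernel = make_autograd_impl(op, info)
#   lib.impl(op._opname, autograd_kernel, "Autograd", with_keyset=True)
#
# Registering with the keyset matters because autograd_impl expects the
# DispatchKeySet as its first argument.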


def supports_tensorlist(cls: Any) -> Any:
    """Allows a given autograd.Function class to support List[Tensor] inputs/outputs.

    Regular autograd.Function has a constraint that it only directly supports autograd
    for Tensors. Applying @supports_tensorlist enables an autograd.Function to support
    autograd for List[Tensor] inputs and outputs.
    """

    # NB: All calls to autograd.Function.apply share these variables.
    # We assume that only one call to .apply happens at a time. This means that
    # you cannot call the autograd.Function recursively (e.g. from its own forward).
    input_spec: Optional[spec_t] = None
    output_spec: Optional[spec_t] = None
    result_is_tuple = None

    orig_forward = cls.forward
    orig_backward = cls.backward
    orig_apply = cls.apply

    def new_forward(ctx, *args):
        if input_spec is None:
            raise NotImplementedError(
                "NYI: calling supports_tensorlist autograd.Function.forward directly. "
                "You should probably be calling .apply instead. "
                "Please file an issue if not."
            )
        args = unflatten(list(args), input_spec)
        result = orig_forward(ctx, *args)

        nonlocal output_spec
        nonlocal result_is_tuple
        result_is_tuple = isinstance(result, tuple)
        if not result_is_tuple:
            result = (result,)
        flat_result, output_spec = flatten(result, not_list_of_tensor)

        # Save the input_spec/output_spec for backward because another call to
        # .apply will override the nonlocals.
        if hasattr(ctx, "_pt_metadata"):
            raise RuntimeError(
                "Please don't set ctx._pt_metadata; PyTorch uses it to store info"
            )
        ctx._pt_metadata = (input_spec, output_spec)
        return tuple(flat_result)

    def new_backward(ctx, *grads):
        if not hasattr(ctx, "_pt_metadata"):
            raise NotImplementedError(
                "NYI: calling supports_tensorlist autograd.Function.backward directly. "
                "This will automatically get called by PyTorch autograd. "
                "Please file an issue if you need this."
            )
        input_spec, output_spec = ctx._pt_metadata
        grads = unflatten(list(grads), output_spec)
        grad_inputs = orig_backward(ctx, *grads)
        if not isinstance(grad_inputs, tuple):
            grad_inputs = (grad_inputs,)
        # Assume that any Nones in the backward are Tensors.
        # If the forward has an arg that is [1, 2, 3], the backward should
        # return None as the grad.
        # If the forward has an arg that is [tensor, tensor], the backward
        # may return [None, None], [grad, None], [None, grad], or [grad, grad].
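        # e.g. (illustrative): for forward inputs (t0, [t1, t2]) the input
        # spec has three Tensor leaves, so a backward return of
        # (g0, [None, g2]) flattens to [g0, None, g2] and is accepted, while
        # (g0, None) produces a different spec and raises below.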
        flat_grad_inputs, grad_inputs_spec = flatten(
            grad_inputs, not_list_of_optional_tensor
        )
        if grad_inputs_spec != input_spec:
            raise RuntimeError(
                f"Expected the return from backward to be of the same structure "
                f"as the inputs. Got: {grad_inputs_spec} (return from backward), "
                f"{input_spec} (inputs)"
            )
        return tuple(flat_grad_inputs)

    def new_apply(*args):
        nonlocal input_spec
        if input_spec is not None:
            raise NotImplementedError(
                "NYI: Recursive call to autograd.Function decorated with "
                "`supports_tensorlist`. Please file an issue."
            )
        try:
            flat_args, input_spec = flatten(args, is_leaf=not_list_of_tensor)
            result = orig_apply(*flat_args)  # type: ignore[misc]
        finally:
            input_spec = None
        assert output_spec is not None
        result = unflatten(list(result), output_spec)
        if not result_is_tuple:
            assert isinstance(result, tuple)
            assert len(result) == 1
            return result[0]
        return result

    cls.forward = new_forward
    cls.backward = new_backward
    cls.apply = new_apply
    return cls


def not_list_of_tensor(tree):
    if isinstance(tree, tuple):
        return False
    if isinstance(tree, list):
        return any(not isinstance(l, Tensor) for l in tree)
    return True


def not_list_of_optional_tensor(tree):
    if isinstance(tree, tuple):
        return False
    if isinstance(tree, list):
        return any(l is not None and not isinstance(l, Tensor) for l in tree)
    return True
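

# Illustrative behavior of the two is_leaf predicates above (t0, t1 are
# hypothetical Tensors):
#   not_list_of_tensor([t0, t1])  -> False  (flattened into its Tensors)
#   not_list_of_tensor([1, 2, 3]) -> True   (kept whole as a single leaf)
#   not_list_of_tensor(t0)        -> True   (a plain Tensor is a leaf)
# not_list_of_optional_tensor additionally lets None stand in for a Tensor
# gradient, e.g. not_list_of_optional_tensor([None, t1]) -> False.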

flatten = _pytree.tree_flatten
unflatten = _pytree.tree_unflatten
spec_t = _pytree.TreeSpec