blob: a0bfc0b05ed6641a927d582da38ae74295398fda [file] [log] [blame]
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch
import functools
from torch import Tensor
from typing import Any, Callable, Optional, Tuple, Union, List
from torch.utils._pytree import tree_flatten, tree_unflatten, _broadcast_to_and_flatten, TreeSpec
from .pytree_hacks import tree_map_
from functools import partial
import os
import itertools
from torch._C._functorch import (
_add_batch_dim,
_remove_batch_dim,
_vmap_decrement_nesting,
_vmap_increment_nesting,
is_batchedtensor,
)
# Annotation alias for vmap's `in_dims`: an int, or a (potentially nested)
# tuple matching the structure of the inputs.
in_dims_t = Union[int, Tuple]
# Annotation alias for vmap's `out_dims`: an int, or a tuple of ints.
out_dims_t = Union[int, Tuple[int, ...]]
def doesnt_support_saved_tensors_hooks(f):
    """Decorator: run ``f`` inside ``torch.autograd.graph.disable_saved_tensors_hooks``.

    The returned wrapper keeps ``f``'s metadata (via ``functools.wraps``) and
    forwards all positional and keyword arguments unchanged.
    """
    reason = (
        "torch.func transforms don't yet support saved tensor hooks. "
        "Please open an issue with your use case."
    )

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        hooks_disabled = torch.autograd.graph.disable_saved_tensors_hooks(reason)
        with hooks_disabled:
            return f(*args, **kwargs)

    return wrapper
# Checks that all args-to-be-batched have the same batch dim size
def _validate_and_get_batch_size(
flat_in_dims: List[Optional[int]],
flat_args: List) -> int:
batch_sizes = [arg.size(in_dim) for in_dim, arg in zip(flat_in_dims, flat_args)
if in_dim is not None]
if len(batch_sizes) == 0:
raise ValueError('vmap: Expected at least one Tensor to vmap over')
if batch_sizes and any(size != batch_sizes[0] for size in batch_sizes):
raise ValueError(
f'vmap: Expected all tensors to have the same size in the mapped '
f'dimension, got sizes {batch_sizes} for the mapped dimension')
return batch_sizes[0]
def _num_outputs(batched_outputs: Union[Tensor, Tuple[Tensor, ...]]) -> int:
if isinstance(batched_outputs, tuple):
return len(batched_outputs)
return 1
# If value is a tuple, check it has length `num_elements`.
# If value is not a tuple, make a tuple with `value` repeated `num_elements` times
def _as_tuple(value: Any, num_elements: int, error_message_lambda: Callable[[], str]) -> Tuple:
if not isinstance(value, tuple):
return (value,) * num_elements
if len(value) != num_elements:
raise ValueError(error_message_lambda())
return value
def _process_batched_inputs(
        in_dims: in_dims_t, args: Tuple, func: Callable
) -> Tuple[int, List[Any], List[Any], TreeSpec]:
    """Validate ``in_dims`` against ``args`` and flatten both.

    Returns:
        ``(batch_size, flat_in_dims, flat_args, args_spec)``. Negative entries of
        ``flat_in_dims`` are normalized to their non-negative equivalents.

    Raises:
        ValueError: if ``in_dims`` is not an int/tuple, if ``args`` is empty, if
            ``in_dims`` does not match the structure of ``args``, or if an
            in_dim is not an int/None, targets a non-Tensor, or is out of range.
            Also raised if the mapped sizes disagree.
    """
    if not isinstance(in_dims, (int, tuple)):
        raise ValueError(
            f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): '
            f'expected `in_dims` to be int or a (potentially nested) tuple '
            f'matching the structure of inputs, got: {type(in_dims)}.')
    if not args:
        raise ValueError(
            f'vmap({_get_name(func)})(<inputs>): got no inputs. Maybe you forgot to add '
            f'inputs, or you are trying to vmap over a function with no inputs. '
            f'The latter is unsupported.')

    flat_args, args_spec = tree_flatten(args)
    flat_in_dims = _broadcast_to_and_flatten(in_dims, args_spec)
    if flat_in_dims is None:
        raise ValueError(
            f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): '
            f'in_dims is not compatible with the structure of `inputs`. '
            f'in_dims has structure {tree_flatten(in_dims)[1]} but inputs '
            f'has structure {args_spec}.')

    for idx, (arg, in_dim) in enumerate(zip(flat_args, flat_in_dims)):
        if in_dim is None:
            # Not mapped over: nothing to validate or normalize.
            continue
        if not isinstance(in_dim, int):
            raise ValueError(
                f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): '
                f'Got in_dim={in_dim} for an input but in_dim must be either '
                f'an integer dimension or None.')
        if not isinstance(arg, Tensor):
            raise ValueError(
                f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): '
                f'Got in_dim={in_dim} for an input but the input is of type '
                f'{type(arg)}. We cannot vmap over non-Tensor arguments, '
                f'please use None as the respective in_dim')
        ndim = arg.dim()
        if not -ndim <= in_dim < ndim:
            raise ValueError(
                f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): '
                f'Got in_dim={in_dim} for some input, but that input is a Tensor '
                f'of dimensionality {ndim} so expected in_dim to satisfy '
                f'-{ndim} <= in_dim < {ndim}.')
        if in_dim < 0:
            # Store the equivalent non-negative dimension.
            flat_in_dims[idx] = in_dim % ndim

    batch_size = _validate_and_get_batch_size(flat_in_dims, flat_args)
    return batch_size, flat_in_dims, flat_args, args_spec
# Creates BatchedTensors for every Tensor in arg that should be batched.
# Returns the (potentially) batched arguments and the batch_size.
def _create_batched_inputs(
flat_in_dims: List[Any], flat_args: List[Any], vmap_level: int, args_spec) -> Tuple:
# See NOTE [Ignored _remove_batch_dim, _add_batch_dim]
batched_inputs = [arg if in_dim is None else
_add_batch_dim(arg, in_dim, vmap_level)
for in_dim, arg in zip(flat_in_dims, flat_args)]
return tree_unflatten(batched_inputs, args_spec)
def _maybe_remove_batch_dim(name, batched_output, vmap_level, batch_size, out_dim):
if out_dim is None:
if isinstance(batched_output, torch.Tensor) and is_batchedtensor(batched_output):
raise ValueError(
f'vmap({name}, ...): `{name}` can not return a '
f'BatchedTensor when out_dim is None'
)
return batched_output
# out_dim is non None
if not isinstance(batched_output, torch.Tensor):
raise ValueError(f'vmap({name}, ...): `{name}` must only return '
f'Tensors, got type {type(batched_output)}. '
'Did you mean to set out_dim= to None for output?')
return _remove_batch_dim(batched_output, vmap_level, batch_size, out_dim)
# Undoes the batching (and any batch dimensions) associated with `vmap_level`.
def _unwrap_batched(
        batched_outputs: Union[Tensor, Tuple[Tensor, ...]],
        out_dims: out_dims_t,
        vmap_level: int, batch_size: int, func: Callable) -> Tuple:
    """Strip the ``vmap_level`` batch dimension from every output leaf.

    Each flattened leaf is handed to ``_maybe_remove_batch_dim`` with its
    matching out_dim. The results are re-assembled into the structure of
    ``batched_outputs``.

    Raises:
        ValueError: if ``out_dims`` cannot be matched against the output
            structure, or if ``_maybe_remove_batch_dim`` rejects a leaf.
    """
    flat_batched_outputs, output_spec = tree_flatten(batched_outputs)

    def incompatible_error():
        raise ValueError(
            f'vmap({_get_name(func)}, ..., out_dims={out_dims})(<inputs>): '
            f'out_dims is not compatible with the structure of `outputs`. '
            f'out_dims has structure {tree_flatten(out_dims)[1]} but outputs '
            f'has structure {output_spec}.')

    if not isinstance(batched_outputs, torch.Tensor):
        flat_out_dims = _broadcast_to_and_flatten(out_dims, output_spec)
        if flat_out_dims is None:
            incompatible_error()
    elif out_dims is None or isinstance(out_dims, int):
        # A bare Tensor output pairs with a single int/None out_dim; this edge
        # case is spelled out explicitly (see test_out_dims_edge_case).
        flat_out_dims = [out_dims]
    elif isinstance(out_dims, tuple) and len(out_dims) == 1:
        flat_out_dims = out_dims
    else:
        incompatible_error()

    unbatched = [
        _maybe_remove_batch_dim(_get_name(func), leaf, vmap_level, batch_size, leaf_out_dim)
        for leaf, leaf_out_dim in zip(flat_batched_outputs, flat_out_dims)
    ]
    return tree_unflatten(unbatched, output_spec)
def _check_int_or_none(x, func, out_dims):
if isinstance(x, int):
return
if x is None:
return
raise ValueError(
f'vmap({_get_name(func)}, ..., out_dims={out_dims}): `out_dims` must be '
f'an int, None or a python collection of ints representing where in the outputs the '
f'vmapped dimension should appear.')
def _check_out_dims_is_int_or_int_pytree(out_dims: out_dims_t, func: Callable) -> None:
    """Validate that ``out_dims`` is an int or a pytree whose leaves are int/None.

    A plain int is accepted immediately. Otherwise ``_check_int_or_none`` is
    applied to every leaf, and it raises ValueError on the first bad one.
    """
    if not isinstance(out_dims, int):
        leaf_checker = partial(_check_int_or_none, func=func, out_dims=out_dims)
        tree_map_(leaf_checker, out_dims)
def _get_name(func: Callable):
if hasattr(func, '__name__'):
return func.__name__
# Not all callables have __name__, in fact, only static functions/methods do.
# A callable created via functools.partial or an nn.Module, to name some
# examples, don't have a __name__.
return repr(func)
# True once lazy_load_decompositions() has run (whether or not it registered
# anything).
DECOMPOSITIONS_LOADED = False
# torch.library.Library that holds the registered FuncTorchBatched kernels;
# stays None if lazy_load_decompositions() returns before registering.
VMAP_DECOMPOSITIONS_LIB = None
# torch.package, Python 3.11, and torch.jit-less environments are unhappy with
# decompositions. Only load them when needed if possible.
def lazy_load_decompositions():
    """Register selected Python decompositions as FuncTorchBatched kernels, once.

    The first call marks loading as done and then, unless disabled, registers
    entries from ``torch._decomp.decomposition_table`` for a fixed list of aten
    ops. They go into a module-global ``torch.library.Library("aten", "IMPL",
    "FuncTorchBatched")``. Later calls return immediately.

    Registration is skipped when the ``PYTORCH_JIT`` environment variable is
    set to anything other than ``"1"``, or when ``__debug__`` is False (Python
    run with ``-O``).

    Raises:
        RuntimeError: if one of the listed ops has no entry in the
            decomposition table.
    """
    global DECOMPOSITIONS_LOADED
    if DECOMPOSITIONS_LOADED:
        return
    # Marked as loaded *before* registering: if registration below raises,
    # later calls will not retry.
    # NOTE(review): no lock guards this flag; concurrent first calls could
    # both proceed to registration — confirm whether that can happen.
    DECOMPOSITIONS_LOADED = True
    if not (os.environ.get("PYTORCH_JIT", "1") == "1" and __debug__):
        return
    # use an alternate way to register an operator into the decomposition table
    # _register_jit_decomposition doesn't work for some operators, e.g. addr,
    # because the Tensor types generated cannot be unioned by torchscript
    # decomp should be type OpOverload
    global VMAP_DECOMPOSITIONS_LIB
    VMAP_DECOMPOSITIONS_LIB = torch.library.Library("aten", "IMPL", "FuncTorchBatched")
    # Imported here (not at module top) so the import only happens when needed.
    from torch._decomp import decomposition_table

    def _register_python_decomposition_vmap(decomp):
        # Register the table's Python decomposition as the batched kernel.
        if decomp in decomposition_table:
            VMAP_DECOMPOSITIONS_LIB.impl(decomp, decomposition_table[decomp])
        else:
            raise RuntimeError(f"could not find decomposition for {decomp}")

    _register_python_decomposition_vmap(torch.ops.aten.mse_loss_backward.default)
    _register_python_decomposition_vmap(torch.ops.aten.smooth_l1_loss_backward.default)
    _register_python_decomposition_vmap(torch.ops.aten.huber_loss_backward.default)
    _register_python_decomposition_vmap(torch.ops.aten.nll_loss_forward.default)
    _register_python_decomposition_vmap(torch.ops.aten.nll_loss2d_forward.default)
    _register_python_decomposition_vmap(torch.ops.aten.nll_loss_backward.default)
    _register_python_decomposition_vmap(torch.ops.aten.nll_loss2d_backward.default)
    _register_python_decomposition_vmap(torch.ops.aten.addr.default)
def vmap_impl(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs):
    """Validate inputs and run ``func`` under vmap, optionally in chunks.

    Decompositions are lazily loaded first, and ``out_dims`` and the inputs are
    validated. Without ``chunk_size`` the whole batch goes through
    ``_flat_vmap``. With it, the flat inputs are split via
    ``_get_chunked_inputs`` and processed by ``_chunked_vmap``. ``kwargs`` are
    forwarded to ``func``.
    """
    lazy_load_decompositions()
    _check_out_dims_is_int_or_int_pytree(out_dims, func)
    batch_size, flat_in_dims, flat_args, args_spec = _process_batched_inputs(in_dims, args, func)

    if chunk_size is None:
        return _flat_vmap(
            func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs
        )

    chunks_flat_args = _get_chunked_inputs(flat_args, flat_in_dims, batch_size, chunk_size)
    return _chunked_vmap(func, flat_in_dims, chunks_flat_args,
                         args_spec, out_dims, randomness, **kwargs)
def get_chunk_sizes(total_elems, chunk_size):
    """Split ``total_elems`` into consecutive chunk sizes of at most ``chunk_size``.

    Every chunk holds ``chunk_size`` elements except possibly the last, which
    holds the remainder. Returns an empty list when ``total_elems`` is 0.

    Args:
        total_elems: number of elements to split (assumed non-negative).
        chunk_size: maximum chunk size (assumed positive; 0 raises
            ZeroDivisionError).

    Returns:
        List of chunk sizes, in order.
    """
    n_chunks, remainder = divmod(total_elems, chunk_size)
    chunk_sizes = [chunk_size] * n_chunks
    # A trailing partial chunk holds whatever is left over.
    if remainder != 0:
        chunk_sizes.append(remainder)
    return chunk_sizes
def _get_chunked_inputs(flat_args, flat_in_dims, batch_size, chunk_size):
    """Split each mapped arg along its in_dim and group the pieces per chunk.

    Mapped args are split with ``Tensor.tensor_split`` at the cumulative chunk
    boundaries. Unmapped args (in_dim None) are repeated once per boundary.

    Returns:
        A ``zip`` iterator. Each item is one chunk's flat args, in the same
        order as ``flat_args``.
    """
    if chunk_size is None:
        split_idxs = (batch_size,)
    else:
        split_idxs = tuple(itertools.accumulate(get_chunk_sizes(batch_size, chunk_size)))

    n_splits = len(split_idxs)
    per_arg_chunks = tuple(
        [arg] * n_splits if in_dim is None else arg.tensor_split(split_idxs, dim=in_dim)
        for arg, in_dim in zip(flat_args, flat_in_dims)
    )
    # Transpose from (arg, chunk) to (chunk, arg): each item is one chunk's flat args.
    return zip(*per_arg_chunks)
def _flatten_chunks_output(chunks_output_):
# chunks_output is a list of chunked outputs
# flatten chunked outputs:
flat_chunks_output = []
arg_spec = None
for output in chunks_output_:
flat_output, arg_specs = tree_flatten(output)
flat_chunks_output.append(flat_output)
if arg_spec is None:
arg_spec = arg_specs
# transpose chunk dim and flatten structure
# flat_output_chunks is flat list of chunks
flat_output_chunks = list(zip(*flat_chunks_output))
return flat_output_chunks, arg_spec
def _concat_chunked_outputs(out_dims, arg_spec, flat_output_chunks):
# concat chunks on out_dim
flat_out_dims = _broadcast_to_and_flatten(out_dims, arg_spec)
assert len(flat_out_dims) == len(flat_output_chunks)
flat_output = []
for idx, out_dim in enumerate(flat_out_dims):
flat_output.append(torch.cat(flat_output_chunks[idx], dim=out_dim))
# release tensors
flat_output_chunks[idx] = None
return flat_output
# Applies vmap on chunked_input and returns concatenated output over the chunks.
def _chunked_vmap(func, flat_in_dims, chunks_flat_args, args_spec, out_dims, randomness, **kwargs):
    """Run ``_flat_vmap`` once per chunk and concatenate the results on ``out_dims``.

    Args:
        func: the function being vmapped.
        flat_in_dims: flattened in_dims, shared by every chunk.
        chunks_flat_args: iterable of per-chunk flat argument sequences.
        args_spec: TreeSpec used to unflatten each chunk's arguments.
        out_dims: used both for per-chunk unbatching and for the final concat.
        randomness: forwarded to ``_flat_vmap``. When ``"same"``, the RNG state
            captured before the loop is restored before each chunk.
        **kwargs: forwarded to ``func`` through ``_flat_vmap``.

    Returns:
        The concatenated output, unflattened with the first chunk's output spec.
    """
    chunks_output = []
    # With randomness="same", every chunk restarts from the same RNG state.
    # NOTE(review): torch.get_rng_state() captures the default CPU generator;
    # confirm whether other device generators need the same reset.
    rs = torch.get_rng_state() if randomness == "same" else None
    for flat_args in chunks_flat_args:
        batch_size = _validate_and_get_batch_size(flat_in_dims, flat_args)
        # Because of how `_get_chunked_inputs` splits the input,
        # we may get a tensor with `0` batch-size. We skip any computation
        # in that case.
        # Eg.
        # >>> chunk_size = 1
        # >>> batch_size = 6
        # >>> t = torch.zeros(batch_size, 1)
        # >>> t.tensor_split([1, 2, 3, 4, 5, 6])
        # (tensor([[0.]]), tensor([[0.]]), tensor([[0.]]), tensor([[0.]]),
        # tensor([[0.]]), tensor([[0.]]), tensor([], size=(0, 1)))
        if batch_size == 0:
            continue
        if rs is not None:
            torch.set_rng_state(rs)
        chunks_output.append(
            _flat_vmap(
                func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs
            )
        )
    # NOTE(review): if every chunk was skipped, chunks_output is empty and
    # arg_spec below is None, which is then passed to _broadcast_to_and_flatten;
    # confirm callers never reach this with an all-empty batch.
    flat_output_chunks, arg_spec = _flatten_chunks_output(chunks_output)
    # chunked output tensors are held by both `flat_output_chunks` and `chunks_output`.
    # eagerly remove the reference from `chunks_output`.
    del chunks_output
    # concat chunks on out_dim
    flat_output = _concat_chunked_outputs(out_dims, arg_spec, flat_output_chunks)
    # finally unflatten the output
    return tree_unflatten(flat_output, arg_spec)
# Vmap refactored helper functions:
def _check_randomness_arg(randomness):
if randomness not in ['error', 'different', 'same']:
raise RuntimeError(f"Only allowed values for randomness are 'error', 'different', or 'same'. Got {randomness}")
@doesnt_support_saved_tensors_hooks
def _flat_vmap(func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs):
    """Run ``func`` once over the whole batch at a fresh vmap level.

    The vmap nesting level is incremented for the call and always decremented
    afterwards (even on error). Inputs are wrapped via ``_create_batched_inputs``,
    and outputs are unwrapped via ``_unwrap_batched`` according to ``out_dims``.
    """
    level = _vmap_increment_nesting(batch_size, randomness)
    try:
        wrapped_inputs = _create_batched_inputs(flat_in_dims, flat_args, level, args_spec)
        raw_outputs = func(*wrapped_inputs, **kwargs)
        return _unwrap_batched(raw_outputs, out_dims, level, batch_size, func)
    finally:
        _vmap_decrement_nesting()
# `restore_vmap` is a private helper function. It is vmap but has the following
# differences:
# - instead of returning outputs, it returns an (outputs, out_dims) tuple.
# out_dims is a pytree of same shape as outputs and contains Optional[int]
# specifying where the vmapped dimension, if it exists, is in the corresponding output.
# - does no validation on in_dims or inputs (vmap expects at least one Tensor to be vmapped).
# restore_vmap allows for no inputs to have the vmap dimension
# - does no validation on outputs (vmap expects only Tensor outputs)
# restore_vmap allows for return of arbitrary outputs (not just Tensors)
#
# The TL;DR is that restore_vmap is more general than vmap and has a slightly
# different API. The relaxations are so that we can "pause" vmap in the middle
# of its execution and then "restore" it later (this is what we do in
# the generate_vmap_rule=True implementation of autograd.Function).
#
# restore_vmap could technically be used in the implementation of vmap, but doing
# that refactor is challenging because:
# - vmap couples the tensor-wrapping code with error checking
# - vmap's tensor unwrapping code is in C++; we would need to rewrite part of it
# in python because it overlaps with unwrap_batched
def restore_vmap(func, in_dims, batch_size, randomness):
    """Return a callable that runs ``func`` under a fresh vmap level.

    The returned ``inner(*args, **kwargs)`` increments the vmap nesting level
    (always decrementing it afterwards). It wraps ``args`` according to
    ``in_dims`` via ``wrap_batched``, calls ``func`` and returns
    ``(outputs, out_dims)`` via ``unwrap_batched``.

    Saved-tensor hooks are disabled while ``inner`` runs. Previously the
    decorator wrapped this factory, so hooks were only disabled while the
    closure was being *built*, never while ``func`` actually executed. It now
    matches ``_flat_vmap``.
    """
    @doesnt_support_saved_tensors_hooks
    def inner(*args, **kwargs):
        vmap_level = _vmap_increment_nesting(batch_size, randomness)
        try:
            batched_inputs = wrap_batched(args, in_dims, vmap_level)
            batched_outputs = func(*batched_inputs, **kwargs)
            return unwrap_batched(batched_outputs, vmap_level)
        finally:
            _vmap_decrement_nesting()

    return inner
def wrap_batched(args, bdims, level):
    """Wrap ``args`` as BatchedTensors at ``level`` according to ``bdims``.

    ``bdims`` is broadcast against the structure of ``args``. It must be
    compatible; this is asserted, not validated with a user-facing error.
    Returns ``args`` rebuilt with the mapped leaves wrapped.
    """
    leaves, spec = tree_flatten(args)
    leaf_bdims = _broadcast_to_and_flatten(bdims, spec)
    assert leaf_bdims is not None
    return _create_batched_inputs(leaf_bdims, leaves, level, spec)
def unwrap_batched(args, level):
    """Unwrap every Tensor leaf of ``args`` at ``level``.

    Returns ``(outputs, bdims)``, both with the structure of ``args``. Tensor
    leaves go through ``torch._C._functorch._unwrap_batched``, and non-Tensor
    leaves map to ``(leaf, None)``. If ``args`` has no leaves, returns
    ``(args, ())``.
    """
    leaves, spec = tree_flatten(args)
    if not leaves:
        return args, ()

    def _unwrap_leaf(leaf):
        if isinstance(leaf, torch.Tensor):
            return torch._C._functorch._unwrap_batched(leaf, level)
        return leaf, None

    values, leaf_bdims = zip(*[_unwrap_leaf(leaf) for leaf in leaves])
    return tree_unflatten(values, spec), tree_unflatten(leaf_bdims, spec)