blob: efb0f6ed0b81cac91f3b52d75f8ea419660f0b45 [file] [log] [blame]
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import functools
import itertools
import os
import threading
from functools import partial
from typing import Any, Callable, Optional, Tuple, Union, List

import torch
from torch import Tensor
from torch._C._functorch import (
    _add_batch_dim,
    _remove_batch_dim,
    _vmap_decrement_nesting,
    _vmap_increment_nesting,
    is_batchedtensor,
)
from torch._functorch.utils import exposed_in
from torch.utils._pytree import tree_flatten, tree_unflatten, _broadcast_to_and_flatten, TreeSpec

from .pytree_hacks import tree_map_
# Type of the ``in_dims`` argument: a single int, or a (potentially nested)
# tuple matching the structure of the inputs (entries may also be None, see
# the checks in _process_batched_inputs).
in_dims_t = Union[int, Tuple]
# Type of the ``out_dims`` argument: a single int or a tuple of ints (None
# entries are also accepted by _check_int_or_none).
out_dims_t = Union[int, Tuple[int, ...]]
def doesnt_support_saved_tensors_hooks(f):
    """Decorator: run ``f`` with saved-tensor hooks disabled.

    Any attempt to use saved-tensor hooks while the decorated function runs
    raises with the message below. The wrapper preserves ``f``'s metadata via
    ``functools.wraps`` and forwards all arguments and the return value.
    """
    hooks_disabled_msg = (
        "torch.func transforms don't yet support saved tensor hooks. "
        "Please open an issue with your use case."
    )

    @functools.wraps(f)
    def _wrapper(*args, **kwargs):
        hooks_off = torch.autograd.graph.disable_saved_tensors_hooks(hooks_disabled_msg)
        with hooks_off:
            return f(*args, **kwargs)

    return _wrapper
def _validate_and_get_batch_size(
        flat_in_dims: List[Optional[int]],
        flat_args: List) -> int:
    """Return the common mapped-dimension size of all batched arguments.

    Args:
        flat_in_dims: per-argument batch dim, or None for unbatched arguments.
        flat_args: the flattened arguments, aligned with ``flat_in_dims``.

    Raises:
        ValueError: if no argument is batched, or if the batched arguments
            disagree on the size of their mapped dimension.
    """
    sizes = []
    for dim, value in zip(flat_in_dims, flat_args):
        if dim is None:
            continue
        sizes.append(value.size(dim))

    if not sizes:
        raise ValueError('vmap: Expected at least one Tensor to vmap over')

    first = sizes[0]
    if not all(size == first for size in sizes):
        raise ValueError(
            f'vmap: Expected all tensors to have the same size in the mapped '
            f'dimension, got sizes {sizes} for the mapped dimension')
    return first
def _num_outputs(batched_outputs: Union[Tensor, Tuple[Tensor, ...]]) -> int:
    """Count outputs: a tuple contributes its length, anything else counts as one."""
    return len(batched_outputs) if isinstance(batched_outputs, tuple) else 1
def _as_tuple(value: Any, num_elements: int, error_message_lambda: Callable[[], str]) -> Tuple:
    """Normalize ``value`` to a tuple of length ``num_elements``.

    A tuple is returned unchanged when its length matches. Any non-tuple value
    is repeated ``num_elements`` times.

    Raises:
        ValueError: if ``value`` is a tuple of the wrong length. The message is
            produced lazily by ``error_message_lambda``.
    """
    if isinstance(value, tuple):
        if len(value) == num_elements:
            return value
        raise ValueError(error_message_lambda())
    return (value,) * num_elements
def _process_batched_inputs(
        in_dims: in_dims_t, args: Tuple, func: Callable
) -> Tuple[int, List[Any], List[Any], TreeSpec]:
    """Validate ``in_dims`` against ``args`` and flatten both.

    Args:
        in_dims: an int or a (potentially nested) tuple mirroring ``args``.
        args: the positional arguments the batched function was called with.
        func: the user function (only used to build error messages).

    Returns:
        ``(batch_size, flat_in_dims, flat_args, args_spec)``. Negative entries
        of ``flat_in_dims`` are normalized to their non-negative equivalents.

    Raises:
        ValueError: on malformed ``in_dims``, no inputs, structure mismatch,
            a non-Tensor input with an int in_dim, or an out-of-range in_dim.
    """
    def _prefix():
        # Built only when an error is actually raised.
        return f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): '

    if not isinstance(in_dims, (int, tuple)):
        raise ValueError(
            _prefix() +
            f'expected `in_dims` to be int or a (potentially nested) tuple '
            f'matching the structure of inputs, got: {type(in_dims)}.')
    if not args:
        raise ValueError(
            f'vmap({_get_name(func)})(<inputs>): got no inputs. Maybe you forgot to add '
            f'inputs, or you are trying to vmap over a function with no inputs. '
            f'The latter is unsupported.')

    flat_args, args_spec = tree_flatten(args)
    flat_in_dims = _broadcast_to_and_flatten(in_dims, args_spec)
    if flat_in_dims is None:
        raise ValueError(
            _prefix() +
            f'in_dims is not compatible with the structure of `inputs`. '
            f'in_dims has structure {tree_flatten(in_dims)[1]} but inputs '
            f'has structure {args_spec}.')

    for idx, (arg, in_dim) in enumerate(zip(flat_args, flat_in_dims)):
        if in_dim is None:
            # Unbatched input: nothing to validate or normalize.
            continue
        if not isinstance(in_dim, int):
            raise ValueError(
                _prefix() +
                f'Got in_dim={in_dim} for an input but in_dim must be either '
                f'an integer dimension or None.')
        if not isinstance(arg, Tensor):
            raise ValueError(
                _prefix() +
                f'Got in_dim={in_dim} for an input but the input is of type '
                f'{type(arg)}. We cannot vmap over non-Tensor arguments, '
                f'please use None as the respective in_dim')
        ndim = arg.dim()
        if not -ndim <= in_dim < ndim:
            raise ValueError(
                _prefix() +
                f'Got in_dim={in_dim} for some input, but that input is a Tensor '
                f'of dimensionality {ndim} so expected in_dim to satisfy '
                f'-{ndim} <= in_dim < {ndim}.')
        if in_dim < 0:
            flat_in_dims[idx] = in_dim % ndim

    return _validate_and_get_batch_size(flat_in_dims, flat_args), flat_in_dims, flat_args, args_spec
def _create_batched_inputs(
        flat_in_dims: List[Any], flat_args: List[Any], vmap_level: int, args_spec) -> Tuple:
    """Wrap each batched flat argument as a BatchedTensor and rebuild the pytree.

    Arguments whose in_dim is None are passed through untouched. The rest are
    wrapped with ``_add_batch_dim`` at ``vmap_level``. The result is unflattened
    with ``args_spec``.
    """
    # See NOTE [Ignored _remove_batch_dim, _add_batch_dim]
    wrapped = []
    for dim, value in zip(flat_in_dims, flat_args):
        wrapped.append(value if dim is None else _add_batch_dim(value, dim, vmap_level))
    return tree_unflatten(wrapped, args_spec)
def _maybe_remove_batch_dim(name, batched_output, vmap_level, batch_size, out_dim):
    """Unwrap one output of the vmapped function ``name``.

    With an int ``out_dim``, the output must be a Tensor. Its batch dimension
    at ``vmap_level`` is materialized at ``out_dim`` via ``_remove_batch_dim``.
    With ``out_dim=None``, the output is returned as-is, but it may not be a
    BatchedTensor.

    Raises:
        ValueError: for a non-Tensor output with an int ``out_dim``, or a
            BatchedTensor output with ``out_dim=None``.
    """
    is_tensor = isinstance(batched_output, torch.Tensor)
    if out_dim is not None:
        if not is_tensor:
            raise ValueError(f'vmap({name}, ...): `{name}` must only return '
                             f'Tensors, got type {type(batched_output)}. '
                             'Did you mean to set out_dim= to None for output?')
        return _remove_batch_dim(batched_output, vmap_level, batch_size, out_dim)

    if is_tensor and is_batchedtensor(batched_output):
        raise ValueError(
            f'vmap({name}, ...): `{name}` can not return a '
            f'BatchedTensor when out_dim is None'
        )
    return batched_output
def _unwrap_batched(
        batched_outputs: Union[Tensor, Tuple[Tensor, ...]],
        out_dims: out_dims_t,
        vmap_level: int, batch_size: int, func: Callable) -> Tuple:
    """Undo the batching associated with ``vmap_level`` on every output.

    ``out_dims`` is matched against the output pytree. For a single Tensor
    output it may be an int, None, or a 1-tuple. The per-leaf dims are then
    applied with ``_maybe_remove_batch_dim``, and the result keeps the
    structure of ``batched_outputs``.

    Raises:
        ValueError: if ``out_dims`` is not compatible with the output structure.
    """
    leaves, out_spec = tree_flatten(batched_outputs)

    def _raise_incompatible():
        raise ValueError(
            f'vmap({_get_name(func)}, ..., out_dims={out_dims})(<inputs>): '
            f'out_dims is not compatible with the structure of `outputs`. '
            f'out_dims has structure {tree_flatten(out_dims)[1]} but outputs '
            f'has structure {output_spec_placeholder(out_spec)}.')

    if not isinstance(batched_outputs, torch.Tensor):
        dims_per_leaf = _broadcast_to_and_flatten(out_dims, out_spec)
        if dims_per_leaf is None:
            _raise_incompatible()
    # A lone Tensor output needs explicit handling (see test_out_dims_edge_case).
    elif out_dims is None or isinstance(out_dims, int):
        dims_per_leaf = [out_dims]
    elif isinstance(out_dims, tuple) and len(out_dims) == 1:
        dims_per_leaf = out_dims
    else:
        _raise_incompatible()

    unwrapped = [
        _maybe_remove_batch_dim(_get_name(func), leaf, vmap_level, batch_size, dim)
        for leaf, dim in zip(leaves, dims_per_leaf)
    ]
    return tree_unflatten(unwrapped, out_spec)


def output_spec_placeholder(spec):
    return spec
def _check_int_or_none(x, func, out_dims):
    """Accept ``x`` if it is an int or None; otherwise raise ValueError.

    ``func`` and ``out_dims`` are used only to build the error message.
    """
    if x is None or isinstance(x, int):
        return
    raise ValueError(
        f'vmap({_get_name(func)}, ..., out_dims={out_dims}): `out_dims` must be '
        f'an int, None or a python collection of ints representing where in the outputs the '
        f'vmapped dimension should appear.')
def _check_out_dims_is_int_or_int_pytree(out_dims: out_dims_t, func: Callable) -> None:
    """Validate that every leaf of ``out_dims`` is an int or None.

    A bare int is accepted immediately. Otherwise each leaf is checked with
    ``_check_int_or_none``, which raises ValueError on the first bad leaf.
    """
    if not isinstance(out_dims, int):
        leaf_checker = partial(_check_int_or_none, func=func, out_dims=out_dims)
        tree_map_(leaf_checker, out_dims)
def _get_name(func: Callable):
    """Return a human-readable name for ``func`` for use in error messages.

    Not every callable has ``__name__`` (e.g. ``functools.partial`` objects or
    ``nn.Module`` instances), so fall back to ``repr(func)`` for those.
    """
    try:
        return func.__name__
    except AttributeError:
        return repr(func)
DECOMPOSITIONS_LOADED = False
# Serializes the one-time registration below so concurrent first calls do not
# observe a half-finished load.
DECOMPOSITIONS_LOCK = threading.Lock()
VMAP_DECOMPOSITIONS_LIB = None


# torch.package, Python 3.11, and torch.jit-less environments are unhappy with
# decompositions. Only load them when needed if possible.
def lazy_load_decompositions():
    """Register the Python decompositions that vmap relies on, at most once.

    When ``PYTORCH_JIT`` is not "1", or Python runs with ``-O`` (``__debug__`` is
    False), nothing is registered but the load is still marked as done.
    Otherwise the ``torch._decomp.decomposition_table`` entries for
    ``aten.mse_loss_backward.default`` and ``aten.addr.default`` are registered
    on a ``torch.library.Library("aten", "IMPL", "FuncTorchBatched")``.

    ``DECOMPOSITIONS_LOADED`` is set only after registration completes, under
    ``DECOMPOSITIONS_LOCK``. So a concurrent caller cannot return early while
    registration is still in progress. If registration raises, the next call
    attempts it again instead of silently skipping it.

    Raises:
        RuntimeError: if a required decomposition is missing from the table.
    """
    global DECOMPOSITIONS_LOADED
    global VMAP_DECOMPOSITIONS_LIB
    # Fast path without taking the lock once loading has finished.
    if DECOMPOSITIONS_LOADED:
        return

    with DECOMPOSITIONS_LOCK:
        # Another thread may have finished loading while we waited for the lock.
        if DECOMPOSITIONS_LOADED:
            return

        if not (os.environ.get("PYTORCH_JIT", "1") == "1" and __debug__):
            DECOMPOSITIONS_LOADED = True
            return

        # use an alternate way to register an operator into the decomposition table
        # _register_jit_decomposition doesn't work for some operators, e.g. addr,
        # because the Tensor types generated cannot be unioned by torchscript
        # decomp should be type OpOverload
        VMAP_DECOMPOSITIONS_LIB = torch.library.Library("aten", "IMPL", "FuncTorchBatched")

        from torch._decomp import decomposition_table

        def _register_python_decomposition_vmap(decomp):
            if decomp in decomposition_table:
                VMAP_DECOMPOSITIONS_LIB.impl(decomp, decomposition_table[decomp])
            else:
                raise RuntimeError(f"could not find decomposition for {decomp}")

        _register_python_decomposition_vmap(torch.ops.aten.mse_loss_backward.default)
        _register_python_decomposition_vmap(torch.ops.aten.addr.default)

        # Only mark as loaded once every registration above has succeeded.
        DECOMPOSITIONS_LOADED = True
# vmap(func)(inputs) wraps all Tensor inputs to be batched in BatchedTensors,
# sends those into func, and then unwraps the output BatchedTensors. Operations
# on BatchedTensors perform the batched operations that the user is asking for.
#
# vmap's randomness behavior differs from JAX's, which would require a PRNG key
# to be passed everywhere.
@exposed_in('torch.func')
def vmap(
        func: Callable,
        in_dims: in_dims_t = 0,
        out_dims: out_dims_t = 0,
        randomness: str = 'error',
        *,
        chunk_size=None) -> Callable:
    """
    vmap is the vectorizing map; ``vmap(func)`` returns a new function that
    maps ``func`` over some dimension of the inputs. Semantically, vmap
    pushes the map into PyTorch operations called by ``func``, effectively
    vectorizing those operations.

    vmap is useful for handling batch dimensions: one can write a function
    ``func`` that runs on examples and then lift it to a function that can
    take batches of examples with ``vmap(func)``. vmap can also be used to
    compute batched gradients when composed with autograd.

    .. note::
        :func:`torch.vmap` is aliased to :func:`torch.func.vmap` for
        convenience. Use whichever one you'd like.

    Args:
        func (function): A Python function that takes one or more arguments.
            Must return one or more Tensors.
        in_dims (int or nested structure): Specifies which dimension of the
            inputs should be mapped over. ``in_dims`` should have a
            structure like the inputs. If the ``in_dim`` for a particular
            input is None, then that indicates there is no map dimension.
            Default: 0.
        out_dims (int or Tuple[int]): Specifies where the mapped dimension
            should appear in the outputs. If ``out_dims`` is a Tuple, then
            it should have one element per output. An ``out_dim`` of None
            returns the corresponding output unchanged (it must not be
            batched). Default: 0.
        randomness (str): Specifies whether the randomness in this
            vmap should be the same or different across batches. If 'different',
            the randomness for each batch will be different. If 'same', the
            randomness will be the same across batches. If 'error', any calls to
            random functions will error. Default: 'error'. WARNING: this flag
            only applies to random PyTorch operations and does not apply to
            Python's random module or numpy randomness.
        chunk_size (None or int): If None (default), apply a single vmap over inputs.
            If not None, then compute the vmap :attr:`chunk_size` samples at a time.
            Note that :attr:`chunk_size=1` is equivalent to computing the vmap with a for-loop.
            If you run into memory issues computing the vmap, please try a non-None chunk_size.

    Returns:
        Returns a new "batched" function. It takes the same inputs as
        ``func``, except each input has an extra dimension at the index
        specified by ``in_dims``. It returns the same outputs as
        ``func``, except each output has an extra dimension at the index
        specified by ``out_dims``.

    .. warning::
        :func:`vmap` works best with functional-style code. Please do not
        perform any side-effects in ``func``, with the exception of
        in-place PyTorch operations. Examples of side-effects include mutating
        Python data structures and assigning values to variables not captured
        in ``func``.

    One example of using :func:`vmap` is to compute batched dot products. PyTorch
    doesn't provide a batched ``torch.dot`` API; instead of unsuccessfully
    rummaging through docs, use :func:`vmap` to construct a new function.

        >>> torch.dot                            # [D], [D] -> []
        >>> batched_dot = torch.func.vmap(torch.dot)  # [N, D], [N, D] -> [N]
        >>> x, y = torch.randn(2, 5), torch.randn(2, 5)
        >>> batched_dot(x, y)

    :func:`vmap` can be helpful in hiding batch dimensions, leading to a simpler
    model authoring experience.

        >>> batch_size, feature_size = 3, 5
        >>> weights = torch.randn(feature_size, requires_grad=True)
        >>>
        >>> def model(feature_vec):
        >>>     # Very simple linear model with activation
        >>>     return feature_vec.dot(weights).relu()
        >>>
        >>> examples = torch.randn(batch_size, feature_size)
        >>> result = torch.vmap(model)(examples)

    :func:`vmap` can also help vectorize computations that were previously difficult
    or impossible to batch. One example is higher-order gradient computation.
    The PyTorch autograd engine computes vjps (vector-Jacobian products).
    Computing a full Jacobian matrix for some function f: R^N -> R^N usually
    requires N calls to ``autograd.grad``, one per Jacobian row. Using :func:`vmap`,
    we can vectorize the whole computation, computing the Jacobian in a single
    call to ``autograd.grad``.

        >>> # Setup
        >>> N = 5
        >>> f = lambda x: x ** 2
        >>> x = torch.randn(N, requires_grad=True)
        >>> y = f(x)
        >>> I_N = torch.eye(N)
        >>>
        >>> # Sequential approach
        >>> jacobian_rows = [torch.autograd.grad(y, x, v, retain_graph=True)[0]
        >>>                  for v in I_N.unbind()]
        >>> jacobian = torch.stack(jacobian_rows)
        >>>
        >>> # vectorized gradient computation
        >>> def get_vjp(v):
        >>>     return torch.autograd.grad(y, x, v)
        >>> jacobian = torch.vmap(get_vjp)(I_N)

    :func:`vmap` can also be nested, producing an output with multiple batched dimensions

        >>> torch.dot                            # [D], [D] -> []
        >>> batched_dot = torch.vmap(torch.vmap(torch.dot))  # [N1, N0, D], [N1, N0, D] -> [N1, N0]
        >>> x, y = torch.randn(2, 3, 5), torch.randn(2, 3, 5)
        >>> batched_dot(x, y) # tensor of size [2, 3]

    If the inputs are not batched along the first dimension, ``in_dims`` specifies
    the dimension along which each input is batched:

        >>> torch.dot                            # [N], [N] -> []
        >>> batched_dot = torch.vmap(torch.dot, in_dims=1)  # [N, D], [N, D] -> [D]
        >>> x, y = torch.randn(2, 5), torch.randn(2, 5)
        >>> batched_dot(x, y)   # output is [5] instead of [2] if batched along the 0th dimension

    If there are multiple inputs each of which is batched along different dimensions,
    ``in_dims`` must be a tuple with the batch dimension for each input:

        >>> torch.dot                            # [D], [D] -> []
        >>> batched_dot = torch.vmap(torch.dot, in_dims=(0, None))  # [N, D], [D] -> [N]
        >>> x, y = torch.randn(2, 5), torch.randn(5)
        >>> batched_dot(x, y) # second arg doesn't have a batch dim because in_dim[1] was None

    If the input is a Python struct, ``in_dims`` must be a tuple containing a struct
    matching the shape of the input:

        >>> f = lambda dict: torch.dot(dict['x'], dict['y'])
        >>> x, y = torch.randn(2, 5), torch.randn(5)
        >>> input = {'x': x, 'y': y}
        >>> batched_dot = torch.vmap(f, in_dims=({'x': 0, 'y': None},))
        >>> batched_dot(input)

    By default, the output is batched along the first dimension. However, it can be batched
    along any dimension by using ``out_dims``

        >>> f = lambda x: x ** 2
        >>> x = torch.randn(2, 5)
        >>> batched_pow = torch.vmap(f, out_dims=1)
        >>> batched_pow(x) # [5, 2]

    For any function that uses kwargs, the returned function will not batch the kwargs but will
    accept kwargs

        >>> x = torch.randn([2, 5])
        >>> def fn(x, scale=4.):
        >>>   return x * scale
        >>>
        >>> batched_pow = torch.vmap(fn)
        >>> assert torch.allclose(batched_pow(x), x * 4)
        >>> batched_pow(x, scale=x) # scale is not batched, output has shape [2, 2, 5]

    .. note::
        vmap does not provide general autobatching or handle variable-length
        sequences out of the box.
    """
    _check_randomness_arg(randomness)
    if not (chunk_size is None or chunk_size > 0):
        raise ValueError(f"vmap: chunk_size should be None or greater than 0. (got {chunk_size})")

    @functools.wraps(func)
    def wrapped(*args, **kwargs):
        lazy_load_decompositions()
        _check_out_dims_is_int_or_int_pytree(out_dims, func)
        batch_size, flat_in_dims, flat_args, args_spec = _process_batched_inputs(in_dims, args, func)

        if chunk_size is not None:
            chunks_flat_args = _get_chunked_inputs(flat_args, flat_in_dims, batch_size, chunk_size)
            return _chunked_vmap(func, flat_in_dims, chunks_flat_args,
                                 args_spec, out_dims, randomness, **kwargs)

        # If chunk_size is not specified.
        return _flat_vmap(
            func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs
        )

    return wrapped
def get_chunk_sizes(total_elems, chunk_size):
    """Split ``total_elems`` into consecutive chunk sizes of at most ``chunk_size``.

    Every chunk has exactly ``chunk_size`` elements, except possibly a smaller
    trailing remainder chunk. ``total_elems == 0`` yields an empty list.

    Example: ``get_chunk_sizes(7, 3) == [3, 3, 1]``.
    """
    n_chunks, remainder = divmod(total_elems, chunk_size)
    chunk_sizes = [chunk_size] * n_chunks
    if remainder != 0:
        chunk_sizes.append(remainder)
    return chunk_sizes
def _get_chunked_inputs(flat_args, flat_in_dims, batch_size, chunk_size):
    """Split each batched flat argument into chunks along its in_dim.

    Returns an iterator of per-chunk argument lists (chunk-major), produced by
    transposing the per-argument splits with ``zip``. Unbatched arguments
    (in_dim None) are repeated once per split index. Because ``zip`` stops at
    the shortest sequence, a trailing empty chunk from ``tensor_split`` may or
    may not survive; callers skip zero-sized chunks.
    """
    if chunk_size is None:
        boundaries = (batch_size,)
    else:
        boundaries = tuple(itertools.accumulate(get_chunk_sizes(batch_size, chunk_size)))

    per_arg_pieces = []
    for tensor, dim in zip(flat_args, flat_in_dims):
        if dim is None:
            per_arg_pieces.append([tensor] * len(boundaries))
        else:
            per_arg_pieces.append(tensor.tensor_split(boundaries, dim=dim))

    # Transpose argument-major -> chunk-major.
    return zip(*per_arg_pieces)
def _flatten_chunks_output(chunks_output_):
    """Flatten each chunk's output and regroup the leaves per output position.

    Returns ``(flat_output_chunks, spec)``:

    - ``flat_output_chunks[i]`` is a tuple holding the i-th leaf from every chunk.
    - ``spec`` is the TreeSpec of the first chunk's output, or None if there
      were no chunks.
    """
    spec = None
    per_chunk_leaves = []
    for chunk_out in chunks_output_:
        leaves, chunk_spec = tree_flatten(chunk_out)
        per_chunk_leaves.append(leaves)
        if spec is None:
            spec = chunk_spec
    # Transpose chunk-major -> output-major.
    return list(zip(*per_chunk_leaves)), spec
def _concat_chunked_outputs(out_dims, arg_spec, flat_output_chunks):
    """Concatenate each output's per-chunk pieces along its ``out_dim``.

    ``out_dims`` is broadcast over ``arg_spec`` to get one dim per flat output.
    Each slot of ``flat_output_chunks`` is overwritten with None after it is
    concatenated, so this list no longer keeps the pieces alive.
    """
    dims_per_output = _broadcast_to_and_flatten(out_dims, arg_spec)
    assert len(dims_per_output) == len(flat_output_chunks)
    concatenated = []
    for position, dim in enumerate(dims_per_output):
        concatenated.append(torch.cat(flat_output_chunks[position], dim=dim))
        # Drop this list's reference to the now-merged pieces.
        flat_output_chunks[position] = None
    return concatenated
def _chunked_vmap(func, flat_in_dims, chunks_flat_args, args_spec, out_dims, randomness, **kwargs):
    """Run ``_flat_vmap`` on each chunk and concatenate the results along ``out_dims``.

    Zero-sized chunks are skipped. ``_get_chunked_inputs`` can yield an empty
    trailing chunk, e.g. ``torch.zeros(6, 1).tensor_split([1, 2, 3, 4, 5, 6])``
    ends with a ``size=(0, 1)`` tensor.

    With ``randomness == "same"``, the RNG state captured before the first
    chunk is restored before every chunk, so all chunks see the same random
    stream.
    """
    # NOTE(review): torch.get_rng_state() captures only the default CPU
    # generator; random ops on other devices are presumably not reset between
    # chunks — confirm whether "same" holds across chunks there.
    saved_rng = None
    if randomness == "same":
        saved_rng = torch.get_rng_state()

    per_chunk_outputs = []
    for chunk_args in chunks_flat_args:
        n = _validate_and_get_batch_size(flat_in_dims, chunk_args)
        if n == 0:
            continue
        if saved_rng is not None:
            torch.set_rng_state(saved_rng)
        per_chunk_outputs.append(
            _flat_vmap(func, n, flat_in_dims, chunk_args, args_spec, out_dims, randomness, **kwargs)
        )

    flat_output_chunks, out_spec = _flatten_chunks_output(per_chunk_outputs)
    # The chunk tensors are now referenced from flat_output_chunks as well;
    # drop this list so they can be freed as soon as they are concatenated.
    del per_chunk_outputs

    flat_output = _concat_chunked_outputs(out_dims, out_spec, flat_output_chunks)
    return tree_unflatten(flat_output, out_spec)
def chunk_vmap(
        func: Callable,
        in_dims: in_dims_t = 0,
        out_dims: out_dims_t = 0,
        randomness: str = 'error',
        chunks=2) -> Callable:
    """
    chunk_vmap is the vectorizing map (vmap) using chunks of input data. It is a mix of vmap (which vectorizes
    everything) and map (which executes things sequentially). ``chunk_vmap`` splits the input into (at most)
    ``chunks`` pieces with :meth:`torch.Tensor.chunk` and vectorizes over one piece at a time. For more details
    about vectorizing map, see :func:`vmap`.

    .. note::
        Please use :func:`vmap` with ``chunk_size`` argument instead of this API.

    Args:
        func (function): A Python function that takes one or more arguments.
            Must return one or more Tensors.
        in_dims (int or nested structure): Specifies which dimension of the
            inputs should be mapped over. ``in_dims`` should have a
            structure like the inputs. If the ``in_dim`` for a particular
            input is None, then that indicates there is no map dimension.
            Default: 0.
        out_dims (int or Tuple[int]): Specifies where the mapped dimension
            should appear in the outputs. If ``out_dims`` is a Tuple, then
            it should have one element per output. Default: 0.
        randomness (str): Specifies whether the randomness in this
            vmap should be the same or different across batches. If 'different',
            the randomness for each batch will be different. If 'same', the
            randomness will be the same across batches. If 'error', any calls to
            random functions will error. Default: 'error'. WARNING: this flag
            only applies to random PyTorch operations and does not apply to
            Python's random module or numpy randomness.
        chunks (int): Number of chunks to use to split the input data. Default is 2.
            If equal to 1 then :func:`vmap` is called.

    Returns:
        Returns a new "batched" function. It takes the same inputs as
        ``func``, except each input has an extra dimension at the index
        specified by ``in_dims``. It returns the same outputs as
        ``func``, except each output has an extra dimension at the index
        specified by ``out_dims``.
    """
    _check_randomness_arg(randomness)

    if chunks == 1:
        return vmap(func, in_dims=in_dims, out_dims=out_dims, randomness=randomness)

    def _get_chunk_flat_args(flat_args_, flat_in_dims_, chunks_):
        flat_args_chunks = tuple(
            t.chunk(chunks_, dim=in_dim) if in_dim is not None else [t, ] * chunks_
            for t, in_dim in zip(flat_args_, flat_in_dims_)
        )
        # transpose chunk dim and flatten structure
        # chunks_flat_args is a list of flatten args
        chunks_flat_args = zip(*flat_args_chunks)
        return chunks_flat_args

    @functools.wraps(func)
    def wrapped_with_chunks(*args, **kwargs):
        # vmap's own wrapper performs this registration. Without it this chunked
        # path could run before the decompositions are registered (e.g. when
        # chunk_vmap is the first vmap entry point used in the process).
        lazy_load_decompositions()
        _check_out_dims_is_int_or_int_pytree(out_dims, func)
        _, flat_in_dims, flat_args, args_spec = _process_batched_inputs(in_dims, args, func)
        # Chunk flat arguments
        chunks_flat_args = _get_chunk_flat_args(flat_args, flat_in_dims, chunks)

        # Apply vmap on chunks
        return _chunked_vmap(func, flat_in_dims, chunks_flat_args, args_spec, out_dims, randomness, **kwargs)

    return wrapped_with_chunks
# Vmap refactored helper functions:
def _check_randomness_arg(randomness):
    """Raise RuntimeError unless ``randomness`` is 'error', 'different' or 'same'."""
    allowed = ('error', 'different', 'same')
    if randomness in allowed:
        return
    raise RuntimeError(f"Only allowed values for randomness are 'error', 'different', or 'same'. Got {randomness}")
@doesnt_support_saved_tensors_hooks
def _flat_vmap(func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs):
    """Run ``func`` once under a new vmap level over already-flattened inputs.

    Args:
        func: the user function being vmapped.
        batch_size: size of the mapped dimension shared by all batched inputs.
        flat_in_dims: per-leaf batch dim (or None) aligned with ``flat_args``.
        flat_args: flattened positional arguments.
        args_spec: TreeSpec used to rebuild the positional arguments.
        out_dims: where the mapped dimension goes in each output.
        randomness: 'error', 'different' or 'same'; forwarded to
            ``_vmap_increment_nesting``.
        **kwargs: forwarded to ``func`` as-is (not batched).

    Returns:
        ``func``'s outputs with the batch dimension of this level removed and
        placed according to ``out_dims`` (see ``_unwrap_batched``).
    """
    vmap_level = _vmap_increment_nesting(batch_size, randomness)
    try:
        batched_inputs = _create_batched_inputs(flat_in_dims, flat_args, vmap_level, args_spec)
        batched_outputs = func(*batched_inputs, **kwargs)
        return _unwrap_batched(batched_outputs, out_dims, vmap_level, batch_size, func)
    finally:
        # Always undo the increment above, even if func or unwrapping raised,
        # so every _vmap_increment_nesting is paired with a decrement.
        _vmap_decrement_nesting()
# `restore_vmap` is a private helper function. It is vmap but has the following
# differences:
# - instead of returning outputs, it returns an (outputs, out_dims) tuple.
# out_dims is a pytree of same shape as outputs and contains Optional[int]
# specifying where the vmapped dimension, if it exists, is in the corresponding output.
# - does no validation on in_dims or inputs (vmap expects at least one Tensor to be vmapped).
# restore_vmap allows for no inputs to have the vmap dimension
# - does no validation on outputs (vmap expects only Tensor outputs)
# restore_vmap allows for return of arbitrary outputs (not just Tensors)
#
# The TL;DR is that restore_vmap is more general than vmap and has a slightly
# different API. The relaxations are so that we can "pause" vmap in the middle
# of its execution and then "restore" it later (this is what we do in
# the generate_vmap_rule=True implementation of autograd.Function).
#
# restore_vmap can be technically used in the implementation of vmap, but doing
# that refactor is a bit technically challenging because:
# - vmap couples the tensor-wrapping code with error checking
# - vmap's tensor unwrapping code is in C++; we would need to rewrite part of it
# in python because it overlaps with unwrap_batched
def restore_vmap(func, in_dims, batch_size, randomness):
    """Return a function that runs ``func`` under a fresh vmap level.

    The returned callable wraps its positional arguments with ``wrap_batched``
    according to ``in_dims`` and calls ``func``. It returns the
    ``(outputs, out_dims)`` pair produced by ``unwrap_batched``. No validation
    is done on inputs or outputs, unlike :func:`vmap`.

    Saved-tensor hooks are disabled while the returned callable runs, matching
    ``_flat_vmap``. The decorator sits on ``inner`` rather than on this factory:
    decorating the factory would only disable hooks while ``inner`` was being
    *created*, not while the vmapped computation executes.
    """
    @doesnt_support_saved_tensors_hooks
    def inner(*args, **kwargs):
        vmap_level = _vmap_increment_nesting(batch_size, randomness)
        try:
            batched_inputs = wrap_batched(args, in_dims, vmap_level)
            batched_outputs = func(*batched_inputs, **kwargs)
            return unwrap_batched(batched_outputs, vmap_level)
        finally:
            # Pair every increment with a decrement, even on error.
            _vmap_decrement_nesting()
    return inner
def wrap_batched(args, bdims, level):
    """Flatten ``args``, broadcast ``bdims`` over it and batch at ``level``.

    Apart from asserting that ``bdims`` is structurally compatible with
    ``args``, no validation is performed (unlike ``_process_batched_inputs``).
    """
    leaves, tree_spec = tree_flatten(args)
    leaf_bdims = _broadcast_to_and_flatten(bdims, tree_spec)
    assert leaf_bdims is not None
    return _create_batched_inputs(leaf_bdims, leaves, level, tree_spec)
def unwrap_batched(args, level):
    """Unwrap every Tensor leaf of ``args`` at ``level``.

    Returns ``(outputs, bdims)``, two pytrees with the structure of ``args``:

    - Tensor leaves are unwrapped with ``torch._C._functorch._unwrap_batched``.
    - Other leaves pass through with a bdim of None.
    - If ``args`` has no leaves, ``(args, ())`` is returned.
    """
    leaves, spec = tree_flatten(args)
    if not leaves:
        return args, ()

    outputs = []
    bdims = []
    for leaf in leaves:
        if isinstance(leaf, torch.Tensor):
            value, bdim = torch._C._functorch._unwrap_batched(leaf, level)
        else:
            value, bdim = leaf, None
        outputs.append(value)
        bdims.append(bdim)
    return tree_unflatten(outputs, spec), tree_unflatten(bdims, spec)