# blob: 1d7356b6ca7e53f557ece9ca1fd172a1cbe8c300 [file] [log] [blame]
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import itertools
import torch
from functorch import vmap
import torch.utils._pytree as pytree
from functorch_additional_op_db import additional_op_db
from torch.testing._internal.common_methods_invocations import DecorateInfo
from torch.testing._internal.common_methods_invocations import op_db
import os
import unittest
from torch.testing._internal.common_device_type import toleranceOverride
from collections import namedtuple
# True only when the FUNCTORCH_TEST_FBCODE environment variable is exactly '1'.
IS_FBCODE = os.getenv('FUNCTORCH_TEST_FBCODE') == '1'
def loop(op, in_dims, out_dim, batch_size, *batched_args, **kwarg_values):
    """Reference implementation of one level of vmap using a plain Python loop.

    For every index in ``range(batch_size)`` each flattened argument whose
    in_dim is not ``None`` is sliced with ``select(in_dim, idx)``; arguments
    whose in_dim is ``None`` are passed through untouched. ``op`` is called
    on the sliced arguments and the per-index results are stacked along
    ``out_dim``.

    Args:
        op: callable under test.
        in_dims: pytree with the same structure as ``batched_args`` holding
            the batch dimension (or ``None``) of every leaf.
        out_dim: dimension along which the per-index outputs are stacked.
        batch_size: number of slices to iterate over.
        *batched_args: positional arguments for ``op``.
        **kwarg_values: keyword arguments forwarded to ``op`` unchanged.

    Returns:
        A single stacked tensor when ``op`` returns a tensor, otherwise a
        list with one stacked tensor per element of ``op``'s output.
    """
    # The flattened structure does not depend on the batch index, so it is
    # computed (and validated) once instead of on every iteration.
    flat_args, args_spec = pytree.tree_flatten(batched_args)
    flat_dims, dims_spec = pytree.tree_flatten(in_dims)
    assert args_spec == dims_spec

    outs = []
    for idx in range(batch_size):
        new_args = [
            a.select(in_dim, idx) if in_dim is not None else a
            for a, in_dim in zip(flat_args, flat_dims)
        ]
        outs.append(op(*pytree.tree_unflatten(new_args, args_spec), **kwarg_values))

    if isinstance(outs[0], torch.Tensor):
        # Honour out_dim here as well; the multi-output branch below always did.
        return torch.stack(outs, out_dim)
    return [
        torch.stack([out[idx] for out in outs], out_dim)
        for idx in range(len(outs[0]))
    ]
# Like the loop helper function but for 2 levels of vmap. If more levels than this are ever
# needed it is probably possible to generalize loop, but that seemed too complicated for now.
def loop2(op, in_dims1, in_dims2, out_dim1, out_dim2, batch_size1, batch_size2, *batched_args, **kwarg_values):
    """Reference implementation of two nested levels of vmap using Python loops.

    The outer loop slices every batched argument along ``in_dims1`` and the
    inner loop slices the result along ``in_dims2`` before calling ``op``.
    The inner results are stacked along ``out_dim1`` and those stacks are
    then stacked along ``out_dim2``.

    Args:
        op: callable under test.
        in_dims1: pytree of batch dims (or ``None``) used by the outer loop.
        in_dims2: pytree of batch dims (or ``None``) used by the inner loop.
        out_dim1: dimension used to stack the inner-loop outputs.
        out_dim2: dimension used to stack the outer-loop outputs.
        batch_size1: number of outer iterations.
        batch_size2: number of inner iterations.
        *batched_args: positional arguments for ``op``.
        **kwarg_values: keyword arguments forwarded to ``op`` unchanged.

    Returns:
        A single tensor when ``op`` returns a tensor, otherwise a list with
        one stacked tensor per element of ``op``'s output.
    """
    flat_args, args_spec = pytree.tree_flatten(batched_args)
    flat_dims1, dims_spec1 = pytree.tree_flatten(in_dims1)
    flat_dims2, dims_spec2 = pytree.tree_flatten(in_dims2)
    assert args_spec == dims_spec1
    assert args_spec == dims_spec2
    assert len(flat_dims1) == len(flat_dims2)

    def stack_results(results, dim):
        # Stack a list of per-index results: either tensors, or equally long
        # sequences of tensors that are stacked position by position.
        if isinstance(results[0], torch.Tensor):
            return torch.stack(results, dim)
        return [
            torch.stack([r[idx] for r in results], dim)
            for idx in range(len(results[0]))
        ]

    outs = []
    for idx1 in range(batch_size1):
        arg_split = [
            a.select(in_dim1, idx1) if in_dim1 is not None else a
            for a, in_dim1 in zip(flat_args, flat_dims1)
        ]
        out_split = []
        for idx2 in range(batch_size2):
            new_args = [
                a.select(in_dim, idx2) if in_dim is not None else a
                for a, in_dim in zip(arg_split, flat_dims2)
            ]
            out_split.append(op(*pytree.tree_unflatten(new_args, args_spec), **kwarg_values))
        outs.append(out_split)

    loop_out = [stack_results(out_split, out_dim1) for out_split in outs]
    # loop_out is always a list, so the tensor-vs-sequence decision has to be
    # made on its elements (stack_results inspects results[0]); checking the
    # list itself would send single-tensor outputs down the sequence path.
    return stack_results(loop_out, out_dim2)
def is_valid_inplace_sample_input(sample_input, op, inplace_variant):
    """Return whether ``sample_input`` can be used with the in-place variant.

    A sample is rejected when there is no in-place variant, when the sample
    is flagged as broadcasting its input, or when calling ``op`` on it
    produces a result whose dtype differs from the input's dtype.
    """
    if inplace_variant is None or sample_input.broadcasts_input:
        return False
    # Run the out-of-place op once and compare its result dtype with the input's.
    call_args = (sample_input.input,) + sample_input.args
    result = op(*call_args, **sample_input.kwargs)
    return sample_input.input.dtype == result.dtype
# This is kind of dangerous, please think carefully before using it.
# Known risks:
# - the return value had better not be mutated, so it's best to return immutable types
#   (e.g. prefer tuples to lists)
# - Don't hash tensors in a global context, that'll keep them around forever
def memoize(fn):
    """Cache the results of ``fn`` keyed by its positional arguments.

    The cache is an unbounded dict owned by the returned wrapper, so every
    argument tuple (and every result) stays alive as long as the wrapper
    does. Only positional, hashable arguments are supported.

    Args:
        fn: the function whose results should be cached.

    Returns:
        A wrapper that calls ``fn`` once per distinct argument tuple and
        returns the stored result afterwards.
    """
    # Imported locally so this helper does not depend on the module's import block.
    import functools

    memo = {}

    # Preserve fn's __name__/__doc__ (and expose it as __wrapped__) so that
    # memoized helpers still identify themselves in tracebacks and logs.
    @functools.wraps(fn)
    def wrapped(*args):
        if args not in memo:
            memo[args] = fn(*args)
        return memo[args]

    return wrapped
# NB: This is O(2 ** num_tensors).
# num_tensors ranges from 1 to 10, with 2-4 being most common.
# Try not to make it any more expensive if you're modifying it.
@memoize
def get_bdim_choices(num_tensors):
    """Return the batch-dim assignments to try for ``num_tensors`` tensors.

    The result is a tuple of tuples: first the all-zeros assignment, then
    every combination of ``-1`` / ``None`` except the final all-``None``
    one (in which nothing would be batched).
    """
    all_zeros = (0,) * num_tensors
    combos = list(itertools.product((-1, None), repeat=num_tensors))
    # product() emits the all-None combination last; it is the one dropped.
    assert combos[-1] == (None,) * num_tensors
    return (all_zeros, *combos[:-1])
# NB: This is O(2 ** num_tensors).
# num_tensors ranges from 1 to 10, with 2-4 being most common.
# Try not to make it any more expensive if you're modifying it.
def get_bdim_choices_batch_norm(num_tensors, _, running_mean=None, running_var=None, *args):
    """Return the batch-dim assignments to try for batch norm style ops.

    Args:
        num_tensors: number of tensor arguments a choice must cover.
        _: the op's input argument (unused here).
        running_mean: the running mean argument, or ``None`` if absent.
        running_var: the running variance argument, or ``None`` if absent.
        *args: remaining positional arguments (unused here).

    Returns:
        A tuple of batch-dim tuples; the trailing all-``None`` choice is
        always dropped.
    """
    bdim_options = (-1, None)
    candidates = []
    if running_mean is not None and running_var is not None:
        # running_mean and running_var are specified as tensors. Batch norm doesn't work if the
        # input is batched but running_mean/var are unbatched, so this tests all other cases.
        candidates.append((0,) * num_tensors)
        for combo in itertools.product(bdim_options, repeat=num_tensors):
            input_bdim = combo[0]
            mean_bdim = combo[1]
            var_bdim = combo[2]
            stats_batched = mean_bdim is not None and var_bdim is not None
            if input_bdim is not None and not stats_batched:
                continue
            candidates.append(combo)
    else:
        # instance norm turns these into unbatched 0 tensors, so we cannot batch the input if
        # either is not specified
        rest = num_tensors - 1
        candidates.append((None,) + (0,) * rest)
        candidates.extend((None,) + combo for combo in itertools.product(bdim_options, repeat=rest))
    assert candidates[-1] == (None,) * num_tensors
    return tuple(candidates[:-1])
def add_batch_dim(arg, bdim, batch_size=3):
    """Return ``(batched, bdim)`` where ``batched`` is ``arg`` with a batch dim added.

    ``bdim`` must be ``0`` (new leading dimension) or ``-1`` (new trailing
    dimension). In both cases the new dimension has size ``batch_size`` and
    every slice along it holds the values of ``arg``.
    """
    assert bdim in (0, -1)
    assert isinstance(arg, torch.Tensor)
    if bdim == -1:
        trailing = arg.unsqueeze(-1).expand(*arg.shape, batch_size)
        return (trailing.contiguous(), bdim)
    if bdim == 0:
        # Tile along a new leading dimension; existing dimensions repeat once.
        repeats = [batch_size] + [1] * arg.dim()
        return (arg.repeat(repeats), bdim)
def construct_in_dims(bdim_choice_for_tensors, is_tensors):
    """Expand per-tensor batch dims into a per-argument ``in_dims`` tuple.

    Each ``True`` entry of ``is_tensors`` consumes the next value from
    ``bdim_choice_for_tensors``; each ``False`` entry becomes ``None``.
    """
    remaining = iter(bdim_choice_for_tensors)
    # A list comprehension (rather than a generator expression) keeps the
    # behaviour of a plain loop if the bdim iterator runs out early.
    return tuple([next(remaining) if flag else None for flag in is_tensors])
def is_batch_norm_training(op_name, kwarg_values):
    """Tell whether a batch norm / instance norm call runs in training mode.

    Returns ``False`` for any other op. For the two norm ops the single
    ``bool`` value found among ``kwarg_values`` decides; when there is none,
    instance norm defaults to training and batch norm does not.
    """
    # instance norm calls batch norm
    if op_name not in ("nn.functional.batch_norm", "nn.functional.instance_norm"):
        return False
    # batch norm and instance norm require the value to be a plain bool
    bool_kwargs = [value for value in kwarg_values.values() if isinstance(value, bool)]
    if not bool_kwargs:
        # instance norm defaults to training, batch norm doesn't
        return op_name == "nn.functional.instance_norm"
    assert len(bool_kwargs) == 1
    return bool_kwargs[0]
def generate_vmap_inputs(arg_values, kwarg_values, is_batch_norm_and_training=False, batch_size=2):
    """Yield ``(batched_args, in_dims, kwarg_values)`` triples for a vmap test.

    The positional arguments are flattened, and for every batch-dim choice
    each tensor leaf with a non-``None`` dim gets a batch dimension of size
    ``batch_size`` added. ``kwarg_values`` is passed through unchanged.
    Nothing is yielded for a training-mode batch norm call that has a
    single tensor argument.
    """
    leaves, structure = pytree.tree_flatten(tuple(arg_values))
    tensor_mask = [isinstance(leaf, torch.Tensor) for leaf in leaves]
    tensor_count = sum(tensor_mask)

    # For Batch Norm, if there's only an input, we can't
    # batch it since running_mean/var will be seen as unbatched tensors
    if is_batch_norm_and_training and tensor_count == 1:
        return

    if is_batch_norm_and_training:
        bdim_choices = get_bdim_choices_batch_norm(tensor_count, *arg_values)
    else:
        bdim_choices = get_bdim_choices(tensor_count)

    # Memoized per call so each (tensor, bdim) pair is batched only once
    # across all the choices generated below.
    @memoize
    def get_batched_arg(arg, bdim):
        assert isinstance(arg, torch.Tensor)
        assert bdim is not None
        return add_batch_dim(arg, bdim, batch_size)[0]

    for bdim_choice in bdim_choices:
        flat_in_dims = construct_in_dims(bdim_choice, tensor_mask)
        flat_batched = tuple(
            leaf if dim is None else get_batched_arg(leaf, dim)
            for leaf, dim in zip(leaves, flat_in_dims)
        )
        batched_args = pytree.tree_unflatten(flat_batched, structure)
        in_dims = pytree.tree_unflatten(flat_in_dims, structure)
        yield batched_args, in_dims, kwarg_values
def clone_if_tensor(x):
    """Return a clone of ``x`` when it is a tensor, otherwise ``x`` itself."""
    return x.clone() if isinstance(x, torch.Tensor) else x
def compute_quantities_for_vmap_test(
        op, orig_batched_args, orig_kwarg_values, in_dims,
        out_dim=0, batch_size=2, compute_loop_out=True,
        clone_inputs=False):
    """Yield two ``(expected, actual)`` pairs for checking ``vmap(op)``.

    The first pair is ``(loop_out, batched_out)``: the for-loop reference
    result (``None`` when ``compute_loop_out`` is false) and the result of
    ``vmap``. The second pair compares ``batched_out`` with an extra
    dimension inserted at index 1 against a doubly-vmapped call in which
    the inner ``vmap`` batches only a dummy tensor.

    When ``clone_inputs`` is true, fresh clones of the tensor inputs are
    made before each of the three evaluations.
    """

    def fresh_inputs():
        if not clone_inputs:
            return orig_batched_args, orig_kwarg_values
        cloned_args = pytree.tree_map(clone_if_tensor, orig_batched_args)
        cloned_kwargs = pytree.tree_map(clone_if_tensor, orig_kwarg_values)
        return cloned_args, cloned_kwargs

    batched_args, kwarg_values = fresh_inputs()
    loop_out = None
    if compute_loop_out:
        loop_out = loop(op, in_dims, out_dim, batch_size, *batched_args, **kwarg_values)

    # Used for debugging the resulting operations
    # from functorch import make_fx
    # def f(a):
    #     return op(a)
    # t = make_fx(vmap(f, in_dims=in_dims, out_dims=out_dim))(*batched_args, **kwarg_values)
    # print(in_dims, [arg.shape for arg in batched_args], kwarg_values)
    batched_args, kwarg_values = fresh_inputs()
    batched_out = vmap(op, in_dims=in_dims, out_dims=out_dim)(*batched_args, **kwarg_values)
    yield (loop_out, batched_out)

    # Tests case where we dispatch to a batching rule with no bdims
    # This should be handled by autogenerated plumbing. For vmap support
    # added via a manual plumbing you may need to handle this specially.
    def add_bdim_if_tensor(x):
        return x.unsqueeze(1) if isinstance(x, torch.Tensor) else x

    def call_ignoring_dummy(_dummy, *inner_args, **inner_kwargs):
        # The dummy exists only so that the inner vmap has something to batch over.
        return op(*inner_args, **inner_kwargs)

    dummy = torch.ones(batch_size, 1)
    expected = pytree.tree_map(add_bdim_if_tensor, batched_out)

    inner_in_dims = (0,) + pytree.tree_map(lambda x: None, in_dims)
    outer_in_dims = (0,) + in_dims
    batched_args, kwarg_values = fresh_inputs()
    nested = vmap(vmap(call_ignoring_dummy, inner_in_dims), outer_in_dims)
    output = nested(dummy, *batched_args, **kwarg_values)
    yield (expected, output)
def get_fallback_and_vmap_exhaustive(op, arg_values, kwarg_values, is_batch_norm_and_training=False, compute_loop_out=True):
    """Yield every comparison pair for ``op`` across all generated vmap inputs.

    Each batched input produced by ``generate_vmap_inputs`` is run through
    ``compute_quantities_for_vmap_test`` with an output dim of 0 and a
    batch size of 2, and all resulting pairs are yielded in order.
    """
    out_dim = 0
    batch_size = 2
    for batched_args, in_dims, kwarg_values in generate_vmap_inputs(
            arg_values, kwarg_values, is_batch_norm_and_training):
        yield from compute_quantities_for_vmap_test(
            op, batched_args, kwarg_values, in_dims, out_dim, batch_size, compute_loop_out)
def opinfo_in_dict(opinfo, d):
    """Return True if ``d`` contains the opinfo's name or its ``name.variant`` key."""
    if opinfo.name in d:
        return True
    return f'{opinfo.name}.{opinfo.variant_test_name}' in d
# Record describing a decorator to attach to one OpInfo: which op/variant it
# targets, the decorator itself, and the device type / dtypes it is limited to.
DecorateMeta = namedtuple(
    "DecorateMeta",
    ("op_name", "variant_name", "decorator", "device_type", "dtypes"),
)
def decorate(op_name, variant_name='', *, decorator=None, device_type=None, dtypes=None):
    """Build a ``DecorateMeta`` for the given op; ``decorator`` is required."""
    assert decorator is not None
    # Positional order matches the DecorateMeta field order.
    return DecorateMeta(op_name, variant_name, decorator, device_type, dtypes)
def xfail(op_name, variant_name='', *, device_type=None, dtypes=None):
    """Describe an expected failure (``unittest.expectedFailure``) for an op."""
    return decorate(
        op_name,
        variant_name,
        decorator=unittest.expectedFailure,
        device_type=device_type,
        dtypes=dtypes,
    )
def skip(op_name, variant_name='', *, device_type=None, dtypes=None):
    """Describe an unconditional skip (``unittest.skip``) for an op."""
    skip_decorator = unittest.skip("Skipped!")
    return decorate(
        op_name,
        variant_name,
        decorator=skip_decorator,
        device_type=device_type,
        dtypes=dtypes,
    )
def skipOps(test_case_name, base_test_name, to_skip):
    """Attach the decorators described by ``to_skip`` to their OpInfos.

    For each entry, the single OpInfo with a matching name and variant gets
    a ``DecorateInfo`` for ``(test_case_name, base_test_name)`` appended to
    its ``decorators``. The OpInfo objects are modified when this function
    is called; the decorator it returns leaves the test function untouched.
    """
    all_opinfos = op_db + additional_op_db
    for decorate_meta in to_skip:
        matching_opinfos = []
        for candidate in all_opinfos:
            if (candidate.name == decorate_meta.op_name
                    and candidate.variant_test_name == decorate_meta.variant_name):
                matching_opinfos.append(candidate)
        assert len(matching_opinfos) > 0, f"Couldn't find OpInfo for {decorate_meta}"
        assert len(matching_opinfos) == 1, (
            "OpInfos should be uniquely determined by their (name, variant_name). "
            f"Got more than one result for ({decorate_meta.op_name}, {decorate_meta.variant_name})"
        )
        opinfo = matching_opinfos[0]
        extra = DecorateInfo(decorate_meta.decorator,
                             test_case_name, base_test_name,
                             device_type=decorate_meta.device_type,
                             dtypes=decorate_meta.dtypes)
        opinfo.decorators = (*opinfo.decorators, extra)

    # This decorator doesn't modify fn in any way
    def wrapped(fn):
        return fn
    return wrapped
def expectedFailureIf(condition):
    """Return a decorator applying ``unittest.expectedFailure`` only if ``condition`` is truthy."""
    def decorator(fn):
        return unittest.expectedFailure(fn) if condition else fn
    return decorator
def tol2(op_name, variant_name, override_dct, *, device_type=None):
    """Bundle a tolerance override for a named op variant into a 4-tuple."""
    entry = (op_name, variant_name, override_dct, device_type)
    return entry
def tol1(op_name, override_dct, *, device_type=None):
    """Same as ``tol2`` but for an op's default (empty-string) variant."""
    default_variant = ''
    return tol2(op_name, default_variant, override_dct, device_type=device_type)
def opsToleranceOverride(test_case_name, base_test_name, overrides):
    """Attach tolerance overrides to the matching OpInfos.

    Args:
        test_case_name: name of the test class the overrides apply to.
        base_test_name: name of the test the overrides apply to.
        overrides: iterable of ``(op_name, variant_name, override_dct,
            device_type)`` tuples, as produced by ``tol1`` / ``tol2``.

    Returns:
        A decorator that returns the test function unchanged. The OpInfo
        objects are modified when this function is called.

    Raises:
        AssertionError: if an entry matches no OpInfo or more than one.
    """
    all_opinfos = op_db + additional_op_db
    # Unpack into distinct names so the tolerance dict never shadows the
    # entry being processed and the messages below can name the op.
    for op_name, variant_name, override_dct, device_type in overrides:
        matching_opinfos = [o for o in all_opinfos
                            if o.name == op_name and o.variant_test_name == variant_name]
        assert len(matching_opinfos) > 0, (
            f"Couldn't find OpInfo for ({op_name}, {variant_name})"
        )
        assert len(matching_opinfos) == 1, (
            "OpInfos should be uniquely determined by their (name, variant_name). "
            f"Got more than one result for ({op_name}, {variant_name})"
        )
        opinfo = matching_opinfos[0]
        decorators = list(opinfo.decorators)
        decorators.append(DecorateInfo(
            toleranceOverride(override_dct),
            test_case_name, base_test_name, device_type=device_type))
        opinfo.decorators = tuple(decorators)

    # This decorator doesn't modify fn in any way
    def wrapped(fn):
        return fn
    return wrapped
class DisableVmapFallback:
    """Context manager that switches the vmap fallback off for its body.

    The previous setting is remembered on entry and restored on exit,
    whether or not the body raised.
    """

    def __enter__(self):
        functorch_c = torch._C._functorch
        self.prev_state = functorch_c._is_vmap_fallback_enabled()
        functorch_c._set_vmap_fallback_enabled(False)

    def __exit__(self, *exc_info):
        # Exception details are not inspected; returning None lets any
        # exception from the body propagate.
        torch._C._functorch._set_vmap_fallback_enabled(self.prev_state)
def check_vmap_fallback(test_case, thunk, opinfo, dry_run=False):
    """Run ``thunk`` with the vmap fallback disabled.

    Any exception is re-raised unless ``dry_run`` is set, in which case the
    matching ``xfail(...)`` line for ``opinfo`` is printed instead.
    ``test_case`` is accepted but not used.
    """
    try:
        with DisableVmapFallback():
            thunk()
    except Exception:
        if not dry_run:
            raise
        # Dry run: emit a ready-to-paste xfail entry rather than failing.
        if not opinfo.variant_test_name:
            print(f"xfail('{opinfo.name}'),")
        else:
            print(f"xfail('{opinfo.name}', '{opinfo.variant_test_name}'),")