# Owner(s): ["module: functorch"]
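"""Consistency checks for functorch batching-rule registrations.

These tests compare the operators registered to the FuncTorchBatched and
FuncTorchBatchedDecomposition dispatch keys against the dispatcher's
CompositeImplicitAutograd registrations.
"""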
import typing
import unittest

from torch.testing._internal.common_utils import (
    TestCase,
    run_tests,
    instantiate_parametrized_tests,
    parametrize,
    subtest,
)
from torch._C import (
    _dispatch_get_registrations_for_dispatch_key as get_registrations_for_dispatch_key,
)
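
# CompositeImplicitAutograd operators that nevertheless have a direct FuncTorchBatched
# batching rule registered; they are expected failures for
# test_register_a_batching_rule_for_composite_implicit_autograd below.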
xfail_functorch_batched = {
"aten::flatten.using_ints",
"aten::imag",
"aten::is_nonzero",
"aten::isfinite",
"aten::isreal",
"aten::item",
"aten::linalg_pinv",
"aten::linalg_pinv.atol_rtol_float",
"aten::linalg_slogdet",
"aten::linalg_lu_factor",
"aten::linear",
"aten::log_sigmoid",
"aten::log_softmax.int",
"aten::logdet",
"aten::masked_select_backward",
"aten::movedim.intlist",
"aten::one_hot",
"aten::real",
"aten::silu_backward",
"aten::special_xlogy",
"aten::special_xlogy.other_scalar",
"aten::special_xlogy.self_scalar",
"aten::tensor_split.indices",
"aten::tensor_split.sections",
"aten::to.device",
"aten::to.dtype",
"aten::to.dtype_layout",
"aten::to.other",
"aten::upsample_bicubic2d.vec",
"aten::upsample_bilinear2d.vec",
"aten::upsample_linear1d.vec",
"aten::upsample_nearest1d.vec",
"aten::upsample_nearest2d.vec",
"aten::upsample_nearest3d.vec",
"aten::upsample_trilinear3d.vec",
"aten::where",
}
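
# FuncTorchBatchedDecomposition registrations that are not CompositeImplicitAutograd
# operators; expected failures for test_register_functorch_batched_decomposition below.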
xfail_functorch_batched_decomposition = {
"aten::diagonal_copy",
"aten::is_same_size",
"aten::unfold_copy",
}
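
# CompositeImplicitAutograd operators that look vmap-implementable but have no
# registration in BatchedDecompositions.cpp yet; expected failures for
# test_unimplemented_batched_registrations below.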
xfail_not_implemented = {
"aten::absolute_",
"aten::affine_grid_generator_backward",
"aten::align_as",
"aten::align_tensors",
"aten::align_to",
"aten::align_to.ellipsis_idx",
"aten::alpha_dropout",
"aten::alpha_dropout_",
"aten::arccos_",
"aten::arccosh_",
"aten::arcsin_",
"aten::arcsinh_",
"aten::arctan2_",
"aten::arctan_",
"aten::arctanh_",
"aten::argwhere",
"aten::bilinear",
"aten::can_cast",
"aten::cat.names",
"aten::chain_matmul",
"aten::chalf",
"aten::choose_qparams_optimized",
"aten::clip_",
"aten::clip_.Tensor",
"aten::coalesce",
"aten::column_stack",
"aten::concat.names",
"aten::concatenate.names",
"aten::conj",
"aten::conv_tbc_backward",
"aten::ctc_loss.IntList",
"aten::ctc_loss.Tensor",
"aten::cudnn_is_acceptable",
"aten::cummaxmin_backward",
"aten::data",
"aten::diagflat",
"aten::divide.out_mode",
"aten::divide_.Scalar",
"aten::dropout",
"aten::dropout_",
"aten::embedding_bag",
"aten::embedding_bag.padding_idx",
"aten::feature_alpha_dropout",
"aten::feature_alpha_dropout_",
"aten::feature_dropout",
"aten::feature_dropout_",
"aten::fft_ihfft2",
"aten::fft_ihfftn",
"aten::fill_diagonal_",
"aten::fix_",
"aten::flatten.named_out_dim",
"aten::flatten.using_ints",
"aten::flatten.using_names",
"aten::flatten_dense_tensors",
"aten::float_power_.Scalar",
"aten::float_power_.Tensor",
"aten::floor_divide_.Scalar",
"aten::frobenius_norm",
"aten::fused_moving_avg_obs_fake_quant",
"aten::get_gradients",
"aten::greater_.Scalar",
"aten::greater_.Tensor",
"aten::greater_equal_.Scalar",
"aten::greater_equal_.Tensor",
"aten::gru.data",
"aten::gru.input",
"aten::gru_cell",
"aten::histogramdd",
"aten::histogramdd.TensorList_bins",
"aten::histogramdd.int_bins",
"aten::imag",
"aten::infinitely_differentiable_gelu_backward",
"aten::isclose",
"aten::isfinite",
"aten::isreal",
"aten::istft",
"aten::item",
"aten::kl_div",
"aten::ldexp_",
"aten::less_.Scalar",
"aten::less_.Tensor",
"aten::less_equal_.Scalar",
"aten::less_equal_.Tensor",
"aten::linalg_cond.p_str",
"aten::linalg_eigh",
"aten::linalg_eigh.eigvals",
"aten::linalg_lu_factor",
"aten::linalg_matrix_rank",
"aten::linalg_matrix_rank.out_tol_tensor",
"aten::linalg_matrix_rank.tol_tensor",
"aten::linalg_pinv",
"aten::linalg_pinv.atol_rtol_float",
"aten::linalg_pinv.out_rcond_tensor",
"aten::linalg_pinv.rcond_tensor",
"aten::linalg_slogdet",
"aten::linalg_svd.U",
"aten::linalg_tensorsolve",
"aten::linear",
"aten::log_sigmoid",
"aten::log_softmax.int",
"aten::logdet",
"aten::logsumexp.names",
"aten::lstm.data",
"aten::lstm.input",
"aten::lstm_cell",
"aten::lu_solve",
"aten::margin_ranking_loss",
"aten::masked_select_backward",
"aten::matrix_exp",
"aten::matrix_exp_backward",
"aten::max.names_dim",
"aten::max.names_dim_max",
"aten::mean.names_dim",
"aten::median.names_dim",
"aten::median.names_dim_values",
"aten::min.names_dim",
"aten::min.names_dim_min",
"aten::mish_backward",
"aten::moveaxis.int",
"aten::movedim.intlist",
"aten::multilabel_margin_loss",
"aten::nanmedian.names_dim",
"aten::nanmedian.names_dim_values",
"aten::nanquantile",
"aten::nanquantile.scalar",
"aten::narrow.Tensor",
"aten::native_channel_shuffle",
"aten::negative_",
"aten::nested_to_padded_tensor",
"aten::nonzero_numpy",
"aten::norm.names_ScalarOpt_dim",
"aten::norm.names_ScalarOpt_dim_dtype",
"aten::norm_except_dim",
"aten::not_equal_.Scalar",
"aten::not_equal_.Tensor",
"aten::one_hot",
"aten::output_nr",
"aten::pad_sequence",
"aten::pdist",
"aten::pin_memory",
"aten::promote_types",
"aten::qr.Q",
"aten::quantile",
"aten::quantile.scalar",
"aten::real",
"aten::refine_names",
"aten::rename",
"aten::rename_",
"aten::requires_grad_",
"aten::retain_grad",
"aten::retains_grad",
"aten::rnn_relu.data",
"aten::rnn_relu.input",
"aten::rnn_relu_cell",
"aten::rnn_tanh.data",
"aten::rnn_tanh.input",
"aten::rnn_tanh_cell",
"aten::set_.source_Tensor_storage_offset",
"aten::set_data",
"aten::silu_backward",
"aten::slow_conv3d",
"aten::smm",
"aten::special_chebyshev_polynomial_t.n_scalar",
"aten::special_chebyshev_polynomial_t.x_scalar",
"aten::special_chebyshev_polynomial_u.n_scalar",
"aten::special_chebyshev_polynomial_u.x_scalar",
"aten::special_chebyshev_polynomial_v.n_scalar",
"aten::special_chebyshev_polynomial_v.x_scalar",
"aten::special_chebyshev_polynomial_w.n_scalar",
"aten::special_chebyshev_polynomial_w.x_scalar",
"aten::special_hermite_polynomial_h.n_scalar",
"aten::special_hermite_polynomial_h.x_scalar",
"aten::special_hermite_polynomial_he.n_scalar",
"aten::special_hermite_polynomial_he.x_scalar",
"aten::special_laguerre_polynomial_l.n_scalar",
"aten::special_laguerre_polynomial_l.x_scalar",
"aten::special_legendre_polynomial_p.n_scalar",
"aten::special_legendre_polynomial_p.x_scalar",
"aten::special_shifted_chebyshev_polynomial_t.n_scalar",
"aten::special_shifted_chebyshev_polynomial_t.x_scalar",
"aten::special_shifted_chebyshev_polynomial_u.n_scalar",
"aten::special_shifted_chebyshev_polynomial_u.x_scalar",
"aten::special_shifted_chebyshev_polynomial_v.n_scalar",
"aten::special_shifted_chebyshev_polynomial_v.x_scalar",
"aten::special_shifted_chebyshev_polynomial_w.n_scalar",
"aten::special_shifted_chebyshev_polynomial_w.x_scalar",
"aten::special_xlogy",
"aten::special_xlogy.other_scalar",
"aten::special_xlogy.self_scalar",
"aten::square_",
"aten::sspaddmm",
"aten::std.correction_names",
"aten::std.names_dim",
"aten::std_mean.correction_names",
"aten::std_mean.names_dim",
"aten::stft",
"aten::stft.center",
"aten::stride.int",
"aten::subtract.Scalar",
"aten::subtract_.Scalar",
"aten::subtract_.Tensor",
"aten::svd.U",
"aten::tensor_split.indices",
"aten::tensor_split.sections",
"aten::tensor_split.tensor_indices_or_sections",
"aten::thnn_conv2d",
"aten::to.device",
"aten::to.dtype",
"aten::to.dtype_layout",
"aten::to.other",
"aten::to_dense",
"aten::to_dense_backward",
"aten::to_mkldnn_backward",
"aten::trace_backward",
"aten::triplet_margin_loss",
"aten::unflatten_dense_tensors",
"aten::unsafe_chunk",
"aten::upsample_bicubic2d.vec",
"aten::upsample_bilinear2d.vec",
"aten::upsample_linear1d.vec",
"aten::upsample_nearest1d.vec",
"aten::upsample_nearest2d.vec",
"aten::upsample_nearest3d.vec",
"aten::upsample_trilinear3d.vec",
"aten::vander",
"aten::var.correction_names",
"aten::var.names_dim",
"aten::var_mean.correction_names",
"aten::var_mean.names_dim",
"aten::where",
}


def dispatch_registrations(
    dispatch_key: str, xfails: set, filter_func: typing.Callable = lambda reg: True
):
    """Parametrize a test over every operator registered for `dispatch_key`.

    Registrations listed in `xfails` are marked with `unittest.expectedFailure`,
    and `filter_func` can be used to drop registrations from the parametrization.
    """
    registrations = sorted(get_registrations_for_dispatch_key(dispatch_key))
    subtests = [
        subtest(reg, name=f"[{reg}]",
                decorators=([unittest.expectedFailure] if reg in xfails else []))
        for reg in registrations if filter_func(reg)
    ]
    return parametrize("registration", subtests)
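

# Snapshot the registrations for each relevant dispatch key so the tests below can
# do cheap membership checks against them.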
CompositeImplicitAutogradRegistrations = set(
    get_registrations_for_dispatch_key("CompositeImplicitAutograd")
)
FuncTorchBatchedRegistrations = set(
    get_registrations_for_dispatch_key("FuncTorchBatched")
)
FuncTorchBatchedDecompositionRegistrations = set(
    get_registrations_for_dispatch_key("FuncTorchBatchedDecomposition")
)


def filter_vmap_implementable(reg):
    """Keep registrations that we expect to be implementable for vmap.

    Internal (underscore-prefixed) ops, out variants, dimname overloads,
    fbgemm/quantize/sparse ops, and `is_*` predicates are excluded.
    """
    reg = reg.lower()
    if not reg.startswith("aten::"):
        return False
    if reg.startswith("aten::_"):
        return False
    if reg.endswith(".out"):
        return False
    if reg.endswith("_out"):
        return False
    if ".dimname" in reg:
        return False
    if "_dimname" in reg:
        return False
    if "fbgemm" in reg:
        return False
    if "quantize" in reg:
        return False
    if "sparse" in reg:
        return False
    if "::is_" in reg:
        return False
    return True
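

# Each test is parametrized over every registration for one dispatch key and
# cross-checks it against the registration snapshots above.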
class TestFunctorchDispatcher(TestCase):
    @dispatch_registrations("CompositeImplicitAutograd", xfail_functorch_batched)
    def test_register_a_batching_rule_for_composite_implicit_autograd(
        self, registration
    ):
        assert registration not in FuncTorchBatchedRegistrations, (
            f"You've added a batching rule for a CompositeImplicitAutograd operator {registration}. "
            "The correct way to add vmap support for it is to put it into BatchRulesDecomposition to "
            "reuse the CompositeImplicitAutograd decomposition"
        )

    @dispatch_registrations(
        "FuncTorchBatchedDecomposition", xfail_functorch_batched_decomposition
    )
    def test_register_functorch_batched_decomposition(self, registration):
        assert registration in CompositeImplicitAutogradRegistrations, (
            "The registrations in BatchedDecompositions.cpp must be for CompositeImplicitAutograd "
            f"operations. If your operation {registration} is not CompositeImplicitAutograd, "
            "then please register it to the FuncTorchBatched key in another file."
        )

    @dispatch_registrations(
        "CompositeImplicitAutograd", xfail_not_implemented, filter_vmap_implementable
    )
    def test_unimplemented_batched_registrations(self, registration):
        assert registration in FuncTorchBatchedDecompositionRegistrations, (
            f"Please check that there is an OpInfo that covers the operator {registration} "
            "and add a registration in BatchedDecompositions.cpp. "
            "If your operator isn't user facing, please add it to the xfail list"
        )


instantiate_parametrized_tests(TestFunctorchDispatcher)

if __name__ == "__main__":
    run_tests()