blob: 850dbc666cc4b98dc206c00dd9c50c55609a5a8a [file] [log] [blame]
// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
#include <ATen/Operators.h>
#include <ATen/FunctionalTensorWrapper.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/functorch/BatchRulesHelper.h>
#include <ATen/functorch/BatchedFallback.h>
#include <ATen/functorch/DynamicLayer.h>
#include <ATen/functorch/PlumbingHelper.h>
namespace at { namespace functorch {
#define OP_DECOMPOSE(op) m.impl(#op, static_cast<decltype(&ATEN_FN(op))>(native::op));
#define OP_DECOMPOSE2(op, overload) m.impl(#op"."#overload, static_cast<decltype(&ATEN_FN2(op, overload))>(native::op));
TORCH_LIBRARY_IMPL(aten, FuncTorchVmapMode, m) {
OP_DECOMPOSE(alpha_dropout_);
OP_DECOMPOSE(dropout_);
OP_DECOMPOSE(feature_alpha_dropout_);
OP_DECOMPOSE(feature_dropout_);
}
TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
OP_DECOMPOSE2(__and__, Scalar);
OP_DECOMPOSE2(__and__, Tensor);
OP_DECOMPOSE2(__iand__, Tensor);
OP_DECOMPOSE2(__iand__, Scalar);
OP_DECOMPOSE2(__ior__, Tensor);
OP_DECOMPOSE2(__ior__, Scalar);
OP_DECOMPOSE2(__ixor__, Tensor);
OP_DECOMPOSE2(__ixor__, Scalar);
OP_DECOMPOSE2(__or__, Tensor);
OP_DECOMPOSE2(__or__, Scalar);
OP_DECOMPOSE2(__xor__, Tensor);
OP_DECOMPOSE2(__xor__, Scalar);
OP_DECOMPOSE(_batch_norm_impl_index);
OP_DECOMPOSE(absolute);
OP_DECOMPOSE(arctan2);
OP_DECOMPOSE(avg_pool1d);
OP_DECOMPOSE(adaptive_max_pool1d);
OP_DECOMPOSE(adaptive_avg_pool1d);
m.impl("adaptive_avg_pool2d", native::adaptive_avg_pool2d_symint);
OP_DECOMPOSE(adaptive_avg_pool3d);
OP_DECOMPOSE(adjoint);
OP_DECOMPOSE(arccos);
OP_DECOMPOSE(arccosh);
OP_DECOMPOSE(arcsin);
OP_DECOMPOSE(arcsinh);
OP_DECOMPOSE(arctan);
OP_DECOMPOSE(arctanh);
OP_DECOMPOSE(atleast_1d);
OP_DECOMPOSE2(atleast_1d, Sequence);
OP_DECOMPOSE(atleast_2d);
OP_DECOMPOSE2(atleast_2d, Sequence);
OP_DECOMPOSE(atleast_3d);
OP_DECOMPOSE2(atleast_3d, Sequence);
OP_DECOMPOSE(batch_norm);
OP_DECOMPOSE2(bitwise_or, Scalar);
OP_DECOMPOSE2(bitwise_xor, Scalar);
OP_DECOMPOSE(broadcast_tensors);
OP_DECOMPOSE(broadcast_to);
OP_DECOMPOSE(cartesian_prod);
OP_DECOMPOSE(cdist);
OP_DECOMPOSE(clip);
OP_DECOMPOSE2(clip, Tensor );
OP_DECOMPOSE(concat);
OP_DECOMPOSE(conj_physical);
OP_DECOMPOSE(combinations);
OP_DECOMPOSE(corrcoef);
OP_DECOMPOSE(cosine_embedding_loss);
OP_DECOMPOSE(cosine_similarity);
OP_DECOMPOSE(cov);
m.impl("cross_entropy_loss", native::cross_entropy_loss_symint);
OP_DECOMPOSE2(cumulative_trapezoid, x);
OP_DECOMPOSE2(cumulative_trapezoid, dx);
OP_DECOMPOSE2(dsplit, int);
OP_DECOMPOSE2(dsplit, array);
OP_DECOMPOSE(det);
m.impl("diag_backward", native::diag_backward_symint);
OP_DECOMPOSE(diff);
OP_DECOMPOSE(dstack);
OP_DECOMPOSE(einsum);
m.impl("embedding_backward", native::embedding_backward_symint);
OP_DECOMPOSE(expand_as);
OP_DECOMPOSE(fft_fft);
OP_DECOMPOSE(fft_fftshift);
OP_DECOMPOSE(fft_fft2);
OP_DECOMPOSE(fft_fftn);
OP_DECOMPOSE(fft_hfft);
OP_DECOMPOSE(fft_hfft2);
OP_DECOMPOSE(fft_hfftn);
OP_DECOMPOSE(fft_ifft);
OP_DECOMPOSE(fft_ifftshift);
OP_DECOMPOSE(fft_ifft2);
OP_DECOMPOSE(fft_ifftn);
OP_DECOMPOSE(fft_ihfft);
OP_DECOMPOSE(fft_irfft);
OP_DECOMPOSE(fft_irfft2);
OP_DECOMPOSE(fft_irfftn);
OP_DECOMPOSE(fft_rfft);
OP_DECOMPOSE(fft_rfft2);
OP_DECOMPOSE(fft_rfftn);
OP_DECOMPOSE(fix);
OP_DECOMPOSE(fliplr);
OP_DECOMPOSE(flipud);
OP_DECOMPOSE2(float_power, Tensor_Tensor);
OP_DECOMPOSE2(float_power, Tensor_Scalar);
OP_DECOMPOSE(ger);
OP_DECOMPOSE2(gradient, scalarint);
OP_DECOMPOSE2(gradient, scalararray);
OP_DECOMPOSE2(gradient, array);
OP_DECOMPOSE2(gradient, scalarrayint);
OP_DECOMPOSE2(gradient, scalarrayarray);
OP_DECOMPOSE2(gradient, tensorarrayint);
OP_DECOMPOSE2(gradient, tensorarray);
OP_DECOMPOSE2(greater_equal, Tensor );
OP_DECOMPOSE2(greater, Tensor );
OP_DECOMPOSE(grid_sampler);
OP_DECOMPOSE(group_norm);
OP_DECOMPOSE(hinge_embedding_loss);
OP_DECOMPOSE2(hsplit, int);
OP_DECOMPOSE2(hsplit, array);
OP_DECOMPOSE(hstack);
m.impl("index_select_backward", native::index_select_backward_symint);
OP_DECOMPOSE(inner);
OP_DECOMPOSE(inverse);
OP_DECOMPOSE(instance_norm);
OP_DECOMPOSE(kron);
OP_DECOMPOSE(l1_loss);
OP_DECOMPOSE(layer_norm);
OP_DECOMPOSE2(ldexp, Tensor);
OP_DECOMPOSE2(less_equal, Tensor );
OP_DECOMPOSE2(less, Tensor );
OP_DECOMPOSE(linalg_cond);
OP_DECOMPOSE(linalg_cholesky);
OP_DECOMPOSE(linalg_det);
OP_DECOMPOSE(linalg_eigvalsh);
OP_DECOMPOSE(linalg_eigvals);
OP_DECOMPOSE(linalg_inv);
OP_DECOMPOSE(linalg_matmul);
OP_DECOMPOSE(linalg_matrix_norm);
OP_DECOMPOSE2(linalg_matrix_norm, str_ord);
OP_DECOMPOSE(linalg_multi_dot);
OP_DECOMPOSE(linalg_norm);
OP_DECOMPOSE2(linalg_norm, ord_str);
OP_DECOMPOSE(linalg_solve);
OP_DECOMPOSE(linalg_solve_ex);
OP_DECOMPOSE(linalg_svd);
OP_DECOMPOSE(linalg_svdvals);
OP_DECOMPOSE(linalg_tensorinv);
OP_DECOMPOSE(_lu_with_info);
OP_DECOMPOSE(matmul);
OP_DECOMPOSE(matrix_H);
OP_DECOMPOSE(matrix_power);
OP_DECOMPOSE2(max, other );
OP_DECOMPOSE(max_pool1d_with_indices);
OP_DECOMPOSE(max_pool2d);
OP_DECOMPOSE(meshgrid);
OP_DECOMPOSE2(meshgrid, indexing);
OP_DECOMPOSE(mH);
OP_DECOMPOSE2(min, other );
OP_DECOMPOSE2(moveaxis, intlist);
OP_DECOMPOSE2(movedim, int);
OP_DECOMPOSE(msort);
OP_DECOMPOSE(mT);
m.impl("narrow", native::narrow_symint);
OP_DECOMPOSE(negative);
OP_DECOMPOSE2(frobenius_norm, dim);
OP_DECOMPOSE2(nuclear_norm, dim);
OP_DECOMPOSE(nuclear_norm);
m.impl("nll_loss_nd", native::nll_loss_nd_symint);
m.impl("nll_loss", native::nll_loss_symint);
m.impl("nll_loss2d", native::nll_loss2d_symint);
OP_DECOMPOSE2(not_equal, Tensor );
OP_DECOMPOSE(outer);
OP_DECOMPOSE(pairwise_distance);
OP_DECOMPOSE(pinverse);
OP_DECOMPOSE(poisson_nll_loss);
OP_DECOMPOSE(positive);
OP_DECOMPOSE(qr);
OP_DECOMPOSE(ravel);
OP_DECOMPOSE2(repeat_interleave, self_int);
OP_DECOMPOSE2(repeat_interleave, self_Tensor);
m.impl("reshape", native::reshape_symint);
OP_DECOMPOSE(resolve_conj);
OP_DECOMPOSE(resolve_neg);
OP_DECOMPOSE(row_stack);
OP_DECOMPOSE(rrelu);
OP_DECOMPOSE2(softmax, int);
OP_DECOMPOSE(special_gammainc);
OP_DECOMPOSE(special_gammaincc);
OP_DECOMPOSE(special_logit);
OP_DECOMPOSE(special_log_softmax);
OP_DECOMPOSE(special_logsumexp);
OP_DECOMPOSE(special_multigammaln);
OP_DECOMPOSE(special_polygamma);
OP_DECOMPOSE(special_softmax);
m.impl("split.sizes", native::split_symint);
OP_DECOMPOSE(square);
OP_DECOMPOSE(numpy_T);
OP_DECOMPOSE(reshape_as);
OP_DECOMPOSE(slogdet);
OP_DECOMPOSE(t);
OP_DECOMPOSE2(result_type, Tensor);
OP_DECOMPOSE2(result_type, Scalar);
OP_DECOMPOSE2(result_type, Scalar_Tensor);
OP_DECOMPOSE2(result_type, Scalar_Scalar);
OP_DECOMPOSE(is_same_size);
OP_DECOMPOSE(view_as);
OP_DECOMPOSE2(size, int);
OP_DECOMPOSE(is_complex);
OP_DECOMPOSE(std);
OP_DECOMPOSE2(std, dim);
OP_DECOMPOSE(std_mean);
OP_DECOMPOSE2(std_mean, dim);
OP_DECOMPOSE(swapaxes);
OP_DECOMPOSE2(subtract, Tensor);
OP_DECOMPOSE(sum_to_size);
OP_DECOMPOSE(svd);
OP_DECOMPOSE(swapdims);
OP_DECOMPOSE(take_along_dim);
OP_DECOMPOSE(tensordot);
OP_DECOMPOSE(tile);
OP_DECOMPOSE2(trapezoid, x);
OP_DECOMPOSE2(trapezoid, dx);
OP_DECOMPOSE2(trapz, x);
OP_DECOMPOSE2(trapz, dx);
m.impl("value_selecting_reduction_backward", native::value_selecting_reduction_backward_symint);
OP_DECOMPOSE(var);
OP_DECOMPOSE2(var, dim);
OP_DECOMPOSE(var_mean);
OP_DECOMPOSE2(var_mean, dim);
OP_DECOMPOSE2(vsplit, int);
OP_DECOMPOSE2(vsplit, array);
OP_DECOMPOSE(vstack);
OP_DECOMPOSE2(where, ScalarOther);
OP_DECOMPOSE2(where, ScalarSelf);
OP_DECOMPOSE(orgqr);
OP_DECOMPOSE2(unflatten, int);
OP_DECOMPOSE(_convolution_double_backward);
OP_DECOMPOSE(conv_transpose1d);
OP_DECOMPOSE2(conv_transpose2d, input);
OP_DECOMPOSE2(conv_transpose3d, input);
OP_DECOMPOSE(conv1d);
OP_DECOMPOSE(conv2d);
OP_DECOMPOSE(conv3d);
OP_DECOMPOSE2(conv1d, padding);
OP_DECOMPOSE2(conv2d, padding);
OP_DECOMPOSE2(conv3d, padding);
OP_DECOMPOSE(_convolution_mode);
OP_DECOMPOSE(frobenius_norm);
OP_DECOMPOSE(type_as);
OP_DECOMPOSE(linalg_diagonal);
m.impl("pad", native::pad_symint);
m.impl("_pad_circular", native::_pad_circular_symint);
OP_DECOMPOSE(t_);
OP_DECOMPOSE(swapdims_);
OP_DECOMPOSE(swapaxes_);
OP_DECOMPOSE(unfold_copy);
// divide, alias for div
OP_DECOMPOSE2(divide, Tensor);
OP_DECOMPOSE2(divide_, Tensor);
OP_DECOMPOSE2(divide, Scalar);
OP_DECOMPOSE2(divide, Tensor_mode);
OP_DECOMPOSE2(divide_, Tensor_mode);
OP_DECOMPOSE2(divide, Scalar_mode);
OP_DECOMPOSE2(divide_, Scalar_mode);
// divide, alias for div
OP_DECOMPOSE2(true_divide, Tensor);
OP_DECOMPOSE2(true_divide_, Tensor);
OP_DECOMPOSE2(true_divide, Scalar);
OP_DECOMPOSE2(true_divide_, Scalar);
// multiply, alias for mul
OP_DECOMPOSE2(multiply, Tensor)
OP_DECOMPOSE2(multiply_, Tensor)
OP_DECOMPOSE2(multiply, Scalar)
OP_DECOMPOSE2(multiply_, Scalar)
}
}}