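// TORCH_ASSERT_NO_OPERATORS keeps this translation unit from pulling in the
// generated at:: operator headers, limiting rebuilds when operators change.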
#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/Dispatch.h>
#include <ATen/NumericUtils.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/ReduceAllOps.h>
#include <ATen/native/ReduceOps.h>
#include <ATen/native/SharedReduceOps.h>
#include <ATen/native/TensorCompare.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/ReduceOps.h>
#include <ATen/cuda/NumericLimits.cuh>
#include <ATen/native/cuda/Reduce.cuh>
namespace at {
namespace native {
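// Reduces each slice to a (min value, index) pair via ArgMinOps and projects
// out the int64 index. The accumulator starts at (acc_t's upper_bound(), 0),
// so the first element visited always replaces it.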
template <typename scalar_t, typename acc_t = scalar_t>
void argmin_kernel_cuda_impl(TensorIterator& iter) {
  gpu_reduce_kernel<scalar_t, int64_t>(
      iter,
      ArgMinOps<acc_t>{},
      thrust::pair<acc_t, int64_t>(
          at::numeric_limits<acc_t>::upper_bound(), 0));
}
void argmin_kernel_cuda(TensorIterator& iter) {
  // For float16 & bfloat16, instead of implementing is_nan and warp_shfl_down,
  // we can convert float16 & bfloat16 to float and do all the operations in
  // float.
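  // Operand 0 of the iterator is the int64 output of indices, so the input
  // dtype lives at operand 1 (iter.dtype(1)).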
  if (iter.dtype(1) == kHalf) {
    argmin_kernel_cuda_impl<at::Half, float>(iter);
  } else if (iter.dtype(1) == kBFloat16) {
    argmin_kernel_cuda_impl<at::BFloat16, float>(iter);
  } else {
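    // AT_DISPATCH_ALL_TYPES instantiates the lambda once per integral and
    // floating-point dtype, binding scalar_t to the corresponding C++ type.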
    AT_DISPATCH_ALL_TYPES(iter.dtype(1), "argmin_cuda", [&]() {
      argmin_kernel_cuda_impl<scalar_t>(iter);
    });
  }
}
REGISTER_DISPATCH(argmin_stub, &argmin_kernel_cuda);
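// Binds this kernel to the argmin_stub DispatchStub, so argmin on CUDA
// tensors routes here.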
} // namespace native
} // namespace at