// blob: 86d8bbd528c8fcf095a0d71ce3c2c62099989c17 [file] [log] [blame]
#define TORCH_ASSERT_NO_OPERATORS
#define _USE_MATH_DEFINES
#include <ATen/native/Activation.h>
#include <cmath>
#include <thrust/tuple.h>
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/core/TensorBase.h>
#include <c10/core/Scalar.h>
#include <c10/cuda/CUDAMathCompat.h>
#include <ATen/cuda/ApplyGridUtils.cuh>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
#include <ATen/native/cuda/Loops.cuh>
namespace at {
namespace native {
namespace {
// Launches the element-wise threshold operation over `iter` for one concrete
// element type.
//
// For each pair of operands (x, other) handed to the functor, the result is
// `value` when `x <= threshold` holds and `other` when it does not.
//
// Template parameters:
//   scalar_t   element type the functor reads and writes.
// Parameters:
//   iter       iterator describing the operands; forwarded untouched to
//              gpu_kernel_with_scalars.
//   threshold  cutoff that x is compared against (inclusive: x == threshold
//              selects `value`).
//   value      replacement result used when the comparison succeeds.
//
// NOTE(review): which tensor supplies `x` and which supplies `other` is
// decided by how the caller configured `iter` — not visible here.
template <typename scalar_t>
void threshold_kernel_impl(
    TensorIteratorBase& iter,
    scalar_t threshold,
    scalar_t value) {
  // Captured by value so that `threshold` and `value` are copied into the
  // functor that gets handed to the GPU loop helper.
  auto threshold_op = [=] GPU_LAMBDA(scalar_t x, scalar_t other) -> scalar_t {
    if (x <= threshold) {
      return value;
    }
    return other;
  };
  gpu_kernel_with_scalars(iter, threshold_op);
}
// Type-dispatching entry point for the threshold kernel.
//
// Selects the concrete element type from `iter.dtype()` via
// AT_DISPATCH_ALL_TYPES_AND2 (with Half and BFloat16 passed as the two extra
// types), converts both scalars to that type, and forwards everything to
// threshold_kernel_impl<scalar_t>.
//
// Parameters:
//   iter       iterator describing the operands; its dtype drives dispatch.
//   threshold  cutoff, converted with Scalar::to<scalar_t>().
//   value      replacement value, converted with Scalar::to<scalar_t>().
static void threshold_kernel_cuda(
    TensorIteratorBase& iter,
    const Scalar& threshold,
    const Scalar& value) {
  const auto dispatch_dtype = iter.dtype();
  AT_DISPATCH_ALL_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      dispatch_dtype,
      "threshold_cuda",
      [&] {
        // `scalar_t` is supplied by the dispatch macro for the chosen dtype.
        const scalar_t typed_threshold = threshold.to<scalar_t>();
        const scalar_t typed_value = value.to<scalar_t>();
        threshold_kernel_impl<scalar_t>(iter, typed_threshold, typed_value);
      });
}
} // namespace
// Registers threshold_kernel_cuda as the implementation behind threshold_stub
// for this translation unit. `threshold_stub` is declared outside this file
// (presumably in ATen/native/Activation.h — confirm).
REGISTER_DISPATCH(threshold_stub, &threshold_kernel_cuda);
} // namespace native
} // namespace at