blob: 87aa784b7d5d39c503b3d21e53ba160ec5f7464b [file] [log] [blame]
#define TORCH_ASSERT_NO_OPERATORS
#include <limits>
#include <ATen/native/UnaryOps.h>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/Math.cuh>
namespace at { namespace native {
// We manually overload ceil because std::ceil does not work with std::complex types.
template <typename scalar_t>
__host__ __device__ static inline scalar_t ceil_wrapper(scalar_t a) {
return std::ceil(a);
}
template<typename T>
__host__ __device__ static inline std::complex<T> ceil_wrapper(std::complex<T> v) {
return std::complex<T>(std::ceil(v.real()), std::ceil(v.imag()));
}
void ceil_kernel_cuda(TensorIteratorBase& iter) {
AT_DISPATCH_FLOATING_TYPES_AND2(
ScalarType::Half, ScalarType::BFloat16,
iter.dtype(), "ceil_cuda",
[&]() {
gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
return ceil_wrapper(a);
});
});
}
void frac_kernel_cuda(TensorIteratorBase& iter) {
AT_DISPATCH_FLOATING_TYPES_AND2(
ScalarType::Half, ScalarType::BFloat16,
iter.dtype(), "frac_cuda",
[&]() {
gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
return a - ::trunc(a);
});
});
}
// We manually overload floor because std::floor does not work with std::complex types.
template <typename scalar_t>
__host__ __device__ static inline scalar_t floor_wrapper(scalar_t a) {
return std::floor(a);
}
template<typename T>
__host__ __device__ static inline std::complex<T> floor_wrapper(std::complex<T> v) {
return std::complex<T>(std::floor(v.real()), std::floor(v.imag()));
}
void floor_kernel_cuda(TensorIteratorBase& iter) {
AT_DISPATCH_FLOATING_TYPES_AND2(
ScalarType::Half, ScalarType::BFloat16,
iter.dtype(), "floor_cuda",
[&]() {
gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
return floor_wrapper(a);
});
});
}
template <typename scalar_t>
__host__ __device__ static inline scalar_t reciprocal_wrapper(scalar_t a) {
return static_cast<scalar_t>(1)/a;
}
template<typename T>
__host__ __device__ static inline c10::complex<T> reciprocal_wrapper(c10::complex<T> v) {
// Handle extreme cases for numpy compatibility
auto both_inf = [](T real, T imag) {
return (::isinf(real) && ::isinf(imag));
};
auto either_inf = [](T real, T imag) {
return ::isinf(real) || ::isinf(imag);
};
auto either_nan = [](T real, T imag) {
return ::isnan(real) || ::isnan(imag);
};
if (either_nan(v.real(), v.imag()) || both_inf(v.real(), v.imag())) {
// If either is Nan or both are infinite, return {nan, nan}
return {std::numeric_limits<T>::quiet_NaN(), std::numeric_limits<T>::quiet_NaN()};
} else if (either_inf(v.real(), v.imag())) {
// If either is Inf, return {0, 0}
return {0, 0};
}
const c10::complex<T> one = c10::complex<T>(1.0, 0);
return one/v;
}
void reciprocal_kernel_cuda(TensorIteratorBase& iter) {
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
ScalarType::Half, ScalarType::BFloat16,
iter.common_dtype(), "reciprocal_cuda",
[&]() {
gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
return reciprocal_wrapper(a);
});
});
}
// We manually overload nearbyint because std::nearbyint does not work with std::complex types and ROCm.
template <typename scalar_t>
__host__ __device__ static inline scalar_t nearbyint_wrapper(scalar_t a) {
return static_cast<scalar_t>(::nearbyintf(static_cast<float>(a)));
}
__host__ __device__ static inline double nearbyint_wrapper(double a) {
return ::nearbyint(a);
}
__host__ __device__ static inline c10::complex<float> nearbyint_wrapper(c10::complex<float> a) {
return c10::complex<float>(::nearbyintf(static_cast<float>(a.real())), ::nearbyintf(static_cast<float>(a.imag())));
}
#pragma push
#pragma diag_suppress 177 // Function was declared but never referenced
__host__ __device__ static inline c10::complex<double> nearbyint_wrapper(c10::complex<double> a) {
return c10::complex<double>(::nearbyint(static_cast<double>(a.real())), ::nearbyint(static_cast<double>(a.imag())));
}
#pragma pop
void round_kernel_cuda(TensorIteratorBase& iter) {
AT_DISPATCH_FLOATING_TYPES_AND2(
ScalarType::Half, ScalarType::BFloat16,
iter.dtype(), "round_cuda",
[&]() {
gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
// We do not use std::round because we would like to round midway numbers to the nearest even integer.
return nearbyint_wrapper(a);
});
});
}
void round_decimals_kernel_cuda(TensorIteratorBase& iter, int64_t decimals) {
AT_DISPATCH_FLOATING_TYPES_AND2(
ScalarType::Half, ScalarType::BFloat16,
iter.dtype(), "round_cuda",
[&]() {
bool neg_flag = false;
scalar_t ten_pow_decimals;
if (decimals < 0) {
decimals = -decimals;
neg_flag = true;
}
ten_pow_decimals = static_cast<scalar_t>(std::pow(10, decimals));
gpu_kernel(iter, [ten_pow_decimals, neg_flag]GPU_LAMBDA(scalar_t a) -> scalar_t {
return neg_flag ? std::nearbyint(a / ten_pow_decimals) * ten_pow_decimals
: std::nearbyint(a * ten_pow_decimals) / ten_pow_decimals;
});
});
}
// We manually overload trunc because std::trunc does not work with std::complex types and ROCm.
template <typename scalar_t>
__host__ __device__ static inline scalar_t trunc_wrapper(scalar_t a) {
return static_cast<scalar_t>(::truncf(static_cast<float>(a)));
}
__host__ __device__ static inline double trunc_wrapper(double a) {
return ::trunc(a);
}
__host__ __device__ static inline c10::complex<float> trunc_wrapper(c10::complex<float> a) {
return c10::complex<float>(::truncf(static_cast<float>(a.real())), ::truncf(static_cast<float>(a.imag())));
}
__host__ __device__ static inline c10::complex<double> trunc_wrapper(c10::complex<double> a) {
return c10::complex<double>(::trunc(static_cast<double>(a.real())), ::trunc(static_cast<double>(a.imag())));
}
void trunc_kernel_cuda(TensorIteratorBase& iter) {
AT_DISPATCH_FLOATING_TYPES_AND2(
ScalarType::Half, ScalarType::BFloat16,
iter.dtype(), "trunc_cuda",
[&]() {
gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
return trunc_wrapper(a);
});
});
}
REGISTER_DISPATCH(ceil_stub, &ceil_kernel_cuda);
REGISTER_DISPATCH(frac_stub, &frac_kernel_cuda);
REGISTER_DISPATCH(floor_stub, &floor_kernel_cuda);
REGISTER_DISPATCH(reciprocal_stub, &reciprocal_kernel_cuda);
REGISTER_DISPATCH(round_stub, &round_kernel_cuda);
REGISTER_DISPATCH(round_decimals_stub, &round_decimals_kernel_cuda);
REGISTER_DISPATCH(trunc_stub, &trunc_kernel_cuda);
}} // namespace at::native