#define TORCH_ASSERT_NO_OPERATORS
#define _USE_MATH_DEFINES

#include <ATen/native/Activation.h>

#include <cmath>

#include <thrust/tuple.h>

#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/core/TensorBase.h>
#include <c10/core/Scalar.h>
#include <c10/cuda/CUDAMathCompat.h>
#include <ATen/cuda/ApplyGridUtils.cuh>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
#include <ATen/native/cuda/Loops.cuh>

namespace at::native {

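// GELU(x) = x * Phi(x), where Phi is the standard normal CDF. The exact
// path evaluates Phi via erf; GeluType::Tanh selects the tanh-based
// approximation 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
// from Hendrycks & Gimpel, "Gaussian Error Linear Units (GELUs)".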
void GeluCUDAKernelImpl(TensorIteratorBase& it, GeluType approximate) {
  if (approximate == GeluType::Tanh) {
    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, it.dtype(), "GeluCUDAKernelImpl", [&]() {
      gpu_kernel(it, [] GPU_LAMBDA(scalar_t x) -> scalar_t {
        using opmath_t = at::opmath_type<scalar_t>;
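        // kBeta = sqrt(2) * (2/sqrt(pi)) * 0.5 = sqrt(2/pi); the math is
        // done in opmath_t (float for Half/BFloat16) for accuracy.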
        constexpr opmath_t kBeta = M_SQRT2 * M_2_SQRTPI * opmath_t(0.5);
        constexpr opmath_t kKappa = 0.044715;
        auto x_cube = static_cast<opmath_t>(x) * static_cast<opmath_t>(x) * static_cast<opmath_t>(x);
        auto inner = kBeta * (static_cast<opmath_t>(x) + kKappa * x_cube);
        return opmath_t(0.5) * static_cast<opmath_t>(x) * (opmath_t(1) + c10::cuda::compat::tanh(inner));
      });
    });
  } else {
    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, it.dtype(), "GeluCUDAKernelImpl", [&]() {
      gpu_kernel(it, [] GPU_LAMBDA(scalar_t x) -> scalar_t {
        using opmath_t = at::opmath_type<scalar_t>;
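        // kAlpha = 1/sqrt(2), so erf(x * kAlpha) = 2 * Phi(x) - 1 and the
        // expression below is exactly x * Phi(x).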
        constexpr opmath_t kAlpha = M_SQRT1_2;
        return static_cast<opmath_t>(x) * opmath_t(0.5) * (opmath_t(1) + ::erf(static_cast<opmath_t>(x) * kAlpha));
      });
    });
  }
}

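// Backward: for the exact form, d/dx [x * Phi(x)] = Phi(x) + x * phi(x),
// where phi is the standard normal PDF. For the tanh approximation, write
// GELU(x) = left * right with left = 0.5 * x and right = 1 + tanh(inner),
// then apply the product rule, using tanh'(z) = 1 - tanh(z)^2.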
void GeluBackwardCUDAKernelImpl(TensorIteratorBase& it, GeluType approximate) {
  if (approximate == GeluType::Tanh) {
    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16,
        it.dtype(), "GeluBackwardCUDAKernelImpl", [&]() {
      gpu_kernel(it, [] GPU_LAMBDA(scalar_t dy, scalar_t x) -> scalar_t {
        using opmath_t = at::opmath_type<scalar_t>;
        constexpr opmath_t kBeta = M_SQRT2 * M_2_SQRTPI * opmath_t(0.5);
        constexpr opmath_t kKappa = 0.044715;
        auto x_sq = static_cast<opmath_t>(x) * static_cast<opmath_t>(x);
        auto x_cube = x_sq * static_cast<opmath_t>(x);
        auto inner = kBeta * (static_cast<opmath_t>(x) + kKappa * x_cube);
        auto tanh_inner = c10::cuda::compat::tanh(inner);

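        // Product rule: (left * right)' = left' * right + left * right',
        // with left' = 0.5 and right' = (1 - tanh^2(inner)) * inner'.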
        auto left = opmath_t(0.5) * static_cast<opmath_t>(x);
        auto right = opmath_t(1) + tanh_inner;

        auto left_derivative = opmath_t(0.5) * right;

        auto tanh_derivative = opmath_t(1) - tanh_inner * tanh_inner;
        auto inner_derivative = kBeta * (opmath_t(1) + opmath_t(3) * kKappa * x_sq);
        auto right_derivative = left * tanh_derivative * inner_derivative;

        return static_cast<opmath_t>(dy) * (left_derivative + right_derivative);
      });
    });
  } else {
    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16,
        it.dtype(), "GeluBackwardCUDAKernelImpl", [&]() {
      gpu_kernel(it, [] GPU_LAMBDA(scalar_t dy, scalar_t x) -> scalar_t {
        using opmath_t = at::opmath_type<scalar_t>;
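        // kBeta = (2/sqrt(pi)) * (1/sqrt(2)) * 0.5 = 1/sqrt(2*pi), the
        // normalizing constant of the standard normal PDF phi.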
        constexpr opmath_t kBeta = M_2_SQRTPI * M_SQRT1_2 * opmath_t(0.5);
        constexpr opmath_t kAlpha = M_SQRT1_2;
        const opmath_t cdf =
            opmath_t(0.5) * (opmath_t(1) + ::erf(static_cast<opmath_t>(x) * kAlpha));
        const opmath_t pdf =
            c10::cuda::compat::exp(
                opmath_t(-0.5) * static_cast<opmath_t>(x) * static_cast<opmath_t>(x)) *
            kBeta;
        return static_cast<opmath_t>(dy) * (cdf + static_cast<opmath_t>(x) * pdf);
      });
    });
  }
}
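
// A minimal sketch of how these implementations are usually exposed,
// assuming the GeluKernel/GeluBackwardKernel dispatch stubs declared in
// ATen/native/Activation.h (the registration lines are not part of this
// excerpt):
//
//   REGISTER_DISPATCH(GeluKernel, &GeluCUDAKernelImpl);
//   REGISTER_DISPATCH(GeluBackwardKernel, &GeluBackwardCUDAKernelImpl);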

} // namespace at::native