| #pragma once |
| |
| #include <c10/util/BFloat16.h> |
| #include <c10/util/Half.h> |
| #include <c10/util/math_compat.h> |
| |
| C10_CLANG_DIAGNOSTIC_PUSH() |
| #if C10_CLANG_HAS_WARNING("-Wimplicit-float-conversion") |
| C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-float-conversion") |
| #endif |
| |
| namespace std { |
| |
| template <typename T> |
| struct is_reduced_floating_point |
| : std::integral_constant< |
| bool, |
| std::is_same<T, c10::Half>::value || |
| std::is_same<T, c10::BFloat16>::value> {}; |
| |
| template <typename T> |
| constexpr bool is_reduced_floating_point_v = |
| is_reduced_floating_point<T>::value; |
| |
| template < |
| typename T, |
| typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0> |
| inline T acos(T a) { |
| return std::acos(float(a)); |
| } |
| template < |
| typename T, |
| typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0> |
| inline T asin(T a) { |
| return std::asin(float(a)); |
| } |
| template < |
| typename T, |
| typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0> |
| inline T atan(T a) { |
| return std::atan(float(a)); |
| } |
| template < |
| typename T, |
| typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0> |
| inline T atanh(T a) { |
| return std::atanh(float(a)); |
| } |
| template < |
| typename T, |
| typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0> |
| inline T erf(T a) { |
| return std::erf(float(a)); |
| } |
| template < |
| typename T, |
| typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0> |
| inline T erfc(T a) { |
| return std::erfc(float(a)); |
| } |
| template < |
| typename T, |
| typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0> |
| inline T exp(T a) { |
| return std::exp(float(a)); |
| } |
| template < |
| typename T, |
| typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0> |
| inline T expm1(T a) { |
| return std::expm1(float(a)); |
| } |
| template < |
| typename T, |
| typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0> |
| inline T log(T a) { |
| return std::log(float(a)); |
| } |
| template < |
| typename T, |
| typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0> |
| inline T log10(T a) { |
| return std::log10(float(a)); |
| } |
| template < |
| typename T, |
| typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0> |
| inline T log1p(T a) { |
| return std::log1p(float(a)); |
| } |
| template < |
| typename T, |
| typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0> |
| inline T log2(T a) { |
| return std::log2(float(a)); |
| } |
| template < |
| typename T, |
| typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0> |
| inline T ceil(T a) { |
| return std::ceil(float(a)); |
| } |
| template < |
| typename T, |
| typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0> |
| inline T cos(T a) { |
| return std::cos(float(a)); |
| } |
| template < |
| typename T, |
| typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0> |
| inline T floor(T a) { |
| return std::floor(float(a)); |
| } |
| template < |
| typename T, |
| typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0> |
| inline T nearbyint(T a) { |
| return std::nearbyint(float(a)); |
| } |
| template < |
| typename T, |
| typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0> |
| inline T sin(T a) { |
| return std::sin(float(a)); |
| } |
| template < |
| typename T, |
| typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0> |
| inline T tan(T a) { |
| return std::tan(float(a)); |
| } |
| template < |
| typename T, |
| typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0> |
| inline T sinh(T a) { |
| return std::sinh(float(a)); |
| } |
| template < |
| typename T, |
| typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0> |
| inline T cosh(T a) { |
| return std::cosh(float(a)); |
| } |
| template < |
| typename T, |
| typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0> |
| inline T tanh(T a) { |
| return std::tanh(float(a)); |
| } |
| template < |
| typename T, |
| typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0> |
| inline T trunc(T a) { |
| return std::trunc(float(a)); |
| } |
| template < |
| typename T, |
| typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0> |
| inline T lgamma(T a) { |
| return std::lgamma(float(a)); |
| } |
| template < |
| typename T, |
| typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0> |
| inline T sqrt(T a) { |
| return std::sqrt(float(a)); |
| } |
| template < |
| typename T, |
| typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0> |
| inline T rsqrt(T a) { |
| return 1.0 / std::sqrt(float(a)); |
| } |
| template < |
| typename T, |
| typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0> |
| inline T abs(T a) { |
| return std::abs(float(a)); |
| } |
| #if defined(_MSC_VER) && defined(__CUDACC__) |
| template < |
| typename T, |
| typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0> |
| inline T pow(T a, double b) { |
| return std::pow(float(a), float(b)); |
| } |
| #else |
| template < |
| typename T, |
| typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0> |
| inline T pow(T a, double b) { |
| return std::pow(float(a), b); |
| } |
| #endif |
| template < |
| typename T, |
| typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0> |
| inline T pow(T a, T b) { |
| return std::pow(float(a), float(b)); |
| } |
| template < |
| typename T, |
| typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0> |
| inline T fmod(T a, T b) { |
| return std::fmod(float(a), float(b)); |
| } |
| |
| /* |
| The following function is inspired by the implementation in `musl` |
| Link to License: https://git.musl-libc.org/cgit/musl/tree/COPYRIGHT |
| ---------------------------------------------------------------------- |
| Copyright © 2005-2020 Rich Felker, et al. |
| |
| Permission is hereby granted, free of charge, to any person obtaining |
| a copy of this software and associated documentation files (the |
| "Software"), to deal in the Software without restriction, including |
| without limitation the rights to use, copy, modify, merge, publish, |
| distribute, sublicense, and/or sell copies of the Software, and to |
| permit persons to whom the Software is furnished to do so, subject to |
| the following conditions: |
| |
| The above copyright notice and this permission notice shall be |
| included in all copies or substantial portions of the Software. |
| |
| THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, |
| EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF |
| MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. |
| IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY |
| CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, |
| TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE |
| SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. |
| ---------------------------------------------------------------------- |
| */ |
| template < |
| typename T, |
| typename std::enable_if_t<is_reduced_floating_point_v<T>, int> = 0> |
| C10_HOST_DEVICE inline T nextafter(T from, T to) { |
| // Reference: |
| // https://git.musl-libc.org/cgit/musl/tree/src/math/nextafter.c |
| using int_repr_t = uint16_t; |
| using float_t = T; |
| constexpr uint8_t bits = 16; |
| union { |
| float_t f; |
| int_repr_t i; |
| } ufrom = {from}, uto = {to}; |
| |
| // get a mask to get the sign bit i.e. MSB |
| int_repr_t sign_mask = int_repr_t{1} << (bits - 1); |
| |
| // short-circuit: if either is NaN, return NaN |
| if (from != from || to != to) { |
| return from + to; |
| } |
| |
| // short-circuit: if they are exactly the same. |
| if (ufrom.i == uto.i) { |
| return from; |
| } |
| |
| // mask the sign-bit to zero i.e. positive |
| // equivalent to abs(x) |
| int_repr_t abs_from = ufrom.i & ~sign_mask; |
| int_repr_t abs_to = uto.i & ~sign_mask; |
| if (abs_from == 0) { |
| // if both are zero but with different sign, |
| // preserve the sign of `to`. |
| if (abs_to == 0) { |
| return to; |
| } |
| // smallest subnormal with sign of `to`. |
| ufrom.i = (uto.i & sign_mask) | int_repr_t{1}; |
| return ufrom.f; |
| } |
| |
| // if abs(from) > abs(to) or sign(from) != sign(to) |
| if (abs_from > abs_to || ((ufrom.i ^ uto.i) & sign_mask)) { |
| ufrom.i--; |
| } else { |
| ufrom.i++; |
| } |
| |
| return ufrom.f; |
| } |
| |
| } // namespace std |
| |
| C10_CLANG_DIAGNOSTIC_POP() |