blob: 88b9f7c2efab3e80b91ba091cdda1e337832dfb7 [file] [log] [blame]
#ifndef CAFFE2_UTILS_MATH_UTILS_H_
#define CAFFE2_UTILS_MATH_UTILS_H_
#include <vector>
#include "caffe2/core/common.h"
// MATH_UTILS_DECL is the declaration specifier used by the small helper
// functions defined in this header. It always contains `inline` (the helpers
// are defined in a header). When any of the CUDA/HIP conditions tested below
// holds, it additionally contains `__host__ __device__` so the helpers can be
// called from both host and device code.
// NOTE(review): under nvcc, __CUDA_ARCH__ is only defined during the device
// compilation pass, so nvcc's host pass sees plain `inline`.
#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) || \
defined(__HIP__) || (defined(__clang__) && defined(__CUDA__))
#define MATH_UTILS_DECL inline __host__ __device__
#else
#define MATH_UTILS_DECL inline
#endif
namespace caffe2 {
namespace math {
namespace utils {
// Logical negation expressed in the element type: returns `!x` converted to
// T. For arithmetic T this is T(1) when x == 0 and T(0) otherwise.
template <typename T>
MATH_UTILS_DECL T Not(const T x) {
  // `!x` is a bool; convert it back to the element type for the caller.
  return static_cast<T>(!x);
}
// Returns the sign of x as T: T(1) if x > 0, T(-1) if x < 0, and T(0)
// otherwise. "Otherwise" covers x == 0 and, for floating-point T, NaN,
// since both comparisons are false for NaN.
template <typename T>
MATH_UTILS_DECL T Sign(const T x) {
  if (x > 0) {
    return T(1);
  }
  if (x < 0) {
    return T(-1);
  }
  return T(0);
}
// Arithmetic negation: returns -x converted to T. For narrow integer types the
// negation happens in the promoted type and is converted back on assignment.
template <typename T>
MATH_UTILS_DECL T Negate(const T x) {
  const T negated = -x;
  return negated;
}
// Reciprocal: returns T(1) / x. The numerator is built as T, so the division
// uses T's arithmetic (integer division for integral T). No check for x == 0
// is performed; that case follows T's own division semantics.
template <typename T>
MATH_UTILS_DECL T Inv(const T x) {
  const T one(1);
  return one / x;
}
// Returns x squared using a single multiplication (cheaper than a pow call).
template <typename T>
MATH_UTILS_DECL T Square(const T x) {
  const T squared = x * x;
  return squared;
}
// Returns x cubed using two multiplications.
template <typename T>
MATH_UTILS_DECL T Cube(const T x) {
  // Same evaluation order as x * x * x (multiplication is left-associative).
  return (x * x) * x;
}
// Returns true iff 0 <= a < b, using one unsigned comparison instead of the
// two signed checks `a >= 0 && a < b`.
//
// Precondition: b >= 0. Converting a negative `a` to unsigned int wraps it to
// a value above INT_MAX, which is never less than a non-negative `b`, so
// negative `a` is rejected. Non-negative values keep their value under the
// conversion, so for them the unsigned comparison is the ordinary a < b.
MATH_UTILS_DECL bool IsAGeZeroAndALtB(const int a, const int b) {
  const unsigned int a_as_unsigned = static_cast<unsigned int>(a);
  const unsigned int b_as_unsigned = static_cast<unsigned int>(b);
  return a_as_unsigned < b_as_unsigned;
}
// Increments the multi-dimensional index `index` by one position, using
// `dims` as the extent of each index digit.
// NOTE(review): presumably `dims` and `index` both hold `ndim` entries and the
// last digit varies fastest — confirm against the implementation.
// NOTE(review): no definition is visible in this header; presumably these
// templates are explicitly instantiated for specific TIndex types elsewhere.
template <typename TIndex>
TORCH_API void
IncreaseIndexInDims(int ndim, const TIndex* dims, TIndex* index);
// Converts the multi-dimensional `index` digits into a single linear index
// for a shape given by `dims`.
// NOTE(review): presumably both arrays hold `n` entries and the linearization
// is row-major — confirm against the implementation.
template <typename TIndex>
TORCH_API TIndex
GetIndexFromDims(const int n, const TIndex* dims, const TIndex* index);
// Returns whether `perm` (n entries) is the identity permutation, i.e.
// perm[i] == i for every i.
TORCH_API bool IsIdentityPermutation(const int n, const int* perm);
// NOTE(review): judging by the name, presumably validates that Y_dims is a
// legal reduction of X_dims (each Y dim equal to the X dim or 1) for tensors
// of rank `ndim` — confirm against the implementation.
TORCH_API bool
CheckReduceDims(const int ndim, const int* X_dims, const int* Y_dims);
// NOTE(review): presumably returns true when reducing X_dims to Y_dims is a
// row-wise reduction of a (*rows x *cols) matrix and writes that shape to
// rows/cols — confirm against the implementation.
TORCH_API bool IsRowwiseReduce(
const int ndim,
const int* X_dims,
const int* Y_dims,
int* rows,
int* cols);
// NOTE(review): column-wise counterpart of IsRowwiseReduce; presumably writes
// the matrix shape to rows/cols on success — confirm.
TORCH_API bool IsColwiseReduce(
const int ndim,
const int* X_dims,
const int* Y_dims,
int* rows,
int* cols);
// NOTE(review): presumably returns true when the reduction can be viewed as a
// (*pre x *mid x *nxt) tensor reduced over both end dimensions — confirm.
TORCH_API bool IsBothEndsReduce(
const int ndim,
const int* X_dims,
const int* Y_dims,
int* pre,
int* mid,
int* nxt);
// Computes the broadcast binary operation dims for operands A (A_ndim dims)
// and B (B_ndim dims), writing them to the three output arrays.
// NOTE(review): presumably A and B are right-aligned numpy-style,
// A_broadcast_dims / B_broadcast_dims receive each operand's dims padded with
// leading 1s to the common rank, and C_broadcast_dims receives the output
// shape — confirm against the implementation.
template <typename TIndex>
TORCH_API void ComputeBroadcastBinaryOpDims(
const int A_ndim,
const TIndex* A_dims,
const int B_ndim,
const TIndex* B_dims,
TIndex* A_broadcast_dims,
TIndex* B_broadcast_dims,
TIndex* C_broadcast_dims);
// NOTE(review): presumably returns true when A op B reduces to a
// (*rows x *cols) matrix op where one operand is a single row broadcast over
// all rows; *broadcast_1st then tells whether the first operand (A) is the
// broadcast one — confirm against the implementation.
TORCH_API bool IsRowwiseBroadcastBinaryOp(
const int ndim,
const int* A_dims,
const int* B_dims,
int* rows,
int* cols,
bool* broadcast_1st);
// NOTE(review): column-wise counterpart of IsRowwiseBroadcastBinaryOp;
// presumably one operand is a single column broadcast over all columns —
// confirm.
TORCH_API bool IsColwiseBroadcastBinaryOp(
const int ndim,
const int* A_dims,
const int* B_dims,
int* rows,
int* cols,
bool* broadcast_1st);
// NOTE(review): presumably detects a (*pre x *mid x *nxt) layout where one
// operand spans only the middle dimension and is broadcast over both ends —
// confirm against the implementation.
TORCH_API bool IsBothEndsBroadcastBinaryOp(
const int ndim,
const int* A_dims,
const int* B_dims,
int* pre,
int* mid,
int* nxt,
bool* broadcast_1st);
// NOTE(review): presumably returns true when the permutation `axes` (ndim
// entries) only swaps the last two axes, i.e. the transpose is a batch of 2D
// transposes — confirm against the implementation.
TORCH_API bool IsBatchTranspose2D(const int ndim, const int* axes);
// Computes a transpose permutation `transpose_axes` (num_dims entries) for a
// reduction over the `num_reduce_axes` axes listed in `reduce_axes`.
// NOTE(review): presumably the reduced axes are moved to the end so the
// reduction becomes a trailing reduce — confirm against the implementation.
TORCH_API void ComputeTransposeAxesForReduceOp(
const int num_dims,
const int num_reduce_axes,
const int* reduce_axes,
int* transpose_axes);
// Overload that takes the shape `dims` (ndim entries) instead of an explicit
// list of reduced axes and writes the permutation to `axes`.
// NOTE(review): presumably an axis counts as reduced when dims[i] == 1 —
// confirm against the implementation.
TORCH_API void
ComputeTransposeAxesForReduceOp(const int ndim, const int* dims, int* axes);
// Computes `strides` for a tensor of shape `dims` (ndim entries) when its
// axes are permuted by `axes`.
// NOTE(review): presumably strides[i] is the row-major stride of input axis
// axes[i]; no definition is visible here, so presumably the template is
// explicitly instantiated for specific TIndex types elsewhere — confirm.
template <typename TIndex>
TORCH_API void ComputeTransposedStrides(
int ndim,
const TIndex* dims,
const int* axes,
TIndex* strides);
} // namespace utils
// Ceiling division: returns ceil(a / b) for integral T with a >= 0 and b > 0.
// Implemented as (a + b - 1) / b. The caller must ensure that a + b - 1 does
// not overflow (or wrap, for unsigned T).
template <typename T>
constexpr T DivUp(const T a, const T b) {
  return (a + b - T(1)) / b;
}
// Rounds a up to the smallest multiple of b that is >= a (a itself when it is
// already a multiple). Computed as b * DivUp(a, b), so the overflow caveat of
// DivUp applies, and the final multiplication must also fit in T.
template <typename T>
constexpr T RoundUp(const T a, const T b) {
  return b * DivUp<T>(a, b);
}
// Returns floor(log2(n)) for a positive integer n (e.g. 1 -> 0, 8 -> 3,
// 9 -> 3). Inputs n <= 1 return `p`, which is 0 for external callers.
// `p` is the recursion's accumulator; callers should leave it at its default.
template <typename T>
constexpr int IntegerLog2(T n, int p = 0) {
  return (n <= 1) ? p : IntegerLog2(n / 2, p + 1);
}
// Returns the smallest power of two strictly greater than v, for an integer
// type T (e.g. 3 -> 4, 4 -> 8, 5 -> 8, 1 -> 2).
//
// Requires v > 0 (v == 0 yields 2 because IntegerLog2(0) == 0). The caller
// must ensure the result is representable in T: for a v that is already the
// largest power of two T can hold, T(2) * v overflows.
template <typename T>
constexpr T IntegerNextHighestPowerOf2(T v) {
  // A positive v is a power of two iff it has exactly one set bit, i.e.
  // (v & (v - 1)) == 0; the next one up is then 2 * v. Otherwise the answer
  // is 1 << (floor(log2(v)) + 1). The check is written inline because no
  // IntegerIsPowerOf2 helper is declared in this header.
  return (v && !(v & (v - T(1))))
      ? T(2) * v
      : (T(1) << (IntegerLog2(v) + 1));
}
} // namespace math
} // namespace caffe2
#endif // CAFFE2_UTILS_MATH_UTILS_H_