// blob: 89f0992e30f14b9cdd7a82fa37e18d3130ea1432 [file] [log] [blame] [edit]
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once
#include <cstdint>
#include <type_traits>
#include <executorch/kernels/optimized/blas/BlasKernel.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
namespace executorch {
namespace cpublas {
// Selects how a matrix operand is to be read by the gemm routines declared
// in this header.
enum class TransposeType {
// Operand is used as stored; to_blas() maps this to 'N'.
NoTranspose,
// Operand is used transposed; to_blas() maps this to 'T'.
Transpose,
// Operand is used conjugate-transposed; to_blas() maps this to 'C'.
ConjTranspose,
};
// clang-format off
// Declaration only; the definition is not part of this header.
// The transpose flags and the m/n/k sizes are passed by value, while the
// three leading dimensions are passed by pointer, so the callee is able to
// rewrite *lda, *ldb and *ldc in place.
// NOTE(review): presumably this canonicalizes the leading dimensions for
// degenerate (size-1) shapes before a GEMM call -- confirm against the
// definition.
void normalize_last_dims(
TransposeType transa, TransposeType transb,
int64_t m, int64_t n, int64_t k,
int64_t *lda, int64_t *ldb, int64_t *ldc);
// clang-format on
// Translates a TransposeType into its one-character code:
// Transpose -> 'T', ConjTranspose -> 'C', NoTranspose -> 'N'.
// These are the letters BLAS-style interfaces conventionally use for their
// transpose arguments.
inline char to_blas(TransposeType trans) {
  if (trans == TransposeType::Transpose) {
    return 'T';
  }
  if (trans == TransposeType::ConjTranspose) {
    return 'C';
  }
  // NoTranspose lands here, as does any value outside the declared
  // enumerators: "no transpose" is the default.
  return 'N';
}
// clang-format off
// Dispatches a GEMM call to one of the four gemm_*_ kernels depending on
// whether each input operand is flagged as transposed.
//
// Parameters:
//   transa, transb - transpose flags for operands a and b.
//   m, n, k        - problem sizes, forwarded unchanged to the kernel.
//   alpha, beta    - scalars (of the accumulation type opmath_t), forwarded
//                    unchanged to the kernel.
//   a/lda, b/ldb   - read-only operands and their leading dimensions.
//   c/ldc          - output operand and its leading dimension.
//
// ConjTranspose is routed exactly like Transpose. None of the kernels called
// here receives a transpose flag, so no conjugation can take place; for real
// scalar types a conjugate transpose is the same as a plain transpose, which
// is also how BLAS treats 'C' for real GEMM.
// NOTE(review): this assumes only real scalar types are instantiated (every
// gemm overload visible in this header is real) -- confirm before adding a
// complex instantiation.
template <typename scalar_t, typename opmath_t>
void gemm_impl(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    opmath_t alpha,
    const scalar_t *a, int64_t lda,
    const scalar_t *b, int64_t ldb,
    opmath_t beta,
    scalar_t *c, int64_t ldc) {
  // Classify on "anything other than NoTranspose" so that ConjTranspose can
  // never be mistaken for a non-transposed operand, nor drag the *other*
  // operand into the transposed path.
  const bool a_transposed = transa != TransposeType::NoTranspose;
  const bool b_transposed = transb != TransposeType::NoTranspose;

  if (!a_transposed && !b_transposed) {
    gemm_notrans_(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
  } else if (a_transposed && !b_transposed) {
    gemm_transa_(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
  } else if (!a_transposed && b_transposed) {
    gemm_transb_(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
  } else { // both operands are transposed
    gemm_transab_(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
  }
}
// clang-format on
// clang-format off
// GEMM overload for double operands. Declaration only; the definition is not
// part of this header.
// NOTE(review): presumably follows the BLAS GEMM contract,
// C = alpha * op(A) * op(B) + beta * C, with op() selected by transa/transb
// and lda/ldb/ldc acting as leading dimensions -- confirm against the
// definition.
void gemm(
TransposeType transa, TransposeType transb,
int64_t m, int64_t n, int64_t k,
double alpha,
const double *a, int64_t lda,
const double *b, int64_t ldb,
double beta,
double *c, int64_t ldc);
// clang-format on
// clang-format off
// GEMM overload for float operands. Declaration only; the definition is not
// part of this header.
// NOTE(review): presumably follows the BLAS GEMM contract,
// C = alpha * op(A) * op(B) + beta * C, with op() selected by transa/transb
// and lda/ldb/ldc acting as leading dimensions -- confirm against the
// definition.
void gemm(
TransposeType transa, TransposeType transb,
int64_t m, int64_t n, int64_t k,
const float alpha,
const float *a, int64_t lda,
const float *b, int64_t ldb,
const float beta,
float *c, int64_t ldc);
// clang-format on
// clang-format off
// GEMM overload for exec_aten::Half operands; alpha and beta are also passed
// as exec_aten::Half. Declaration only; the definition is not part of this
// header.
// NOTE(review): presumably follows the BLAS GEMM contract,
// C = alpha * op(A) * op(B) + beta * C, with op() selected by transa/transb
// and lda/ldb/ldc acting as leading dimensions -- confirm against the
// definition.
void gemm(
TransposeType transa, TransposeType transb,
int64_t m, int64_t n, int64_t k,
const exec_aten::Half alpha,
const exec_aten::Half *a, int64_t lda,
const exec_aten::Half *b, int64_t ldb,
const exec_aten::Half beta,
exec_aten::Half *c, int64_t ldc);
// GEMM overload for exec_aten::BFloat16 operands; alpha and beta are also
// passed as exec_aten::BFloat16. Declaration only; the definition is not part
// of this header.
// NOTE(review): presumably follows the BLAS GEMM contract,
// C = alpha * op(A) * op(B) + beta * C, with op() selected by transa/transb
// and lda/ldb/ldc acting as leading dimensions -- confirm against the
// definition.
void gemm(
TransposeType transa, TransposeType transb,
int64_t m, int64_t n, int64_t k,
const exec_aten::BFloat16 alpha,
const exec_aten::BFloat16 *a, int64_t lda,
const exec_aten::BFloat16 *b, int64_t ldb,
const exec_aten::BFloat16 beta,
exec_aten::BFloat16 *c, int64_t ldc);
// clang-format on
// clang-format off
// GEMM entry point for integral element types; the enable_if constraint
// removes this overload from consideration for every non-integral T.
//
// The leading dimensions are first passed through normalize_last_dims, then
// alpha and beta are converted to utils::compute_dtype<T> and the call is
// forwarded to gemm_impl together with the untouched operand pointers.
template <typename T,
          typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
void gemm(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    const T alpha,
    const T *a, int64_t lda,
    const T *b, int64_t ldb,
    const T beta,
    T *c, int64_t ldc) {
  using opmath_t = utils::compute_dtype<T>;

  // lda/ldb/ldc are by-value parameters, so only this function's own copies
  // can be rewritten here; the caller's variables are never touched.
  normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);

  // Scalars are handed to gemm_impl in the computation type, not in T.
  const opmath_t alpha_compute = static_cast<opmath_t>(alpha);
  const opmath_t beta_compute = static_cast<opmath_t>(beta);

  gemm_impl(
      transa, transb,
      m, n, k,
      alpha_compute,
      a, lda,
      b, ldb,
      beta_compute,
      c, ldc);
}
// clang-format on
} // namespace cpublas
} // namespace executorch