// blob: fcfcd2e5ebbdb3439f848b3f42a4d7709d7017a3 [file] [log] [blame]
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800))
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <mma.h>
#endif
#include <ATen/ATen.h>
#include <ATen/core/Tensor.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/DeviceGuard.h>
#include <c10/cuda/CUDAGuard.h>
namespace at::native {
// Integer quotient a / b (truncating toward zero, i.e. floor for
// non-negative operands), returned in the common type of a and b.
template <typename U, typename V>
constexpr __host__ __device__ auto divDown(U a, V b) -> decltype(a + b) {
  static_assert(
      std::is_integral<U>::value && std::is_integral<V>::value, "");
  const auto quotient = a / b;
  return quotient;
}
// Ceiling of a / b for non-negative a and positive b, returned in the common
// type of a and b. Written as quotient + (remainder != 0) rather than
// (a + b - 1) / b so that it cannot overflow near the type's maximum.
template <typename U, typename V>
constexpr __host__ __device__ auto divUp(U a, V b) -> decltype(a + b) {
  static_assert(
      std::is_integral<U>::value && std::is_integral<V>::value, "");
  const bool hasRemainder = (a % b) != 0;
  const uint64_t rounded = a / b + (hasRemainder ? 1 : 0);
  return rounded;
}
// Rounds a down to a multiple of b: divDown(a, b) * b.
template <typename U, typename V>
constexpr __host__ __device__ auto roundDown(U a, V b) -> decltype(a + b) {
  static_assert(
      std::is_integral<U>::value && std::is_integral<V>::value, "");
  const auto wholeMultiples = divDown(a, b);
  return wholeMultiples * b;
}
// Rounds a up to a multiple of b: divUp(a, b) * b.
template <typename U, typename V>
constexpr __host__ __device__ auto roundUp(U a, V b) -> decltype(a + b) {
  static_assert(
      std::is_integral<U>::value && std::is_integral<V>::value, "");
  const auto neededMultiples = divUp(a, b);
  return neededMultiples * b;
}
// Returns true when b divides a exactly and the quotient is at least 1
// (so isEvenDivisor(0, b) is false). A zero divisor divides nothing and
// returns false; the guard also prevents integer division by zero, which is
// undefined behavior.
template <typename U, typename V>
constexpr __host__ __device__ bool isEvenDivisor(U a, V b) {
  static_assert(std::is_integral<U>::value && std::is_integral<V>::value, "");
  return (b != V(0)) && (a % b == 0) && ((a / b) >= 1);
}
// n raised to a non-negative integer power by repeated multiplication;
// any power <= 0 yields 1.
template <class T>
constexpr __host__ __device__ T pow(T n, int power) {
  if (power <= 0) {
    return 1;
  }
  return n * pow(n, power - 1);
}
// 2 raised to `power`, computed in type T. The original `pow(2, power)`
// always multiplied in `int`, so wide types overflowed (e.g.
// pow2<uint64_t>(40)) even though T could hold the result.
template <class T>
constexpr __host__ __device__ T pow2(int power) {
  return pow(T(2), power);
}
static_assert(pow2<int>(8) == 256, "pow2");
static_assert(pow2<uint64_t>(40) == (uint64_t(1) << 40), "pow2");
// floor(log2(n)) + p for n >= 1, and p for n <= 1 (so log2(0) == 0).
// `p` accumulates the exponent across the tail-recursive calls.
template <typename T>
constexpr __host__ __device__ int log2(T n, int p = 0) {
  if (n <= 1) {
    return p;
  }
  return log2(n / 2, p + 1);
}
static_assert(log2(2) == 1, "log2");
static_assert(log2(3) == 1, "log2");
static_assert(log2(4) == 2, "log2");
// True when v has exactly one bit set. Clearing the lowest set bit
// (v & (v - 1)) leaves zero only for powers of two; zero is excluded
// explicitly.
template <typename T>
constexpr __host__ __device__ bool isPowerOf2(T v) {
  static_assert(std::is_integral<T>::value, "");
  return v != 0 && (v & (v - 1)) == 0;
}
static_assert(isPowerOf2(2048), "isPowerOf2");
static_assert(!isPowerOf2(3333), "isPowerOf2");
// Smallest power of two strictly greater than v.
// Every v < 1 (zero, or negative for signed T) maps to 1 == 2^0. Without
// that guard, log2(v) returns 0 for such v and the result would wrongly be 2.
template <typename T>
constexpr __host__ __device__ T nextHighestPowerOf2(T v) {
  static_assert(std::is_integral<T>::value, "");
  if (v < (T)1) {
    return (T)1;
  }
  return (isPowerOf2(v) ? (T)2 * v : ((T)1 << (log2(v) + 1)));
}
static_assert(nextHighestPowerOf2(0) == 1, "nextHighestPowerOf2");
static_assert(nextHighestPowerOf2(1) == 2, "nextHighestPowerOf2");
static_assert(nextHighestPowerOf2(2) == 4, "nextHighestPowerOf2");
static_assert(nextHighestPowerOf2(3) == 4, "nextHighestPowerOf2");
static_assert(nextHighestPowerOf2(4) == 8, "nextHighestPowerOf2");
static_assert(nextHighestPowerOf2(15) == 16, "nextHighestPowerOf2");
static_assert(nextHighestPowerOf2(16) == 32, "nextHighestPowerOf2");
static_assert(nextHighestPowerOf2(17) == 32, "nextHighestPowerOf2");
static_assert(
    nextHighestPowerOf2(1536000000u) == 2147483648u,
    "nextHighestPowerOf2");
static_assert(
    nextHighestPowerOf2((size_t)2147483648ULL) == (size_t)4294967296ULL,
    "nextHighestPowerOf2");
// Largest power of two strictly less than v, or 0 when none exists (v <= 1).
// Without the guard, v == 0 would yield 1 (which is greater than v), because
// log2(0) == 0.
template <typename T>
constexpr __host__ __device__ T nextLowestPowerOf2(T v) {
  static_assert(std::is_integral<T>::value, "");
  if (v <= (T)1) {
    return (T)0;
  }
  return (isPowerOf2(v) ? v / (T)2 : ((T)1 << (log2(v))));
}
static_assert(nextLowestPowerOf2(0) == 0, "nextLowestPowerOf2");
static_assert(nextLowestPowerOf2(1) == 0, "nextLowestPowerOf2");
static_assert(nextLowestPowerOf2(2) == 1, "nextLowestPowerOf2");
static_assert(nextLowestPowerOf2(3) == 2, "nextLowestPowerOf2");
static_assert(nextLowestPowerOf2(4) == 2, "nextLowestPowerOf2");
static_assert(nextLowestPowerOf2(15) == 8, "nextLowestPowerOf2");
static_assert(nextLowestPowerOf2(16) == 8, "nextLowestPowerOf2");
static_assert(nextLowestPowerOf2(17) == 16, "nextLowestPowerOf2");
// True when the address of p is a multiple of `align` bytes.
inline __host__ __device__ bool isPointerAligned(const void* p, int align) {
  const auto address = reinterpret_cast<uintptr_t>(p);
  return (address % align) == 0;
}
// Returns the number of bytes to add to p to reach the next Align-byte
// boundary (0 if p is already aligned). Align must be a power of two, so
// masking with (Align - 1) gives the offset past the previous boundary.
template <int Align>
inline __host__ __device__ uint32_t getAlignmentRoundUp(const void* p) {
  static_assert(isPowerOf2(Align), "");
  const auto misalignment = static_cast<uint32_t>(
      reinterpret_cast<uintptr_t>(p) & static_cast<uintptr_t>(Align - 1));
  if (misalignment == 0) {
    return 0;
  }
  return static_cast<uint32_t>(Align) - misalignment;
}
// Lanes per warp assumed throughout this file: lane ids are taken from
// threadIdx.x and blocks are launched kWarpSize threads wide in x.
constexpr int32_t kWarpSize{32};
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800))
// f16 vector types
// Plain aggregates of N __half values, each explicitly aligned to its total
// size (N * 2 bytes).
struct __align__(2) f16x1 {
__half vals[1];
};
struct __align__(4) f16x2 {
__half vals[2];
};
struct __align__(8) f16x4 {
__half vals[4];
};
struct __align__(16) f16x8 {
__half vals[8];
};
// bf16 vector types
// Same as the f16 types above, holding __nv_bfloat16 values.
struct __align__(2) bf16x1 {
__nv_bfloat16 vals[1];
};
struct __align__(4) bf16x2 {
__nv_bfloat16 vals[2];
};
struct __align__(8) bf16x4 {
__nv_bfloat16 vals[4];
};
struct __align__(16) bf16x8 {
__nv_bfloat16 vals[8];
};
// bf162 vector types
// Arrays of __nv_bfloat162 pairs; bf16x2x4 is the output type of
// convert_i4x8_to_bf16x2x4 (8 bf16 values).
struct __align__(4) bf16x2x1 {
__nv_bfloat162 vals[1];
};
struct __align__(8) bf16x2x2 {
__nv_bfloat162 vals[2];
};
struct __align__(16) bf16x2x4 {
__nv_bfloat162 vals[4];
};
// The *_u32 types have the same size and alignment as bf16x2x{4,2,1} but
// store raw 32-bit words. Results are memcpy'd into them so that each word
// can be passed as an "r" (32-bit register) operand to the inline-PTX mma;
// __nv_bfloat162 itself is a struct and cannot be used that way.
struct __align__(16) bf16x2x4_u32 {
uint32_t vals[4];
};
struct __align__(8) bf16x2x2_u32 {
uint32_t vals[2];
};
struct __align__(4) bf16x2x1_u32 {
uint32_t vals[1];
};
// Generic N-element vector aligned to its total size, sizeof(T) * N.
// NOTE(review): __align__ needs a power-of-two value, so sizeof(T) * N is
// presumably required to be a power of two.
template <typename T, int N>
struct __align__(sizeof(T) * N) VectorType {
T vals[N];
};
// Converts eight packed 4-bit values into eight bf16 values in [-8, 7].
//
// Adapted from
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
//
// `source` holds nibbles in the interleaved order written by
// matrix_to_m16n8k16_Bint4_layout. Nibble i (bits [4i, 4i + 4)) is paired
// with nibble i + 4 (bits [4i + 16, 4i + 20)). Output word i receives
// {nibble i, nibble i + 4} as its {low, high} bf16 halves. Each output value
// is the unsigned nibble minus 8.
//
// Uses bf16x2 fma, which needs sm_80+ (see the enclosing #if guard).
inline __device__ bf16x2x4 convert_i4x8_to_bf16x2x4(uint32_t source) {
  bf16x2x4 result;
  constexpr int kElements = 8;
  uint32_t* h = reinterpret_cast<uint32_t*>(&result);

  // lop3 truth-table immediate for (a & b) | c. With a = i4s, b = MASK and
  // c = MAGIC, this computes (i4s & MASK) | MAGIC.
  static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
  // Selects the lowest nibble of each 16-bit half.
  static constexpr uint32_t MASK = 0x000f000f;
  // bf16 128.0 in both halves. At 128 the 7 mantissa bits have a unit of
  // exactly 1, so OR-ing a nibble x into the mantissa gives the bf16 value
  // 128 + x.
  static constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300;

  // bf16 has too few mantissa bits to convert several nibbles per word with
  // shifted magic numbers (as the fp16 variant does). Instead, each output
  // word gets its own lop3, and the source moves right by one nibble
  // (4 bits) between words.
  uint32_t i4s = source;
#pragma unroll
  for (int ii = 0; ii < kElements / 2; ++ii) {
    // h[ii] = (i4s & 0x000f000f) | 0x43004300
    // Pure register computation: no `volatile` needed, so the compiler is
    // free to schedule it.
    asm("lop3.b32 %0, %1, %2, %3, %4;\n"
        : "=r"(h[ii])
        : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
    i4s >>= 4;
  }

  // bf16 {-136, -136} and {1.0, 1.0} as packed integers.
  static constexpr uint32_t BF16_BIAS = 0xC308C308;
  static constexpr uint32_t BF16_ONE = 0x3F803F80;

  // (128 + x) * 1.0 + (-136) == x - 8. This is exact in bf16 for x in [0, 15].
#pragma unroll
  for (int ii = 0; ii < kElements / 2; ++ii) {
    asm("fma.rn.bf16x2 %0, %1, %2, %3;\n"
        : "=r"(h[ii])
        : "r"(h[ii]), "r"(BF16_ONE), "r"(BF16_BIAS));
  }
  return result;
}
// How partial sums over the k dimension are combined across thread blocks.
enum class KReductionType : int {
  // No k-reduction is needed between blocks: the number of k-tiles processed
  // per block is exact, so the output can be written directly.
  None,
};
// Loads the A matrix in 16-bit standard m x k row major layout, and writes
// the C matrix in 16-bit standard m x n row major layout.
//
//   A: __nv_bfloat16, size [m][k], dense row major
//   C: __nv_bfloat16, size [m][n], dense row major
//
// Fragments follow the mma.m16n8k16 register layout. Lane l owns rows
// (l / 4) and (l / 4) + 8 of the m-tile, and the column pair starting at
// (l % 4) * 2 within each 8-wide column group.
//
// Element offsets are formed in 64 bits: m * k and m * n can exceed
// INT32_MAX for large inputs, and a 32-bit product would overflow.
template <KReductionType ReduceType>
struct ALayout_RM {
  static constexpr int32_t kMTileSize = 16;
  static constexpr int32_t kNTileSize = 8;
  static constexpr int32_t kKTileSize = 16;

  // Loads KTilesToLoad consecutive 16 x 16 k-tiles of A, starting at k-tile
  // kTileStart of m-tile mTile, into this lane's mma A fragments `out`.
  // Rows >= m are zero-filled. Columns are NOT bounds checked: the caller
  // must guarantee (kTileStart + KTilesToLoad) * kKTileSize <= k.
  // Per k-tile i: out[i].vals = {row r cols c..c+1, row r+8 cols c..c+1,
  //                              row r cols c+8..c+9, row r+8 cols c+8..c+9}.
  template <int KTilesToLoad>
  static __device__ void load(
      const void* A,
      int32_t m,
      int32_t k,
      int32_t mTiles,
      int32_t mTile,
      int32_t kTiles,
      int32_t kTileStart,
      int32_t laneId,
      bf16x2x4_u32 out[KTilesToLoad]) {
    const int32_t mLane = mTile * kMTileSize + (laneId / 4);
    const int32_t kLane = kTileStart * kKTileSize + (laneId % 4) * 2;

    // access
    //   [mTile * kMTileSize + (laneId / 4)]
    //   [kTileStart * kKTileSize + (laneId % 4) * 2]
    auto aPtr = reinterpret_cast<const __nv_bfloat16*>(A) +
        static_cast<int64_t>(mLane) * k + kLane;
    auto aPtrPlus8Rows = aPtr + static_cast<int64_t>(8) * k;

    const bool m0InBounds = mLane < m;
    const bool m1InBounds = (mLane + 8) < m;

#pragma unroll
    for (int i = 0; i < KTilesToLoad; ++i) {
      // Each uint32 load reads two adjacent bf16 values.
      out[i].vals[0] = m0InBounds
          ? *reinterpret_cast<const uint32_t*>(aPtr + i * kKTileSize)
          : uint32_t(0);
      out[i].vals[1] = m1InBounds
          ? *reinterpret_cast<const uint32_t*>(aPtrPlus8Rows + i * kKTileSize)
          : uint32_t(0);
      out[i].vals[2] = m0InBounds
          ? *reinterpret_cast<const uint32_t*>(aPtr + i * kKTileSize + 8)
          : uint32_t(0);
      out[i].vals[3] = m1InBounds
          ? *reinterpret_cast<const uint32_t*>(
                aPtrPlus8Rows + i * kKTileSize + 8)
          : uint32_t(0);
    }
  }

  // Writes this lane's fp32 mma accumulator `out`, rounded to bf16, to the
  // 16 x 8 C tile (mTile, nTile). Rows >= m are skipped. Columns are not
  // bounds checked (n is assumed to be a multiple of kNTileSize).
  static __device__ void store(
      void* C,
      int32_t m,
      int32_t n,
      int32_t mOutTiles,
      int32_t mTile,
      int32_t nOutTiles,
      int32_t nTile,
      int32_t laneId,
      const float4& out) {
    static_assert(ReduceType == KReductionType::None, "");

    if constexpr (ReduceType == KReductionType::None) {
      // sum.x / sum.y are written at
      //   [laneId / 4], [(laneId % 4) * 2, (laneId % 4) * 2 + 1]
      // sum.z / sum.w are written at
      //   [8 + (laneId / 4)], [(laneId % 4) * 2, (laneId % 4) * 2 + 1]
      // i.e., same columns, different row.
      const int32_t outRow = mTile * kMTileSize + (laneId / 4);
      const int32_t outCol = nTile * kNTileSize + (laneId % 4) * 2;

      // Pointer where sum.x / sum.y is written
      auto cPtr = reinterpret_cast<__nv_bfloat16*>(C) +
          static_cast<int64_t>(outRow) * n + outCol;

      auto v01 = __float22bfloat162_rn(float2{out.x, out.y});
      auto v23 = __float22bfloat162_rn(float2{out.z, out.w});

      if (outRow < m) {
        *reinterpret_cast<__nv_bfloat162*>(cPtr) = v01;
      }

      // sum.z, sum.w at +8 rows from cPtr
      if (outRow + 8 < m) {
        *reinterpret_cast<__nv_bfloat162*>(
            cPtr + static_cast<int64_t>(8) * n) = v23;
      }
    }
  }
};
// B-operand layout for mma.m16n8k16 in which B is stored as packed 4-bit
// values (as written by matrix_to_m16n8k16_Bint4_layout). Values are
// dequantized to bf16 on load with a per-group (scale, zero) pair.
//
// InnerKTiles: k-tiles stored contiguously per lane (2, 4 or 8). Each uint32
//   packs 2 k-tiles, so a lane reads InnerKTiles / 2 words per group.
// QGroupSize: k elements sharing one (scale, zero) pair. Must be a power of
//   two and >= 2 * kKTileSize (see the static_asserts in load()).
template <int InnerKTiles, int QGroupSize>
struct BLayout_TC_int4 {
static constexpr int32_t kInnerKTiles = InnerKTiles;
static constexpr int32_t kMTileSize = 16;
static constexpr int32_t kNTileSize = 8;
static constexpr int32_t kKTileSize = 16;
// Loads and dequantizes KTilesToLoad k-tiles for n-tile nTile into this
// lane's mma B fragments `out`, starting at k-tile kTileStart.
// NOTE(review): the B offset uses kTileStart / InnerKTiles, so kTileStart
// and KTilesToLoad are presumably multiples of InnerKTiles; the callers in
// this file appear to satisfy this, but confirm before relying on it.
// No bounds checks are performed on n or k.
template <int KTilesToLoad>
static __device__ void load(
// type uint32, size [n / 8][k / (InnerKTiles * 16)][32][InnerKTiles / 2]
// n / 8: n-tiles (n8)
// k / (InnerKTiles * 16): TC size per k-tile is 16 (m16n8k16)
// 32: value per warp lane
// (InnerKTiles / 2): B layout has 4 values per lane (16 bits) per k-tile.
// 2 k-tiles packed is a uint32 (hence InnerKTiles == 2 is our smallest
// value) 4 k-tiles packed is a uint32x2 (64 bits) 8 k-tiles packed is a
// uint32x4 (128 bits)
const void* __restrict__ B,
// size [k / qGroupSize][n][2]
// Contains the scale and zero point of each of the quantized int4 values
// within B
// v_reconstructed = (bf16(B_int4_val) * scale) + zero
// (matches the __hfma2(v, scale, zero) below)
const void* __restrict__ quantizationInfo,
int32_t n,
int32_t k,
int32_t nTiles,
int32_t nTile,
int32_t kTiles,
int32_t kTileStart,
int32_t laneId,
bf16x2x4_u32 out[KTilesToLoad / InnerKTiles][InnerKTiles / 2]) {
// offset [nTile][kTileStart / InnerKTiles][laneId][0]
auto bPtr = reinterpret_cast<const int32_t*>(B) +
(((nTile * (kTiles / InnerKTiles) + (kTileStart / InnerKTiles)) *
kWarpSize) +
laneId) *
(InnerKTiles / 2);
// Raw packed words: b_int4[i][j] holds k-tiles (i * InnerKTiles + 2j) and
// (i * InnerKTiles + 2j + 1) for this lane.
int32_t b_int4[KTilesToLoad / InnerKTiles][InnerKTiles / 2];
#pragma unroll
for (int i = 0; i < KTilesToLoad / InnerKTiles; ++i) {
// Consecutive InnerKTiles groups are kWarpSize * (InnerKTiles / 2) words apart
auto bPtrCur = bPtr + i * kWarpSize * (InnerKTiles / 2);
if constexpr (InnerKTiles == 2) {
b_int4[i][0] = bPtrCur[0];
}
if constexpr (InnerKTiles == 4) {
// asm volatile("ld.global.cs.v2.u32 {%0, %1}, [%2];\n"
// : "=r"(b_int4[i][0]), "=r"(b_int4[i][1])
// : "l"(bPtrCur));
// Single 8-byte vector load
int2 load8 = reinterpret_cast<const int2*>(bPtrCur)[0];
b_int4[i][0] = load8.x;
b_int4[i][1] = load8.y;
}
if constexpr (InnerKTiles == 8) {
// asm volatile("ld.global.cs.v4.u32 {%0, %1, %2, %3}, [%4];\n"
// : "=r"(b_int4[i][0]), "=r"(b_int4[i][1]),
// "=r"(b_int4[i][2]), "=r"(b_int4[i][3]) : "l"(bPtrCur));
// Single 16-byte vector load
int4 load16 = reinterpret_cast<const int4*>(bPtrCur)[0];
b_int4[i][0] = load16.x;
b_int4[i][1] = load16.y;
b_int4[i][2] = load16.z;
b_int4[i][3] = load16.w;
}
}
// Load needed info for dequantization
static_assert(isPowerOf2(QGroupSize), "");
static_assert(isEvenDivisor(QGroupSize, kKTileSize), "");
// smallest quantization group size is 32 (2 k-tiles are packed in an int32)
static_assert(QGroupSize >= kKTileSize * 2, "");
constexpr int kKTilesPerQGroup = (QGroupSize / kKTileSize);
// a q-group could be larger than what we are handling in a single warp
constexpr int kNumQGroups = (KTilesToLoad / kKTilesPerQGroup) < 1
? 1
: (KTilesToLoad / kKTilesPerQGroup);
__nv_bfloat162 qScaleAndZero[kNumQGroups];
{
// Each lane reads the (scale, zero) pair for its own n column.
int32_t laneN = nTile * kNTileSize + (laneId / 4);
int32_t groupStart = (kTileStart * kKTileSize) / QGroupSize;
// NOTE(review): shadows the `n` parameter; recomputed from nTiles so the
// row stride matches the padded n (nTiles * 8).
int32_t n = nTiles * kNTileSize;
// offset [qScale_kGroup][qScale_n][0]
auto qInfoPtr = reinterpret_cast<const __nv_bfloat16*>(quantizationInfo) +
(groupStart * n + laneN) * 2;
#pragma unroll
for (int i = 0; i < kNumQGroups; ++i) {
// .x = scale, .y = zero for q-group (groupStart + i)
qScaleAndZero[i] =
*reinterpret_cast<const __nv_bfloat162*>(qInfoPtr + i * n * 2);
}
}
//
// De-quantize int4 values to bf16. Values are dequantized as truly int4
// [-8, 7] range; dequant = (bf16(int4_value) * bf16_scale) + bf16_zero
//
{
// FIXME: does this negatively affect register counts, or will nvcc
// move this expansion (and data loads above) closer to the point of use?
// Broadcast scale and zero to both bf16 halves for the bf16x2 fma.
__nv_bfloat162 qScale[kNumQGroups];
__nv_bfloat162 qZero[kNumQGroups];
#pragma unroll
for (int i = 0; i < kNumQGroups; ++i) {
qScale[i] = __bfloat162bfloat162(qScaleAndZero[i].x);
qZero[i] = __bfloat162bfloat162(qScaleAndZero[i].y);
}
#pragma unroll
for (int i = 0; i < KTilesToLoad / InnerKTiles; ++i) {
#pragma unroll
for (int j = 0; j < InnerKTiles / 2; ++j) {
bf16x2x4 v = convert_i4x8_to_bf16x2x4(b_int4[i][j]);
// First of the two k-tiles held in b_int4[i][j], relative to kTileStart
int curKTile = i * InnerKTiles + j * 2;
int curQGroup = (curKTile * kKTileSize) / QGroupSize;
// The dequantized values in `v` for a given lane have the same n
// dimension (the B tensor core layout has all values in the same
// thread along the same n) but different k dimension, but all are
// guaranteed to occur within the same quantization group, so we need
// only load a single scale + zero to cover what this lane has
#pragma unroll
for (int k = 0; k < 4; ++k) {
v.vals[k] = __hfma2(v.vals[k], qScale[curQGroup], qZero[curQGroup]);
}
// type pun, the __nv_bfloat162 value in bf16x2x4 is a struct and
// can't be used as a 32-bit asm register argument for `mma`
static_assert(sizeof(bf16x2x4) == sizeof(out[0][0]), "");
std::memcpy(&out[i][j], &v, sizeof(bf16x2x4_u32));
}
}
}
}
};
// Computes one 16 x 8 (m x n) tile of C = A * B per thread block, using
// mma.sync m16n8k16 with bf16 inputs and fp32 accumulation.
//
// Launch layout (as set up by launch_tinygemm_kernel):
//   blockIdx.z = m-tile, blockIdx.y = n-tile, blockIdx.x = k-split index;
//   threadIdx.x = lane (kWarpSize lanes), threadIdx.y = warp (Warps warps).
// Each warp accumulates a disjoint subset of the k-tiles. The per-warp
// partial sums are then reduced through static shared memory
// (Warps * kWarpSize float4), and warp 0 stores the result via CLayout.
//
// Preconditions: kTiles is a positive multiple of BLayout::kInnerKTiles
// (checked on the host by launch_tinygemm_kernel).
// The bf16 mma requires sm_80+ (see the enclosing #if guard).
template <
typename ALayout,
typename BLayout,
typename CLayout,
int Warps,
int KTilesPerIteration>
__global__
__launch_bounds__(Warps* kWarpSize) void tinygemm_m16n8k16_chunk_kernel(
// Data for the A matrix, loaded as per ALayout
const void* const __restrict__ A,
// Data for the B matrix, loaded as per BLayout
const void* const __restrict__ B,
// Optional quantization data for dequantizing B, loaded as per BLayout
const void* const __restrict__ B_quantizationInfo,
// Output data for the C matrix, stored as per CLayout
void* __restrict__ C,
// The size of the matrix multiplication
int32_t m,
int32_t n,
int32_t k,
// The size of the matrix multiplication, in multiples of our TC tile size
int32_t mTiles,
int32_t nTiles,
int32_t kTiles) {
constexpr int32_t kMTileSize = 16;
constexpr int32_t kNTileSize = 8;
constexpr int32_t kKTileSize = 16;
static_assert(
ALayout::kMTileSize == kMTileSize && ALayout::kNTileSize == kNTileSize &&
ALayout::kKTileSize == kKTileSize,
"");
static_assert(
BLayout::kMTileSize == kMTileSize && BLayout::kNTileSize == kNTileSize &&
BLayout::kKTileSize == kKTileSize,
"");
static_assert(
CLayout::kMTileSize == kMTileSize && CLayout::kNTileSize == kNTileSize &&
CLayout::kKTileSize == kKTileSize,
"");
constexpr int kInnerKTiles = BLayout::kInnerKTiles;
// 2/4/8 inner k-tiles correspond to 4, 8 and 16 byte innermost loads
static_assert(
kInnerKTiles == 2 || kInnerKTiles == 4 || kInnerKTiles == 8, "");
// We always process at least kInnerKTiles k-tiles back to back in a warp
static_assert(
KTilesPerIteration >= kInnerKTiles &&
isEvenDivisor(KTilesPerIteration, kInnerKTiles),
"");
// threadIdx.y selects the warp, threadIdx.x the lane within it
auto warpId = threadIdx.y;
auto laneId = threadIdx.x;
int32_t mTile = blockIdx.z;
int32_t nTile = blockIdx.y;
// Per-lane fp32 accumulator for this lane's slice of the 16 x 8 C tile
float4 c{0.0f, 0.0f, 0.0f, 0.0f};
// First, handle whole multiples of KTilesPerIteration
auto kTilesLimit = roundDown(kTiles, KTilesPerIteration);
// Each warp handles a set of KTilesPerIteration under the above limit
// NOTE(review): the stride below does not include gridDim.x. Blocks with
// different blockIdx.x would therefore process overlapping k-tiles, so this
// is only correct for gridDim.x == 1 (which is how launch_tinygemm_kernel
// launches it).
for (int32_t kTileBase = (blockIdx.x * Warps + warpId) * KTilesPerIteration;
kTileBase < kTilesLimit;
kTileBase += Warps * KTilesPerIteration) {
//
// Load data from A
//
bf16x2x4_u32 a[KTilesPerIteration];
ALayout::template load<KTilesPerIteration>(
A, m, k, mTiles, mTile, kTiles, kTileBase, laneId, a);
//
// Load data from B and de-quantize as needed
// Each k-tile is bf16x2x2
//
bf16x2x4_u32 b[KTilesPerIteration / kInnerKTiles][kInnerKTiles / 2];
BLayout::template load<KTilesPerIteration>(
B,
B_quantizationInfo,
n,
k,
nTiles,
nTile,
kTiles,
kTileBase,
laneId,
b);
//
// Now, perform the matrix multiplication
//
// We accumulate across k-tiles here
#pragma unroll
for (int i = 0; i < KTilesPerIteration / kInnerKTiles; ++i) {
static_assert(isEvenDivisor(kInnerKTiles, 2) && kInnerKTiles >= 2, "");
#pragma unroll
for (int j = 0; j < kInnerKTiles / 2; ++j) {
// We don't simply accumulate into `c` as this creates a too-strong
// execution dependency. Instead, we only periodically accumulate into
// `c`
// cTmp starts at zero and is passed as both the C input and the D
// output of the mma, so each mma yields a fresh partial product.
float4 cTmp[2];
#pragma unroll
for (int k = 0; k < 2; ++k) {
cTmp[k] = float4{0.0f, 0.0f, 0.0f, 0.0f};
}
// k-tile (j * 2 + k) of group i: A fragment a[...], and B words
// vals[0..1] (first k-tile) or vals[2..3] (second k-tile) of b[i][j]
#pragma unroll
for (int k = 0; k < 2; ++k) {
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};"
: "=f"(cTmp[k].x),
"=f"(cTmp[k].y),
"=f"(cTmp[k].z),
"=f"(cTmp[k].w)
: "r"(a[i * kInnerKTiles + j * 2 + k].vals[0]),
"r"(a[i * kInnerKTiles + j * 2 + k].vals[1]),
"r"(a[i * kInnerKTiles + j * 2 + k].vals[2]),
"r"(a[i * kInnerKTiles + j * 2 + k].vals[3]),
"r"(b[i][(j * 2 + k) / 2].vals[((j * 2 + k) % 2) * 2 + 0]),
"r"(b[i][(j * 2 + k) / 2].vals[((j * 2 + k) % 2) * 2 + 1]),
"f"(cTmp[k].x),
"f"(cTmp[k].y),
"f"(cTmp[k].z),
"f"(cTmp[k].w));
}
#pragma unroll
for (int k = 0; k < 2; ++k) {
c.x += cTmp[k].x;
c.y += cTmp[k].y;
c.z += cTmp[k].z;
c.w += cTmp[k].w;
}
}
}
} // for all tiles under kTilesLimit
// Now, there could be a remainder of 1 to KTilesPerIteration - 1 k-tiles
// remaining. We guarantee that the number of warps is >= KTilesPerIteration /
// kInnerKTiles, so that each warp can simply load kInnerKTiles and do its
// thing without needing more warps
// (Because the host requires kTiles to be a multiple of kInnerKTiles, the
// remainder is itself a multiple of kInnerKTiles.)
static_assert(Warps >= KTilesPerIteration / kInnerKTiles, "");
auto kTileBaseRemaining = kTilesLimit + warpId * kInnerKTiles;
// If we have any remainder k-tiles, some warps will handle them, processing
// kInnerKTiles k-tiles at a time
if (kTileBaseRemaining < kTiles) {
bf16x2x4_u32 a[kInnerKTiles];
ALayout::template load<kInnerKTiles>(
A, m, k, mTiles, mTile, kTiles, kTileBaseRemaining, laneId, a);
bf16x2x4_u32 b[1][kInnerKTiles / 2];
BLayout::template load<kInnerKTiles>(
B,
B_quantizationInfo,
n,
k,
nTiles,
nTile,
kTiles,
kTileBaseRemaining,
laneId,
b);
#pragma unroll
for (int j = 0; j < kInnerKTiles / 2; ++j) {
// We don't simply accumulate into `c` as this creates a too-strong
// execution dependency. Instead, we only periodically accumulate into
// `c`
float4 cTmp[2];
#pragma unroll
for (int k = 0; k < 2; ++k) {
cTmp[k] = float4{0.0f, 0.0f, 0.0f, 0.0f};
}
#pragma unroll
for (int k = 0; k < 2; ++k) {
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};"
: "=f"(cTmp[k].x), "=f"(cTmp[k].y), "=f"(cTmp[k].z), "=f"(cTmp[k].w)
: "r"(a[j * 2 + k].vals[0]),
"r"(a[j * 2 + k].vals[1]),
"r"(a[j * 2 + k].vals[2]),
"r"(a[j * 2 + k].vals[3]),
"r"(b[0][(j * 2 + k) / 2].vals[((j * 2 + k) % 2) * 2 + 0]),
"r"(b[0][(j * 2 + k) / 2].vals[((j * 2 + k) % 2) * 2 + 1]),
"f"(cTmp[k].x),
"f"(cTmp[k].y),
"f"(cTmp[k].z),
"f"(cTmp[k].w));
}
#pragma unroll
for (int k = 0; k < 2; ++k) {
c.x += cTmp[k].x;
c.y += cTmp[k].y;
c.z += cTmp[k].z;
c.w += cTmp[k].w;
}
}
}
//
// Reduce independent k-tiles (same m/n) across warps
//
// Static shared memory: one float4 partial sum per (warp, lane)
__shared__ float4 smem_sum[Warps][kWarpSize];
// FIXME: this likely doesn't need to be a true reduction tree, can just be a
// serial sum, maybe (unless nvcc/ptxas goes back to its old ways)
// smem_sum[warpId][laneId] = TreeReduce4<KTilesPerIteration>::reduce(c);
smem_sum[warpId][laneId] = c;
// Reached by every thread in the block (there are no early exits above)
__syncthreads();
if (warpId == 0) {
float4 sum_f32{0.0f, 0.0f, 0.0f, 0.0f};
// Reduce across the block in the first warp
for (int i = 0; i < Warps; ++i) {
float4 v = smem_sum[i][laneId];
sum_f32.x += v.x;
sum_f32.y += v.y;
sum_f32.z += v.z;
sum_f32.w += v.w;
}
// Write the reduced result (in the first warp) into the output
CLayout::store(
C,
m,
n,
mTiles,
mTile,
// n for C output becomes k for A input, so for m16n8k16,
// we need to halve the tiles
// NOTE(review): ALayout_RM::store does not read its nOutTiles argument.
nTiles / 2,
nTile,
laneId,
sum_f32);
}
}
// Launches tinygemm_m16n8k16_chunk_kernel to compute C_final = A * B on
// `stream`.
//
// Grid:  (1, nTiles, mTiles), i.e. one block per 16 x 8 output tile. Each
//        block loops over every k-tile, so gridDim.x must stay 1.
// Block: (kWarpSize, Warps).
//
// Requires kTiles and KTilesPerWarp to be positive multiples of
// BLayout::kInnerKTiles. nTiles and mTiles must fit the CUDA gridDim.y and
// gridDim.z limit of 65535. Throws via TORCH_CHECK on invalid sizes, and via
// C10_CUDA_KERNEL_LAUNCH_CHECK if the launch itself fails.
template <
    typename ALayout,
    typename BLayout,
    typename CLayout,
    int Warps,
    int KTilesPerWarp>
void launch_tinygemm_kernel(
    const at::Tensor& A,
    const at::Tensor& B,
    const at::Tensor* qScaleAndZeros, /* optional */
    at::Tensor& C_final,
    int32_t mTiles,
    int32_t nTiles,
    int32_t kTiles,
    int32_t m,
    int32_t n,
    int32_t k,
    cudaStream_t stream) {
  // The chunking kernel requires that kTiles is a multiple of kInnerKTiles
  TORCH_CHECK(
      kTiles >= BLayout::kInnerKTiles &&
      isEvenDivisor(kTiles, BLayout::kInnerKTiles));
  TORCH_CHECK(
      KTilesPerWarp >= BLayout::kInnerKTiles &&
      isEvenDivisor(KTilesPerWarp, BLayout::kInnerKTiles));

  // n-tiles and m-tiles map to gridDim.y and gridDim.z, which CUDA limits to
  // 65535. Report this clearly rather than as an opaque invalid-configuration
  // launch failure.
  constexpr int32_t kMaxGridYZ = 65535;
  TORCH_CHECK(
      nTiles > 0 && nTiles <= kMaxGridYZ && mTiles > 0 && mTiles <= kMaxGridYZ,
      "tinygemm: nTiles (",
      nTiles,
      ") and mTiles (",
      mTiles,
      ") must each be in [1, 65535]");

  // After intra-block reduction across the k dimension, we are left with this
  // many tiles
  // int32_t postKernelKTiles = kTiles / (Warps * KTilesPerWarp);
  const int32_t postKernelKTiles = 1; // we loop

  const dim3 grid(postKernelKTiles, nTiles, mTiles);
  const dim3 block(kWarpSize, Warps);

  auto func = tinygemm_m16n8k16_chunk_kernel<
      ALayout,
      BLayout,
      CLayout,
      Warps,
      KTilesPerWarp>;

  func<<<grid, block, 0, stream>>>(
      A.data_ptr(),
      B.data_ptr(),
      qScaleAndZeros ? qScaleAndZeros->data_ptr() : nullptr,
      C_final.data_ptr(),
      m,
      n,
      k,
      mTiles,
      nTiles,
      kTiles);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
// Repacks an [n][k] int32 matrix holding one 4-bit value per element (valid
// range [0, 15]) into the packed B-operand layout read by BLayout_TC_int4.
//
// Launch: grid (k / (InnerKTiles * 16), ceil(n / 8)), block of kWarpSize
// threads. Thread t of block (kOuterTile, nTile) writes
// out[nTile][kOuterTile][t][0 .. InnerKTiles / 2). Rows >= n and columns >= k
// are packed as 0.
//
// Only the low 4 bits of each input element are used. Masking keeps an
// out-of-range input from spilling into the neighbouring nibbles of the pack.
//
// FIXME: parallelize better, smem staging etc?
template <int InnerKTiles>
__global__ void matrix_to_m16n8k16_Bint4_layout(
    // size [n][k]
    const at::PackedTensorAccessor32<int32_t, 2, at::RestrictPtrTraits> in,
    // size [ceil(n / 8)][ceil(k / (InnerKTiles * 16))][32][InnerKTiles / 2]
    at::PackedTensorAccessor32<int32_t, 4, at::RestrictPtrTraits> out) {
  // int4 values are packed into int32 values, which require at least 8. Given
  // m16n8k16 B layout requires 4 scalar values/lane, the minimum number of
  // innermost k-tiles that we can use is 2.
  static_assert(InnerKTiles >= 2 && isPowerOf2(InnerKTiles), "");

  constexpr int32_t kNTileSize = 8;
  constexpr int32_t kKTileSize = 16;
  constexpr uint32_t kNibbleMask = 0xF;

  // gridDim.x corresponds to the number of k-tiles divided by InnerKTiles
  auto kOuterTile = blockIdx.x;
  auto nTile = blockIdx.y;
  auto t = threadIdx.x;

  // n dimension that this lane loads from (the same for every inner k-tile).
  // pIn is only dereferenced when n0Valid.
  auto n0 = nTile * kNTileSize + (t / 4);
  bool n0Valid = n0 < in.size(0);
  auto pIn = &in[n0][0];

  // Two k-tiles are packed into an int32 at a time
#pragma unroll
  for (int innerKTile = 0; innerKTile < InnerKTiles; innerKTile += 2) {
    // k indices this lane owns in the mma B fragment: two adjacent pairs in
    // each of the two k-tiles, 8 apart within a tile.
    int32_t ks[8];

    auto kBase0 = (kOuterTile * InnerKTiles + innerKTile) * kKTileSize;
    ks[0] = kBase0 + (t % 4) * 2;
    ks[1] = ks[0] + 1;
    ks[2] = ks[0] + 8;
    ks[3] = ks[0] + 8 + 1;

    auto kBase1 = kBase0 + kKTileSize;
    ks[4] = kBase1 + (t % 4) * 2;
    ks[5] = ks[4] + 1;
    ks[6] = ks[4] + 8;
    ks[7] = ks[4] + 8 + 1;

    uint32_t v[8];
#pragma unroll
    for (int i = 0; i < 8; ++i) {
      v[i] = (n0Valid && ks[i] < in.size(1))
          ? (uint32_t(pIn[ks[i]]) & kNibbleMask)
          : uint32_t(0);
    }

    // Interleaved order: low to high nibble = v0 v2 v4 v6 v1 v3 v5 v7, so
    // nibbles i and i + 4 hold the k-adjacent pair (v[2i], v[2i + 1]).
    int32_t pack = (v[7] << 28) | (v[5] << 24) | (v[3] << 20) | (v[1] << 16) |
        (v[6] << 12) | (v[4] << 8) | (v[2] << 4) | v[0];

    // inner k-tiles pack two at a time
    out[nTile][kOuterTile][t][innerKTile / 2] = pack;
  }
}
#endif
// Computes C = A * dequant(B) for bf16 A and int4-packed B.
//
// A:              [m][k] bf16, contiguous row major.
// B:              int32 [n / 8][k / (innerKTiles * 16)][32][innerKTiles / 2],
//                 as produced by _convert_weight_to_int4pack_cuda with
//                 innerKTiles in {2, 4, 8}.
// qGroupSize:     k elements per quantization group: 32, 64, 128 or 256.
// qScaleAndZeros: [k / qGroupSize][n][2] bf16, contiguous. Index [..][..][0]
//                 is the scale and [..][..][1] the zero point
//                 (dequant = bf16(int4_value) * scale + zero).
// Returns C:      [m][n] bf16, where n = B.size(0) * 8.
//
// All ranks and dtypes are validated before any size() call. k must exactly
// match B's packed k extent; otherwise the kernel would read past the end of
// rows of A and past the end of B.
at::Tensor _weight_int4pack_mm_cuda(
    const at::Tensor& A,
    const at::Tensor& B,
    int64_t qGroupSize,
    const at::Tensor& qScaleAndZeros) {
  c10::cuda::CUDAGuard g(A.device());

  TORCH_CHECK(
      A.device() == B.device() && A.device() == qScaleAndZeros.device());

  constexpr int32_t kMTileSize = 16;
  constexpr int32_t kNTileSize = 8;
  constexpr int32_t kKTileSize = 16;

  // A is standard row major
  TORCH_CHECK(A.dtype() == at::kBFloat16);
  TORCH_CHECK(A.is_contiguous());
  TORCH_CHECK(A.dim() == 2);

  // B has B_innerKTiles k-tiles in the innermost dimension
  TORCH_CHECK(B.dtype() == at::kInt);
  TORCH_CHECK(B.is_contiguous());
  TORCH_CHECK(B.dim() == 4);

  // The kernel reinterprets qScaleAndZeros as densely packed bf16 pairs
  TORCH_CHECK(qScaleAndZeros.dtype() == at::kBFloat16);
  TORCH_CHECK(qScaleAndZeros.is_contiguous());
  TORCH_CHECK(qScaleAndZeros.dim() == 3);

  // row major layout
  auto m = A.size(0);
  auto mTiles = divUp(m, kMTileSize);

  // tensor core layout
  auto nTiles = B.size(0);
  auto n = nTiles * kNTileSize;

  // row major layout
  auto k = A.size(1);
  auto kTiles = divUp(k, kKTileSize);

  // The number of inner k-tiles is the innermost dimension of B times 2:
  // 2 k-tiles (4 values per lane per tile, 8 values total) quantized to int4
  // are packed into 1 int32 for int4 B
  auto B_innerKTiles = B.size(3) * 2;
  TORCH_CHECK(B_innerKTiles == 2 || B_innerKTiles == 4 || B_innerKTiles == 8);

  // k must exactly fill B's groups of B_innerKTiles k-tiles; the kernel reads
  // whole groups from both A and B
  TORCH_CHECK(k == B.size(1) * B_innerKTiles * kKTileSize);
  TORCH_CHECK(B.size(2) == kWarpSize);

  // Validate the scale and zero point tensor for dequantization
  // These are the only versions handled at the moment
  TORCH_CHECK(
      qGroupSize == 32 || qGroupSize == 64 || qGroupSize == 128 ||
      qGroupSize == 256);
  TORCH_CHECK(
      kTiles * kKTileSize >= qGroupSize &&
      isEvenDivisor(kTiles * kKTileSize, qGroupSize));

  // One (scale, zero) row per q-group; the kernel reads every group in k
  auto numQGroups = qScaleAndZeros.size(0);
  TORCH_CHECK(numQGroups == kTiles * kKTileSize / qGroupSize);
  TORCH_CHECK(qScaleAndZeros.size(1) == n);
  TORCH_CHECK(qScaleAndZeros.size(2) == 2);

  // Output is a standard row-major matrix
  auto C_final = at::empty(
      {m, n}, at::TensorOptions().dtype(at::kBFloat16).device(A.device()));

#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800))
  // Nothing to compute. This also avoids launching a grid with a zero
  // dimension, which CUDA rejects as an invalid configuration.
  if (C_final.numel() == 0) {
    return C_final;
  }

  auto stream = at::cuda::getCurrentCUDAStream();
#define RUN_GEMM(WARPS, K_TILES_PER_WARP, Q_GROUP_SIZE, REDUCE_TYPE) \
  do {                                                               \
    using ACLayout = ALayout_RM<REDUCE_TYPE>;                         \
                                                                     \
    TORCH_CHECK(                                                     \
        K_TILES_PER_WARP >= B_innerKTiles &&                         \
        isEvenDivisor(K_TILES_PER_WARP, B_innerKTiles));             \
                                                                     \
    switch (B_innerKTiles) {                                         \
      case 2:                                                        \
        if constexpr (K_TILES_PER_WARP >= 2) {                       \
          using BLayout = BLayout_TC_int4<2, Q_GROUP_SIZE>;          \
          launch_tinygemm_kernel<                                    \
              ACLayout,                                              \
              BLayout,                                               \
              ACLayout,                                              \
              WARPS,                                                 \
              K_TILES_PER_WARP>(                                     \
              A,                                                     \
              B,                                                     \
              &qScaleAndZeros,                                       \
              C_final,                                               \
              mTiles,                                                \
              nTiles,                                                \
              kTiles,                                                \
              m,                                                     \
              n,                                                     \
              k,                                                     \
              stream);                                               \
        }                                                            \
        break;                                                       \
      case 4:                                                        \
        if constexpr (K_TILES_PER_WARP >= 4) {                       \
          using BLayout = BLayout_TC_int4<4, Q_GROUP_SIZE>;          \
          launch_tinygemm_kernel<                                    \
              ACLayout,                                              \
              BLayout,                                               \
              ACLayout,                                              \
              WARPS,                                                 \
              K_TILES_PER_WARP>(                                     \
              A,                                                     \
              B,                                                     \
              &qScaleAndZeros,                                       \
              C_final,                                               \
              mTiles,                                                \
              nTiles,                                                \
              kTiles,                                                \
              m,                                                     \
              n,                                                     \
              k,                                                     \
              stream);                                               \
        }                                                            \
        break;                                                       \
      case 8:                                                        \
        if constexpr (K_TILES_PER_WARP >= 8) {                       \
          using BLayout = BLayout_TC_int4<8, Q_GROUP_SIZE>;          \
          launch_tinygemm_kernel<                                    \
              ACLayout,                                              \
              BLayout,                                               \
              ACLayout,                                              \
              WARPS,                                                 \
              K_TILES_PER_WARP>(                                     \
              A,                                                     \
              B,                                                     \
              &qScaleAndZeros,                                       \
              C_final,                                               \
              mTiles,                                                \
              nTiles,                                                \
              kTiles,                                                \
              m,                                                     \
              n,                                                     \
              k,                                                     \
              stream);                                               \
        }                                                            \
        break;                                                       \
      default:                                                       \
        break;                                                       \
    }                                                                \
  } while (false)

#define HANDLE_Q_GROUP(WARPS, K_TILES_PER_WARP, REDUCE_TYPE) \
  do {                                                       \
    switch (qGroupSize) {                                    \
      case 32:                                               \
        RUN_GEMM(WARPS, K_TILES_PER_WARP, 32, REDUCE_TYPE);  \
        break;                                               \
      case 64:                                               \
        RUN_GEMM(WARPS, K_TILES_PER_WARP, 64, REDUCE_TYPE);  \
        break;                                               \
      case 128:                                              \
        RUN_GEMM(WARPS, K_TILES_PER_WARP, 128, REDUCE_TYPE); \
        break;                                               \
      case 256:                                              \
        RUN_GEMM(WARPS, K_TILES_PER_WARP, 256, REDUCE_TYPE); \
        break;                                               \
    }                                                        \
  } while (false)

  HANDLE_Q_GROUP(8, 8, KReductionType::None);

#undef HANDLE_Q_GROUP
#undef RUN_GEMM

  return C_final;
#else
  TORCH_CHECK(false, "_weight_int4pack_mm_cuda is not available for build.");
  return C_final;
#endif
}
// Repacks quantized weights into the tensor-core B layout used by
// _weight_int4pack_mm_cuda.
//
// input is [n][k] (int32 dtype), one 4-bit value per element (expected range
// [0, 15]); k must be a positive multiple of innerKTiles * 16.
// output is [ceil(n / 8)][k / (innerKTiles * 16)][32][innerKTiles / 2] int32.
// innerKTiles must be 2, 4 or 8.
at::Tensor _convert_weight_to_int4pack_cuda(
    const at::Tensor& in,
    int64_t innerKTiles) {
  c10::cuda::CUDAGuard g(in.device());

  TORCH_CHECK(in.dim() == 2);
  TORCH_CHECK(in.dtype() == at::kInt);
  TORCH_CHECK(in.is_contiguous());

  // At least 2 k-tiles need to be packed back to back in the innermost
  // dimension, as the m16n8k16 tensor core tile presents 4 scalar values for
  // the B matrix, but the minimum word size for the packed format is 4 bytes
  // (int32). 4 inner K-tiles = 8 byte load, 8 inner k-tiles = 16 byte load
  // which is the maximum vectorized load/store size
  TORCH_CHECK(innerKTiles == 2 || innerKTiles == 4 || innerKTiles == 8);

  constexpr int32_t kNTileSize = 8;
  constexpr int32_t kKTileSize = 16;

  auto nTiles = divUp(in.size(0), kNTileSize);

  // k-tiles are packed back to back in the innermost dimension in order to
  // allow for 4/8/16 byte loads
  TORCH_CHECK(isEvenDivisor(in.size(1), innerKTiles * kKTileSize));
  // kSuperTiles is the number of groups of innerKTiles k-tiles, i.e.
  // k / (innerKTiles * kKTileSize)
  auto kSuperTiles = divUp(in.size(1), innerKTiles * kKTileSize);

  // each block handles `innerKTiles` k-tiles.
  // 2 k-tiles are a single int32
  auto out = at::empty(
      {nTiles, kSuperTiles, 32, innerKTiles / 2},
      at::TensorOptions().dtype(at::kInt).device(in.device()));

#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800))
  // n == 0: nothing to pack. This also avoids launching a grid with a zero
  // dimension, which CUDA rejects as an invalid configuration.
  if (out.numel() == 0) {
    return out;
  }

  auto stream = at::cuda::getCurrentCUDAStream();
  dim3 grid(kSuperTiles, nTiles);

  if (innerKTiles == 2) {
    matrix_to_m16n8k16_Bint4_layout<2><<<grid, kWarpSize, 0, stream>>>(
        in.packed_accessor32<int32_t, 2, at::RestrictPtrTraits>(),
        out.packed_accessor32<int32_t, 4, at::RestrictPtrTraits>());
  } else if (innerKTiles == 4) {
    matrix_to_m16n8k16_Bint4_layout<4><<<grid, kWarpSize, 0, stream>>>(
        in.packed_accessor32<int32_t, 2, at::RestrictPtrTraits>(),
        out.packed_accessor32<int32_t, 4, at::RestrictPtrTraits>());
  } else if (innerKTiles == 8) {
    matrix_to_m16n8k16_Bint4_layout<8><<<grid, kWarpSize, 0, stream>>>(
        in.packed_accessor32<int32_t, 2, at::RestrictPtrTraits>(),
        out.packed_accessor32<int32_t, 4, at::RestrictPtrTraits>());
  }
  // Surface launch-configuration errors (e.g. gridDim.y > 65535) here rather
  // than at some later, unrelated CUDA call
  C10_CUDA_KERNEL_LAUNCH_CHECK();

  return out;
#else
  TORCH_CHECK(false, "_convert_weight_to_int4pack_cuda is not available for build.");
  return out;
#endif
}
} // namespace at::native