Compare commits

...

3 Commits

Author SHA1 Message Date
Johannes Gäßler 148995e5e5 llama-bench: more compact markdown tables (#7879) 2024-06-11 14:45:40 +02:00
Georgi Gerganov 4bfe50f741 tests : check the Python version (#7872)
ggml-ci
2024-06-11 10:10:20 +03:00
Johannes Gäßler bdcb8f4222 CUDA: int8 tensor cores for MMQ (q4_K, q5_K, q6_K) (#7860) 2024-06-11 08:26:07 +02:00
4 changed files with 383 additions and 8 deletions
+21
View File
@@ -1033,6 +1033,27 @@ struct markdown_printer : public printer {
if (field == "n_gpu_layers") {
return 3;
}
if (field == "n_threads") {
return 7;
}
if (field == "n_batch") {
return 7;
}
if (field == "n_ubatch") {
return 8;
}
if (field == "type_k" || field == "type_v") {
return 6;
}
if (field == "split_mode") {
return 5;
}
if (field == "flash_attn") {
return 2;
}
if (field == "use_mmap") {
return 4;
}
if (field == "test") {
return 13;
}
+66
View File
@@ -1,5 +1,27 @@
#include "common.cuh"
struct mma_int_A_I16K4 {
static constexpr int I = 16;
static constexpr int K = 4;
static constexpr int ne = 2;
int x[ne] = {0};
static __device__ __forceinline__ int get_i(const int l) {
const int ret = (l%2) * (I/2) + threadIdx.x / K;
GGML_CUDA_ASSUME(ret >= 0);
GGML_CUDA_ASSUME(ret < I);
return ret;
}
static __device__ __forceinline__ int get_k(const int /* l */) {
const int ret = threadIdx.x % K;
GGML_CUDA_ASSUME(ret >= 0);
GGML_CUDA_ASSUME(ret < K);
return ret;
}
};
struct mma_int_A_I16K8 {
static constexpr int I = 16;
static constexpr int K = 8;
@@ -22,6 +44,28 @@ struct mma_int_A_I16K8 {
}
};
struct mma_int_B_J8K4 {
static constexpr int J = 8;
static constexpr int K = 4;
static constexpr int ne = 1;
int x[ne] = {0};
static __device__ __forceinline__ int get_j(const int /* l */) {
const int ret = threadIdx.x / K;
GGML_CUDA_ASSUME(ret >= 0);
GGML_CUDA_ASSUME(ret < J);
return ret;
}
static __device__ __forceinline__ int get_k(const int /* l */) {
const int ret = threadIdx.x % K;
GGML_CUDA_ASSUME(ret >= 0);
GGML_CUDA_ASSUME(ret < K);
return ret;
}
};
struct mma_int_B_J8K8 {
static constexpr int J = 8;
static constexpr int K = 8;
@@ -65,6 +109,28 @@ struct mma_int_C_I16J8 {
return ret;
}
__device__ __forceinline__ void mma_K4(const mma_int_A_I16K4 & mma_A, const mma_int_B_J8K4 & mma_B) {
#ifdef INT8_MMA_AVAILABLE
#if __CUDA_ARCH__ >= CC_AMPERE
asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
: "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
: "r"(mma_A.x[0]), "r"(mma_A.x[1]), "r"(mma_B.x[0]));
#else
// On Turing m16n8k16 mma is not available, use 2x m8n8k16 mma instead:
asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
: "+r"(x[0]), "+r"(x[1])
: "r"(mma_A.x[0]), "r"(mma_B.x[0]));
asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
: "+r"(x[2]), "+r"(x[3])
: "r"(mma_A.x[1]), "r"(mma_B.x[0]));
#endif // __CUDA_ARCH__ >= CC_AMPERE
#else
GGML_UNUSED(mma_A);
GGML_UNUSED(mma_B);
NO_DEVICE_CODE;
#endif // INT8_MMA_AVAILABLE
}
__device__ __forceinline__ void mma_K8(const mma_int_A_I16K8 & mma_A, const mma_int_B_J8K8 & mma_B) {
#ifdef INT8_MMA_AVAILABLE
#if __CUDA_ARCH__ >= CC_AMPERE
+294 -6
View File
@@ -1089,7 +1089,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
}
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mul_mat(
static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
@@ -1115,6 +1115,97 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mul_mat(
}
}
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mma(
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
typedef mma_int_A_I16K8 mma_A;
typedef mma_int_B_J8K8 mma_B;
typedef mma_int_C_I16J8 mma_C;
const int * y_qs = (const int *) y + 4;
const half2 * y_ds = (const half2 *) y;
const int i0 = threadIdx.y*mma_A::I;
static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
mma_A A[2];
int scA[mma_C::ne/2][2];
int mA[mma_C::ne/2][2];
half2 dmA[mma_C::ne/2];
#pragma unroll
for (int kvdr = 0; kvdr < VDR_Q4_K_Q8_1_MMQ; kvdr += 4) {
#pragma unroll
for (int l = 0; l < mma_A::ne; ++l) {
const int i = i0 + mma_A::get_i(l);
const int k = k0 + mma_A::get_k(l);
A[kvdr/4].x[l] = (x_ql[i*(WARP_SIZE + 1) + k] >> kvdr) & 0x0F0F0F0F;
}
#pragma unroll
for (int l = 0; l < mma_C::ne/2; ++l) {
const int i = i0 + mma_C::get_i(2*l);
const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]) + 2 * ((k0 % 16) / 8);
const uint8_t * m = sc + 8;
scA[l][kvdr/4] = sc[kvdr/4];
mA[l][kvdr/4] = m[kvdr/4];
}
}
#pragma unroll
for (int l = 0; l < mma_C::ne/2; ++l) {
const int i = i0 + mma_C::get_i(2*l);
dmA[l] = x_dm[i*(WARP_SIZE/QI5_K) + i/QI5_K + k0/QI5_K];
}
#pragma unroll
for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
float tmpd[mma_C::ne] = {0.0f};
float tmpm[mma_C::ne] = {0.0f};
#pragma unroll
for (int kvdr = 0; kvdr < VDR_Q5_K_Q8_1_MMQ; kvdr += 4) {
mma_C C;
mma_B B;
half2 dsB[mma_C::ne/2];
#pragma unroll
for (int l = 0; l < mma_B::ne; ++l) {
const int j = j0 + mma_B::get_j(l);
const int k = (2*k0 + 2*kvdr + mma_B::get_k(l)) % WARP_SIZE;
B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
}
#pragma unroll
for (int l = 0; l < mma_C::ne/2; ++l) {
const int j = j0 + mma_C::get_j(l);
dsB[l] = y_ds[j*MMQ_TILE_Y_K + ((2*k0 + 2*kvdr)/QI8_1) % (WARP_SIZE/QI8_1)];
}
C.mma_K8(A[kvdr/4], B);
#pragma unroll
for (int l = 0; l < mma_C::ne; ++l) {
tmpd[l] += (C.x[l]*scA[l/2][kvdr/4]) * __low2float(dsB[l%2]);
tmpm[l] += mA[l/2][kvdr/4] * __high2float(dsB[l%2]);
}
}
#pragma unroll
for (int l = 0; l < mma_C::ne; ++l) {
sum[(j0/mma_B::J)*mma_C::ne + l] += __low2float(dmA[l/2])*tmpd[l] - __high2float(dmA[l/2])*tmpm[l];
}
}
}
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_K(
const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
@@ -1188,7 +1279,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
}
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mul_mat(
static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
@@ -1214,6 +1305,97 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mul_mat(
}
}
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mma(
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
typedef mma_int_A_I16K8 mma_A;
typedef mma_int_B_J8K8 mma_B;
typedef mma_int_C_I16J8 mma_C;
const int * y_qs = (const int *) y + 4;
const half2 * y_ds = (const half2 *) y;
const int i0 = threadIdx.y*mma_A::I;
static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
mma_A A[2];
int scA[mma_C::ne/2][2];
int mA[mma_C::ne/2][2];
half2 dmA[mma_C::ne/2];
#pragma unroll
for (int kvdr = 0; kvdr < VDR_Q5_K_Q8_1_MMQ; kvdr += 4) {
#pragma unroll
for (int l = 0; l < mma_A::ne; ++l) {
const int i = i0 + mma_A::get_i(l);
const int k = QR5_K*k0 + QR5_K*kvdr + mma_A::get_k(l);
A[kvdr/4].x[l] = x_ql[i*(QR5_K*WARP_SIZE + 1) + k];
}
#pragma unroll
for (int l = 0; l < mma_C::ne/2; ++l) {
const int i = i0 + mma_C::get_i(2*l);
const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]) + 2 * ((k0 % 16) / 8);
const uint8_t * m = sc + 8;
scA[l][kvdr/4] = sc[kvdr/4];
mA[l][kvdr/4] = m[kvdr/4];
}
}
#pragma unroll
for (int l = 0; l < mma_C::ne/2; ++l) {
const int i = i0 + mma_C::get_i(2*l);
dmA[l] = x_dm[i*(WARP_SIZE/QI5_K) + i/QI5_K + k0/QI5_K];
}
#pragma unroll
for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
float tmpd[mma_C::ne] = {0.0f};
float tmpm[mma_C::ne] = {0.0f};
#pragma unroll
for (int kvdr = 0; kvdr < VDR_Q5_K_Q8_1_MMQ; kvdr += 4) {
mma_C C;
mma_B B;
half2 dsB[mma_C::ne/2];
#pragma unroll
for (int l = 0; l < mma_B::ne; ++l) {
const int j = j0 + mma_B::get_j(l);
const int k = (2*k0 + 2*kvdr + mma_B::get_k(l)) % WARP_SIZE;
B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
}
#pragma unroll
for (int l = 0; l < mma_C::ne/2; ++l) {
const int j = j0 + mma_C::get_j(l);
dsB[l] = y_ds[j*MMQ_TILE_Y_K + ((2*k0 + 2*kvdr)/QI8_1) % (WARP_SIZE/QI8_1)];
}
C.mma_K8(A[kvdr/4], B);
#pragma unroll
for (int l = 0; l < mma_C::ne; ++l) {
tmpd[l] += (C.x[l]*scA[l/2][kvdr/4]) * __low2float(dsB[l%2]);
tmpm[l] += mA[l/2][kvdr/4] * __high2float(dsB[l%2]);
}
}
#pragma unroll
for (int l = 0; l < mma_C::ne; ++l) {
sum[(j0/mma_B::J)*mma_C::ne + l] += __low2float(dmA[l/2])*tmpd[l] - __high2float(dmA[l/2])*tmpm[l];
}
}
}
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q6_K(
const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
@@ -1280,7 +1462,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
}
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mul_mat(
static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
@@ -1307,6 +1489,97 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mul_mat(
}
}
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
typedef mma_int_A_I16K4 mma_A;
typedef mma_int_B_J8K4 mma_B;
typedef mma_int_C_I16J8 mma_C;
const float * x_df = (const float *) x_dm;
const int * y_qs = (const int *) y + 4;
const float * y_df = (const float *) y;
const int i0 = threadIdx.y*mma_A::I;
static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
mma_A A[4];
int scA[mma_C::ne/2][4];
float dA[mma_C::ne/2];
#pragma unroll
for (int kvdr = 0; kvdr < VDR_Q6_K_Q8_1_MMQ; kvdr += 4) {
#pragma unroll
for (int l = 0; l < mma_A::ne; ++l) {
const int i = i0 + mma_A::get_i(l);
const int k = QR6_K*k0 + QR6_K*kvdr + mma_A::get_k(l);
A[kvdr/2 + 0].x[l] = x_ql[i*(QR6_K*WARP_SIZE + 1) + k + 0];
A[kvdr/2 + 1].x[l] = x_ql[i*(QR6_K*WARP_SIZE + 1) + k + mma_A::K];
}
#pragma unroll
for (int l = 0; l < mma_C::ne/2; ++l) {
const int i = i0 + mma_C::get_i(2*l);
const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/8]);
scA[l][kvdr/2 + 0] = sc[kvdr/2 + 0];
scA[l][kvdr/2 + 1] = sc[kvdr/2 + 1];
}
}
#pragma unroll
for (int l = 0; l < mma_C::ne/2; ++l) {
const int i = i0 + mma_C::get_i(2*l);
dA[l] = x_df[i*(WARP_SIZE/QI6_K) + i/QI6_K + k0/QI6_K];
}
#pragma unroll
for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
float tmp[mma_C::ne] = {0.0f};
#pragma unroll
for (int kvdr = 0; kvdr < VDR_Q6_K_Q8_1_MMQ; kvdr += 4) {
mma_C C[2];
mma_B B[2];
float dB[mma_C::ne/2];
#pragma unroll
for (int l = 0; l < mma_B::ne; ++l) {
const int j = j0 + mma_B::get_j(l);
const int k = (2*k0 + 2*kvdr + mma_B::get_k(l)) % WARP_SIZE;
B[0].x[l] = y_qs[j*MMQ_TILE_Y_K + k + 0];
B[1].x[l] = y_qs[j*MMQ_TILE_Y_K + k + mma_B::K];
}
#pragma unroll
for (int l = 0; l < mma_C::ne/2; ++l) {
const int j = j0 + mma_C::get_j(l);
dB[l] = y_df[j*MMQ_TILE_Y_K + ((2*k0 + 2*kvdr)/QI8_1) % (WARP_SIZE/QI8_1)];
}
C[0].mma_K4(A[kvdr/2 + 0], B[0]);
C[1].mma_K4(A[kvdr/2 + 1], B[1]);
#pragma unroll
for (int l = 0; l < mma_C::ne; ++l) {
tmp[l] += (C[0].x[l]*scA[l/2][kvdr/2 + 0] + C[1].x[l]*scA[l/2][kvdr/2 + 1])*dB[l%2];
}
}
#pragma unroll
for (int l = 0; l < mma_C::ne; ++l) {
sum[(j0/mma_B::J)*mma_C::ne + l] += tmp[l]*dA[l/2];
}
}
}
template<int mmq_x, int mmq_y, int nwarps, bool need_check>
static __device__ __forceinline__ void mmq_write_back_dp4a(const float * __restrict__ sum, float * __restrict__ dst, const int & ne0, const int & ne1) {
#pragma unroll
@@ -1448,24 +1721,39 @@ template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_K> {
static constexpr int vdr = VDR_Q4_K_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_K<mmq_y, nwarps, need_check>;
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
#ifdef INT8_MMA_AVAILABLE
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
#else
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
#endif // INT8_MMA_AVAILABLE
};
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_K> {
static constexpr int vdr = VDR_Q5_K_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_K<mmq_y, nwarps, need_check>;
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
#ifdef INT8_MMA_AVAILABLE
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
#else
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
#endif // INT8_MMA_AVAILABLE
};
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q6_K> {
static constexpr int vdr = VDR_Q6_K_Q8_1_MMQ;
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_K<mmq_y, nwarps, need_check>;
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q6_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
#ifdef INT8_MMA_AVAILABLE
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q6_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
#else
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q6_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
#endif // INT8_MMA_AVAILABLE
};
static int mmq_need_sum(const ggml_type type_x) {
+2 -2
View File
@@ -870,7 +870,7 @@ int main() {
}
});
if (getenv("LLAMA_PYTHON_AVAILABLE") || (std::system("python --version") == 0)) {
if (getenv("LLAMA_PYTHON_AVAILABLE") || (std::system("python -c \"import sys; exit(1) if sys.version_info < (3, 8) else print('Python version is sufficient')\"") == 0)) {
test_all("Python", [](const TestCase & tc) {
write("test-json-schema-input.tmp", tc.schema);
tc.verify_status(std::system(
@@ -878,7 +878,7 @@ int main() {
tc.verify(read("test-grammar-output.tmp"));
});
} else {
fprintf(stderr, "\033[33mWARNING: Python not found, skipping Python JSON schema -> grammar tests.\n\033[0m");
fprintf(stderr, "\033[33mWARNING: Python not found (min version required is 3.8), skipping Python JSON schema -> grammar tests.\n\033[0m");
}
if (getenv("LLAMA_NODE_AVAILABLE") || (std::system("node --version") == 0)) {