Compare commits

...

1 Commits

Author SHA1 Message Date
Jeffrey Morgan
8e22b09e2c ggml-cuda: fix fattn build for GLM 4.7 flash support
Port upstream llama.cpp fixes for GLM 4.7 flash attention:

fattn-tile.cuh:
- Use nthreads=128 for ncols=4 to keep shared memory under 48KB
  on Maxwell (sm_50/52) which has 48KB limit

fattn-mma-f16.cuh:
- Fix np formula: cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols
- Remove static_assert for cols_per_warp != 8
- Add cols_per_warp == 8 MMA handling
- Add Volta guard to skip ncols < 32 (not supported)
- Remove extern for (576, 512, 2, 4) which isn't used
2026-01-24 03:47:16 +00:00
3 changed files with 119 additions and 46 deletions

View File

@@ -19,20 +19,20 @@ CUDA changes:
- Add template instances for ncols2=4
- Fix nbatch_fa values in nvidia_fp32 config (32->64)
---
ggml/src/ggml-cuda/fattn-mma-f16.cuh | 15 ++++++++++++---
ggml/src/ggml-cuda/fattn-tile.cuh | 18 +++++++++++++++++-
ggml/src/ggml-cuda/fattn.cu | 12 ++++++++----
...attn-mma-f16-instance-ncols1_16-ncols2_4.cu | 1 +
...fattn-mma-f16-instance-ncols1_2-ncols2_4.cu | 1 +
...fattn-mma-f16-instance-ncols1_4-ncols2_4.cu | 1 +
...fattn-mma-f16-instance-ncols1_8-ncols2_4.cu | 1 +
ggml/src/ggml-metal/ggml-metal-device.m | 8 ++------
ggml/src/ggml-metal/ggml-metal-ops.cpp | 2 +-
ggml/src/ggml-metal/ggml-metal.metal | 1 +
10 files changed, 45 insertions(+), 15 deletions(-)
ggml/src/ggml-cuda/fattn-mma-f16.cuh | 40 +++++++++++++++----
ggml/src/ggml-cuda/fattn-tile.cuh | 16 ++++++++
ggml/src/ggml-cuda/fattn.cu | 12 ++++--
...ttn-mma-f16-instance-ncols1_16-ncols2_4.cu | 1 +
...attn-mma-f16-instance-ncols1_2-ncols2_4.cu | 1 +
...attn-mma-f16-instance-ncols1_4-ncols2_4.cu | 1 +
...attn-mma-f16-instance-ncols1_8-ncols2_4.cu | 1 +
ggml/src/ggml-metal/ggml-metal-device.m | 8 +---
ggml/src/ggml-metal/ggml-metal-ops.cpp | 2 +-
ggml/src/ggml-metal/ggml-metal.metal | 1 +
10 files changed, 64 insertions(+), 19 deletions(-)
diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh
index 7bd1044c1..a627302f9 100644
index 7bd1044c1..3dea2205e 100644
--- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh
+++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh
@@ -66,7 +66,8 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
@@ -65,45 +65,103 @@ index 7bd1044c1..a627302f9 100644
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 64, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 64, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 64, 1, false);
@@ -1585,3 +1588,9 @@ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 64)
@@ -397,7 +400,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
constexpr int ncols = ncols1 * ncols2;
constexpr int cols_per_warp = T_B_KQ::I;
constexpr int cols_per_thread = 2; // This is specifically KQ columns, Volta only has a single VKQ column.
- constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
+ constexpr int np = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column.
constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols);
constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2(DKQ, DV, ncols);
constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2(DKQ, DV, ncols);
@@ -467,7 +470,6 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
}
}
} else {
- static_assert(cols_per_warp != 8, "cols_per_warp == 8 not implemented");
#pragma unroll
for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += T_A_KQ::J) {
load_ldmatrix(Q_B[0], tile_Q + (threadIdx.y / np)*(T_B_KQ::I*stride_tile_Q) + k_KQ_0, stride_tile_Q);
@@ -479,8 +481,18 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
T_A_KQ K_A;
load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K);
- // Wide version of KQ_C is column-major => swap A and B.
- mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A);
+ if constexpr (cols_per_warp == 8) {
+ mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
+ } else {
+ // Wide version of KQ_C is column-major
+#if defined(AMD_WMMA_AVAILABLE)
+ // RDNA matrix C is column-major.
+ mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
+#else
+ // swap A and B for CUDA.
+ mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A);
+#endif // defined(AMD_WMMA_AVAILABLE)
+ }
}
}
}
@@ -841,7 +853,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
constexpr int cols_per_warp = T_B_KQ::I;
constexpr int cols_per_thread = 2; // This is specifically KQ columns, Volta only has a single VKQ column.
- constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
+ constexpr int np = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column.
constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa (DKQ, DV, ncols);
constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2 (DKQ, DV, ncols);
constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2 (DKQ, DV, ncols);
@@ -1353,6 +1365,13 @@ static __global__ void flash_attn_ext_f16(
NO_DEVICE_CODE;
return;
}
+#ifdef VOLTA_MMA_AVAILABLE
+ if (ncols1*ncols2 < 32) {
+ NO_DEVICE_CODE;
+ return;
+ }
+#endif // VOLTA_MMA_AVAILABLE
+
#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
if (ncols1*ncols2 > 32) {
NO_DEVICE_CODE;
@@ -1585,3 +1604,8 @@ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 64)
extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16);
extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16);
extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16);
+
+// GLM 4.7 Flash uses gqa_ratio 4:
+extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 4);
+extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4);
+extern DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4);
+extern DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4);
+// For GLM 4.7 Flash
+extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4);
+extern DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4);
+extern DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4);
diff --git a/ggml/src/ggml-cuda/fattn-tile.cuh b/ggml/src/ggml-cuda/fattn-tile.cuh
index 7c4d6fe67..6389ba5c4 100644
index 7c4d6fe67..371be7442 100644
--- a/ggml/src/ggml-cuda/fattn-tile.cuh
+++ b/ggml/src/ggml-cuda/fattn-tile.cuh
@@ -68,6 +68,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 64, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 256, 2, 64, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
return 0;
@@ -122,7 +124,9 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
@@ -122,6 +124,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 64)
- GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 32, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 256, 2, 64, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 32, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 32, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 32, 64)
return 0;
}
@@ -183,6 +187,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 128)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 256, 2, 64, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 512, 1, 128, 64)
@@ -112,8 +170,8 @@ index 7c4d6fe67..6389ba5c4 100644
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5, 32, 256)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3, 64, 128)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 256, 4, 64, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 4, 64, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 4, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 256, 2, 128, 64)

View File

@@ -400,7 +400,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
constexpr int ncols = ncols1 * ncols2;
constexpr int cols_per_warp = T_B_KQ::I;
constexpr int cols_per_thread = 2; // This is specifically KQ columns, Volta only has a single VKQ column.
constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
constexpr int np = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column.
constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols);
constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2(DKQ, DV, ncols);
constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2(DKQ, DV, ncols);
@@ -470,7 +470,6 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
}
}
} else {
static_assert(cols_per_warp != 8, "cols_per_warp == 8 not implemented");
#pragma unroll
for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += T_A_KQ::J) {
load_ldmatrix(Q_B[0], tile_Q + (threadIdx.y / np)*(T_B_KQ::I*stride_tile_Q) + k_KQ_0, stride_tile_Q);
@@ -482,8 +481,18 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
T_A_KQ K_A;
load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K);
// Wide version of KQ_C is column-major => swap A and B.
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A);
if constexpr (cols_per_warp == 8) {
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
} else {
// Wide version of KQ_C is column-major
#if defined(AMD_WMMA_AVAILABLE)
// RDNA matrix C is column-major.
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
#else
// swap A and B for CUDA.
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A);
#endif // defined(AMD_WMMA_AVAILABLE)
}
}
}
}
@@ -844,7 +853,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
constexpr int cols_per_warp = T_B_KQ::I;
constexpr int cols_per_thread = 2; // This is specifically KQ columns, Volta only has a single VKQ column.
constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
constexpr int np = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column.
constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa (DKQ, DV, ncols);
constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2 (DKQ, DV, ncols);
constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2 (DKQ, DV, ncols);
@@ -1356,6 +1365,13 @@ static __global__ void flash_attn_ext_f16(
NO_DEVICE_CODE;
return;
}
#ifdef VOLTA_MMA_AVAILABLE
if (ncols1*ncols2 < 32) {
NO_DEVICE_CODE;
return;
}
#endif // VOLTA_MMA_AVAILABLE
#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
if (ncols1*ncols2 > 32) {
NO_DEVICE_CODE;
@@ -1589,8 +1605,7 @@ extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16);
extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16);
extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16);
// GLM 4.7 Flash uses gqa_ratio 4:
extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 4);
extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4);
extern DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4);
extern DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4);
// For GLM 4.7 Flash
extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4);
extern DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4);
extern DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4);

View File

@@ -68,7 +68,7 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
@@ -124,9 +124,9 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 32, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 32, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 32, 64)
return 0;
}
@@ -187,7 +187,7 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 128)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 512, 1, 128, 64)
@@ -251,8 +251,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5, 32, 256)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3, 64, 128)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 256, 4, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 4, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 4, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 256, 2, 128, 64)