CUDA: use async data loading for FlashAttention (#11894)
* CUDA: use async data loading for FlashAttention

---------

Co-authored-by: Diego Devesa <[email protected]>
JohannesGaessler and slaren authored Feb 17, 2025
1 parent f7b1116 commit 73e2ed3
Showing 6 changed files with 724 additions and 719 deletions.
21 changes: 15 additions & 6 deletions ggml/src/ggml-cuda/common.cuh
@@ -41,12 +41,13 @@
#define CUDART_HMAX 11070 // CUDA 11.7, min. ver. for which __hmax and __hmax2 are known to work (may be higher than needed)
#define CUDART_HMASK 12000 // CUDA 12.0, min. ver. for half2 -> uint mask comparisons

#define GGML_CUDA_CC_PASCAL 600
#define GGML_CUDA_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
#define GGML_CUDA_CC_VOLTA 700
#define GGML_CUDA_CC_TURING 750
#define GGML_CUDA_CC_AMPERE 800
+#define GGML_CUDA_CC_ADA_LOVELACE 890
#define GGML_CUDA_CC_OFFSET_AMD 0x1000000

// GCN/CDNA, wave size is 64
#define GGML_CUDA_CC_GCN4 (GGML_CUDA_CC_OFFSET_AMD + 0x803) // Tonga, Fiji, Polaris, minimum for fast fp16
@@ -199,6 +200,10 @@ typedef float2 dfloat2;
#define NEW_MMA_AVAILABLE
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING

+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#define CP_ASYNC_AVAILABLE
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+
#if !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
#define FLASH_ATTN_AVAILABLE
#endif // !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
@@ -231,6 +236,10 @@ static bool new_mma_available(const int cc) {
return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING;
}

+static bool cp_async_available(const int cc) {
+return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_AMPERE;
+}
+
static constexpr __device__ int ggml_cuda_get_physical_warp_size() {
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
return __AMDGCN_WAVEFRONT_SIZE;
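The compile-time guard and the host-side helper are two halves of the same check: CP_ASYNC_AVAILABLE controls which code paths get compiled into the device code for each architecture, while cp_async_available(cc) lets host code pick a variant for the GPU that is actually present. A minimal sketch of how the pair would typically be used, assuming common.cuh is included; the kernels and the launch helper below are illustrative, not part of this commit:

// Hypothetical kernel: only compiled to real device code on architectures with cp.async support.
static __global__ void k_async_path(float * dst) {
#ifdef CP_ASYNC_AVAILABLE
    dst[threadIdx.x] = 1.0f; // placeholder for a body that issues cp.async loads
#else
    GGML_UNUSED(dst);
    NO_DEVICE_CODE; // this compilation architecture cannot use cp.async
#endif // CP_ASYNC_AVAILABLE
}

// Hypothetical fallback kernel with ordinary synchronous loads.
static __global__ void k_sync_path(float * dst) {
    dst[threadIdx.x] = 1.0f;
}

// Hypothetical host-side dispatch on the compute capability reported at runtime.
static void launch_example(float * dst, cudaStream_t stream, const int cc) {
    if (cp_async_available(cc)) {
        k_async_path<<<1, 32, 0, stream>>>(dst); // Ampere or newer NVIDIA GPU
    } else {
        k_sync_path<<<1, 32, 0, stream>>>(dst);  // pre-Ampere or AMD GPU
    }
}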
46 changes: 46 additions & 0 deletions ggml/src/ggml-cuda/cp-async.cuh
@@ -0,0 +1,46 @@
// Simplified API for asynchronous data loading.

#include "common.cuh"

// Copies data from global to shared memory, cg == cache global, i.e. the data is cached in L2 but not in L1.
// Both the src and dst pointers must be aligned to 16 bytes.
// Shared memory uses 32 bit addressing, so the destination pointer is passed as an unsigned int.
// Generic pointers can be converted to 32 bit shared memory pointers using __cvta_generic_to_shared.
// Only the 16 byte copy is exposed because the 4 and 8 byte copies did not yield performance improvements.
template <int preload>
static __device__ __forceinline__ void cp_async_cg_16(const unsigned int dst, const void * src) {
static_assert(preload == 0 || preload == 64 || preload == 128 || preload == 256, "bad preload");
#ifdef CP_ASYNC_AVAILABLE
#if CUDART_VERSION >= 11040
if (preload == 256) {
asm volatile("cp.async.cg.shared.global.L2::256B [%0], [%1], 16;"
: : "r"(dst), "l"(src));
} else if (preload == 128) {
asm volatile("cp.async.cg.shared.global.L2::128B [%0], [%1], 16;"
: : "r"(dst), "l"(src));
} else if (preload == 64) {
asm volatile("cp.async.cg.shared.global.L2::64B [%0], [%1], 16;"
: : "r"(dst), "l"(src));
} else
#endif // CUDART_VERSION >= 11040
{
asm volatile("cp.async.cg.shared.global.L2 [%0], [%1], 16;"
: : "r"(dst), "l"(src));
}
#else
GGML_UNUSED(dst);
GGML_UNUSED(src);
NO_DEVICE_CODE;
#endif // CP_ASYNC_AVAILABLE
}

// Makes each thread wait until its asynchronous data copies are done.
// This does NOT provide any additional synchronization.
// In particular, when copying data with multiple warps a call to __syncthreads will be needed.
static __device__ __forceinline__ void cp_async_wait_all() {
#ifdef CP_ASYNC_AVAILABLE
asm volatile("cp.async.wait_all;");
#else
NO_DEVICE_CODE;
#endif // CP_ASYNC_AVAILABLE
}
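A minimal usage sketch, assuming cp-async.cuh is included; the kernel, the 256-thread block size and the float4 payload are illustrative, not taken from this commit. Each thread converts its shared-memory destination to a 32 bit address with __cvta_generic_to_shared, issues one 16 byte asynchronous copy, then waits on its own copies and calls __syncthreads so that data staged by other warps is visible too:

// Hypothetical kernel: 256 threads cooperatively stage 4 KiB (one float4 per thread)
// from global to shared memory with cp.async, then read it back.
static __global__ void stage_tile(const float4 * __restrict__ src, float4 * __restrict__ dst) {
#ifdef CP_ASYNC_AVAILABLE
    __shared__ float4 tile[256];

    // Convert the generic shared-memory pointer to a 32 bit shared-state-space address.
    const unsigned int dst_shared = (unsigned int) __cvta_generic_to_shared(tile + threadIdx.x);

    // One asynchronous 16 byte copy per thread; preload == 0 requests no extra L2 prefetch.
    cp_async_cg_16<0>(dst_shared, src + threadIdx.x);

    // Wait for this thread's own copies, then synchronize the block so that copies
    // issued by other warps are also complete before anyone reads the tile.
    cp_async_wait_all();
    __syncthreads();

    dst[threadIdx.x] = tile[255 - threadIdx.x]; // deliberately read data another thread loaded
#else
    GGML_UNUSED(src);
    GGML_UNUSED(dst);
    NO_DEVICE_CODE;
#endif // CP_ASYNC_AVAILABLE
}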
15 changes: 9 additions & 6 deletions ggml/src/ggml-cuda/fattn-common.cuh
@@ -716,7 +716,9 @@ void launch_fattn(

ggml_cuda_pool & pool = ctx.pool();
cudaStream_t main_stream = ctx.stream();
-const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
+const int id = ggml_cuda_get_device();
+const int cc = ggml_cuda_info().devices[id].cc;
+const int nsm = ggml_cuda_info().devices[id].nsm;

ggml_cuda_pool_alloc<half> K_f16(pool);
ggml_cuda_pool_alloc<half> V_f16(pool);
@@ -768,13 +770,14 @@
dim3 blocks_num;
if (parallel_blocks == 0) {
// For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
-const int tiles_nwaves = (ntiles_total - nsm - 1) / nsm;
-const bool tiles_inefficient = 3*nsm < 2*tiles_nwaves*ntiles_total;
-const bool short_context = K->ne[1] < 4096;
+const int tiles_nwaves = (ntiles_total + 2*nsm - 1) / (2*nsm);
+const int tiles_efficiency_percent = 100 * ntiles_total / (2*nsm*tiles_nwaves);

const int nblocks_stream_k = 2*nsm;

-blocks_num.x = short_context && !tiles_inefficient ? ntiles_total : nblocks_stream_k;
+const bool use_stream_k = tiles_efficiency_percent < 75 || cc >= GGML_CUDA_CC_ADA_LOVELACE;
+
+blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_total;
blocks_num.y = 1;
blocks_num.z = 1;

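To make the new launch heuristic concrete, a hedged worked example with assumed numbers (108 SMs as on an A100, tile counts chosen for illustration): the stream-k launch uses 2*108 = 216 blocks, so 160 tiles fill 100*160/216 = 74 % of that wave and use_stream_k is true even below Ada Lovelace, while 200 tiles fill 92 %, so a pre-Ada GPU launches one block per whole tile instead and skips the fixup; on cc >= 890 the stream-k path is always taken. The helper below merely mirrors the arithmetic above, assuming common.cuh is included, and is not code from the commit:

// Hedged sketch of the decision above with explicit inputs (not from the commit).
static bool decide_stream_k(const int ntiles_total, const int nsm, const int cc) {
    const int tiles_nwaves             = (ntiles_total + 2*nsm - 1) / (2*nsm);      // waves of 2*nsm blocks
    const int tiles_efficiency_percent = 100 * ntiles_total / (2*nsm*tiles_nwaves); // how full those waves are on average
    return tiles_efficiency_percent < 75 || cc >= GGML_CUDA_CC_ADA_LOVELACE;
}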
@@ -827,7 +830,7 @@
CUDA_CHECK(cudaGetLastError());

if constexpr (parallel_blocks == 0) {
-if (blocks_num.x % ntiles_total != 0) { // Fixup is only needed if the SMs work on fractional tiles.
+if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
const dim3 block_dim_combine(D, 1, 1);
const dim3 blocks_num_combine = blocks_num;

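For the corrected fixup condition, a brief worked example with assumed numbers: with 108 SMs the stream-k launch uses blocks_num.x = 216. If ntiles_total is 432 the tiles divide evenly, two whole tiles per block (432 % 216 == 0), and the fixup pass is skipped; if it is 500 the division leaves a remainder (500 % 216 == 68), some tiles end up split between two blocks, and the combine/fixup kernel has to run.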