-
Notifications
You must be signed in to change notification settings - Fork 10.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
CUDA: use async data loading for FlashAttention (#11894)
* CUDA: use async data loading for FlashAttention --------- Co-authored-by: Diego Devesa <[email protected]>
- Loading branch information
1 parent
f7b1116
commit 73e2ed3
Showing
6 changed files
with
724 additions
and
719 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
// Simplified API for asynchronous data loading. | ||
|
||
#include "common.cuh" | ||
|
||
// Copies data from global to shared memory, cg == cache global. | ||
// Both the src and dst pointers must be aligned to 16 bit. | ||
// Shared memory uses 32 bit addressing, the pointer is passed as unsigned int. | ||
// Generic pointers can be converted to 32 bit shared memory pointers using __cvta_generic_to_shared. | ||
// Only the 16 bit copy is exposed because 4 and 8 bit copies did not yield performance improvements. | ||
template <int preload> | ||
static __device__ __forceinline__ void cp_async_cg_16(const unsigned int dst, const void * src) { | ||
static_assert(preload == 0 || preload == 64 || preload == 128 || preload == 256, "bad preload"); | ||
#ifdef CP_ASYNC_AVAILABLE | ||
#if CUDART_VERSION >= 11040 | ||
if (preload == 256) { | ||
asm volatile("cp.async.cg.shared.global.L2::256B [%0], [%1], 16;" | ||
: : "r"(dst), "l"(src)); | ||
} else if (preload == 128) { | ||
asm volatile("cp.async.cg.shared.global.L2::128B [%0], [%1], 16;" | ||
: : "r"(dst), "l"(src)); | ||
} else if (preload == 64) { | ||
asm volatile("cp.async.cg.shared.global.L2::64B [%0], [%1], 16;" | ||
: : "r"(dst), "l"(src)); | ||
} else | ||
#endif // CUDART_VERSION >= 11040 | ||
{ | ||
asm volatile("cp.async.cg.shared.global.L2 [%0], [%1], 16;" | ||
: : "r"(dst), "l"(src)); | ||
} | ||
#else | ||
GGML_UNUSED(dst); | ||
GGML_UNUSED(src); | ||
NO_DEVICE_CODE; | ||
#endif // CP_ASYNC_AVAILABLE | ||
} | ||
|
||
// Makes each thread wait until its asynchronous data copies are done. | ||
// This does NOT provide any additional synchronization. | ||
// In particular, when copying data with multiple warps a call to __syncthreads will be needed. | ||
static __device__ __forceinline__ void cp_async_wait_all() { | ||
#ifdef CP_ASYNC_AVAILABLE | ||
asm volatile("cp.async.wait_all;"); | ||
#else | ||
NO_DEVICE_CODE; | ||
#endif // CP_ASYNC_AVAILABLE | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.