[WIP] support paged attention #710
Draft
jslhcl wants to merge 18 commits into main from leca/pagedAttention
Changes from 2 commits

Commits (18):
0927cca  support paged attention
cf10813  Merge branch 'main' of https://github.com/microsoft/onnxruntime-exten…
6931a52  add kernel functions and build successfully
77ae3b0  move checkInput() to paged_attention.h
9ac5f18  call flash attention code for prompt mode
90ee6a6  runtime error in RunMultiHeadAttention
17c2c41  UT can run now
75b2d04  new UT with q,k and v whose shape is 381x512, and not use GetScratchB…
68dac5d  make context_lens CPU input, change scale in test, change batch_size …
e8bf9ea  change window_size and causal to make ut pass
8352a13  enhance flash attention lib to support page attention, and UT
4bfcfad  runtime error in gpu_data_transfer.cc when copying result back to cpu…
05c0ba2  sync ort-genai changes
7191ea5  sync flash attn 2.5.9 kernel code. test_cuda_paged_attention3 is good…
654cabf  fix runtime error and add new test case
e374fb4  UT for paged attention decoding case works
17cf303  fix block_tables and context_lens to avoid unknown random error
cb65793  sync with ort-genai and add document
@@ -0,0 +1,111 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include <cassert>
#include <optional>
#include <vector>
#include "paged_attention_impl.h"

template <typename T>
struct PagedAttention {
  OrtStatusPtr OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) {
    int64_t num_heads = 0, head_size = 0;
    ORTX_RETURN_IF_ERROR(api.KernelInfoGetAttribute_int64(&info, "num_heads", &num_heads));
    assert(num_heads > 0);
    num_heads_ = static_cast<int32_t>(num_heads);
    num_kv_heads_ = static_cast<int32_t>(OrtW::GetOpAttributeOrDefault<int64_t>(info, "num_kv_heads", num_heads));

    ORTX_RETURN_IF_ERROR(api.KernelInfoGetAttribute_int64(&info, "head_size", &head_size));
    assert(head_size > 0);
    head_size_ = static_cast<int32_t>(head_size);

    ORTX_RETURN_IF_ERROR(api.KernelInfoGetAttribute_float(&info, "scale", &scale_));
    assert(scale_ > 0);

    // Grouped-query attention: each kv head serves num_queries_per_kv_ consecutive query heads.
    num_queries_per_kv_ = num_heads_ / num_kv_heads_;
    std::vector<int32_t> head_mapping_host(num_heads_);
    for (int i = 0; i < num_kv_heads_; i++) {
      for (int j = 0; j < num_queries_per_kv_; j++) {
        head_mapping_host[i * num_queries_per_kv_ + j] = i;
      }
    }

    OrtAllocator* allocator = nullptr;
    ORTX_RETURN_IF_ERROR(api.KernelInfoGetAllocator(&info, OrtMemType::OrtMemTypeDefault, &allocator));
    allocator_ = UniquePtrWithDeletor<OrtAllocator>{allocator, [&api](OrtAllocator* p) { api.ReleaseAllocator(p); }};
    // Allocate in bytes (not elements) and copy the mapping to device memory.
    head_mapping_ = GetScratchBuffer<int32_t>(allocator_->Alloc(allocator_.get(), num_heads_ * sizeof(int32_t)), allocator_.get());
    InitializeHeadMapping(head_mapping_.get(), head_mapping_host.data(), head_mapping_host.size() * sizeof(int32_t));
    return nullptr;
  }

  OrtStatusPtr Compute(Ort::Custom::CUDAKernelContext* ctx, const ortc::Tensor<T>& query, const ortc::Tensor<T>& key,
                       const ortc::Tensor<T>& value, const ortc::Tensor<T>& key_cache, const ortc::Tensor<T>& value_cache,
                       const ortc::Tensor<int32_t>& block_tables, const ortc::Tensor<int32_t>& slot_mappings,
                       std::optional<const ortc::Tensor<int32_t>*> context_lens,
                       std::optional<const ortc::Tensor<int64_t>*> positions,
                       std::optional<const ortc::Tensor<T>*> cos_sin_cache, ortc::Tensor<T>& attn_out) const {
    InputMetadata input_metadata;
    ORTX_RETURN_IF_ERROR(CheckInputs(reinterpret_cast<cudaStream_t>(ctx->GetCudaStream()), allocator_.get(), query, key, value,
                                     key_cache, value_cache, block_tables, slot_mappings, context_lens, positions,
                                     num_heads_, head_size_, input_metadata));
    const std::vector<int64_t>& query_shape = query.Shape();
    T* output_data = attn_out.Allocate(query_shape);

    if (cos_sin_cache.has_value()) {
      int64_t rot_dim = (*cos_sin_cache)->Shape()[1];
      assert(rot_dim == head_size_);
      // Rotary embedding mutates query/key in place, hence the const_cast.
      rotary_embedding_neox(reinterpret_cast<cudaStream_t>(ctx->GetCudaStream()), (*positions)->Data(),
                            const_cast<void*>(query.DataRaw()), const_cast<void*>(key.DataRaw()), head_size_,
                            (*cos_sin_cache)->DataRaw(), input_metadata.num_valid_tokens, rot_dim, num_heads_, num_kv_heads_, 1);
    }

    const std::vector<int64_t>& key_cache_shape = key_cache.Shape();
    if (input_metadata.num_valid_tokens > 0 && key_cache_shape.size() > 3) {
      int64_t key_shape_r[3] = {input_metadata.num_valid_tokens, num_kv_heads_, head_size_};
      int64_t value_shape_r[3] = {input_metadata.num_valid_tokens, num_kv_heads_, head_size_};
      int block_size = gsl::narrow<int>(key_cache_shape[3]);
      reshape_and_cache(reinterpret_cast<cudaStream_t>(ctx->GetCudaStream()), key.DataRaw(), value.DataRaw(), key_cache.DataRaw(),
                        value_cache.DataRaw(), slot_mappings.Data(), key_shape_r, value_shape_r, block_size, key_cache_shape[4], 1);
    }

    using TT = typename CudaT<T>::MappedType;
    if (input_metadata.num_prompt_tokens > 0) {
      // TODO(leca): flash attention for prompt > 0 case
      return nullptr;  // Don't handle prompt with decoding case for now
    }

    if (input_metadata.num_generation_tokens > 0) {
      constexpr int PARTITION_SIZE = 512;
      int max_num_partitions = (input_metadata.max_context_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
      // Use the single-pass v1 kernel when the context fits in one partition or the batch is
      // already large enough to saturate the GPU; otherwise use the partitioned v2 kernel.
      bool use_v1 = max_num_partitions == 1 || (query_shape[0] * query_shape[1]) > PARTITION_SIZE;
      int64_t generation_query_shape[3] = {input_metadata.num_valid_tokens, num_heads_, head_size_};
      if (use_v1) {
        paged_attention_v1(reinterpret_cast<cudaStream_t>(ctx->GetCudaStream()), reinterpret_cast<TT*>(output_data), query.DataRaw(),
                           key_cache.DataRaw(), value_cache.DataRaw(), head_mapping_.get(), scale_,
                           block_tables.Data(), context_lens.has_value() ? (*context_lens)->Data() : nullptr,
                           value_cache.Shape()[3], input_metadata.max_context_len, nullptr,
                           input_metadata.max_num_blocks_per_seq, generation_query_shape, num_queries_per_kv_, 1);
      } else {
        OrtMemoryInfo* mem_info = nullptr;
        ORTX_RETURN_IF_ERROR(OrtW::API::CreateOrtMemoryInfo("Cuda", OrtDeviceAllocator, ctx->device_id, OrtMemTypeDefault, &mem_info));
        // tmp_out holds the per-partition partial outputs: [num_seqs, num_heads, max_num_partitions, head_size].
        void* tmp_output_raw = ctx->GetScratchBufferUnderMultiStream(mem_info, query_shape[0] * query_shape[1] * max_num_partitions * sizeof(T));
        UniquePtrWithDeletor<T> tmp_output = GetScratchBuffer<T>(tmp_output_raw, allocator_.get());  // TODO(leca): should deallocate inside ORT
        void* exp_sums_raw = ctx->GetScratchBufferUnderMultiStream(mem_info, query_shape[0] * query_shape[1] * num_heads_ * max_num_partitions * sizeof(T));
        UniquePtrWithDeletor<T> exp_sums = GetScratchBuffer<T>(exp_sums_raw, allocator_.get());
        void* max_logits_raw = ctx->GetScratchBufferUnderMultiStream(mem_info, query_shape[0] * query_shape[1] * num_heads_ * max_num_partitions * sizeof(T));
        UniquePtrWithDeletor<T> max_logits = GetScratchBuffer<T>(max_logits_raw, allocator_.get());
        // Argument order matches the declaration: out, exp_sums, max_logits, tmp_out, query, ...
        paged_attention_v2(reinterpret_cast<cudaStream_t>(ctx->GetCudaStream()), reinterpret_cast<TT*>(output_data), exp_sums_raw,
                           max_logits_raw, tmp_output_raw, query.DataRaw(),
                           key_cache.DataRaw(), value_cache.DataRaw(), head_mapping_.get(), scale_,
                           block_tables.Data(), context_lens.has_value() ? (*context_lens)->Data() : nullptr,
                           value_cache.Shape()[3], input_metadata.max_context_len, nullptr,
                           input_metadata.max_num_blocks_per_seq, generation_query_shape, num_queries_per_kv_, 1);

        OrtW::API::ReleaseMemoryInfo(mem_info);
      }
    }
    return nullptr;
  }

 private:
  int32_t num_heads_;                           // number of attention (query) heads
  int32_t num_kv_heads_;                        // number of key/value heads
  int32_t head_size_;                           // dimension of each attention head
  float scale_;                                 // softmax scale, typically 1/sqrt(head_size_)
  UniquePtrWithDeletor<int32_t> head_mapping_;  // device copy of the query-head -> kv-head map
  int32_t num_queries_per_kv_;
  UniquePtrWithDeletor<OrtAllocator> allocator_;
};
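For reviewers unfamiliar with the grouped-query layout, here is a minimal standalone sketch (illustrative only, not part of this change) of the two pieces of arithmetic above: the head mapping built in OnModelAttach and the v1/v2 partition heuristic in Compute. The head counts and context lengths are hypothetical.

#include <cstdio>
#include <vector>

int main() {
  // Hypothetical config: 8 query heads sharing 2 kv heads (grouped-query attention).
  const int num_heads = 8, num_kv_heads = 2;
  const int num_queries_per_kv = num_heads / num_kv_heads;  // 4

  // Same construction as OnModelAttach: query head q reads kv head q / num_queries_per_kv.
  std::vector<int> head_mapping(num_heads);
  for (int i = 0; i < num_kv_heads; i++)
    for (int j = 0; j < num_queries_per_kv; j++)
      head_mapping[i * num_queries_per_kv + j] = i;
  for (int q = 0; q < num_heads; q++)
    printf("query head %d -> kv head %d\n", q, head_mapping[q]);

  // Same heuristic as Compute: one partition means the single-pass v1 kernel suffices.
  const int PARTITION_SIZE = 512;
  const int contexts[] = {100, 512, 2000};
  for (int max_context_len : contexts) {
    int max_num_partitions = (max_context_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
    printf("context %d -> %d partition(s), %s kernel\n", max_context_len, max_num_partitions,
           max_num_partitions == 1 ? "v1" : "v2 (unless the batch is large)");
  }
  return 0;
}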
@@ -0,0 +1,145 @@
#include "paged_attention_impl.h" | ||
#include <vector> | ||
|
||
namespace cuda { | ||
|
||
inline OrtStatusPtr CudaCall(cudaError_t cuda_error) { | ||
if (cuda_error == cudaSuccess) return nullptr; | ||
return OrtW::API::CreateStatus(ORT_FAIL, MakeString("cuda error:", (int)cuda_error).c_str()); | ||
} | ||
|
||
void InitializeHeadMapping(void* dest_data, const void* src_data, size_t count) { | ||
cudaMemcpy(dest_data, src_data, count, cudaMemcpyHostToDevice); | ||
} | ||
|
||
template <typename T> | ||
OrtStatusPtr CheckInputs(const cudaStream_t stream, OrtAllocator* allocator, const ortc::Tensor<T>& query, const ortc::Tensor<T>& key, | ||
const ortc::Tensor<T>& value, const ortc::Tensor<T>& key_cache, const ortc::Tensor<T>& value_cache, | ||
const ortc::Tensor<int32_t>& block_tables, const ortc::Tensor<int32_t>& slot_mappings, | ||
std::optional<const ortc::Tensor<int32_t>*> context_lens, | ||
std::optional<const ortc::Tensor<int64_t>*> positions, InputMetadata& input_metadata) { | ||
const std::vector<int64_t>& query_shape = query.Shape(); | ||
if (query_shape.size() < 2 || query_shape.size() > 3) { | ||
return OrtW::CreateStatus(MakeString("Invalid query shape, expect 2 or 3 dimensions"), ORT_INVALID_ARGUMENT); | ||
} | ||
if (query_shape.back() != num_heads_ * head_size_) { | ||
return OrtW::CreateStatus(MakesString("query shape should equal to num_heads_ * head_size_")); | ||
} | ||
|
||
// TODO(leca): Cpu input or CUDA input? | ||
int seq_len = query_shape.size() == 3 ? query_shape[1] : query_shape[0]; | ||
if (positions.has_value()) { | ||
std::vector<int64_t> positions_host((*positions)->Shape().size()); | ||
ORTX_RETURN_IF_ERROR(CudaCall(cudaMemcpy(positions_host.data(), (*positions)->DataRaw(), (*positions)->SizeInBytes(), cudaMemcpyDeviceToHost))); | ||
while (positions_host.back() == 0) { | ||
positions_host.pop_back(); | ||
seq_len--; | ||
} | ||
|
||
input_metadata.max_num_blocks_per_seq = 0; | ||
// in prompt mode | ||
if (positions_host.size() > 1 || positions_host.back() == 0) { | ||
input_metadata.num_prompt_tokens = seq_len; | ||
input_metadata.num_generation_tokens = 0; | ||
} else { | ||
input_metadata.num_prompt_tokens = 0; | ||
input_metadata.num_generation_tokens = seq_len; | ||
input_metadata.max_context_len = positions_host.back() + 1; // TODO(leca): what if position_host is empty? | ||
|
||
int32_t block_size = gsl::narrow<int32_t>(key_cache.Shape()[3]); | ||
for (int i = 0; i < positions_host.back() + 1; i += block_size) input_metadata.max_num_blocks_per_seq++; | ||
} | ||
} else { | ||
// TODO(leca): context_lens is nullptr? | ||
std::vector<int32_t> context_len_host((*context_lens)->SizeInBytes()); | ||
ORTX_RETURN_IF_ERROR(CudaCall(cudaMemcpy(context_len_host.data(), *(context_lens)->DataRaw(), *(context_lens)->SizeInBytes(), cudaMemcpyDeviceToHost))); | ||
std::vector<int64_t> position_ids; | ||
for (size_t i = 0; i < context_len_host.size(); i++) { | ||
if (context_len_host[i] == 0) continue; | ||
std::vector<int64_t> position_id(context_len_host[i]); | ||
std::iota(position_id.begin(), position_id.end(), 0); // fill position_id with {0, 1, 2, ...context_len_span[i]-1} | ||
position_ids.insert(position_ids.end(), position_id.begin(), position_id.end()); | ||
} | ||
input_metadata.position_ids = GetScratchBuffer<int64_t>(allocator->Alloc(allocator, cnt), allocator); | ||
ORTX_RETURN_IF_ERROR(CudaCall(cudaMemcpyAsync(input_metadata.position_ids.get(), position_ids.data(), position_ids.size(), cudaMemcpyHostToDevice, stream))); | ||
} | ||
input_metadata.num_valid_tokens = seq_len; | ||
|
||
return nullptr; | ||
} | ||
|
||
void paged_attention_v1( | ||
const cudaStream_t stream, | ||
void* out, // [num_seqs, num_heads, head_size] | ||
const void* query, // [num_seqs, num_heads, head_size] | ||
const void* key_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] | ||
const void* value_cache, // [num_blocks, num_kv_heads, head_size, block_size] | ||
const int* head_mapping, // [num_heads] | ||
float scale, | ||
const int* block_tables, // [num_seqs, max_num_blocks_per_seq] | ||
const int* context_lens, // [num_seqs] | ||
int block_size, | ||
int max_context_len, | ||
const float* __restrict__ alibi_slopes, | ||
const int max_num_blocks_per_seq, | ||
const int64_t* query_shapes, | ||
int num_queries_per_kv, | ||
int dtype) { | ||
|
||
} | ||
|
||
template<typename T> | ||
void paged_attention_v2( | ||
const cudaStream_t stream, | ||
void* out, // [num_seqs, num_heads, head_size] | ||
void* exp_sums, // [num_seqs, num_heads, max_num_partitions] | ||
void* max_logits, // [num_seqs, num_heads, max_num_partitions] | ||
void* tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] | ||
const void* query, // [num_seqs, num_heads, head_size] | ||
const void* key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] | ||
const void* value_cache, // [num_blocks, num_heads, head_size, block_size] | ||
const int* head_mapping, // [num_heads] | ||
float scale, | ||
const int* block_tables, // [num_seqs, max_num_blocks_per_seq] | ||
const int* context_lens, // [num_seqs] | ||
int block_size, | ||
int max_context_len, | ||
const float* alibi_slopes, | ||
const int max_num_blocks_per_seq, | ||
const int64_t* query_shapes, | ||
int num_queries_per_kv, | ||
int dtype) { | ||
|
||
} | ||
|
||
void rotary_embedding_neox( | ||
const cudaStream_t stream, | ||
const int64_t* positions, // [num_tokens] | ||
void* query, // [num_tokens, num_heads * head_size] | ||
void* key, // [num_tokens, num_kv_heads * head_size] | ||
int head_size, | ||
const void* cos_sin_cache, // [max_position, rot_dim] | ||
int num_tokens, | ||
int rot_dim, | ||
int num_heads, | ||
int num_kv_heads, | ||
int dtype) { | ||
|
||
} | ||
|
||
void reshape_and_cache( | ||
const cudaStream_t stream, | ||
const void* key, // [num_tokens, num_heads, head_size] | ||
const void* value, // [num_tokens, num_heads, head_size] | ||
const void* key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] | ||
const void* value_cache, // [num_blocks, num_heads, head_size, block_size] | ||
const int* slot_mapping, // [num_tokens] | ||
const int64_t* key_shapes, | ||
const int64_t* value_shapes, | ||
const int64_t block_size, | ||
const int vec_x, | ||
int dtype) { | ||
|
||
} | ||
|
||
} // namespace cuda |
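Since the kernels above are still stubs, a short CPU-side sketch may help reviewers picture the paged-KV addressing their shape comments describe. This is an illustration of the vLLM-style block layout, not the kernel itself; the slot value and block size are hypothetical.

#include <cstdio>

int main() {
  // A token's flat slot (one entry of slot_mapping) splits into a physical
  // block number and an offset within that block.
  const int block_size = 16;                   // tokens per cache block
  const int slot = 37;                         // hypothetical slot_mapping entry
  const int block_number = slot / block_size;  // -> physical block 2
  const int block_offset = slot % block_size;  // -> token 5 inside that block
  printf("slot %d -> block %d, offset %d\n", slot, block_number, block_offset);

  // During decoding, context position p of sequence s is resolved through the
  // block table, which lets each sequence use non-contiguous cache blocks:
  //   physical_block = block_tables[s * max_num_blocks_per_seq + p / block_size];
  //   value vector   = value_cache[physical_block][kv_head][...][p % block_size];
  return 0;
}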
Review comment:
If you want the kernel to support eager execution, you may use OrtxStatus instead of OrtStatusPtr.
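A rough sketch of what that suggestion could look like (illustrative, not from the PR; it assumes the OrtxStatus type and the kOrtxErrorInvalidArgument code from onnxruntime-extensions' status headers):

// Returning the backend-neutral OrtxStatus instead of OrtStatusPtr would let
// the same kernel run under eager execution as well as graph mode.
template <typename T>
struct PagedAttention {
  OrtxStatus OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) {
    int64_t num_heads = 0;
    if (api.KernelInfoGetAttribute_int64(&info, "num_heads", &num_heads) != nullptr || num_heads <= 0) {
      return {kOrtxErrorInvalidArgument, "num_heads attribute missing or invalid"};
    }
    // ... parse the remaining attributes the same way ...
    return {};  // a default-constructed OrtxStatus signals success
  }
};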