#pragma once

#include <cstdint>
#include <optional>
#include <string>
#include <vector>

#include <torch/torch.h>

/// @brief Declares the version-1 paged-attention entry point.
///
/// Only the prototype is visible in this header; the definition lives
/// elsewhere. Parameter roles below are inferred from names and types and
/// must be confirmed against the definition before being relied on.
///
/// @param out            Tensor taken by mutable reference (presumably the
///                       result buffer -- TODO confirm).
/// @param query          Tensor taken by mutable reference.
/// @param key_cache      Tensor taken by mutable reference.
/// @param value_cache    Tensor taken by mutable reference.
/// @param num_kv_heads   Integer; presumably the number of key/value heads.
/// @param scale          Floating-point factor (presumably applied to the
///                       attention logits -- TODO confirm).
/// @param block_tables   Tensor; presumably maps sequences to cache blocks.
/// @param seq_lens       Tensor; presumably per-sequence lengths.
/// @param block_size     Integer; presumably the cache block size.
/// @param max_seq_len    Integer; presumably an upper bound on seq_lens.
/// @param alibi_slopes   Optional tensor; callers may pass std::nullopt.
/// @param kv_cache_dtype String naming the cache data type; the accepted
///                       values are not visible from this header.
/// @param k_scale        Tensor; presumably a scaling factor for keys.
/// @param v_scale        Tensor; presumably a scaling factor for values.
/// @param tp_rank        Integer; presumably the tensor-parallel rank.
/// @param blocksparse_local_blocks      Block-sparse setting (meaning defined
///                                      by the implementation).
/// @param blocksparse_vert_stride       Block-sparse setting.
/// @param blocksparse_block_size        Block-sparse setting.
/// @param blocksparse_head_sliding_step Block-sparse setting.
void paged_attention_v1(
    torch::Tensor& out,
    torch::Tensor& query,
    torch::Tensor& key_cache,
    torch::Tensor& value_cache,
    int64_t num_kv_heads,
    double scale,
    torch::Tensor& block_tables,
    torch::Tensor& seq_lens,
    int64_t block_size,
    int64_t max_seq_len,
    const std::optional<torch::Tensor>& alibi_slopes,
    const std::string& kv_cache_dtype,
    torch::Tensor& k_scale,
    torch::Tensor& v_scale,
    const int64_t tp_rank,
    const int64_t blocksparse_local_blocks,
    const int64_t blocksparse_vert_stride,
    const int64_t blocksparse_block_size,
    const int64_t blocksparse_head_sliding_step);

/// @brief Declares the version-2 paged-attention entry point.
///
/// Compared with paged_attention_v1 the signature adds three mutable tensors
/// (`exp_sums`, `max_logits`, `tmp_out`) directly after `out`; every other
/// parameter matches v1 in name, type and order. Only the prototype is
/// visible here, so the roles below are assumptions to confirm against the
/// definition.
///
/// @param out            Tensor taken by mutable reference (presumably the
///                       result buffer -- TODO confirm).
/// @param exp_sums       Mutable tensor; presumably scratch space for partial
///                       softmax sums -- TODO confirm.
/// @param max_logits     Mutable tensor; presumably scratch space for partial
///                       maxima -- TODO confirm.
/// @param tmp_out        Mutable tensor; presumably scratch space for partial
///                       outputs -- TODO confirm.
/// @param query          Tensor taken by mutable reference.
/// @param key_cache      Tensor taken by mutable reference.
/// @param value_cache    Tensor taken by mutable reference.
/// @param num_kv_heads   Integer; presumably the number of key/value heads.
/// @param scale          Floating-point factor (presumably applied to the
///                       attention logits -- TODO confirm).
/// @param block_tables   Tensor; presumably maps sequences to cache blocks.
/// @param seq_lens       Tensor; presumably per-sequence lengths.
/// @param block_size     Integer; presumably the cache block size.
/// @param max_seq_len    Integer; presumably an upper bound on seq_lens.
/// @param alibi_slopes   Optional tensor; callers may pass std::nullopt.
/// @param kv_cache_dtype String naming the cache data type; the accepted
///                       values are not visible from this header.
/// @param k_scale        Tensor; presumably a scaling factor for keys.
/// @param v_scale        Tensor; presumably a scaling factor for values.
/// @param tp_rank        Integer; presumably the tensor-parallel rank.
/// @param blocksparse_local_blocks      Block-sparse setting (meaning defined
///                                      by the implementation).
/// @param blocksparse_vert_stride       Block-sparse setting.
/// @param blocksparse_block_size        Block-sparse setting.
/// @param blocksparse_head_sliding_step Block-sparse setting.
void paged_attention_v2(
    torch::Tensor& out,
    torch::Tensor& exp_sums,
    torch::Tensor& max_logits,
    torch::Tensor& tmp_out,
    torch::Tensor& query,
    torch::Tensor& key_cache,
    torch::Tensor& value_cache,
    int64_t num_kv_heads,
    double scale,
    torch::Tensor& block_tables,
    torch::Tensor& seq_lens,
    int64_t block_size,
    int64_t max_seq_len,
    const std::optional<torch::Tensor>& alibi_slopes,
    const std::string& kv_cache_dtype,
    torch::Tensor& k_scale,
    torch::Tensor& v_scale,
    const int64_t tp_rank,
    const int64_t blocksparse_local_blocks,
    const int64_t blocksparse_vert_stride,
    const int64_t blocksparse_block_size,
    const int64_t blocksparse_head_sliding_step);

/// @brief Declares the block-swap operation between two tensors.
///
/// Only the prototype is visible here; behaviour must be confirmed against
/// the definition.
///
/// @param src            Tensor taken by mutable reference (presumably the
///                       source of the blocks -- TODO confirm).
/// @param dst            Tensor taken by mutable reference (presumably the
///                       destination of the blocks -- TODO confirm).
/// @param block_mapping  Tensor taken by const reference; presumably pairs
///                       source block indices with destination block
///                       indices -- layout not visible from this header.
void swap_blocks(torch::Tensor& src,
                 torch::Tensor& dst,
                 const torch::Tensor& block_mapping);

/// @brief Declares the block-copy operation over lists of key/value caches.
///
/// The two vectors are taken by const reference, so the function cannot add,
/// remove or reseat their elements. NOTE(review): a const `torch::Tensor` is
/// still a handle to mutable storage, so the cache contents are presumably
/// written through these handles -- confirm against the definition.
///
/// @param key_caches     Vector of key-cache tensors (const reference).
/// @param value_caches   Vector of value-cache tensors (const reference).
/// @param block_mapping  Tensor taken by const reference; presumably pairs
///                       source block indices with destination block
///                       indices -- layout not visible from this header.
void copy_blocks(std::vector<torch::Tensor> const& key_caches,
                 std::vector<torch::Tensor> const& value_caches,
                 const torch::Tensor& block_mapping);

/// @brief Declares the operation that stores key/value tensors into the
///        key and value caches.
///
/// Only the prototype is visible here; the description above is inferred from
/// the function and parameter names and must be confirmed against the
/// definition.
///
/// @param key            Tensor taken by mutable reference.
/// @param value          Tensor taken by mutable reference.
/// @param key_cache      Tensor taken by mutable reference (presumably
///                       written to -- TODO confirm).
/// @param value_cache    Tensor taken by mutable reference (presumably
///                       written to -- TODO confirm).
/// @param slot_mapping   Tensor; presumably gives the cache slot for each
///                       input row -- TODO confirm.
/// @param kv_cache_dtype String naming the cache data type; the accepted
///                       values are not visible from this header.
/// @param k_scale        Tensor; presumably a scaling factor for keys.
/// @param v_scale        Tensor; presumably a scaling factor for values.
void reshape_and_cache(torch::Tensor& key,
                       torch::Tensor& value,
                       torch::Tensor& key_cache,
                       torch::Tensor& value_cache,
                       torch::Tensor& slot_mapping,
                       const std::string& kv_cache_dtype,
                       torch::Tensor& k_scale,
                       torch::Tensor& v_scale);

/// @brief Declares the "flash" variant of the key/value cache-store
///        operation.
///
/// The parameter list is identical in names, types and order to
/// reshape_and_cache. NOTE(review): the variant presumably targets a
/// different cache layout -- how it differs is not visible from this header;
/// confirm against the definition.
///
/// @param key            Tensor taken by mutable reference.
/// @param value          Tensor taken by mutable reference.
/// @param key_cache      Tensor taken by mutable reference (presumably
///                       written to -- TODO confirm).
/// @param value_cache    Tensor taken by mutable reference (presumably
///                       written to -- TODO confirm).
/// @param slot_mapping   Tensor; presumably gives the cache slot for each
///                       input row -- TODO confirm.
/// @param kv_cache_dtype String naming the cache data type; the accepted
///                       values are not visible from this header.
/// @param k_scale        Tensor; presumably a scaling factor for keys.
/// @param v_scale        Tensor; presumably a scaling factor for values.
void reshape_and_cache_flash(torch::Tensor& key,
                             torch::Tensor& value,
                             torch::Tensor& key_cache,
                             torch::Tensor& value_cache,
                             torch::Tensor& slot_mapping,
                             const std::string& kv_cache_dtype,
                             torch::Tensor& k_scale,
                             torch::Tensor& v_scale);

/// @brief Declares a query for a numeric attribute of a device.
///
/// Only the prototype is visible here; the set of valid attribute codes and
/// the behaviour on invalid input are defined by the implementation.
///
/// @param attribute  Integer code selecting the attribute to read; the
///                   numbering scheme is not visible from this header.
/// @param device_id  Integer identifying the device to query.
/// @return The attribute value as a 64-bit signed integer.
int64_t get_device_attribute(int64_t attribute, int64_t device_id);

/// @brief Declares a query for a device's maximum shared memory per block.
///
/// Only the prototype is visible here. NOTE(review): the unit of the result
/// is presumably bytes -- confirm against the definition.
///
/// @param device_id  Integer identifying the device to query.
/// @return The limit as a 64-bit signed integer.
int64_t get_max_shared_memory_per_block_device_attribute(int64_t device_id);

/// @brief Declares the FP8 cache conversion operation.
///
/// Only the prototype is visible here; the direction of the conversion and
/// how `scale` is applied must be confirmed against the definition.
///
/// @param dst_cache      Tensor taken by mutable reference (presumably
///                       receives the converted data -- TODO confirm).
/// @param src_cache      Tensor taken by mutable reference (presumably the
///                       data to convert -- TODO confirm).
/// @param scale          Floating-point scaling factor.
/// @param kv_cache_dtype String naming the cache data type; the accepted
///                       values are not visible from this header.
void convert_fp8(torch::Tensor& dst_cache,
                 torch::Tensor& src_cache,
                 const double scale,
                 const std::string& kv_cache_dtype);