|
from typing import List, Optional |
|
|
|
import torch |
|
|
|
from ._ops import ops |
|
|
|
|
|
|
|
def paged_attention_v1(
    out: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    num_kv_heads: int,
    scale: float,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    block_size: int,
    max_seq_len: int,
    alibi_slopes: Optional[torch.Tensor],
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
    tp_rank: int = 0,
    blocksparse_local_blocks: int = 0,
    blocksparse_vert_stride: int = 0,
    blocksparse_block_size: int = 64,
    blocksparse_head_sliding_step: int = 0,
) -> None:
    """Call the ``paged_attention_v1`` op with all arguments forwarded as-is.

    Every parameter is passed positionally, in declaration order, to
    ``ops.paged_attention_v1``. The wrapper performs no validation or
    conversion of its own and always returns ``None``; whatever the op
    returns is discarded.

    NOTE(review): the op presumably writes its result into ``out`` in place;
    confirm against the kernel implementation. Tensor shapes and dtypes are
    dictated by the op and are not checked here.
    """
    kernel = ops.paged_attention_v1
    kernel(
        # Output buffer followed by the query and the two caches.
        out, query, key_cache, value_cache,
        # Head count and scaling factor.
        num_kv_heads, scale,
        # Block table / sequence-length description.
        block_tables, seq_lens, block_size, max_seq_len,
        # Optional ALiBi slopes (may be None) and cache dtype / scales.
        alibi_slopes, kv_cache_dtype, k_scale, v_scale,
        # Trailing defaulted parameters are always forwarded explicitly.
        tp_rank,
        blocksparse_local_blocks,
        blocksparse_vert_stride,
        blocksparse_block_size,
        blocksparse_head_sliding_step,
    )
|
|
|
|
|
def paged_attention_v2(
    out: torch.Tensor,
    exp_sum: torch.Tensor,
    max_logits: torch.Tensor,
    tmp_out: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    num_kv_heads: int,
    scale: float,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    block_size: int,
    max_seq_len: int,
    alibi_slopes: Optional[torch.Tensor],
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
    tp_rank: int = 0,
    blocksparse_local_blocks: int = 0,
    blocksparse_vert_stride: int = 0,
    blocksparse_block_size: int = 64,
    blocksparse_head_sliding_step: int = 0,
) -> None:
    """Call the ``paged_attention_v2`` op with all arguments forwarded as-is.

    Every parameter is passed positionally, in declaration order, to
    ``ops.paged_attention_v2``. Compared with ``paged_attention_v1`` the
    signature carries three additional tensors (``exp_sum``, ``max_logits``,
    ``tmp_out``) between ``out`` and ``query``. The wrapper performs no
    validation of its own and always returns ``None``.

    NOTE(review): ``exp_sum``, ``max_logits`` and ``tmp_out`` look like
    scratch buffers used by the op, and ``out`` like the in-place result
    buffer; confirm against the kernel implementation.
    """
    kernel = ops.paged_attention_v2
    kernel(
        # Output buffer plus the three extra tensors specific to v2.
        out, exp_sum, max_logits, tmp_out,
        # Query and the two caches.
        query, key_cache, value_cache,
        # Head count and scaling factor.
        num_kv_heads, scale,
        # Block table / sequence-length description.
        block_tables, seq_lens, block_size, max_seq_len,
        # Optional ALiBi slopes (may be None) and cache dtype / scales.
        alibi_slopes, kv_cache_dtype, k_scale, v_scale,
        # Trailing defaulted parameters are always forwarded explicitly.
        tp_rank,
        blocksparse_local_blocks,
        blocksparse_vert_stride,
        blocksparse_block_size,
        blocksparse_head_sliding_step,
    )
|
|
|
|
|
def reshape_and_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
) -> None:
    """Call the ``reshape_and_cache`` op with all arguments forwarded as-is.

    The eight parameters are passed positionally, in declaration order, to
    ``ops.reshape_and_cache``. Nothing is returned by this wrapper.

    NOTE(review): judging by the names, the op stores ``key``/``value`` into
    ``key_cache``/``value_cache`` at positions given by ``slot_mapping``;
    confirm against the kernel implementation.
    """
    kernel = ops.reshape_and_cache
    kernel(
        key, value,
        key_cache, value_cache,
        slot_mapping,
        kv_cache_dtype, k_scale, v_scale,
    )
|
|
|
|
|
def reshape_and_cache_flash(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    k_scale: torch.Tensor,
    v_scale: torch.Tensor,
) -> None:
    """Call the ``reshape_and_cache_flash`` op, forwarding all arguments.

    The eight parameters are passed positionally, in declaration order, to
    ``ops.reshape_and_cache_flash``. Note that ``k_scale`` and ``v_scale``
    are annotated as tensors in this wrapper. Nothing is returned.

    NOTE(review): presumably the cache layout expected here differs from the
    one used by ``reshape_and_cache``; confirm against the kernel
    implementation.
    """
    kernel = ops.reshape_and_cache_flash
    kernel(
        key, value,
        key_cache, value_cache,
        slot_mapping,
        kv_cache_dtype, k_scale, v_scale,
    )
|
|
|
|
|
def copy_blocks(
    key_caches: List[torch.Tensor],
    value_caches: List[torch.Tensor],
    block_mapping: torch.Tensor,
) -> None:
    """Call the ``copy_blocks`` op with the given caches and mapping.

    The three arguments are forwarded positionally and unchanged to
    ``ops.copy_blocks``; this wrapper returns ``None``.

    NOTE(review): the layout of ``block_mapping`` is defined by the op and is
    not validated here; confirm against the kernel implementation.
    """
    kernel = ops.copy_blocks
    kernel(
        key_caches,
        value_caches,
        block_mapping,
    )
|
|
|
|
|
def swap_blocks(
    src: torch.Tensor,
    dst: torch.Tensor,
    block_mapping: torch.Tensor,
) -> None:
    """Call the ``swap_blocks`` op with ``src``, ``dst`` and the mapping.

    The three arguments are forwarded positionally and unchanged to
    ``ops.swap_blocks``; this wrapper returns ``None``.

    NOTE(review): the layout of ``block_mapping`` and any device
    requirements on ``src``/``dst`` are defined by the op; confirm against
    the kernel implementation.
    """
    kernel = ops.swap_blocks
    kernel(
        src,
        dst,
        block_mapping,
    )
|
|
|
|
|
def convert_fp8(
    output: torch.Tensor,
    input: torch.Tensor,
    scale: float = 1.0,
    kv_dtype: str = "fp8",
) -> None:
    """Call the ``convert_fp8`` op, forwarding all four arguments.

    ``output``, ``input``, ``scale`` (default ``1.0``) and ``kv_dtype``
    (default ``"fp8"``) are passed positionally, in that order, to
    ``ops.convert_fp8``. This wrapper returns ``None``.

    NOTE(review): presumably the converted data is written into ``output``;
    confirm against the kernel implementation.
    """
    # The parameter name ``input`` shadows the builtin, but it is part of the
    # public signature and is therefore kept.
    kernel = ops.convert_fp8
    kernel(
        output,
        input,
        scale,
        kv_dtype,
    )
|
|
|
|
|
# Public API of this module. Every wrapper function defined in this file is
# listed so that a star-import exposes the complete set; previously
# ``reshape_and_cache_flash`` and ``swap_blocks`` were defined but missing
# from this list.
__all__ = [
    "convert_fp8",
    "paged_attention_v1",
    "paged_attention_v2",
    "reshape_and_cache",
    "reshape_and_cache_flash",
    "copy_blocks",
    "swap_blocks",
]
|
|