from typing import List, Optional

import torch

from ._ops import ops


# Paged attention ops.
def paged_attention_v1(
    out: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    num_kv_heads: int,
    scale: float,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    block_size: int,
    max_seq_len: int,
    alibi_slopes: Optional[torch.Tensor],
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
    tp_rank: int = 0,
    blocksparse_local_blocks: int = 0,
    blocksparse_vert_stride: int = 0,
    blocksparse_block_size: int = 64,
    blocksparse_head_sliding_step: int = 0,
) -> None:
    """Single-pass paged attention over a block-structured KV cache.

    Results are written in place into ``out``.
    """
    ops.paged_attention_v1(
        out,
        query,
        key_cache,
        value_cache,
        num_kv_heads,
        scale,
        block_tables,
        seq_lens,
        block_size,
        max_seq_len,
        alibi_slopes,
        kv_cache_dtype,
        k_scale,
        v_scale,
        tp_rank,
        blocksparse_local_blocks,
        blocksparse_vert_stride,
        blocksparse_block_size,
        blocksparse_head_sliding_step,
    )


def paged_attention_v2(
    out: torch.Tensor,
    exp_sum: torch.Tensor,
    max_logits: torch.Tensor,
    tmp_out: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    num_kv_heads: int,
    scale: float,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    block_size: int,
    max_seq_len: int,
    alibi_slopes: Optional[torch.Tensor],
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
    tp_rank: int = 0,
    blocksparse_local_blocks: int = 0,
    blocksparse_vert_stride: int = 0,
    blocksparse_block_size: int = 64,
    blocksparse_head_sliding_step: int = 0,
) -> None:
    """Two-pass paged attention for longer sequences.

    Partial results are accumulated in ``tmp_out``, ``exp_sum``, and
    ``max_logits`` before the final reduction into ``out``.
    """
    ops.paged_attention_v2(
        out,
        exp_sum,
        max_logits,
        tmp_out,
        query,
        key_cache,
        value_cache,
        num_kv_heads,
        scale,
        block_tables,
        seq_lens,
        block_size,
        max_seq_len,
        alibi_slopes,
        kv_cache_dtype,
        k_scale,
        v_scale,
        tp_rank,
        blocksparse_local_blocks,
        blocksparse_vert_stride,
        blocksparse_block_size,
        blocksparse_head_sliding_step,
    )


def reshape_and_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
) -> None:
    """Scatter new key/value tokens into the paged caches at the slots
    given by ``slot_mapping``."""
    ops.reshape_and_cache(
        key,
        value,
        key_cache,
        value_cache,
        slot_mapping,
        kv_cache_dtype,
        k_scale,
        v_scale,
    )


def reshape_and_cache_flash(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    k_scale: torch.Tensor,
    v_scale: torch.Tensor,
) -> None:
    """Like ``reshape_and_cache``, but for the flash-attention cache layout;
    note that the scales are passed as tensors rather than floats."""
    ops.reshape_and_cache_flash(
        key,
        value,
        key_cache,
        value_cache,
        slot_mapping,
        kv_cache_dtype,
        k_scale,
        v_scale,
    )


def copy_blocks(
    key_caches: List[torch.Tensor],
    value_caches: List[torch.Tensor],
    block_mapping: torch.Tensor,
) -> None:
    """Copy cache blocks within each (key, value) cache pair according to
    ``block_mapping``."""
    ops.copy_blocks(key_caches, value_caches, block_mapping)


def swap_blocks(
    src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor
) -> None:
    """Swap cache blocks between the ``src`` and ``dst`` caches according to
    ``block_mapping`` (e.g. for device/CPU offloading)."""
    ops.swap_blocks(src, dst, block_mapping)


def convert_fp8(
    output: torch.Tensor, input: torch.Tensor, scale: float = 1.0, kv_dtype: str = "fp8"
) -> None:
    """Convert between an fp8 cache tensor and a higher-precision tensor,
    applying ``scale``."""
    ops.convert_fp8(output, input, scale, kv_dtype)


__all__ = [
    "convert_fp8",
    "copy_blocks",
    "paged_attention_v1",
    "paged_attention_v2",
    "reshape_and_cache",
    "reshape_and_cache_flash",
    "swap_blocks",
]
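

# ---------------------------------------------------------------------------
# Usage sketches. These helpers are illustrative only and are not part of the
# public API; they assume the compiled ``ops`` extension is available and a
# CUDA device is present. The tensor layouts shown (in particular the packed
# key-cache layout with an ``x``-element last dimension, and ``"auto"``
# meaning an unquantized cache) follow the vLLM-style paged KV cache and are
# assumptions here; the exact shapes depend on how the kernels were built.
# ---------------------------------------------------------------------------
def _demo_write_then_attend() -> None:  # pragma: no cover - needs CUDA + built ops
    """Hedged sketch of the decode path: cache one token per sequence with
    ``reshape_and_cache``, then read it back with ``paged_attention_v1``."""
    device, dtype = "cuda", torch.float16
    num_seqs, num_heads, num_kv_heads, head_size = 2, 8, 8, 128
    block_size, num_blocks = 16, 64
    # Assumed packing factor of the key cache: 16 bytes / element size.
    x = 16 // torch.tensor([], dtype=dtype).element_size()

    query = torch.randn(num_seqs, num_heads, head_size, device=device, dtype=dtype)
    key = torch.randn(num_seqs, num_kv_heads, head_size, device=device, dtype=dtype)
    value = torch.randn_like(key)

    key_cache = torch.zeros(
        num_blocks, num_kv_heads, head_size // x, block_size, x,
        device=device, dtype=dtype,
    )
    value_cache = torch.zeros(
        num_blocks, num_kv_heads, head_size, block_size,
        device=device, dtype=dtype,
    )

    # Each token lands in slot = block_number * block_size + block_offset:
    # here, the first slot of block 0 and the first slot of block 1.
    slot_mapping = torch.tensor([0, block_size], device=device, dtype=torch.int64)
    reshape_and_cache(
        key, value, key_cache, value_cache, slot_mapping,
        kv_cache_dtype="auto", k_scale=1.0, v_scale=1.0,
    )

    # One logical block per sequence; each sequence holds a single cached token.
    block_tables = torch.tensor([[0], [1]], device=device, dtype=torch.int32)
    seq_lens = torch.ones(num_seqs, device=device, dtype=torch.int32)

    out = torch.empty_like(query)
    paged_attention_v1(
        out, query, key_cache, value_cache,
        num_kv_heads=num_kv_heads, scale=head_size**-0.5,
        block_tables=block_tables, seq_lens=seq_lens,
        block_size=block_size, max_seq_len=block_size,
        alibi_slopes=None, kv_cache_dtype="auto",
        k_scale=1.0, v_scale=1.0,
    )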
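

# A hedged sketch of cache-block management. The ``[num_pairs, 2]``
# (source, destination) mapping format is an assumption based on the vLLM
# kernels: ``copy_blocks`` duplicates blocks within device caches (e.g. when
# sequences fork), while ``swap_blocks`` moves blocks between two caches
# (e.g. device <-> pinned-CPU offload).
def _demo_block_management(
    key_cache: torch.Tensor, value_cache: torch.Tensor
) -> None:  # pragma: no cover - needs CUDA + built ops
    # Copy block 0 -> block 2 and block 1 -> block 3 in both caches at once.
    copy_mapping = torch.tensor(
        [[0, 2], [1, 3]], dtype=torch.int64, device=key_cache.device
    )
    copy_blocks([key_cache], [value_cache], copy_mapping)

    # Swap block 0 of the device key cache into block 0 of a CPU-side buffer.
    # Pinned host memory and a CPU-resident mapping tensor are assumptions.
    cpu_key_cache = torch.zeros_like(key_cache, device="cpu").pin_memory()
    swap_mapping = torch.tensor([[0, 0]], dtype=torch.int64)  # stays on CPU
    swap_blocks(key_cache, cpu_key_cache, swap_mapping)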
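

# A hedged sketch of ``convert_fp8``. The conversion direction (quantize vs.
# dequantize) is assumed to be inferred from the dtypes of ``output`` and
# ``input``, with the fp8 side stored as ``torch.uint8``; both the storage
# dtype and the inferred direction are assumptions based on the vLLM kernels.
def _demo_convert_fp8(value_cache: torch.Tensor) -> None:  # pragma: no cover
    # Quantize a higher-precision cache into fp8 (byte) storage...
    fp8_cache = torch.empty_like(value_cache, dtype=torch.uint8)
    convert_fp8(fp8_cache, value_cache, scale=1.0, kv_dtype="fp8")

    # ...then dequantize it back into a higher-precision buffer.
    restored = torch.empty_like(value_cache)
    convert_fp8(restored, fp8_cache, scale=1.0, kv_dtype="fp8")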