File size: 7,203 Bytes
3a83cdf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
import math
from typing import Optional, Tuple, TypeVar
import torch.nn as nn
import torch
import triton

from functools import lru_cache


from .triton_flash_blocksparse_attn import get_local_strided_sparse_attention_op, _get_sparse_attn_mask, blocksparse_flash_attn_padded_fwd, blocksparse_flash_attn_varlen_fwd


Layout = Tuple[torch.LongTensor, torch.LongTensor]


def create_sparse_attn_mask(
    n_heads: int,
    max_seq_len: int,
    max_seq_len_k: int,
    dtype: torch.dtype,
    device: torch.device,
    BLOCK: int,
    local_blocks: int,
    vert_stride: int,
    homo_head: bool,
    return_dense: bool
) -> Tuple[Layout, torch.Tensor]:
    """Build the block-sparse layout and block mask via ``_get_sparse_attn_mask``.

    All arguments are forwarded unchanged to ``_get_sparse_attn_mask``
    (``max_seq_len`` as ``q_len`` and ``max_seq_len_k`` as ``N_CTX``).

    Args:
        n_heads: Number of attention heads.
        max_seq_len: Query length the pattern is built for.
        max_seq_len_k: Key length the pattern is built for.
        dtype: Forwarded to the mask builder.
        device: Device the mask builder should place its outputs on.
        BLOCK: Sparse block size (in tokens).
        local_blocks: Forwarded as ``local_blocks``.
        vert_stride: Forwarded as ``vert_stride``.
        homo_head: Forwarded as ``homo_head``.
        return_dense: Forwarded as ``return_dense``. NOTE: the third value
            returned by ``_get_sparse_attn_mask`` (presumably the dense mask)
            is discarded here, so it never reaches the caller.

    Returns:
        ``(layout, block_sparse_pattern)``: the first two values returned by
        ``_get_sparse_attn_mask``.
    """
    layout, block_sparse_pattern, _ = _get_sparse_attn_mask(
        n_heads=n_heads,
        q_len=max_seq_len,
        N_CTX=max_seq_len_k,
        dtype=dtype,
        device=device,
        BLOCK=BLOCK,
        local_blocks=local_blocks,
        vert_stride=vert_stride,
        homo_head=homo_head,
        return_dense=return_dense
    )
    return layout, block_sparse_pattern


class BlockSparseAttentionLayer(nn.Module):
    """Local + vertically-strided block-sparse attention layer.

    ``forward`` has two execution paths:

    * If none of the batched/varlen keyword arguments are given, the call is
      delegated to the op returned by ``get_local_strided_sparse_attention_op``.
    * Otherwise the sparse layout is built lazily, cached per (dtype, device),
      and dispatched to ``blocksparse_flash_attn_varlen_fwd`` (3-dim ``k``) or
      ``blocksparse_flash_attn_padded_fwd`` (4-dim ``k``). This path requires
      gradients to be disabled.
    """

    def __init__(
        self,
        n_heads: int,
        max_seq_len: int,
        sparse_block_size: int,
        local_blocks: int,
        vert_stride: int,
        kernel_block_size: Optional[int] = None,
        homo_head: bool = False,
        active_head_range: Optional[Tuple[int, int]] = None
    ) -> None:
        """
        Args:
            n_heads: Number of attention heads the sparse pattern is built for.
            max_seq_len: Maximum sequence length the pattern covers (used for
                both query and key lengths).
            sparse_block_size: Block size (in tokens) of the sparsity pattern.
            local_blocks: Forwarded to the op/mask builders (presumably the
                number of local blocks each block attends to; confirm).
            vert_stride: Forwarded to the op/mask builders (presumably the
                stride of the vertical attention blocks; confirm).
            kernel_block_size: Block size used by the kernels. Falls back to
                ``sparse_block_size`` when ``None`` (or any falsy value).
            homo_head: Forwarded to the builders; when True, head pruning via
                ``active_head_range`` is skipped.
            active_head_range: Optional ``(start, end)`` range of heads kept
                in the cached layout used by the batched/varlen path.
        """
        super().__init__()

        self.n_heads = n_heads
        self.max_seq_len = max_seq_len
        self.sparse_block_size = sparse_block_size
        self.kernel_block_size = kernel_block_size or sparse_block_size
        self.local_blocks = local_blocks
        self.vert_stride = vert_stride
        self.homo_head = homo_head
        self.active_head_range = active_head_range

        # Internal state for the batched/varlen inference path, filled in
        # lazily by _initialize_internals. _dtype/_device record which
        # dtype/device the cached layout was built for.
        self._sparse_block_mask = None
        self._sparse_layout = None
        self._dtype = None
        self._device = None

        # TODO(bapatra): Ideally, I'd want to keep all the code for
        # forward to be handled here, and not branch for training and inference.
        # However, that refactor would need a lot of testing. For now, using the
        # training op as is, and will refactor again later.

    def prune_blocksparse_layout_to_heads(self, h_start: int, h_end: int) -> None:
        """Keep only heads ``[h_start, h_end)`` in the cached mask and layout.

        Slices the leading dimension of the block mask and of both layout
        tensors. The layout is rebuilt as a new tuple instead of being
        item-assigned. It is declared as the tuple-typed ``Layout`` alias, and
        tuples do not support item assignment. Rebinding also avoids mutating
        objects that may be shared with the builder that produced them (e.g.
        if that builder caches its results).
        """
        self._sparse_block_mask = self._sparse_block_mask[h_start:h_end]
        self._sparse_layout = (
            self._sparse_layout[0][h_start:h_end],
            self._sparse_layout[1][h_start:h_end],
        )

    def _initialize_internals(
        self,
        dtype: torch.dtype,
        device: torch.device
    ) -> None:
        """Build and cache the sparse layout / block mask for ``dtype`` on ``device``.

        ``self._dtype`` / ``self._device`` are recorded only after every step
        has succeeded. If any step fails, ``forward`` sees a stale cache and
        retries, instead of silently reusing a half-built (e.g. unpruned)
        layout.

        Raises:
            AssertionError: if ``active_head_range`` is not a start/end pair,
                or if the sparse/kernel block sizes are inconsistent.
        """
        # Validate the configuration before paying for mask construction.
        prune_heads = (not self.homo_head) and (self.active_head_range is not None)
        if prune_heads:
            assert len(self.active_head_range) == 2, "\"active_head_range\" should be a tuple of start/end index of the heads."
        assert self.sparse_block_size % self.kernel_block_size == 0, f"The sparse block size must be a multiple of {self.kernel_block_size}. Found {self.sparse_block_size}."
        assert self.kernel_block_size >= 16 and math.log2(self.kernel_block_size) % 1 == 0, f"block_size must be power of 2 and at least 16, but {self.kernel_block_size} is given"

        self._sparse_layout, self._sparse_block_mask = create_sparse_attn_mask(
            n_heads=self.n_heads,
            max_seq_len=self.max_seq_len,
            max_seq_len_k=self.max_seq_len,
            dtype=dtype,
            device=device,
            BLOCK=self.sparse_block_size,
            local_blocks=self.local_blocks,
            vert_stride=self.vert_stride,
            homo_head=self.homo_head,
            return_dense=False,
        )
        if prune_heads:
            h_start, h_end = self.active_head_range
            self.prune_blocksparse_layout_to_heads(h_start=h_start, h_end=h_end)

        _mul = self.sparse_block_size // self.kernel_block_size
        if _mul > 1:
            # Expand each sparse block into a _mul x _mul grid of kernel blocks,
            # then re-apply block-level causality at the finer granularity.
            # need to consider if block_m and block_n are different
            # NOTE(review): only the block mask is refined here. The layout that
            # forward() passes to the kernels (with block_size=kernel_block_size)
            # is left at sparse_block_size granularity. Confirm that the kernels
            # expect this.
            self._sparse_block_mask = torch.kron(self._sparse_block_mask, self._sparse_block_mask.new_ones(_mul, _mul))
            num_sparse_blocks = self._sparse_block_mask.size(-1)
            # Build the causal mask directly on the mask's device (no CPU round trip).
            block_idx = torch.arange(num_sparse_blocks, device=self._sparse_block_mask.device)
            block_causal_mask = block_idx[:, None] >= block_idx[None]
            self._sparse_block_mask *= block_causal_mask.type_as(self._sparse_block_mask)

        # Mark the cache valid only once everything above has succeeded.
        self._dtype, self._device = dtype, device

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        sm_scale: float,
        *,
        # Arguments related to batched (padded) inference
        left_paddings: Optional[torch.LongTensor] = None,
        seqlens: Optional[torch.LongTensor] = None,
        # Arguments related to variable-length inference
        cu_seqlens_k: Optional[torch.LongTensor] = None,
        cu_seqlens_q: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        """Run block-sparse attention.

        Args:
            q, k, v: Query/key/value tensors. On the batched/varlen path, a
                3-dim ``k`` selects the variable-length kernel and a 4-dim
                ``k`` selects the padded kernel. The exact layouts are defined
                by those kernels (not visible here).
            sm_scale: Scale forwarded to the op/kernel (presumably the softmax
                scale).
            left_paddings, seqlens: Padded-batch inputs. At least one is
                required when ``k`` is 4-dim.
            cu_seqlens_k, cu_seqlens_q: Variable-length inputs.
                ``cu_seqlens_k`` is required when ``k`` is 3-dim.

        Returns:
            The output of the selected op/kernel.

        Raises:
            AssertionError: if the batched/varlen path runs with gradients
                enabled or is missing its required arguments.
            ValueError: if ``k`` is neither 3- nor 4-dim on that path.
        """
        if left_paddings is None and seqlens is None and cu_seqlens_k is None and cu_seqlens_q is None:
            blocksparse_op = get_local_strided_sparse_attention_op(
                n_heads=self.n_heads,
                max_seq_len=self.max_seq_len,
                sparse_block_size=self.sparse_block_size,
                kernel_block_size=self.kernel_block_size,
                local_blocks=self.local_blocks,
                vert_stride=self.vert_stride,
                homo_head=self.homo_head,
                device=q.device,
                inference=not self.training
            )
            return blocksparse_op(q, k, v, sm_scale)

        assert not torch.is_grad_enabled(), "Variable Length Inference / Batched inference is not supported during training. Please run it in a torch.no_grad() context"
        # (Re)build the cached layout if it is missing or was built for another dtype/device.
        if self._sparse_block_mask is None or (self._dtype != q.dtype) or (self._device != q.device):
            self._initialize_internals(dtype=q.dtype, device=q.device)

        if k.dim() == 3:
            assert cu_seqlens_k is not None
            return blocksparse_flash_attn_varlen_fwd(
                q=q,
                k=k,
                v=v,
                cu_seqlens_k=cu_seqlens_k,
                cu_seqlens_q=cu_seqlens_q,
                sm_scale=sm_scale,
                sparse_layout=self._sparse_layout,
                block_size=self.kernel_block_size,
                max_seqlen=self.max_seq_len,
            )
        if k.dim() == 4:
            assert not (left_paddings is None and seqlens is None), "Either left_paddings or seqlens must be provided for batched inference."
            return blocksparse_flash_attn_padded_fwd(
                q=q,
                k=k,
                v=v,
                sm_scale=sm_scale,
                sparse_layout=self._sparse_layout,
                left_paddings=left_paddings,
                seqlens=seqlens,
                block_size=self.kernel_block_size,
                max_seqlen=self.max_seq_len,
            )
        raise ValueError('q/k/v must be either 3 dim for variable-length input or 4 dim for fixed-length.')