sharpenb commited on
Commit
49b8ca0
1 Parent(s): 5e9ea67

Upload folder using huggingface_hub (#1)

Browse files

- 280a4464349485d578cd0583f7e14d683f0783e5ead54f847d48a475f96be979 (4514f143501729dad06b8894f8e96d3f1f665356)
- 25300cc843729640862f53f6b2617077f0f3d295d8b67795b949167bb850fbe6 (0fd61643772d9961673f09e163a0501900f20bcf)

README.md ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ thumbnail: "https://assets-global.website-files.com/646b351987a8d8ce158d1940/64ec9e96b4334c0e1ac41504_Logo%20with%20white%20text.svg"
3
+ base_model: numind/NuExtract-large
4
+ metrics:
5
+ - memory_disk
6
+ - memory_inference
7
+ - inference_latency
8
+ - inference_throughput
9
+ - inference_CO2_emissions
10
+ - inference_energy_consumption
11
+ tags:
12
+ - pruna-ai
13
+ ---
14
+ <!-- header start -->
15
+ <!-- 200823 -->
16
+ <div style="width: auto; margin-left: auto; margin-right: auto">
17
+ <a href="https://www.pruna.ai/" target="_blank" rel="noopener noreferrer">
18
+ <img src="https://i.imgur.com/eDAlcgk.png" alt="PrunaAI" style="width: 100%; min-width: 400px; display: block; margin: auto;">
19
+ </a>
20
+ </div>
21
+ <!-- header end -->
22
+
23
+ [![Twitter](https://img.shields.io/twitter/follow/PrunaAI?style=social)](https://twitter.com/PrunaAI)
24
+ [![GitHub](https://img.shields.io/github/followers/PrunaAI?label=Follow%20%40PrunaAI&style=social)](https://github.com/PrunaAI)
25
+ [![LinkedIn](https://img.shields.io/badge/LinkedIn-Connect-blue)](https://www.linkedin.com/company/93832878/admin/feed/posts/?feedType=following)
26
+ [![Discord](https://img.shields.io/badge/Discord-Join%20Us-blue?style=social&logo=discord)](https://discord.gg/rskEr4BZJx)
27
+
28
+ # Simply make AI models cheaper, smaller, faster, and greener!
29
+
30
+ - Give a thumbs up if you like this model!
31
+ - Contact us and tell us which model to compress next [here](https://www.pruna.ai/contact).
32
+ - Request access to easily compress your *own* AI models [here](https://z0halsaff74.typeform.com/pruna-access?typeform-source=www.pruna.ai).
33
+ - Read the documentations to know more [here](https://pruna-ai-pruna.readthedocs-hosted.com/en/latest/)
34
+ - Join Pruna AI community on Discord [here](https://discord.gg/CP4VSgck) to share feedback/suggestions or get help.
35
+
36
+ ## Results
37
+
38
+ ![image info](./plots.png)
39
+
40
+ **Frequently Asked Questions**
41
+ - ***How does the compression work?*** The model is compressed with llm-int8.
42
+ - ***How does the model quality change?*** The quality of the model output might vary compared to the base model.
43
+ - ***How is the model efficiency evaluated?*** These results were obtained on HARDWARE_NAME with configuration described in `model/smash_config.json` and are obtained after a hardware warmup. The smashed model is directly compared to the original base model. Efficiency results may vary in other settings (e.g. other hardware, image size, batch size, ...). We recommend to directly run them in the use-case conditions to know if the smashed model can benefit you.
44
+ - ***What is the model format?*** We use safetensors.
45
+ - ***What calibration data has been used?*** If needed by the compression method, we used WikiText as the calibration data.
46
+ - ***What is the naming convention for Pruna Huggingface models?*** We take the original model name and append "turbo", "tiny", or "green" if the smashed model has a measured inference speed, inference memory, or inference energy consumption which is less than 90% of the original base model.
47
+ - ***How to compress my own models?*** You can request premium access to more compression methods and tech support for your specific use-cases [here](https://z0halsaff74.typeform.com/pruna-access?typeform-source=www.pruna.ai).
48
+ - ***What are "first" metrics?*** Results mentioning "first" are obtained after the first run of the model. The first run might take more memory or be slower than the subsequent runs due cuda overheads.
49
+ - ***What are "Sync" and "Async" metrics?*** "Sync" metrics are obtained by syncing all GPU processes and stop measurement when all of them are executed. "Async" metrics are obtained without syncing all GPU processes and stop when the model output can be used by the CPU. We provide both metrics since both could be relevant depending on the use-case. We recommend to test the efficiency gains directly in your use-cases.
50
+
51
+ ## Setup
52
+
53
+ You can run the smashed model with these steps:
54
+
55
+ 0. Check requirements from the original repo numind/NuExtract-large installed. In particular, check python, cuda, and transformers versions.
56
+ 1. Make sure that you have installed quantization related packages.
57
+ ```bash
58
+ pip install transformers accelerate bitsandbytes>0.37.0
59
+ ```
60
+ 2. Load & run the model.
61
+ ```python
62
+ from transformers import AutoModelForCausalLM, AutoTokenizer
63
+
64
+
65
+ model = AutoModelForCausalLM.from_pretrained("PrunaAI/numind-NuExtract-large-bnb-4bit-smashed", trust_remote_code=True, device_map='auto')
66
+ tokenizer = AutoTokenizer.from_pretrained("numind/NuExtract-large")
67
+
68
+ input_ids = tokenizer("What is the color of prunes?,", return_tensors='pt').to(model.device)["input_ids"]
69
+
70
+ outputs = model.generate(input_ids, max_new_tokens=216)
71
+ tokenizer.decode(outputs[0])
72
+ ```
73
+
74
+ ## Configurations
75
+
76
+ The configuration info are in `smash_config.json`.
77
+
78
+ ## Credits & License
79
+
80
+ The license of the smashed model follows the license of the original model. Please check the license of the original model numind/NuExtract-large before using this model which provided the base model. The license of the `pruna-engine` is [here](https://pypi.org/project/pruna-engine/) on Pypi.
81
+
82
+ ## Want to compress other models?
83
+
84
+ - Contact us and tell us which model to compress next [here](https://www.pruna.ai/contact).
85
+ - Request access to easily compress your own AI models [here](https://z0halsaff74.typeform.com/pruna-access?typeform-source=www.pruna.ai).
cl100k_base.tiktoken ADDED
The diff for this file is too large to render. See raw diff
 
config.json ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "/ceph/hdd/staff/charpent/.cache/modelsjayrsbxckim1dlli",
3
+ "architectures": [
4
+ "Phi3SmallForCausalLM"
5
+ ],
6
+ "attention_bias": false,
7
+ "attention_dropout_prob": 0.0,
8
+ "auto_map": {
9
+ "AutoConfig": "configuration_phi3_small.Phi3SmallConfig",
10
+ "AutoModelForCausalLM": "modeling_phi3_small.Phi3SmallForCausalLM",
11
+ "AutoModelForSequenceClassification": "numind/NuExtract-large--modeling_phi3_small.Phi3SmallForSequenceClassification",
12
+ "AutoTokenizer": "numind/NuExtract-large--tokenization_phi3_small.Phi3SmallTokenizer"
13
+ },
14
+ "blocksparse_block_size": 64,
15
+ "blocksparse_homo_head_pattern": false,
16
+ "blocksparse_num_local_blocks": 16,
17
+ "blocksparse_triton_kernel_block_size": 64,
18
+ "blocksparse_vert_stride": 8,
19
+ "bos_token_id": 100257,
20
+ "dense_attention_every_n_layers": 2,
21
+ "dummy_token_indices": [
22
+ 100256,
23
+ 100258,
24
+ 100259,
25
+ 100260,
26
+ 100264,
27
+ 100265,
28
+ 100267,
29
+ 100268,
30
+ 100269,
31
+ 100270,
32
+ 100271,
33
+ 100272,
34
+ 100273,
35
+ 100274,
36
+ 100275,
37
+ 100276,
38
+ 100277,
39
+ 100278,
40
+ 100279,
41
+ 100280,
42
+ 100281,
43
+ 100282,
44
+ 100283,
45
+ 100284,
46
+ 100285,
47
+ 100286,
48
+ 100287,
49
+ 100288,
50
+ 100289,
51
+ 100290,
52
+ 100291,
53
+ 100292,
54
+ 100293,
55
+ 100294,
56
+ 100295,
57
+ 100296,
58
+ 100297,
59
+ 100298,
60
+ 100299,
61
+ 100300,
62
+ 100301,
63
+ 100302,
64
+ 100303,
65
+ 100304,
66
+ 100305,
67
+ 100306,
68
+ 100307,
69
+ 100308,
70
+ 100309,
71
+ 100310,
72
+ 100311,
73
+ 100312,
74
+ 100313,
75
+ 100314,
76
+ 100315,
77
+ 100316,
78
+ 100317,
79
+ 100318,
80
+ 100319,
81
+ 100320,
82
+ 100321,
83
+ 100322,
84
+ 100323,
85
+ 100324,
86
+ 100325,
87
+ 100326,
88
+ 100327,
89
+ 100328,
90
+ 100329,
91
+ 100330,
92
+ 100331,
93
+ 100332,
94
+ 100333,
95
+ 100334,
96
+ 100335,
97
+ 100336,
98
+ 100337,
99
+ 100338,
100
+ 100339,
101
+ 100340,
102
+ 100341,
103
+ 100342,
104
+ 100343,
105
+ 100344,
106
+ 100345,
107
+ 100346,
108
+ 100347,
109
+ 100348,
110
+ 100349,
111
+ 100350,
112
+ 100351
113
+ ],
114
+ "embedding_dropout_prob": 0.1,
115
+ "eos_token_id": 100257,
116
+ "ff_dim_multiplier": null,
117
+ "ff_intermediate_size": 14336,
118
+ "ffn_dropout_prob": 0.1,
119
+ "gegelu_limit": 20.0,
120
+ "gegelu_pad_to_256": true,
121
+ "hidden_act": "gegelu",
122
+ "hidden_size": 4096,
123
+ "initializer_range": 0.02,
124
+ "layer_norm_epsilon": 1e-05,
125
+ "max_position_embeddings": 8192,
126
+ "model_type": "phi3small",
127
+ "mup_attn_multiplier": 1.0,
128
+ "mup_embedding_multiplier": 10.0,
129
+ "mup_use_scaling": true,
130
+ "mup_width_multiplier": 8.0,
131
+ "num_attention_heads": 32,
132
+ "num_hidden_layers": 32,
133
+ "num_key_value_heads": 8,
134
+ "pad_sequence_to_multiple_of_64": true,
135
+ "quantization_config": {
136
+ "_load_in_4bit": true,
137
+ "_load_in_8bit": false,
138
+ "bnb_4bit_compute_dtype": "bfloat16",
139
+ "bnb_4bit_quant_storage": "uint8",
140
+ "bnb_4bit_quant_type": "fp4",
141
+ "bnb_4bit_use_double_quant": false,
142
+ "llm_int8_enable_fp32_cpu_offload": false,
143
+ "llm_int8_has_fp16_weight": false,
144
+ "llm_int8_skip_modules": [
145
+ "lm_head"
146
+ ],
147
+ "llm_int8_threshold": 6.0,
148
+ "load_in_4bit": true,
149
+ "load_in_8bit": false,
150
+ "quant_method": "bitsandbytes"
151
+ },
152
+ "reorder_and_upcast_attn": false,
153
+ "rope_embedding_base": 1000000,
154
+ "rope_position_scale": 1.0,
155
+ "rope_scaling": null,
156
+ "torch_dtype": "float16",
157
+ "transformers_version": "4.42.4",
158
+ "use_cache": true,
159
+ "vocab_size": 100352
160
+ }
configuration_phi3_small.py ADDED
@@ -0,0 +1,250 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ from typing import Any, Dict, List, Optional, Union
17
+
18
+ from transformers.configuration_utils import PretrainedConfig
19
+ from transformers.utils import logging
20
+
21
+ from functools import cached_property
22
+
23
+ """ Phi3Small model configuration """
24
+ logger = logging.get_logger(__name__)
25
+
26
+
27
+ def next_mult(x, y):
28
+ return (x + y - 1) // y * y
29
+
30
+ class Phi3SmallConfig(PretrainedConfig):
31
+ """
32
+ This is the configuration class to store the configuration of a `Phi3Small` model. It is used to
33
+ instantiate a Phi-3-small model according to the specified arguments, defining the model architecture.
34
+ Instantiating a configuration with the defaults will yield a similar configuration to that of the Phi-3-small
35
+ [phi3](https://arxiv.org/pdf/2404.14219) architecture.
36
+
37
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
38
+ documentation from [`PretrainedConfig`] for more information.
39
+
40
+
41
+ Args:
42
+ vocab_size (`int`, *optional*, defaults to 100352):
43
+ Vocabulary size of the Phi3Small model. Defines the number of different tokens that can be represented by the
44
+ `inputs_ids` passed when calling `Phi3Small`.
45
+ max_position_embeddings (`int`, *optional*, defaults to 8192):
46
+ The maximum sequence length that this model might safely be used with.
47
+ rope_embedding_base (`float`, *optional*, defaults to 10^6):
48
+ The base value for the RoPE (Relative Position Encoding) embedding.
49
+ rope_position_scale (`float`, *optional*, defaults to 1.0):
50
+ The scale factor for the RoPE position encoding.
51
+ rope_scaling (`Optional[Dict[str, Union[float, List[float], int]]]`, *optional*, defaults to None):
52
+ The scaling configuration used for LongRoPE.
53
+ hidden_size (`int`, *optional*, defaults to 4096):
54
+ The size of the hidden layers in the model.
55
+ num_hidden_layers (`int`, *optional*, defaults to 32):
56
+ The number of layers in the model.
57
+ num_attention_heads (`int`, *optional*, defaults to 32):
58
+ The number of query heads in the model.
59
+ num_key_value_heads (`int`, *optional*, defaults to 8):
60
+ The number of key-value heads in the model.
61
+ hidden_act (`str`, *optional*, defaults to "gegelu"):
62
+ The activation function used in the model.
63
+ gegelu_limit (`float`, *optional*, defaults to 20.0):
64
+ The limit value for the GELU activation function (for numerical stability).
65
+ gegelu_pad_to_256 (`bool`, *optional*, defaults to True):
66
+ Whether to pad the intermediate size to a multiple of 256 (for faster matmul ops).
67
+ ff_dim_multiplier (`Optional[int]`, *optional*, defaults to None):
68
+ The dimension multiplier for the feed-forward layers.
69
+ ff_intermediate_size (`Optional[int]`, *optional*, defaults to 14336):
70
+ The intermediate size for the feed-forward layers.
71
+ One of `ff_dim_multiplier` or `ff_intermediate_size` must be specified.
72
+ blocksparse_homo_head_pattern (`bool`, *optional*, defaults to False):
73
+ Whether to use a homogeneous head pattern for block-sparse attention.
74
+ blocksparse_block_size (`int`, *optional*, defaults to 64):
75
+ The block size for block-sparse attention.
76
+ blocksparse_num_local_blocks (`int`, *optional*, defaults to 16):
77
+ The number of local blocks for block-sparse attention.
78
+ The local window used in blocksparse equals `blocksparse_num_local_blocks * blocksparse_block_size`
79
+ blocksparse_vert_stride (`int`, *optional*, defaults to 8):
80
+ The vertical stride for block-sparse attention.
81
+ blocksparse_triton_kernel_block_size (`int`, *optional*, defaults to 64):
82
+ The kernel block size for block-sparse attention.
83
+ dense_attention_every_n_layers (`Optional[int]`, *optional*, defaults to 2):
84
+ The frequency of all dense attention layers in the model
85
+ embedding_dropout_prob (`float`, *optional*, defaults to 0.1):
86
+ The dropout probability for the embedding layer.
87
+ attention_dropout_prob (`float`, *optional*, defaults to 0.0):
88
+ The dropout probability for the attention layers.
89
+ ffn_dropout_prob (`float`, *optional*, defaults to 0.1):
90
+ The dropout probability for the feed-forward layers.
91
+ layer_norm_epsilon (`float`, *optional*, defaults to 1e-5):
92
+ The epsilon value for layer normalization.
93
+ initializer_range (`float`, *optional*, defaults to 0.02):
94
+ The range for weight initialization.
95
+ mup_use_scaling (`bool`, *optional*, defaults to True):
96
+ Whether to use scaling for MuP parameters (see: https://arxiv.org/abs/2203.03466).
97
+ mup_width_multiplier (`bool`, *optional*, defaults to 8.0):
98
+ The width multiplier for MuP.
99
+ mup_embedding_multiplier (`bool`, *optional*, defaults to 10.0):
100
+ The embedding multiplier for MuP.
101
+ mup_attn_multiplier (`bool`, *optional*, defaults to 1.0):
102
+ The attention multiplier for MuP.
103
+ use_cache (`bool`, *optional*, defaults to True):
104
+ Whether to use cache for the model.
105
+ bos_token_id (`int`, *optional*, defaults to 100257):
106
+ The token ID for the beginning of sentence.
107
+ eos_token_id (`int`, *optional*, defaults to 100257):
108
+ The token ID for the end of sentence.
109
+ reorder_and_upcast_attn (`bool`, *optional*, defaults to False):
110
+ Whether to reorder and upcast attention.
111
+ pad_sequence_to_multiple_of_64 (`bool`, *optional*, defaults to True):
112
+ Whether to pad the sequence length to a multiple of 64.
113
+ **kwargs:
114
+ Additional keyword arguments.
115
+
116
+ Example:
117
+
118
+ ```python
119
+ >>> from transformers import Phi3SmallConfig, Phi3SmallModel
120
+
121
+ >>> # Initializing a Phi3Small configuration
122
+ >>> configuration = Phi3SmallConfig()
123
+
124
+ >>> # Initializing a model (with random weights) from the configuration
125
+ >>> model = Phi3SmallModel(configuration)
126
+
127
+ >>> # Accessing the model configuration
128
+ >>> configuration = model.config
129
+ ```
130
+ """
131
+
132
+ model_type = "phi3small"
133
+ keys_to_ignore_at_inference = ["past_key_values"]
134
+
135
+
136
+ def __init__(
137
+ self,
138
+ # General information about the model
139
+ vocab_size: int =100352,
140
+ max_position_embeddings: int = 8192,
141
+ # RoPE Related Parameters
142
+ rope_embedding_base: float = 10**6,
143
+ rope_position_scale: float = 1.0,
144
+ rope_scaling: Optional[Dict[str, Union[float, List[float], int]]] = None,
145
+ # General Model Parameters
146
+ hidden_size: int = 4096,
147
+ num_hidden_layers: int = 32,
148
+ # KV Shared Attention Configurations
149
+ num_attention_heads: int = 32,
150
+ num_key_value_heads: int = 8,
151
+ # GEGELU Related Parameters
152
+ hidden_act: str = "gegelu",
153
+ gegelu_limit: float = 20.0,
154
+ gegelu_pad_to_256: bool = True,
155
+ ff_dim_multiplier: Optional[int] = None,
156
+ ff_intermediate_size: Optional[int] = 14336,
157
+ # Block Sparse Attention Parameters
158
+ blocksparse_homo_head_pattern: bool = False,
159
+ blocksparse_block_size: int = 64,
160
+ blocksparse_num_local_blocks: int = 16,
161
+ blocksparse_vert_stride: int = 8,
162
+ blocksparse_triton_kernel_block_size: int = 64,
163
+ # Frequency of block-sparsity
164
+ dense_attention_every_n_layers: Optional[int] = 2,
165
+ # Reegularization parameters
166
+ embedding_dropout_prob: float =0.1,
167
+ attention_dropout_prob: float = 0.0,
168
+ ffn_dropout_prob: float = 0.1,
169
+ layer_norm_epsilon=1e-5,
170
+ initializer_range=0.02,
171
+ # MuP parameters
172
+ mup_use_scaling: bool = True,
173
+ mup_width_multiplier: bool = 8.0,
174
+ mup_embedding_multiplier: bool = 10.0,
175
+ mup_attn_multiplier: bool =1.0,
176
+ use_cache=True,
177
+ # The model does not have a bos token id
178
+ # However, in order for some of the downstream libraries to not break
179
+ # we set this to be the same as the eos_token_id
180
+ bos_token_id: int = 100257,
181
+ eos_token_id: int = 100257,
182
+ reorder_and_upcast_attn=False,
183
+ # Configuration to pad sequence length to a multiple of 64
184
+ pad_sequence_to_multiple_of_64: bool = True,
185
+ **kwargs,
186
+ ):
187
+ self.vocab_size = vocab_size
188
+ self.max_position_embeddings = max_position_embeddings
189
+ self.rope_embedding_base = rope_embedding_base
190
+ self.rope_position_scale = rope_position_scale
191
+ self.rope_scaling = rope_scaling
192
+ self.hidden_size = hidden_size
193
+ # QK Shared Attention
194
+ self.num_hidden_layers = num_hidden_layers
195
+ self.num_attention_heads = num_attention_heads
196
+ self.num_key_value_heads = num_key_value_heads
197
+ # Block Sparse Attention Pattern
198
+ self.blocksparse_homo_head_pattern = blocksparse_homo_head_pattern
199
+ self.blocksparse_block_size = blocksparse_block_size
200
+ self.blocksparse_num_local_blocks = blocksparse_num_local_blocks
201
+ self.blocksparse_vert_stride = blocksparse_vert_stride
202
+ self.blocksparse_triton_kernel_block_size = blocksparse_triton_kernel_block_size
203
+ # Frequency of block sparsity
204
+ self.dense_attention_every_n_layers = dense_attention_every_n_layers
205
+ # Activation function
206
+ self.hidden_act = hidden_act
207
+ self.gegelu_limit = gegelu_limit
208
+ self.gegelu_pad_to_256 = gegelu_pad_to_256
209
+ self.ff_dim_multiplier = ff_dim_multiplier
210
+ self.ff_intermediate_size = ff_intermediate_size
211
+ if self.ff_dim_multiplier is None and self.ff_intermediate_size is None:
212
+ raise ValueError(f"Cannot have both {self.ff_dim_multiplier} and {self.ff_intermediate_size} as None")
213
+ if self.ff_dim_multiplier is not None and self.ff_intermediate_size is not None:
214
+ raise ValueError(f"Cannot specify both {self.ff_dim_multiplier} and {self.ff_intermediate_size}.")
215
+ # General regularization
216
+ self.embedding_dropout_prob = embedding_dropout_prob
217
+ self.attention_dropout_prob = attention_dropout_prob
218
+ self.ffn_dropout_prob = ffn_dropout_prob
219
+ self.layer_norm_epsilon = layer_norm_epsilon
220
+ self.initializer_range = initializer_range
221
+ # MuP parameters
222
+ self.mup_use_scaling = mup_use_scaling
223
+ self.mup_width_multiplier = mup_width_multiplier
224
+ self.mup_embedding_multiplier = mup_embedding_multiplier
225
+ self.mup_attn_multiplier = mup_attn_multiplier
226
+ self.use_cache = use_cache
227
+
228
+ self.reorder_and_upcast_attn = reorder_and_upcast_attn
229
+ self.pad_sequence_to_multiple_of_64 = pad_sequence_to_multiple_of_64
230
+
231
+ self.bos_token_id = bos_token_id
232
+ self.eos_token_id = eos_token_id
233
+
234
+ super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
235
+
236
+ @cached_property
237
+ def dummy_token_indices(self) -> List[int]:
238
+ # Importing here to avoid circular imports
239
+ from .tokenization_phi3_small import Phi3SmallTokenizer
240
+ tokenizer = Phi3SmallTokenizer()
241
+ return tokenizer.dummy_token_indices
242
+
243
+ @property
244
+ def intermediate_size(self) -> int:
245
+ if self.ff_intermediate_size is not None:
246
+ return self.ff_intermediate_size
247
+ intermediate_size = (self.ff_dim_multiplier) * (self.hidden_size // 3) * 2
248
+ if self.gegelu_pad_to_256:
249
+ intermediate_size = next_mult(intermediate_size, 256)
250
+ return intermediate_size
generation_config.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 100257,
4
+ "eos_token_id": [
5
+ 100257,
6
+ 100266
7
+ ],
8
+ "max_new_tokens": 2000,
9
+ "transformers_version": "4.42.4"
10
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3a8ed2260866cb4627489b8904992bda101277c420a15fc226ace769f418b4fd
3
+ size 4751891888
modeling_phi3_small.py ADDED
@@ -0,0 +1,1140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Any, Dict, Optional, List, Tuple, Union
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+
8
+ from einops import rearrange
9
+
10
+ from transformers.modeling_outputs import SequenceClassifierOutputWithPast, CausalLMOutputWithPast, BaseModelOutputWithPast
11
+ from transformers.modeling_utils import PreTrainedModel
12
+ from transformers.utils import logging
13
+
14
+ from transformers.cache_utils import Cache, DynamicCache
15
+
16
+ from .triton_flash_blocksparse_attn import BlockSparseParams
17
+ from .triton_blocksparse_attention_layer import BlockSparseAttentionLayer
18
+ from .positional_embedding import RotaryEmbedding
19
+
20
+ from .configuration_phi3_small import Phi3SmallConfig
21
+
22
+ # Flash Attention Related Imports
23
+ is_flash_attention_available = False
24
+ try:
25
+ import flash_attn
26
+ if int(flash_attn.__version__.split('.')[0]) < 2:
27
+ from flash_attn.flash_attn_interface import (
28
+ flash_attn_func,
29
+ flash_attn_unpadded_kvpacked_func as flash_attn_varlen_kvpacked_func,
30
+ )
31
+
32
+ # rename `max_seqlen`
33
+ def flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens, max_seqlen, dropout_p=0.0, **kwargs):
34
+ return flash_attn_func(qkv, cu_seqlens, dropout_p=dropout_p, max_s=max_seqlen, **kwargs)
35
+
36
+ else:
37
+ from flash_attn.flash_attn_interface import (
38
+ flash_attn_varlen_kvpacked_func,
39
+ )
40
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
41
+ is_flash_attention_available = True
42
+ except ImportError:
43
+ pass
44
+
45
+ logger = logging.get_logger(__name__)
46
+
47
+ LegacyCache = Tuple[Tuple[torch.FloatTensor]]
48
+
49
+ # Taken from https://github.com/allenai/allennlp/blob/main/allennlp/nn/util.py
50
+ def info_value_of_dtype(dtype: torch.dtype):
51
+ """
52
+ Returns the `finfo` or `iinfo` object of a given PyTorch data type. Does not allow torch.bool.
53
+ """
54
+ if dtype == torch.bool:
55
+ raise TypeError("Does not support torch.bool")
56
+ elif dtype.is_floating_point:
57
+ return torch.finfo(dtype)
58
+ else:
59
+ return torch.iinfo(dtype)
60
+
61
+
62
+ # Taken from https://github.com/allenai/allennlp/blob/main/allennlp/nn/util.py
63
+ def min_value_of_dtype(dtype: torch.dtype):
64
+ """
65
+ Returns the minimum value of a given PyTorch data type. Does not allow torch.bool.
66
+ """
67
+ return info_value_of_dtype(dtype).min
68
+
69
+ # Copied from transformers.models.llama.modeling_llama._get_unpad_data
70
+ def _get_unpad_data(attention_mask):
71
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
72
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
73
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
74
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
75
+ return (
76
+ indices,
77
+ cu_seqlens,
78
+ max_seqlen_in_batch,
79
+ )
80
+
81
+
82
+ @torch.jit.script
83
+ def quick_gelu(x):
84
+ return x * torch.sigmoid(1.702 * x)
85
+
86
+
87
+ @torch.jit.script
88
+ def gegelu(input, limit: Optional[float] = None):
89
+ a_gelu, a_linear = input[..., ::2], input[..., 1::2]
90
+ if limit is not None:
91
+ a_gelu = torch.where(
92
+ torch.isinf(a_gelu), a_gelu, a_gelu.clamp(min=None, max=limit)
93
+ )
94
+ a_linear = torch.where(
95
+ torch.isinf(a_linear), a_linear, a_linear.clamp(min=-limit, max=limit)
96
+ )
97
+ out_gelu = quick_gelu(a_gelu)
98
+ return out_gelu * (a_linear + 1)
99
+
100
+ def collapse_first_n_dims(x: torch.Tensor, n: int) -> torch.Tensor:
101
+ """
102
+ Collapse the first `n` dimensions of a tensor into a single dimension.
103
+
104
+ Args:
105
+ x (torch.Tensor): The input tensor.
106
+ n (int): The number of dimensions to collapse.
107
+
108
+ Returns:
109
+ torch.Tensor: The output tensor.
110
+ """
111
+ return x.view(-1, *x.shape[n:])
112
+
113
+ def pad_tensor_to_next_mult_of(
114
+ tensor: torch.Tensor,
115
+ dim: int,
116
+ n: int,
117
+ ) -> Tuple[torch.Tensor, int]:
118
+ """
119
+ Pads a tensor along a specified dimension to the next multiple of a given number.
120
+
121
+ Args:
122
+ tensor (torch.Tensor): The input tensor.
123
+ dim (int): The dimension along which to pad the tensor.
124
+ n (int): The number to pad the tensor to the next multiple of.
125
+
126
+ Returns:
127
+ Tuple[torch.Tensor, int]: A tuple containing the padded tensor and the amount of padding added.
128
+ """
129
+ residual = tensor.size(dim) % n
130
+ if residual == 0:
131
+ return tensor, 0
132
+ padding = n - residual
133
+ padding_tensor = torch.zeros((*tensor.size()[:dim], padding, *tensor.size()[dim + 1:]), device=tensor.device, dtype=tensor.dtype)
134
+ return torch.cat([tensor, padding_tensor], dim=dim), padding
135
+
136
+ def strip_padding_from_tensor(
137
+ tensor: torch.Tensor,
138
+ dim: int,
139
+ residual: int,
140
+ ) -> torch.Tensor:
141
+ """
142
+ Removes padding from a tensor along a specified dimension.
143
+
144
+ Args:
145
+ tensor (torch.Tensor): The input tensor.
146
+ dim (int): The dimension along which to remove padding.
147
+ residual (int): The amount of padding to remove.
148
+
149
+ Returns:
150
+ torch.Tensor: The tensor with padding removed along the specified dimension.
151
+ """
152
+ return torch.narrow(tensor, dim, 0, tensor.size(dim) - residual)
153
+
154
+ class Phi3SmallMLP(nn.Module):
155
+ def __init__(self, config: Phi3SmallConfig):
156
+ super().__init__()
157
+ self.config = config
158
+ assert self.config.hidden_act == "gegelu", "Only `gegelu` is supported for the Phi-3-small model .."
159
+ self.hidden_size = config.hidden_size
160
+ self.gegelu_limit = config.gegelu_limit
161
+ self.intermediate_size = config.intermediate_size
162
+
163
+ self.up_proj = nn.Linear(self.hidden_size, 2 * self.intermediate_size)
164
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size)
165
+ self.dropout = nn.Dropout(config.ffn_dropout_prob)
166
+
167
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
168
+ return self.dropout(
169
+ self.down_proj(
170
+ gegelu(self.up_proj(x), limit=self.gegelu_limit)
171
+ )
172
+ )
173
+
174
+
175
+ class Phi3SmallSelfAttention(nn.Module):
176
+ def __init__(self, config: Phi3SmallConfig, layer_idx: Optional[int] = None) -> None:
177
+ super().__init__()
178
+ self.config = config
179
+ self.layer_idx = layer_idx
180
+ if layer_idx is None:
181
+ logger.warning_once(
182
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
183
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
184
+ "when creating this class."
185
+ )
186
+
187
+ self.hidden_size = config.hidden_size
188
+ # Number of Query Heads
189
+ self.num_heads = config.num_attention_heads
190
+ self.head_dim = self.hidden_size // self.num_heads
191
+ # Number of Key Value Heads
192
+ self.num_key_value_heads = config.num_key_value_heads
193
+ self.num_q_per_kv = self.num_heads // self.num_key_value_heads
194
+ self.max_position_embeddings = config.max_position_embeddings
195
+ self.rope_embedding_base = config.rope_embedding_base
196
+ self.rope_position_scale = config.rope_position_scale
197
+ self.is_causal = True
198
+
199
+ self.attention_dropout_rate = config.attention_dropout_prob
200
+
201
+ norm_factor = None
202
+ if config.mup_use_scaling:
203
+ norm_factor = self.head_dim / config.mup_attn_multiplier
204
+ else:
205
+ norm_factor = math.sqrt(self.head_dim)
206
+ self.softmax_scale = 1.0 / norm_factor
207
+
208
+ self.query_key_value = nn.Linear(self.hidden_size, (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim)
209
+ self.dense = nn.Linear(self.hidden_size, self.hidden_size)
210
+
211
+ self.blocksparse_params = None
212
+ # layer_idx is 0 indexed because that's what the KV Cache expects.
213
+ if self.config.dense_attention_every_n_layers and ((self.layer_idx + 1) % self.config.dense_attention_every_n_layers == 0):
214
+ logger.info(
215
+ f"Layer {layer_idx + 1} is using dense attention since it is divisible by "
216
+ f"{self.config.dense_attention_every_n_layers}"
217
+ )
218
+ assert is_flash_attention_available, "Flash Attention is not available, but is needed for dense attention"
219
+ else:
220
+ # BlockSparse related Parameters
221
+ self.blocksparse_params = BlockSparseParams.from_config(config)
222
+
223
+ if self.blocksparse:
224
+ active_head_range = None
225
+ """
226
+ ... note(bapatra)::
227
+
228
+ In case of tensor parallelism and while using the heterogeneous head patterns,
229
+ the active head range needs to be modified based on the tensor parallel rank
230
+ and the tensor parallel world size.
231
+
232
+ This is because in the case of heterogeneous head patterns, the kernel needs to know
233
+ which head is on which device, so that it can pick the corresponding blocksparse head
234
+ pattern correctly.
235
+
236
+ Example:
237
+ ```python
238
+
239
+ if not self.blocksparse_params.homo_head_pattern:
240
+ tp_rank = torch.distributed.get_rank() % tp_world_size
241
+ num_heads_per_partition = num_heads // tp_world_size
242
+ active_head_range = (tp_rank * num_heads_per_partition, (tp_rank + 1) * num_heads_per_partition)
243
+
244
+ ```
245
+
246
+ """
247
+
248
+ self._blocksparse_layer = BlockSparseAttentionLayer(
249
+ n_heads=self.num_heads,
250
+ max_seq_len=self.max_position_embeddings,
251
+ sparse_block_size=self.blocksparse_params.block_size,
252
+ local_blocks=self.blocksparse_params.num_local_blocks,
253
+ vert_stride=self.blocksparse_params.vert_stride,
254
+ kernel_block_size=self.blocksparse_params.kernel_block_size,
255
+ homo_head=self.blocksparse_params.homo_head_pattern,
256
+ active_head_range=active_head_range,
257
+ )
258
+ self.rotary_emb = RotaryEmbedding.from_config(config)
259
+
260
+
261
+ @property
262
+ def blocksparse(self):
263
+ return self.blocksparse_params is not None
264
+
265
+ def _split_heads(self, mixed_x_layer: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
266
+ bs, sq, _ = mixed_x_layer.size()
267
+ r"""
268
+ The main idea is that we group tensors as
269
+ [bs, sq, (q00, q01, ... q0m, k0, v0), (q10, q11, ... q1m, k1, v1), ... (qn0, qn1, ... qnm, kn, vn)]
270
+ That ways, when the MP column sharding happens, this tensor will be sharded keeping all the
271
+ queries and keys intact. In order to get the correct qkv, we first break into groups, and then
272
+ index into the groups.
273
+ """
274
+
275
+ intermediate_shape = (bs, sq, -1, (self.num_q_per_kv + 2), self.head_dim)
276
+ mixed_x_layer = mixed_x_layer.view(*intermediate_shape)
277
+ q = mixed_x_layer[:, :, :, :-2]
278
+ k = mixed_x_layer[:, :, :, [-2]]
279
+ v = mixed_x_layer[:, :, :, [-1]]
280
+ q, k, v = [
281
+ rearrange(
282
+ x,
283
+ "bs sq group nh hn -> bs sq (group nh) hn"
284
+ ) for x in (q, k, v)
285
+ ]
286
+ return q, k, v
287
+
288
+ # Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2._unpad_input
289
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
290
+ batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
291
+
292
+
293
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
294
+
295
+ key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
296
+ value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
297
+
298
+ if query_length == kv_seq_len:
299
+ query_layer = index_first_axis(
300
+ query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
301
+ )
302
+ cu_seqlens_q = cu_seqlens_k
303
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
304
+ indices_q = indices_k
305
+ elif query_length == 1:
306
+ max_seqlen_in_batch_q = 1
307
+ cu_seqlens_q = torch.arange(
308
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
309
+ ) # There is a memcpy here, that is very bad.
310
+ indices_q = cu_seqlens_q[:-1]
311
+ query_layer = query_layer.squeeze(1)
312
+ else:
313
+ # The -q_len: slice assumes left padding.
314
+ attention_mask = attention_mask[:, -query_length:]
315
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
316
+
317
+ return (
318
+ query_layer,
319
+ key_layer,
320
+ value_layer,
321
+ indices_q,
322
+ (cu_seqlens_q, cu_seqlens_k),
323
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
324
+ )
325
+
326
+ def _apply_blocksparse_attention(
327
+ self,
328
+ q: torch.Tensor,
329
+ k: torch.Tensor,
330
+ v: torch.Tensor,
331
+ attention_mask: Optional[torch.LongTensor],
332
+ return_attention_probs: bool = False,
333
+ ) -> torch.Tensor:
334
+ """
335
+ Applies blocksparse attention to the input tensors.
336
+
337
+ Args:
338
+ q (torch.Tensor): The query tensor of shape (bs, nqp, seq_len, hn).
339
+ k (torch.Tensor): The key tensor of shape (bs, nkp, seq_len, hn).
340
+ v (torch.Tensor): The value tensor of shape (bs, nkp, seq_len, hn).
341
+ attention_mask (Optional[torch.LongTensor]): The attention mask tensor of shape (bs, seq_len).
342
+ return_attention_probs (bool, optional): Whether to return attention probabilities. Defaults to False.
343
+
344
+ Returns:
345
+ torch.Tensor: The context layer tensor of shape (bs, nqp, seq_len, hn).
346
+ """
347
+ assert not return_attention_probs, "return_attention_probs is not supported for blocksparse attention"
348
+ q, k, v = q.contiguous(), k.contiguous(), v.contiguous()
349
+ # shape: (bs, nqp, seq_len, hn)
350
+ if torch.is_grad_enabled():
351
+ # Training or non-batched inference
352
+ context_layer = self._blocksparse_layer(
353
+ q=q, k=k, v=v, sm_scale=self.softmax_scale
354
+ )
355
+ elif attention_mask is None:
356
+ if q.size(0) != 1:
357
+ logger.warning_once(
358
+ "You are attempting to do batched inference without passing the attention mask.\n"
359
+ "This is okay if you are running loglikelihood requests. However, if you want to do generation, "
360
+ "this probably won't work as expected. Please pass the attention mask to the forward function."
361
+ )
362
+ context_layer = self._blocksparse_layer(
363
+ q=q, k=k, v=v, sm_scale=self.softmax_scale
364
+ )
365
+ else:
366
+ """
367
+ Shapes of tensors are as follows:
368
+ q: (bs, nqp, seq_len, hdim)
369
+ k: (bs, nkp, seq_len, hdim)
370
+ v: (bs, nkp, seq_len, hdim)
371
+ We first need to transpose the shapes to fit what the
372
+ kernel needs, and the reinvert it back at the end of the operations
373
+ """
374
+ assert attention_mask.ndim == 2, "The kernel, like flash-attention-2, only supports 2d attention masks ..."
375
+ left_paddings = attention_mask.shape[1] - attention_mask.sum(dim=-1)
376
+ # shape: (bs, seq_len, nqp, hdim)
377
+ q = q.transpose(1, 2).contiguous()
378
+ # shape: (bs, seq_len, nkp, hdim)
379
+ k = k.transpose(1, 2).contiguous()
380
+ # shape: (bs, seq_len, nkp, hdim)
381
+ v = v.transpose(1, 2).contiguous()
382
+ context_layer = self._blocksparse_layer(
383
+ q=q, k=k, v=v, sm_scale=self.softmax_scale, left_paddings=left_paddings.to(torch.int32)
384
+ )
385
+ # shape: (bs, nqp, seq_len, hdim)
386
+ context_layer = context_layer.transpose(1, 2).contiguous()
387
+ return context_layer
388
+
389
+ def _apply_dense_attention(
390
+ self,
391
+ q: torch.Tensor,
392
+ k: torch.Tensor,
393
+ v: torch.Tensor,
394
+ attention_mask: torch.Tensor,
395
+ return_attention_probs: bool = False,
396
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
397
+ """
398
+ Apply dense attention
399
+
400
+ Args:
401
+ q (torch.Tensor):
402
+ The query tensor, shape: (bs, num_query_heads, seq_len, head_size)
403
+ k (torch.Tensor):
404
+ The key tensor, shape: (bs, num_query_heads, seq_len, head_size)
405
+ v (torch.Tensor):
406
+ The value tensor, shape: (bs, num_query_heads, seq_len, head_size)
407
+
408
+ return_attention_probs (bool, optional):
409
+ Return the attention probabilities. Defaults to False.
410
+
411
+ Returns:
412
+ Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
413
+ Return the output of the attention aggregation. If `return_attention_probs` is True, then
414
+ also return the attention probabilities
415
+
416
+ .. note::
417
+ Right now, am assuming the expansion for the query key values is already done
418
+ outside. But ideally, since Flash attention handles the GQA correctly, we can
419
+ avoid doing that.
420
+
421
+ """
422
+ attention_dropout_prob = self.attention_dropout_rate if self.training else 0.0
423
+ # Get into the correct shape for the Flash Attention API
424
+ # shape: (bs, seq_len, nqp, hn)
425
+ q = q.transpose(1, 2).contiguous()
426
+ query_length = q.size(1)
427
+ # shape: (bs, seq_len, npq, hn)
428
+ k = k.transpose(1, 2).contiguous()
429
+ # shape: (bs, seq_len, npq, hn)
430
+ v = v.transpose(1, 2).contiguous()
431
+
432
+ if attention_mask is not None:
433
+ causal = q.size(2) == k.size(2)
434
+ batch_size = q.shape[0]
435
+ flat_q, flat_k, flat_v, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
436
+ q, k, v, attention_mask, query_length
437
+ )
438
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
439
+ max_seqlen_q, max_seqlen_k = max_seq_lens
440
+ flat_kv = torch.cat((flat_k.unsqueeze(1), flat_v.unsqueeze(1)), dim=1)
441
+ attn_output_unpad = flash_attn_varlen_kvpacked_func(
442
+ q=flat_q,
443
+ kv=flat_kv,
444
+ cu_seqlens_q=cu_seqlens_q,
445
+ cu_seqlens_k=cu_seqlens_k,
446
+ max_seqlen_q=max_seqlen_q,
447
+ max_seqlen_k=max_seqlen_k,
448
+ dropout_p=attention_dropout_prob,
449
+ softmax_scale=self.softmax_scale,
450
+ causal=causal,
451
+ return_attn_probs=return_attention_probs
452
+ )
453
+ attention_output = pad_input(
454
+ attn_output_unpad, indices_q, batch_size, query_length
455
+ )
456
+ else:
457
+ kv = torch.cat((k.unsqueeze(2), v.unsqueeze(2)), dim=2)
458
+ cu_seqlens_q = torch.arange(
459
+ 0, (q.size(0) + 1), device=q.device, dtype=torch.int32
460
+ ) * q.size(1)
461
+ cu_seqlens_kv = torch.arange(
462
+ 0, (kv.size(0) + 1), device=kv.device, dtype=torch.int32
463
+ ) * kv.size(1)
464
+ max_seqlen_q = q.size(1)
465
+ max_seqlen_k = kv.size(1)
466
+ attention_output = flash_attn_varlen_kvpacked_func(
467
+ q=collapse_first_n_dims(q, 2),
468
+ kv=collapse_first_n_dims(kv, 2),
469
+ cu_seqlens_q=cu_seqlens_q,
470
+ cu_seqlens_k=cu_seqlens_kv,
471
+ max_seqlen_q=max_seqlen_q,
472
+ max_seqlen_k=max_seqlen_k,
473
+ dropout_p=attention_dropout_prob,
474
+ softmax_scale=self.softmax_scale,
475
+ causal=q.size(1) == kv.size(1),
476
+ return_attn_probs=return_attention_probs
477
+ )
478
+ if return_attention_probs:
479
+ (context_layer, attn_probs) = attention_output
480
+ context_layer = context_layer.view(q.size(0), q.size(1), -1, q.size(3)).transpose(1, 2).contiguous()
481
+ return (context_layer, attn_probs)
482
+ context_layer = attention_output
483
+ context_layer = context_layer.view(q.size(0), q.size(1), -1, q.size(3)).transpose(1, 2).contiguous()
484
+ return context_layer
485
+
486
+
487
+ def expand_kv_to_q_size(self, kv: torch.Tensor, num_q_per_kv: int) -> torch.Tensor:
488
+ """
489
+ Expand the key-value tensor to match the size of the query tensor.
490
+
491
+ Args:
492
+ kv (torch.Tensor): The key-value tensor of shape (bsz, nkp, 2, seq_len, hdim).
493
+ num_q_per_kv (int): The number of queries per key-value.
494
+
495
+ Returns:
496
+ torch.Tensor: The expanded key-value tensor of shape (bsz, nqp, 2, seq_len, hdim).
497
+ Where nqp = num_q_per_kv * nkp
498
+
499
+ .. note(bapatra)::
500
+ Right now, I am using a repeat_interleave to expand the kv to the size of q.
501
+ This incurs a memory penalty, since the tensors are actually copied.
502
+ TODO: If this does yield benefits, then potentially we can use the re-written
503
+ flash attention kernel that can handle GQA.
504
+ """
505
+
506
+ repeats = torch.tensor([num_q_per_kv] * kv.size(1)).to(kv.device)
507
+ total = repeats.sum()
508
+ expanded_kv = torch.repeat_interleave(
509
+ kv,
510
+ repeats=repeats,
511
+ dim=1,
512
+ output_size=total
513
+ )
514
+ return expanded_kv
515
+
516
+ def forward(
517
+ self,
518
+ hidden_states: torch.Tensor,
519
+ attention_mask: Optional[torch.Tensor] = None,
520
+ position_ids: Optional[torch.LongTensor] = None,
521
+ past_key_values: Optional[Cache] = None,
522
+ output_attentions: bool = False,
523
+ use_cache: bool = False,
524
+ **kwargs,
525
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
526
+ """
527
+ The forward function of the Self Attention Layer.
528
+
529
+ Args:
530
+ hidden_states (torch.Tensor):
531
+ The input tensor of shape (bs, q_len, h).
532
+ attention_mask (Optional[torch.Tensor], optional):
533
+ The attention mask tensor of shape (bs, seq_len). This is the 2D attention mask tensor as is standard in the flash-attention
534
+ kernel.
535
+ Defaults to None.
536
+ position_ids (Optional[torch.LongTensor], optional):
537
+ The position ids tensor of shape (bs, q_len). Defaults to None. Unused by the function.
538
+ past_key_value (Optional[Cache], optional):
539
+ The previous kv cache values. Defaults to None.
540
+ output_attentions (bool, optional):
541
+ Whether to return the attention scores. Defaults to False.
542
+ .. note::
543
+ For the blocksparse attention kernel, we do not support returning the attention scores.
544
+ use_cache (bool, optional):
545
+ Whether to use the cache for storing the kv. Defaults to False.
546
+
547
+ Returns:
548
+ Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
549
+ The output tensor of shape (bs, q_len, h),
550
+ the attention scores tensor of shape (bs, nqp, q_len, seq_len) if `output_attentions` is True,
551
+ and the updated cache values if `use_cache` is True.
552
+
553
+ Notations:
554
+ ------------
555
+ bs: batch size
556
+ sq_len: sequence length of the entire sequence
557
+ q_len: sequence length of the query
558
+ cache_sq: sequence length in the cache
559
+ If there is no cache then cache_sq = 0
560
+ and sq_len = q_len
561
+ otherwise sq_len = q_len + cache_sq
562
+ h: hidden size
563
+ nq: number of query heads
564
+ nkv: number of key heads
565
+ hn: hidden size per head
566
+ hn = h // nq
567
+ nqp: number of query heads (per MP partition)
568
+ nqp = nq // (num mp partitions)
569
+ nkvp: number of key-value heads (per MP partition)
570
+ nkvp = nk // (num mp partitions)
571
+
572
+ """
573
+ # shape: (bs, q_len, h)
574
+ bsz, q_len, _ = hidden_states.size()
575
+
576
+ # shape: (bs, q_len, (nqp + 2 * nkvp) * hn)
577
+ mixed_x_layer = self.query_key_value(hidden_states)
578
+ # shape: (bs, q_len, nqp, hn), shape: (bs, q_len, nkvp, hn), shape: (bs, q_len, nkvp, hn)
579
+ q, k, v = self._split_heads(mixed_x_layer)
580
+
581
+ # shape: (bs, qnp, q_len, hn)
582
+ query_states = q.permute(0, 2, 1, 3).contiguous()
583
+ # shape: (bs, nkvp, q_len, hn)
584
+ key_states = k.permute(0, 2, 1, 3).contiguous()
585
+ # shape: (bs, nkvp, q_len, hn)
586
+ value_states = v.permute(0, 2, 1, 3).contiguous()
587
+
588
+ kv_seq_len = key_states.shape[-2]
589
+ if past_key_values is not None:
590
+ if self.layer_idx is None:
591
+ raise ValueError(
592
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
593
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
594
+ "with a layer index."
595
+ )
596
+ if self.rotary_emb is not None:
597
+ seqlen_offset = past_key_values.get_usable_length(kv_seq_len, layer_idx=self.layer_idx)
598
+ # shape: (bs, nqp, q_len, hn), shape: (bs, nkvp, q_len, hn)
599
+ query_states, key_states = self.rotary_emb(
600
+ query_states, key_states, seq_dimension=2, seqlen_offset=seqlen_offset
601
+ )
602
+ key_states, value_states = past_key_values.update(key_states=key_states, value_states=value_states, layer_idx=self.layer_idx)
603
+ else:
604
+ # In this case seq_len = q_len and cache_sq = 0
605
+ if self.rotary_emb is not None:
606
+ # shape: (bs, nqp, seq_len, hn), shape: (bs, nkvp, seq_len, hn)
607
+ query_states, key_states = self.rotary_emb(query_states, key_states, seq_dimension=2)
608
+
609
+ # shape: (bs, nkvp, 2, seq_len, hn)
610
+ kv_states = torch.cat((key_states.unsqueeze(2), value_states.unsqueeze(2)), dim=2)
611
+ # shape: (bs, nqp, 2, seq_len, hn)
612
+ expanded_kv_states = self.expand_kv_to_q_size(kv_states, num_q_per_kv=self.num_q_per_kv)
613
+ # shape: (bs, nqp, seq_len, hn), shape: (bs, nqp, seq_len, hn)
614
+ expanded_key_states, expanded_value_states = expanded_kv_states[:, :, 0], expanded_kv_states[:, :, 1]
615
+ if self.blocksparse:
616
+ attn_function_output = self._apply_blocksparse_attention(
617
+ q=query_states,
618
+ k=expanded_key_states,
619
+ v=expanded_value_states,
620
+ attention_mask=attention_mask,
621
+ return_attention_probs=output_attentions
622
+ )
623
+ else:
624
+ attn_function_output = self._apply_dense_attention(
625
+ q=query_states,
626
+ k=expanded_key_states,
627
+ v=expanded_value_states,
628
+ attention_mask=attention_mask,
629
+ return_attention_probs=output_attentions
630
+ )
631
+
632
+ attn_weights = None
633
+ if output_attentions:
634
+ attn_output, attn_weights = attn_function_output
635
+ else:
636
+ # shape: (bs, nqp, seq_len, hn)
637
+ attn_output = attn_function_output
638
+ # shape: (bs, seq_len, nqp, hn)
639
+ attn_output = attn_output.transpose(1, 2).contiguous()
640
+
641
+ # shape: (bs, seq_len, h)
642
+ attn_output = attn_output.view(bsz, q_len, -1)
643
+ attn_output = self.dense(attn_output)
644
+ return attn_output, attn_weights, past_key_values
645
+
646
+
647
+ class Phi3SmallDecoderLayer(nn.Module):
648
+ def __init__(self, config: Phi3SmallConfig, layer_idx: int):
649
+ super().__init__()
650
+ self.hidden_size = config.hidden_size
651
+ self.self_attn = Phi3SmallSelfAttention(config, layer_idx)
652
+ self.mlp = Phi3SmallMLP(config)
653
+
654
+ self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
655
+ self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
656
+
657
+ def forward(
658
+ self,
659
+ hidden_states: torch.Tensor,
660
+ attention_mask: Optional[torch.Tensor] = None,
661
+ position_ids: Optional[torch.LongTensor] = None,
662
+ past_key_values: Optional[Cache] = None,
663
+ output_attentions: Optional[bool] = None,
664
+ use_cache: Optional[bool] = None,
665
+ **kwargs,
666
+ ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor], Optional[Cache]]:
667
+ residual = hidden_states
668
+ hidden_states = self.input_layernorm(hidden_states)
669
+
670
+ # Self Attention
671
+ hidden_states, self_attn_weights, present_key_values = self.self_attn(
672
+ hidden_states=hidden_states,
673
+ attention_mask=attention_mask,
674
+ position_ids=position_ids,
675
+ past_key_values=past_key_values,
676
+ output_attentions=output_attentions,
677
+ use_cache=use_cache,
678
+ )
679
+ hidden_states = residual + hidden_states
680
+
681
+ # Fully Connected
682
+ residual = hidden_states
683
+ hidden_states = self.post_attention_layernorm(hidden_states)
684
+ hidden_states = self.mlp(hidden_states)
685
+ hidden_states = residual + hidden_states
686
+
687
+ outputs = (hidden_states,)
688
+
689
+ if output_attentions:
690
+ outputs += (self_attn_weights,)
691
+
692
+ if use_cache:
693
+ outputs += (present_key_values,)
694
+
695
+ return outputs
696
+
697
+
698
+
699
+ class Phi3SmallPreTrainedModel(PreTrainedModel):
700
+ config_class = Phi3SmallConfig
701
+ base_model_prefix = "model"
702
+ supports_gradient_checkpointing = True
703
+ _no_split_modules = ["Phi3SmallDecoderLayer"]
704
+ skip_keys_device_placement = "past_key_values"
705
+ _supports_flash_attn_2 = True
706
+ _supports_sdpa = False
707
+ _supports_cache_class = True
708
+
709
+ def _init_weights(self, module: nn.Module):
710
+ std = self.config.initializer_range
711
+ if isinstance(module, nn.Linear):
712
+ # Slightly different from the TF version which uses truncated_normal for initialization
713
+ # cf https://github.com/pytorch/pytorch/pull/5617
714
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
715
+ elif isinstance(module, nn.Embedding):
716
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
717
+ if module.padding_idx is not None:
718
+ module.weight.data[module.padding_idx].zero_()
719
+ elif isinstance(module, nn.LayerNorm):
720
+ module.bias.data.zero_()
721
+ module.weight.data.fill_(1.0)
722
+
723
+ # The output projection on the decoder attention layer as well as the down_proj in the MLP are scaled
724
+ # differently (dubbed `output_layer_init_method` in the Megatron code). This is replicated here
725
+ for name, p in module.named_parameters():
726
+ if any(x in name for x in ("c_proj.weight", "down_proj.weight", "o_proj.weight")):
727
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
728
+ p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.num_hidden_layers)))
729
+
730
+
731
+ class Phi3SmallModel(Phi3SmallPreTrainedModel):
732
+
733
+ def __init__(self, config):
734
+ super().__init__(config)
735
+ self.config = config
736
+
737
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
738
+
739
+ # Embedding Dropout
740
+ self.embedding_dropout = nn.Dropout(config.embedding_dropout_prob)
741
+
742
+ # MuP Embedding scaling
743
+ self.mup_embedding_multiplier = config.mup_embedding_multiplier
744
+
745
+ self.layers = nn.ModuleList([Phi3SmallDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
746
+
747
+ self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
748
+
749
+ self.gradient_checkpointing = False
750
+
751
+ # Initialize weights and apply final processing
752
+ self.post_init()
753
+
754
+ def get_input_embeddings(self):
755
+ return self.embed_tokens
756
+
757
+ def set_input_embeddings(self, value):
758
+ self.embed_tokens = value
759
+
760
+ @property
761
+ def pad_sequence_to_multiple_of_64(self):
762
+ # We only need to do this for the backward pass. So only required
763
+ # when we are in the context of generating gradients
764
+ return self.config.pad_sequence_to_multiple_of_64 and torch.is_grad_enabled()
765
+
766
+ def forward(
767
+ self,
768
+ input_ids: torch.LongTensor = None,
769
+ attention_mask: Optional[torch.Tensor] = None,
770
+ position_ids: Optional[torch.LongTensor] = None,
771
+ past_key_values: Optional[Union[Cache, LegacyCache]] = None,
772
+ inputs_embeds: Optional[torch.FloatTensor] = None,
773
+ use_cache: Optional[bool] = None,
774
+ output_attentions: Optional[bool] = None,
775
+ output_hidden_states: Optional[bool] = None,
776
+ return_dict: Optional[bool] = None,
777
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
778
+
779
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
780
+ output_hidden_states = (
781
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
782
+ )
783
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
784
+
785
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
786
+
787
+ if input_ids is not None and inputs_embeds is not None:
788
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
789
+ elif input_ids is not None:
790
+ batch_size, seq_length = input_ids.shape
791
+ elif inputs_embeds is not None:
792
+ batch_size, seq_length, _ = inputs_embeds.shape
793
+ else:
794
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
795
+
796
+ if self.gradient_checkpointing and self.training:
797
+ if use_cache:
798
+ logger.warning_once(
799
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
800
+ )
801
+ use_cache = False
802
+
803
+ past_key_values_length = 0
804
+
805
+ if use_cache:
806
+ use_legacy_cache = not isinstance(past_key_values, Cache)
807
+ if use_legacy_cache:
808
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
809
+ past_key_values_length = past_key_values.get_usable_length(seq_length)
810
+
811
+ if position_ids is None:
812
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
813
+ position_ids = torch.arange(
814
+ past_key_values_length, past_key_values_length + seq_length, dtype=torch.long, device=device
815
+ )
816
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
817
+ else:
818
+ position_ids = position_ids.view(-1, seq_length).long()
819
+
820
+ if attention_mask is not None:
821
+ if batch_size <= 0:
822
+ raise ValueError("batch_size has to be defined and > 0")
823
+
824
+ if inputs_embeds is None:
825
+ inputs_embeds = self.embed_tokens(input_ids)
826
+ inputs_embeds = self.embedding_dropout(inputs_embeds)
827
+
828
+ if self.mup_embedding_multiplier is not None and self.mup_embedding_multiplier > 0.0:
829
+ inputs_embeds = inputs_embeds * self.mup_embedding_multiplier
830
+
831
+ residual = 0
832
+ if self.pad_sequence_to_multiple_of_64:
833
+ # note(bapatra): Since we don't particularly use the position_ids and the attention mask
834
+ # we don't need to pad them
835
+ inputs_embeds, residual = pad_tensor_to_next_mult_of(tensor=inputs_embeds, dim=1, n=64)
836
+
837
+ hidden_states = inputs_embeds
838
+
839
+ # decoder layers
840
+ all_hidden_states = () if output_hidden_states else None
841
+ all_self_attns = () if output_attentions else None
842
+ next_decoder_cache = None
843
+
844
+ for decoder_layer in self.layers:
845
+ if output_hidden_states:
846
+ all_hidden_states += (hidden_states,)
847
+
848
+ if self.gradient_checkpointing and self.training:
849
+ layer_outputs = self._gradient_checkpointing_func(
850
+ decoder_layer.__call__,
851
+ hidden_states,
852
+ attention_mask,
853
+ position_ids,
854
+ past_key_values,
855
+ output_attentions,
856
+ use_cache,
857
+ )
858
+ else:
859
+ layer_outputs = decoder_layer(
860
+ hidden_states,
861
+ attention_mask=attention_mask,
862
+ position_ids=position_ids,
863
+ past_key_values=past_key_values,
864
+ output_attentions=output_attentions,
865
+ use_cache=use_cache,
866
+ )
867
+ hidden_states = layer_outputs[0]
868
+
869
+ if use_cache:
870
+ # Following the Mistral schema for layer return values
871
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
872
+ if output_attentions:
873
+ all_self_attns += (layer_outputs[1],)
874
+
875
+ hidden_states = self.final_layernorm(hidden_states)
876
+
877
+ if residual > 0:
878
+ hidden_states = strip_padding_from_tensor(tensor=hidden_states, dim=1, residual=residual)
879
+
880
+ # add hidden states from the last decoder layer
881
+ if output_hidden_states:
882
+ all_hidden_states += (hidden_states,)
883
+
884
+ next_cache = None
885
+ if use_cache:
886
+ next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
887
+
888
+ if not return_dict:
889
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
890
+ return BaseModelOutputWithPast(
891
+ last_hidden_state=hidden_states,
892
+ past_key_values=next_cache,
893
+ hidden_states=all_hidden_states,
894
+ attentions=all_self_attns,
895
+ )
896
+
897
+
898
+ class Phi3SmallForCausalLM(Phi3SmallPreTrainedModel):
899
+ _tied_weights_keys = ["lm_head.weight"]
900
+
901
+ def __init__(self, config):
902
+ super().__init__(config)
903
+ self.model = Phi3SmallModel(config)
904
+ self.vocab_size = config.vocab_size
905
+ self.lm_head = nn.Linear(config.hidden_size, self.vocab_size, bias=False)
906
+ self.mup_width_multiplier = config.mup_width_multiplier
907
+
908
+ # Create the mask for the dummy tokens in the vocabulary
909
+ dummy_token_indices = config.dummy_token_indices
910
+ dummy_tokens_mask = torch.zeros(self.vocab_size).bool()
911
+ dummy_tokens_mask[dummy_token_indices] = True
912
+ # shape: (vocab_size,)
913
+ self.register_buffer("dummy_tokens_mask", dummy_tokens_mask, persistent=False)
914
+
915
+ # Initialize weights and apply final processing
916
+ self.post_init()
917
+
918
+ def get_input_embeddings(self):
919
+ return self.model.embed_tokens
920
+
921
+ def set_input_embeddings(self, value):
922
+ self.model.embed_tokens = value
923
+
924
+ def get_output_embeddings(self):
925
+ return self.lm_head
926
+
927
+ def set_output_embeddings(self, value):
928
+ self.lm_head = value
929
+
930
+ def set_decoder(self, decoder):
931
+ self.model = decoder
932
+
933
+ def get_decoder(self):
934
+ return self.model
935
+
936
+ def forward(
937
+ self,
938
+ input_ids: torch.LongTensor = None,
939
+ attention_mask: Optional[torch.Tensor] = None,
940
+ position_ids: Optional[torch.LongTensor] = None,
941
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
942
+ inputs_embeds: Optional[torch.FloatTensor] = None,
943
+ labels: Optional[torch.LongTensor] = None,
944
+ use_cache: Optional[bool] = None,
945
+ output_attentions: Optional[bool] = None,
946
+ output_hidden_states: Optional[bool] = None,
947
+ return_dict: Optional[bool] = None,
948
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
949
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
950
+ output_hidden_states = (
951
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
952
+ )
953
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
954
+
955
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
956
+ outputs = self.model(
957
+ input_ids=input_ids,
958
+ attention_mask=attention_mask,
959
+ position_ids=position_ids,
960
+ past_key_values=past_key_values,
961
+ inputs_embeds=inputs_embeds,
962
+ use_cache=use_cache,
963
+ output_attentions=output_attentions,
964
+ output_hidden_states=output_hidden_states,
965
+ return_dict=return_dict,
966
+ )
967
+
968
+ hidden_states = outputs[0]
969
+ logits = self.lm_head(hidden_states)
970
+ logits = logits.float()
971
+ if self.mup_width_multiplier:
972
+ logits = logits / self.mup_width_multiplier
973
+ logits = logits.masked_fill(self.dummy_tokens_mask, min_value_of_dtype(logits.dtype))
974
+
975
+ loss = None
976
+ if labels is not None:
977
+ # Shift so that tokens < n predict n
978
+ shift_logits = logits[..., :-1, :].contiguous()
979
+ shift_labels = labels[..., 1:].contiguous()
980
+ # Flatten the tokens
981
+ loss_fct = nn.CrossEntropyLoss()
982
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
983
+ shift_labels = shift_labels.view(-1)
984
+ # Enable model parallelism
985
+ shift_labels = shift_labels.to(shift_logits.device)
986
+ loss = loss_fct(shift_logits, shift_labels)
987
+
988
+ if not return_dict:
989
+ output = (logits,) + outputs[1:]
990
+ return (loss,) + output if loss is not None else output
991
+
992
+ return CausalLMOutputWithPast(
993
+ loss=loss,
994
+ logits=logits,
995
+ past_key_values=outputs.past_key_values,
996
+ hidden_states=outputs.hidden_states,
997
+ attentions=outputs.attentions,
998
+ )
999
+
1000
+ def prepare_inputs_for_generation(
1001
+ self,
1002
+ input_ids: torch.LongTensor,
1003
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1004
+ attention_mask: Optional[torch.FloatTensor] = None,
1005
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1006
+ **kwargs
1007
+ ) -> Dict[str, Any]:
1008
+ # only last token for inputs_ids if past is defined in kwargs
1009
+ if past_key_values:
1010
+ input_ids = input_ids[:, -1].unsqueeze(-1)
1011
+
1012
+ position_ids = kwargs.get("position_ids", None)
1013
+
1014
+ if attention_mask is not None and position_ids is None:
1015
+ # create position_ids on the fly for batch generation
1016
+ position_ids = attention_mask.long().cumsum(-1) - 1
1017
+ position_ids.masked_fill_(attention_mask == 0, 1)
1018
+ if past_key_values:
1019
+ position_ids = position_ids[:, -1].unsqueeze(-1)
1020
+ else:
1021
+ position_ids = None
1022
+
1023
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1024
+ if inputs_embeds is not None and past_key_values is None:
1025
+ model_inputs = {"inputs_embeds": inputs_embeds}
1026
+ else:
1027
+ model_inputs = {"input_ids": input_ids}
1028
+
1029
+ model_inputs.update(
1030
+ {
1031
+ "past_key_values": past_key_values,
1032
+ "use_cache": kwargs.get("use_cache"),
1033
+ "position_ids": position_ids,
1034
+ "attention_mask": attention_mask,
1035
+ }
1036
+ )
1037
+ return model_inputs
1038
+
1039
+
1040
+ # Copied from transformers.models.mistral.modeling_mistral.MistralForSequenceClassification with Mistral -> Phi3Small
1041
+ class Phi3SmallForSequenceClassification(Phi3SmallPreTrainedModel):
1042
+ def __init__(self, config):
1043
+ super().__init__(config)
1044
+ self.num_labels = config.num_labels
1045
+ self.model = Phi3SmallModel(config)
1046
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1047
+
1048
+ # Initialize weights and apply final processing
1049
+ self.post_init()
1050
+
1051
+ def get_input_embeddings(self):
1052
+ return self.model.embed_tokens
1053
+
1054
+ def set_input_embeddings(self, value):
1055
+ self.model.embed_tokens = value
1056
+
1057
+
1058
+ def forward(
1059
+ self,
1060
+ input_ids: torch.LongTensor = None,
1061
+ attention_mask: Optional[torch.Tensor] = None,
1062
+ position_ids: Optional[torch.LongTensor] = None,
1063
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1064
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1065
+ labels: Optional[torch.LongTensor] = None,
1066
+ use_cache: Optional[bool] = None,
1067
+ output_attentions: Optional[bool] = None,
1068
+ output_hidden_states: Optional[bool] = None,
1069
+ return_dict: Optional[bool] = None,
1070
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1071
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1072
+
1073
+ transformer_outputs = self.model(
1074
+ input_ids,
1075
+ attention_mask=attention_mask,
1076
+ position_ids=position_ids,
1077
+ past_key_values=past_key_values,
1078
+ inputs_embeds=inputs_embeds,
1079
+ use_cache=use_cache,
1080
+ output_attentions=output_attentions,
1081
+ output_hidden_states=output_hidden_states,
1082
+ return_dict=return_dict,
1083
+ )
1084
+ hidden_states = transformer_outputs[0]
1085
+ logits = self.score(hidden_states)
1086
+
1087
+ if input_ids is not None:
1088
+ batch_size = input_ids.shape[0]
1089
+ else:
1090
+ batch_size = inputs_embeds.shape[0]
1091
+
1092
+ if self.config.pad_token_id is None and batch_size != 1:
1093
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
1094
+ if self.config.pad_token_id is None:
1095
+ sequence_lengths = -1
1096
+ else:
1097
+ if input_ids is not None:
1098
+ # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
1099
+ sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
1100
+ sequence_lengths = sequence_lengths % input_ids.shape[-1]
1101
+ sequence_lengths = sequence_lengths.to(logits.device)
1102
+ else:
1103
+ sequence_lengths = -1
1104
+
1105
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
1106
+
1107
+ loss = None
1108
+ if labels is not None:
1109
+ labels = labels.to(logits.device)
1110
+ if self.config.problem_type is None:
1111
+ if self.num_labels == 1:
1112
+ self.config.problem_type = "regression"
1113
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1114
+ self.config.problem_type = "single_label_classification"
1115
+ else:
1116
+ self.config.problem_type = "multi_label_classification"
1117
+
1118
+ if self.config.problem_type == "regression":
1119
+ loss_fct = nn.MSELoss()
1120
+ if self.num_labels == 1:
1121
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1122
+ else:
1123
+ loss = loss_fct(pooled_logits, labels)
1124
+ elif self.config.problem_type == "single_label_classification":
1125
+ loss_fct = nn.CrossEntropyLoss()
1126
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
1127
+ elif self.config.problem_type == "multi_label_classification":
1128
+ loss_fct = nn.BCEWithLogitsLoss()
1129
+ loss = loss_fct(pooled_logits, labels)
1130
+ if not return_dict:
1131
+ output = (pooled_logits,) + transformer_outputs[1:]
1132
+ return ((loss,) + output) if loss is not None else output
1133
+
1134
+ return SequenceClassifierOutputWithPast(
1135
+ loss=loss,
1136
+ logits=pooled_logits,
1137
+ past_key_values=transformer_outputs.past_key_values,
1138
+ hidden_states=transformer_outputs.hidden_states,
1139
+ attentions=transformer_outputs.attentions,
1140
+ )
positional_embedding.py ADDED
@@ -0,0 +1,288 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Orginally Taken verbatim from xformers library
3
+ https://github.com/facebookresearch/xformers/blob/bcb707576c6a80eaf850aa80e8643d3497ec2bc4/xformers/components/positional_embedding/rotary.py
4
+
5
+ The difference is that xformers seems to assume the inputs to be
6
+ (bs, head, seq_len, dim) while we assume (bs, seq_len, head, dim)
7
+
8
+ """
9
+ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
10
+ #
11
+ # This source code is licensed under the BSD license found in the
12
+ # LICENSE file in the root directory of this source tree.
13
+
14
+
15
+ # CREDITS: This implementation is inspired by GPT-NeoX https://github.com/EleutherAI/gpt-neox
16
+ # NOTE: Almost the same right now, moving parts to Triton is the next step
17
+
18
+ import math
19
+ from typing import List, Optional, Tuple, Dict, Union
20
+
21
+ import torch
22
+ import dataclasses
23
+ from transformers.utils import logging
24
+
25
+ from transformers import PretrainedConfig
26
+
27
+ is_dacite_available = False
28
+ try:
29
+ import dacite
30
+ is_dacite_available = True
31
+ except ImportError:
32
+ pass
33
+
34
+ logger = logging.get_logger(__name__)
35
+
36
+ @dataclasses.dataclass
37
+ class LongRopeConfig(object):
38
+ short_factor: List[float]
39
+ long_factor: List[float]
40
+ original_max_position_embeddings: int
41
+ type: str = "longrope"
42
+ short_mscale: float = -1
43
+ long_mscale: float = -1
44
+
45
+
46
+ def __post_init__(self):
47
+ assert self.type in ("longrope", "su"), f"Invalid type {self.type} for LongRopeConfig. Expected longrope / su"
48
+
49
+
50
+ @classmethod
51
+ def from_dict(cls, config_dict: Dict[str, Union[float, List[float], int]]) -> "LongRopeConfig":
52
+ if is_dacite_available:
53
+ # Preferred since we can also type check the input
54
+ return dacite.from_dict(data_class=cls, data=config_dict)
55
+ kwargs = {}
56
+ for field in dataclasses.fields(cls):
57
+ if field.name in config_dict:
58
+ if field.init:
59
+ kwargs[field.name] = config_dict[field.name]
60
+ else:
61
+ raise ValueError(f"Field {field.name} is not initiable")
62
+ else:
63
+ if field.default is dataclasses.MISSING:
64
+ raise ValueError(f"Field {field.name} is required")
65
+ extra_keys = set(config_dict.keys()) - set(kwargs.keys())
66
+ if len(extra_keys) > 0:
67
+ for key in extra_keys:
68
+ logger.error(f"Unrecognized key {key} in config_dict")
69
+ raise ValueError(f"Unrecognized keys in config_dict")
70
+ return cls(**kwargs)
71
+
72
+ def rotate_half(x):
73
+ x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
74
+ return torch.cat((-x2, x1), dim=x1.ndim - 1)
75
+
76
+
77
+
78
+ @torch.jit.script
79
+ def apply_rotary_pos_emb(x, cos, sin, seq_dimension: int):
80
+ # NOTE: This could probably be moved to Triton
81
+
82
+ if seq_dimension == 0:
83
+ cos = cos[: x.shape[0], None, None, :]
84
+ sin = sin[: x.shape[0], None, None, :]
85
+ elif seq_dimension == 1:
86
+ # Handle a possible sequence length mismatch in between q and k
87
+ cos = cos[None, : x.shape[1], None, :]
88
+ sin = sin[None, : x.shape[1], None, :]
89
+ elif seq_dimension == 2:
90
+ cos = cos[None, None, : x.shape[2], :]
91
+ sin = sin[None, None, : x.shape[2], :]
92
+
93
+ return (x * cos) + (rotate_half(x) * sin)
94
+
95
+
96
+
97
+ class RotaryEmbedding(torch.nn.Module):
98
+ """
99
+ Adapted from the xformers library
100
+
101
+ The rotary position embeddings from RoFormer_ (Su et. al).
102
+ A crucial insight from the method is that the query and keys are
103
+ transformed by rotation matrices which depend on the relative positions.
104
+ Other implementations are available in the Rotary Transformer repo_ and in
105
+ GPT-NeoX_, GPT-NeoX was an inspiration
106
+ .. _RoFormer: https://arxiv.org/abs/2104.09864
107
+ .. _repo: https://github.com/ZhuiyiTechnology/roformer
108
+ .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox
109
+ .. warning: Please note that this embedding is not registered on purpose, as it is transformative
110
+ (it does not create the embedding dimension) and will likely be picked up (imported) on a ad-hoc basis
111
+
112
+ # Arguments
113
+ :param dim_mode: head dimention
114
+ :param max_seq_len:
115
+ :param default_seq_dimension: which dim is the sequence length
116
+ :param dtype: cos/sin dtype
117
+ :param use_fused_kernel: if to use customized fused kernel.
118
+ Note: if used, q, k will be modified inplace. Ok for both forward & backward.
119
+ """
120
+
121
+ def __init__(
122
+ self,
123
+ dim_model: int,
124
+ *,
125
+ max_seq_len: Optional[int] = None,
126
+ dtype: Optional[torch.dtype] = None,
127
+ base=10000,
128
+ position_scale=1,
129
+ device: Optional[torch.device] = None,
130
+ longrope_config: Optional[LongRopeConfig] = None,
131
+ ):
132
+ super().__init__()
133
+ self.base = base
134
+ self.dim_model = dim_model
135
+ self.max_seq_len = max_seq_len
136
+ self.longrope_config = longrope_config
137
+
138
+ if self.is_longrope:
139
+ # Keep the maximum range vector, and slice from it as needed
140
+ self.register_buffer(
141
+ "range_vector",
142
+ torch.arange(max_seq_len, device=device, dtype=torch.float32),
143
+ persistent=False
144
+ )
145
+ self.register_buffer(
146
+ "short_factors",
147
+ torch.tensor(self.longrope_config.short_factor, dtype=torch.float32),
148
+ persistent=False
149
+ )
150
+ self.register_buffer(
151
+ "long_factors",
152
+ torch.tensor(self.longrope_config.long_factor, dtype=torch.float32),
153
+ persistent=False
154
+ )
155
+ else:
156
+ # Generate and save the inverse frequency buffer (non trainable)
157
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim_model, 2).float().to(device) / self.dim_model))
158
+ self.register_buffer("inv_freq", inv_freq)
159
+
160
+ self.position_scale = position_scale
161
+
162
+ if not self.is_longrope:
163
+ dtype = dtype or torch.get_default_dtype()
164
+ self._set_cos_sin_cache(
165
+ seq_len=max_seq_len,
166
+ device=self.inv_freq.device,
167
+ dtype=dtype,
168
+ )
169
+ @property
170
+ def is_longrope(self):
171
+ return self.longrope_config is not None
172
+
173
+ @property
174
+ def original_max_seq_len(self):
175
+ if self.longrope_config is not None:
176
+ return self.longrope_config.original_max_position_embeddings
177
+ logger.warning_once(
178
+ (
179
+ "``original_max_seq_len'' is being accessed, but longrope_config has not been set. "
180
+ "Please only do this if you are sure about the context."
181
+ )
182
+ )
183
+ return self.max_seq_len
184
+
185
+ def get_range_vector(self, seq_len: int, device: torch.device):
186
+ if self.is_longrope:
187
+ assert seq_len < self.range_vector.shape[0], f"Found seq_len {seq_len} greater than max_seq_len {self.range_vector.shape[0]}"
188
+ if self.range_vector.device != device:
189
+ self.range_vector = self.range_vector.to(device)
190
+ return self.range_vector[:seq_len]
191
+ return torch.arange(seq_len, device=device, dtype=torch.float32)
192
+
193
+
194
+ def _calc_mscale(self, scale: torch.Tensor) -> torch.Tensor:
195
+ if scale <= 1.0:
196
+ return 1.0
197
+ return math.sqrt(1 + math.log(scale) / math.log(self.original_max_seq_len))
198
+
199
+ def _set_cos_sin_cache(
200
+ self,
201
+ seq_len: int,
202
+ device: Optional[torch.device] = None,
203
+ dtype: Optional[torch.dtype] = None,
204
+ ) -> None:
205
+ dtype = dtype or torch.get_default_dtype()
206
+ self.max_seq_len_cached = seq_len
207
+ t = (torch.arange(self.max_seq_len_cached, device=device, dtype=torch.float32) * self.position_scale).type_as(self.inv_freq)
208
+ device_type = device.type if device is not None else "cpu"
209
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
210
+ with torch.autocast(device_type=device_type, enabled=False):
211
+ # shape: (seq_len, dim_model // 2)
212
+ freqs = torch.outer(t, self.inv_freq)
213
+ # shape: (seq_len, dim_model)
214
+ emb = torch.cat((freqs, freqs), dim=-1)
215
+ cos = emb.cos()
216
+ sin = emb.sin()
217
+ self.register_buffer("cos_cached", cos.to(dtype), persistent=False)
218
+ self.register_buffer("sin_cached", sin.to(dtype), persistent=False)
219
+
220
+ def forward(
221
+ self, q: torch.Tensor,
222
+ k: torch.Tensor,
223
+ seq_dimension: int = 1,
224
+ seqlen_offset: int = 0,
225
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
226
+ """q, k does not include `seqlen_offset`
227
+ q: Either (bs, seq_len, num_heads, head_dim) or (seq_len, bs, num_heads, head_dim)
228
+ k: Either (bs, seq_len, num_heads, head_dim) or (seq_len, bs, num_heads, head_dim)
229
+ """
230
+ if seq_dimension < 0:
231
+ seq_dimension = k.ndim + seq_dimension
232
+ assert seq_dimension in (0, 1, 2)
233
+ seq_len = k.shape[seq_dimension] + seqlen_offset
234
+
235
+ if self.is_longrope:
236
+ if seq_len > self.original_max_seq_len:
237
+ t = self.get_range_vector(seq_len, device=q.device)
238
+ rescale_factors = self.long_factors.to(q.device)
239
+ long_mscale = self.longrope_config.long_mscale
240
+ mscale = long_mscale if long_mscale > 0 else self._calc_mscale(self.max_seq_len / self.original_max_seq_len)
241
+ else:
242
+ t = self.get_range_vector(self.original_max_seq_len, device=q.device)
243
+ rescale_factors = self.short_factors.to(q.device)
244
+ short_mscale = self.longrope_config.short_mscale
245
+ mscale = short_mscale if short_mscale > 0 else 1.0
246
+ assert rescale_factors.shape == (self.dim_model // 2, ), (
247
+ f"misaligned shape for LongRoPE rescale factors:\n"
248
+ f"\tExpected {(self.dim_model // 2, )}, got {rescale_factors.shape}."
249
+ )
250
+ inv_freq = 1.0 / (rescale_factors * (self.base ** (torch.arange(0, self.dim_model, 2).float().to(q.device) / self.dim_model)))
251
+ device_type = q.device.type if q.device is not None else "cpu"
252
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
253
+ with torch.autocast(device_type=device_type, enabled=False):
254
+ freqs = torch.outer(t, inv_freq)
255
+ emb = torch.cat((freqs, freqs), dim=-1)
256
+ cos = emb.cos() * mscale
257
+ sin = emb.sin() * mscale
258
+ cos_cached = cos.to(q.dtype)
259
+ sin_cached = sin.to(q.dtype)
260
+ else:
261
+ if seq_len > self.max_seq_len_cached:
262
+ self._set_cos_sin_cache(
263
+ seq_len=seq_len,
264
+ device=k.device,
265
+ dtype=k.dtype,
266
+ )
267
+ cos_cached = self.cos_cached
268
+ sin_cached = self.sin_cached
269
+ return (
270
+ apply_rotary_pos_emb(
271
+ q, cos_cached[seqlen_offset:seq_len], sin_cached[seqlen_offset:seq_len], seq_dimension=seq_dimension
272
+ ).to(q.dtype),
273
+ apply_rotary_pos_emb(
274
+ k, cos_cached[seqlen_offset:seq_len], sin_cached[seqlen_offset:seq_len], seq_dimension=seq_dimension
275
+ ).to(k.dtype),
276
+ )
277
+
278
+ @classmethod
279
+ def from_config(cls, config: PretrainedConfig) -> "RotaryEmbedding":
280
+ kwargs = dict(
281
+ dim_model=config.hidden_size // config.num_attention_heads,
282
+ max_seq_len=config.max_position_embeddings,
283
+ base=config.rope_embedding_base,
284
+ position_scale=config.rope_position_scale,
285
+ )
286
+ if config.rope_scaling is not None:
287
+ kwargs["longrope_config"] = LongRopeConfig.from_dict(config.rope_scaling)
288
+ return cls(**kwargs)
smash_config.json ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "api_key": null,
3
+ "verify_url": "http://johnrachwan.pythonanywhere.com",
4
+ "smash_config": {
5
+ "pruners": "None",
6
+ "pruning_ratio": 0.0,
7
+ "factorizers": "None",
8
+ "quantizers": "['llm-int8']",
9
+ "weight_quantization_bits": 4,
10
+ "output_deviation": 0.005,
11
+ "compilers": "None",
12
+ "static_batch": true,
13
+ "static_shape": true,
14
+ "controlnet": "None",
15
+ "unet_dim": 4,
16
+ "device": "cuda",
17
+ "cache_dir": "/ceph/hdd/staff/charpent/.cache/modelsjayrsbxc",
18
+ "batch_size": 1,
19
+ "model_name": "numind/NuExtract-large",
20
+ "task": "text_text_generation",
21
+ "max_batch_size": 1,
22
+ "qtype_weight": "torch.qint8",
23
+ "qtype_activation": "torch.quint8",
24
+ "qobserver": "<class 'torch.ao.quantization.observer.MinMaxObserver'>",
25
+ "qscheme": "torch.per_tensor_symmetric",
26
+ "qconfig": "x86",
27
+ "group_size": 128,
28
+ "damp_percent": 0.1,
29
+ "save_load_fn": "bitsandbytes"
30
+ }
31
+ }
special_tokens_map.json ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ {
2
+ "bos_token": "<|endoftext|>",
3
+ "eos_token": "<|endoftext|>",
4
+ "pad_token": "<|endoftext|>"
5
+ }
tokenization_phi3_small.py ADDED
@@ -0,0 +1,313 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://huggingface.co/Qwen/Qwen-7B-Chat/blob/main/tokenization_qwen.py
2
+ import os
3
+ from typing import Collection, List, Optional, Dict, Set, Tuple, Union
4
+
5
+ from functools import cached_property
6
+
7
+ import base64
8
+
9
+ from transformers import PreTrainedTokenizer, AddedToken, AutoConfig
10
+ from transformers.models.auto.tokenization_auto import get_tokenizer_config
11
+ import tiktoken
12
+
13
+
14
+ """
15
+ This tokenizer is almost identical to tiktoken.get_encoding("cl100k_base")
16
+ with a few additional special tokens to support the ChatML format.
17
+
18
+ TODO(bapatra): Right now, I do not save the special tokens to the vocab file.
19
+ Maybe in the future, that would be useful? Can add that support later.
20
+
21
+ """
22
+
23
+ def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]:
24
+ with open(tiktoken_bpe_file, "rb") as f:
25
+ contents = f.read()
26
+ return {
27
+ base64.b64decode(token): int(rank)
28
+ for token, rank in (line.split() for line in contents.splitlines() if line)
29
+ }
30
+
31
+ # On the megatron codebase, we pad vocabularies to ensure matrix multiplication is fast.
32
+ # this in turn causes some indices to be empty. We account for these empty indices by adding
33
+ # dummy tokens to the tokenizer.
34
+
35
+ EFFECTIVE_PADDED_VOCAB_SIZE = 100352
36
+ ACTUAL_VOCAB_SIZE = 100276
37
+
38
+
39
+ DUMMY_TOKENS = {
40
+ f"<|dummy_id_{11 + offset}|>": 100276 + offset
41
+ for offset in range(1, EFFECTIVE_PADDED_VOCAB_SIZE - ACTUAL_VOCAB_SIZE)
42
+ }
43
+
44
+ SPECIAL_TOKENS = {
45
+ # tiktoken.get_encoding("cl100k_base")._special_tokens
46
+ '<|endoftext|>': 100257,
47
+ '<|fim_prefix|>': 100258,
48
+ '<|fim_middle|>': 100259,
49
+ '<|fim_suffix|>': 100260,
50
+ # Special tokens for post-training
51
+ "<|system|>": 100261,
52
+ "<|user|>": 100262,
53
+ "<|assistant|>": 100263,
54
+ # Dummy unused tokens
55
+ "<|dummy_id_0|>": 100264,
56
+ "<|dummy_id_1|>": 100265,
57
+ # Special tokens for post-training continued
58
+ "<|end|>": 100266,
59
+ # Some dummy tokens, so that tokenization is contiguous and does not cause issues
60
+ # Note that the 100256th token of tiktoken.get_encoding("cl100k_base") does not
61
+ # actually map to anything. So we use a dummy token here.
62
+ "<|dummy_id_2|>": 100256,
63
+ # Likewise, tokens from 100267 to 100275 are also unused
64
+ "<|dummy_id_3|>": 100267,
65
+ "<|dummy_id_4|>": 100268,
66
+ "<|dummy_id_5|>": 100269,
67
+ "<|dummy_id_6|>": 100270,
68
+ "<|dummy_id_7|>": 100271,
69
+ "<|dummy_id_8|>": 100272,
70
+ "<|dummy_id_9|>": 100273,
71
+ "<|dummy_id_10|>": 100274,
72
+ "<|dummy_id_11|>": 100275,
73
+ # The final end of prompt token
74
+ # (unused, but present as a part of tiktoken.get_encoding("cl100k_base")._special_tokens)
75
+ '<|endofprompt|>': 100276,
76
+ # Dummy tokens to account for padding of the tokenizer
77
+ # We pad to ensure tensor cores are used for vocab multiplication
78
+ **DUMMY_TOKENS
79
+ }
80
+
81
+ class Phi3SmallTokenizer(PreTrainedTokenizer):
82
+ vocab_files_names = {
83
+ "vocab_file": "cl100k_base.tiktoken"
84
+ }
85
+
86
+ model_input_names: List[str] = ["input_ids", "attention_mask"]
87
+ padding_side = "left"
88
+
89
+ def __init__(
90
+ self,
91
+ vocab_file: Optional[str] = None,
92
+ errors: str = "replace",
93
+ **kwargs
94
+ ) -> None:
95
+ # PreTrainedTokenizer's init calls _add_tokens, which in turn checks
96
+ # if the token is present in `self.special_tokens``. Hence instantiating it here.
97
+ # The way Qwen gets around this is by checking against SPECIAL_TOKENS
98
+ # But I think it's better to check against the objects own `special_tokens`
99
+ # in case we eventually want to allow the tokenizer to have special tokens.
100
+ self.special_tokens = SPECIAL_TOKENS
101
+
102
+ super().__init__(**kwargs)
103
+ self.errors = errors
104
+
105
+ base = tiktoken.get_encoding("cl100k_base")
106
+ if vocab_file is None:
107
+ self.mergeable_ranks: Dict[bytes, int] = base._mergeable_ranks
108
+ else:
109
+ self.mergeable_ranks = _load_tiktoken_bpe(vocab_file)
110
+
111
+ self.pat_str = base._pat_str
112
+
113
+ enc = tiktoken.Encoding(
114
+ name="phi3small",
115
+ pat_str=self.pat_str,
116
+ mergeable_ranks=self.mergeable_ranks,
117
+ special_tokens=self.special_tokens,
118
+ )
119
+ self.tokenizer = enc
120
+
121
+ self.decoder: Dict[int, bytes] = {
122
+ v: k for k, v in self.mergeable_ranks.items()
123
+ }
124
+ self.decoder.update({v: k for k, v in self.special_tokens.items()})
125
+
126
+ self.eod_id = self.tokenizer.eot_token
127
+ self._eos_token = self._convert_id_to_token(self.eod_id)
128
+
129
+ # Setting the bos_token to be the same as the eos_token
130
+ # Note that this is **not** the correct thing to do, and is done
131
+ # just so that some of the downstream libraries do not break.
132
+ self._bos_token = self._eos_token
133
+
134
+ # Assign the special tokens to class variables
135
+ self.system_id = self.special_tokens["<|system|>"]
136
+ self.user_id = self.special_tokens["<|user|>"]
137
+ self.assistant_id = self.special_tokens["<|assistant|>"]
138
+ self.end_id = self.special_tokens["<|end|>"]
139
+
140
+ @cached_property
141
+ def dummy_token_indices(self) -> List[int]:
142
+ # There are some additional special tokens in the cl100k_base tokenizer
143
+ # that we do not use. Hence, we also consider them to be dummy tokens.
144
+ additional_tokens = [
145
+ "<|fim_prefix|>",
146
+ "<|fim_middle|>",
147
+ "<|fim_suffix|>",
148
+ "<|endofprompt|>"
149
+ ]
150
+ dummy_token_indices = [index for token, index in self.special_tokens.items() if "dummy_id" in token]
151
+ dummy_token_indices.extend([self.special_tokens[token] for token in additional_tokens])
152
+ return sorted(dummy_token_indices)
153
+
154
+ def __getstate__(self):
155
+ state = self.__dict__.copy()
156
+ del state["tokenizer"]
157
+ return state
158
+
159
+ def __setstate__(self, state):
160
+ self.__dict__ = state
161
+ enc = tiktoken.Encoding(
162
+ name="cl100k_im",
163
+ pat_str=self.pat_str,
164
+ mergeable_ranks=self.mergeable_ranks,
165
+ special_tokens=self.special_tokens,
166
+ )
167
+ self.tokenizer = enc
168
+
169
+ def __len__(self):
170
+ return self.tokenizer.n_vocab
171
+
172
+ @classmethod
173
+ def from_pretrained(
174
+ cls,
175
+ pretrained_model_name_or_path: Union[str, os.PathLike],
176
+ *init_inputs,
177
+ **kwargs,
178
+ ):
179
+ cls_kwargs = kwargs
180
+ # First try to load from the tokenization config if it exists
181
+ tokenization_config = get_tokenizer_config(pretrained_model_name_or_path, **kwargs)
182
+ if tokenization_config:
183
+ cls_kwargs = {
184
+ **tokenization_config,
185
+ **cls_kwargs
186
+ }
187
+ else:
188
+ config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True)
189
+ cls_kwargs["model_max_length"] = config.max_position_embeddings
190
+ return cls(**cls_kwargs)
191
+
192
+ def get_vocab(self) -> Dict[Union[str, bytes], int]:
193
+ return {**self.mergeable_ranks, **self.special_tokens}
194
+
195
+ def convert_tokens_to_ids(
196
+ self,
197
+ tokens: Union[bytes, str, List[Union[bytes, str]]]
198
+ ) -> Union[int, List[int]]:
199
+ ids = []
200
+ if isinstance(tokens, (str, bytes)):
201
+ if tokens in self.special_tokens:
202
+ return self.special_tokens[tokens]
203
+ else:
204
+ return self.mergeable_ranks.get(tokens)
205
+ ids: List[int] = []
206
+ for token in tokens:
207
+ ids.append(self.convert_tokens_to_ids(token))
208
+ return ids
209
+
210
+ def _add_tokens(
211
+ self,
212
+ new_tokens: Union[List[str], List[AddedToken]],
213
+ special_tokens: bool = False,
214
+ ) -> int:
215
+ if not special_tokens and new_tokens:
216
+ raise ValueError("Only special tokens can be added to this tokenizer")
217
+ for token in new_tokens:
218
+ surface_form = token.content if isinstance(token, AddedToken) else token
219
+ if surface_form not in self.special_tokens:
220
+ raise ValueError(
221
+ "For now, we do not support unknown special tokens\n"
222
+ "In the future, if there is a need for this, we can add special tokens to the tokenizer\n"
223
+ "starting from rank 100261 - 100263 and then 100266 - 100275.\n"
224
+ "And finally, we can re-construct the enc object back\n"
225
+ )
226
+ return 0
227
+
228
+ def save_vocabulary(self, save_directory: str, **kwargs) -> Tuple[str]:
229
+ file_path = os.path.join(save_directory, "cl100k_base.tiktoken")
230
+ with open(file_path, "w") as f:
231
+ for token, rank in self.mergeable_ranks.items():
232
+ line = base64.b64encode(token).decode("utf-8") + " " + str(rank) + "\n"
233
+ f.write(line)
234
+ return (file_path,)
235
+
236
+ def tokenize(
237
+ self,
238
+ text: str,
239
+ allowed_special: Union[Set, str] = "all",
240
+ disallowed_special: Union[Collection, str] = (),
241
+ **kwargs
242
+ ) -> List[Union[bytes, str]]:
243
+ tokens: List[Union[bytes, str]] = []
244
+ for token_id in self.tokenizer.encode(
245
+ text, allowed_special=allowed_special, disallowed_special=disallowed_special
246
+ ):
247
+ tokens.append(self.decoder[token_id])
248
+ return tokens
249
+
250
+ def convert_tokens_to_string(self, tokens: List[Union[bytes, str]]) -> str:
251
+ """
252
+ Converts a sequence of tokens in a single string.
253
+ """
254
+ text = ""
255
+ temp = b""
256
+ for t in tokens:
257
+ if isinstance(t, str):
258
+ if temp:
259
+ text += temp.decode("utf-8", errors=self.errors)
260
+ temp = b""
261
+ text += t
262
+ elif isinstance(t, bytes):
263
+ temp += t
264
+ else:
265
+ raise TypeError("token should only be of type types or str")
266
+ if temp:
267
+ text += temp.decode("utf-8", errors=self.errors)
268
+ return text
269
+
270
+ @property
271
+ def vocab_size(self):
272
+ return self.tokenizer.n_vocab
273
+
274
+ @property
275
+ def eos_token_id(self) -> int:
276
+ return self.eod_id
277
+
278
+ def _convert_id_to_token(self, index: int) -> Union[bytes, str]:
279
+ """Converts an id to a token, special tokens included"""
280
+ if index in self.decoder:
281
+ return self.decoder[index]
282
+ raise ValueError("unknown ids")
283
+
284
+ def _convert_token_to_id(self, token: Union[bytes, str]) -> int:
285
+ """Converts a token to an id using the vocab, special tokens included"""
286
+ if token in self.special_tokens:
287
+ return self.special_tokens[token]
288
+ if token in self.mergeable_ranks:
289
+ return self.mergeable_ranks[token]
290
+ raise ValueError("unknown token")
291
+
292
+ def _tokenize(self, text: str, **kwargs):
293
+ """
294
+ Converts a string in a sequence of tokens (string), using the tokenizer. Split in words for word-based
295
+ vocabulary or sub-words for sub-word-based vocabularies (BPE/SentencePieces/WordPieces).
296
+ Do NOT take care of added tokens.
297
+ """
298
+ raise NotImplementedError
299
+
300
+ def _decode(
301
+ self,
302
+ token_ids: Union[int, List[int]],
303
+ skip_special_tokens: bool = False,
304
+ errors: str = None,
305
+ **kwargs,
306
+ ) -> str:
307
+ if isinstance(token_ids, int):
308
+ token_ids = [token_ids]
309
+ if skip_special_tokens:
310
+ token_ids = [i for i in token_ids if i < self.eod_id]
311
+ return self.tokenizer.decode(token_ids, errors=errors or self.errors)
312
+
313
+
tokenizer_config.json ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_commit_hash": "fc8e001871f4a6be8e6079093b33de334a2316c9",
3
+ "_from_auto": true,
4
+ "added_tokens_decoder": {},
5
+ "auto_map": {
6
+ "AutoTokenizer": [
7
+ "tokenization_phi3_small.Phi3SmallTokenizer",
8
+ null
9
+ ]
10
+ },
11
+ "bos_token": "<|endoftext|>",
12
+ "chat_template": "{{ bos_token }}{% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>\n' }}{% else %}{{ eos_token }}{% endif %}",
13
+ "clean_up_tokenization_spaces": true,
14
+ "eos_token": "<|endoftext|>",
15
+ "legacy": false,
16
+ "model_max_length": 8192,
17
+ "pad_token": "<|endoftext|>",
18
+ "tokenizer_class": "Phi3SmallTokenizer",
19
+ "trust_remote_code": true
20
+ }
triton_blocksparse_attention_layer.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Optional, Tuple, TypeVar
3
+ import torch.nn as nn
4
+ import torch
5
+ import triton
6
+
7
+ from functools import lru_cache
8
+
9
+
10
+ from .triton_flash_blocksparse_attn import get_local_strided_sparse_attention_op, _get_sparse_attn_mask, blocksparse_flash_attn_padded_fwd, blocksparse_flash_attn_varlen_fwd
11
+
12
+
13
+ Layout = Tuple[torch.LongTensor, torch.LongTensor]
14
+
15
+
16
+ def create_sparse_attn_mask(
17
+ n_heads: int,
18
+ max_seq_len: int,
19
+ max_seq_len_k: int,
20
+ dtype: torch.dtype,
21
+ device: torch.device,
22
+ BLOCK: int,
23
+ local_blocks: int,
24
+ vert_stride: int,
25
+ homo_head: bool,
26
+ return_dense: bool
27
+ ) -> Tuple[Layout, torch.Tensor, Optional[torch.Tensor]]:
28
+ layout, block_sparse_pattern, _ = _get_sparse_attn_mask(
29
+ n_heads=n_heads,
30
+ q_len=max_seq_len,
31
+ N_CTX=max_seq_len_k,
32
+ dtype=dtype,
33
+ device=device,
34
+ BLOCK=BLOCK,
35
+ local_blocks=local_blocks,
36
+ vert_stride=vert_stride,
37
+ homo_head=homo_head,
38
+ return_dense=return_dense
39
+ )
40
+ return layout, block_sparse_pattern
41
+
42
+
43
+ class BlockSparseAttentionLayer(nn.Module):
44
+ def __init__(
45
+ self,
46
+ n_heads: int,
47
+ max_seq_len: int,
48
+ sparse_block_size: int,
49
+ local_blocks: int,
50
+ vert_stride: int,
51
+ kernel_block_size: Optional[int] = None,
52
+ homo_head: bool = False,
53
+ active_head_range: Optional[Tuple[int]] = None
54
+ ) -> None:
55
+ super().__init__()
56
+
57
+ self.n_heads = n_heads
58
+ self.max_seq_len = max_seq_len
59
+ self.sparse_block_size = sparse_block_size
60
+ self.kernel_block_size = kernel_block_size or sparse_block_size
61
+ self.local_blocks = local_blocks
62
+ self.vert_stride = vert_stride
63
+ self.homo_head = homo_head
64
+ self.active_head_range = active_head_range
65
+
66
+ # Internal Parameters used by the layer
67
+ self._sparse_block_mask = None
68
+ self._sparse_layout = None
69
+ self._dtype = None
70
+ self._device = None
71
+
72
+ # TODO(bapatra): Ideally, I'd want to keep all the code for
73
+ # forward to be handled here, and not branch for training and inference.
74
+ # However, that refactor would need a lot of testing. For now, using the
75
+ # training op as is, and will refactor again later.
76
+
77
+ def prune_blocksparse_layout_to_heads(self, h_start: int, h_end: int) -> None:
78
+ self._sparse_block_mask = self._sparse_block_mask[h_start: h_end]
79
+ self._sparse_layout[0] = self._sparse_layout[0][h_start: h_end]
80
+ self._sparse_layout[1] = self._sparse_layout[1][h_start: h_end]
81
+
82
+ def _initialize_internals(
83
+ self,
84
+ dtype: torch.dtype,
85
+ device: torch.device
86
+ ) -> None:
87
+ self._dtype, self._device = dtype, device
88
+ self._sparse_layout, self._sparse_block_mask = create_sparse_attn_mask(
89
+ n_heads=self.n_heads,
90
+ max_seq_len=self.max_seq_len,
91
+ max_seq_len_k=self.max_seq_len,
92
+ dtype=dtype,
93
+ device=device,
94
+ BLOCK=self.sparse_block_size,
95
+ local_blocks=self.local_blocks,
96
+ vert_stride=self.vert_stride,
97
+ homo_head=self.homo_head,
98
+ return_dense=False,
99
+ )
100
+ if (not self.homo_head) and (self.active_head_range is not None):
101
+ assert len(self.active_head_range) == 2, "\"active_head_range\" should be a tuple of start/end index of the heads."
102
+ h_start, h_end = self.active_head_range
103
+ self.prune_blocksparse_layout_to_heads(h_start=h_start, h_end=h_end)
104
+
105
+ assert self.sparse_block_size % self.kernel_block_size == 0, f"The sparse block size must be a multiple of {self.kernel_block_size}. Found {self.sparse_block_size}."
106
+ assert self.kernel_block_size >=16 and math.log2(self.kernel_block_size) % 1 == 0, f"block_size must be power of 2 and at least 16, but {self.kernel_block_size} is given"
107
+ if self.sparse_block_size // self.kernel_block_size > 1:
108
+ _mul = self.sparse_block_size // self.kernel_block_size
109
+ # need to consider if block_m and block_n are different
110
+ self._sparse_block_mask = torch.kron(self._sparse_block_mask, self._sparse_block_mask.new_ones(_mul, _mul))
111
+ num_sparse_blocks = self._sparse_block_mask.size(-1)
112
+ block_causal_mask = torch.arange(0, num_sparse_blocks)[:, None] >= torch.arange(0, num_sparse_blocks)[None]
113
+ self._sparse_block_mask *= block_causal_mask.type_as(self._sparse_block_mask)
114
+
115
+
116
+ def forward(
117
+ self,
118
+ q: torch.Tensor,
119
+ k: torch.Tensor,
120
+ v: torch.Tensor,
121
+ sm_scale: float,
122
+ *,
123
+ # Arguments Related to Block Attention Inference
124
+ left_paddings: Optional[torch.LongTensor] = None,
125
+ seqlens: Optional[torch.LongTensor] = None,
126
+ # Arguements Related to Variable Length Inference
127
+ cu_seqlens_k: Optional[torch.LongTensor] = None,
128
+ cu_seqlens_q: Optional[torch.LongTensor] = None,
129
+ ) -> torch.Tensor:
130
+
131
+ if left_paddings is None and seqlens is None and cu_seqlens_k is None and cu_seqlens_q is None:
132
+ blocksparse_op = get_local_strided_sparse_attention_op(
133
+ n_heads=self.n_heads,
134
+ max_seq_len=self.max_seq_len,
135
+ sparse_block_size=self.sparse_block_size,
136
+ kernel_block_size=self.kernel_block_size,
137
+ local_blocks=self.local_blocks,
138
+ vert_stride=self.vert_stride,
139
+ homo_head=self.homo_head,
140
+ device=q.device,
141
+ inference=not self.training
142
+ )
143
+ return blocksparse_op(q, k, v, sm_scale)
144
+
145
+ assert not torch.is_grad_enabled(), "Variable Length Inference / Batched inference is not supported during training. Please run it in a torch.no_grad() context"
146
+ # First set internals if they have not been set
147
+ if self._sparse_block_mask is None or (self._dtype != q.dtype) or (self._device != q.device):
148
+ self._initialize_internals(dtype=q.dtype, device=q.device)
149
+
150
+ if k.dim() == 3:
151
+ assert cu_seqlens_k is not None
152
+ return blocksparse_flash_attn_varlen_fwd(
153
+ q=q,
154
+ k=k,
155
+ v=v,
156
+ cu_seqlens_k=cu_seqlens_k,
157
+ cu_seqlens_q=cu_seqlens_q,
158
+ sm_scale=sm_scale,
159
+ sparse_layout=self._sparse_layout,
160
+ block_size=self.kernel_block_size,
161
+ max_seqlen=self.max_seq_len,
162
+ )
163
+ if k.dim() == 4:
164
+ assert not (left_paddings is None and seqlens is None), "Either left_paddings or seqlens must be provided for batched inference."
165
+ return blocksparse_flash_attn_padded_fwd(
166
+ q=q,
167
+ k=k,
168
+ v=v,
169
+ sm_scale=sm_scale,
170
+ sparse_layout=self._sparse_layout,
171
+ left_paddings=left_paddings,
172
+ seqlens=seqlens,
173
+ block_size=self.kernel_block_size,
174
+ max_seqlen=self.max_seq_len,
175
+ )
176
+ raise ValueError('q/k/v must be either 3 dim for variable-length input or 4 dim for fixed-length.')
triton_flash_blocksparse_attn.py ADDED
@@ -0,0 +1,1947 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Author: Eric Lin (xihlin)
3
+ """
4
+ """
5
+ ... note(bapatra)::
6
+ This is written as one big file, instead of splitting into logical components because I was running into issues with transformers auto module
7
+ imports when splitting into different files. I've tried keeping the logical partitions demarkated with comment blocks, but it is not ideal.
8
+ In the future, would be really good to revisit this and refactor into a more readable file structure.
9
+
10
+ """
11
+ from typing import TypeVar
12
+ from functools import lru_cache
13
+ import math
14
+ import pytest
15
+ import torch
16
+ import numpy as np
17
+
18
+ import triton
19
+ import triton.language as tl
20
+
21
+ import os
22
+
23
+ import dataclasses
24
+
25
+ Phi3SmallConfig = TypeVar('Phi3SmallConfig')
26
+
27
+ # triton 2.0.0: fail at backward on A100, for the examples, if h_dim=128.
28
+
29
+ # Done
30
+ # 1. strided of qkv
31
+ # 2. seq len not power of 2
32
+ # 3. bf16 with Triton May, 2023
33
+
34
+ # TODO:
35
+ # 1. wip: support non-contiguous backward, also help reduce memory allocation in training (q, k, v split)
36
+ # 2. block sparse with different BLOCK_M, BLOCK_N?
37
+ # 3. for Lq not divided by BLOCK_M, BLOCK_N, only apply mask to K/V on last batch, still need to apply mask on Q.
38
+ # Attempt, fail to compile
39
+ # 4. For 2nd iter of inference, BLOCK_M=1, how to make things work? K/V maynot divided by BLOCK_N.
40
+ # 5. The inner loop can also be paralled via bigger num_stage(better) or on different thread-block (via m/L and atomic update, but this no-comm/sync between blocks)
41
+
42
+
43
+ ###########################################################
44
+ ################### Kernel Parameters #####################
45
+ ###########################################################
46
+
47
+ @dataclasses.dataclass
48
+ class BlockSparseParams(object):
49
+ block_size: int
50
+ kernel_block_size: int
51
+ num_local_blocks: int
52
+ vert_stride: int
53
+ homo_head_pattern: bool = False
54
+
55
+ @classmethod
56
+ def from_config(cls, config: Phi3SmallConfig) -> "BlockSparseParams":
57
+ return cls(
58
+ block_size=config.blocksparse_block_size,
59
+ kernel_block_size=config.blocksparse_triton_kernel_block_size,
60
+ num_local_blocks=config.blocksparse_num_local_blocks,
61
+ vert_stride=config.blocksparse_vert_stride,
62
+ homo_head_pattern=config.blocksparse_homo_head_pattern,
63
+ )
64
+
65
+
66
+ ###########################################################
67
+ ###########################################################
68
+
69
+ ###########################################################
70
+ ################### Utility Functions #####################
71
+ ###########################################################
72
+
73
+ # helper functions for 3D sparse pattern
74
+ # these function are not optimized and very inefficient. Avoid calling them too frequent.
75
+ # currently, it is only called within `get_local_strided_sparse_attention_op`, which is cached.
76
+ def dense_to_crow_col(x):
77
+ ''' Turning a 2D/3D torch tensor (x) to CSR rows/cols indexing.
78
+ param:
79
+ TODO:
80
+ 1. improve efficiency, is it faster if done in CPU, or customize a cuda kernel for it?
81
+ NOTE: col_indices padded -1
82
+ '''
83
+ pad = -1
84
+ dim = x.dim()
85
+ assert x.dim() in (2, 3)
86
+ if x.dim() == 2:
87
+ x = x[None]
88
+ x = [xi.to_sparse_csr() for xi in x]
89
+ crows = torch.vstack([xi.crow_indices() for xi in x])
90
+ cols = [xi.col_indices() for xi in x]
91
+ max_cols = max(len(xi) for xi in cols)
92
+ cols = [torch.cat([xi, pad + xi.new_zeros(max_cols - xi.shape[0])]) for xi in cols]
93
+ cols = torch.vstack(cols)
94
+ if dim == 2:
95
+ crows = crows[0]
96
+ cols = cols[0]
97
+ return crows, cols
98
+
99
+
100
+ def crow_col_to_dense(crows, cols, dtype=torch.float16):
101
+ dim = crows.dim()
102
+ if dim == 1:
103
+ crows = crows[None]
104
+ cols = cols[None]
105
+ device = crows.device
106
+ crows, cols = crows.cpu(), cols.cpu() # faster in cpu
107
+ shape = (crows.shape[0], crows.shape[1] - 1, cols.max() + 1)
108
+ x = torch.zeros(shape, dtype=dtype)
109
+ for i in range(shape[0]):
110
+ for j in range(shape[1]):
111
+ x[i, j, cols[i, crows[i, j]:crows[i, j+1]]] = 1
112
+ if dim == 1:
113
+ x = x[0]
114
+ return x.to(device)
115
+
116
+
117
+ def dense_to_ccol_row(x):
118
+ '''Similar, but to CSC format
119
+ '''
120
+ x = x.transpose(-2, -1)
121
+ return dense_to_crow_col(x)
122
+
123
+
124
+ def ccol_row_to_dense(ccol, rows, dtype=torch.float16):
125
+ return crow_col_to_dense(ccol, rows, dtype).permute(0, 2, 1).contiguous()
126
+
127
+
128
+ def _get_sparse_attn_mask_homo_head(q_len, N_CTX, dtype, device, BLOCK=128, local_blocks=4, vert_stride=4, return_dense=False):
129
+ '''
130
+ :return: a tuple of 3:
131
+ - tuple of crow_indices, col_indices representation of CSR format.
132
+ - block dense mask
133
+ - all token dense mask (be aware that it can be OOM if it is too big) if `return_dense==True`, otherwise, None
134
+ '''
135
+ with torch.no_grad():
136
+ N_BLOCK = triton.cdiv(N_CTX, BLOCK)
137
+ q_pos = torch.arange(N_BLOCK)[:, None]
138
+ k_pos = torch.arange(N_BLOCK)[None]
139
+ mask_vert_strided = (torch.arange(N_BLOCK) + 1) % vert_stride == 0
140
+ block_mask_dense = ((q_pos >= k_pos) & ((q_pos - k_pos < local_blocks) | mask_vert_strided)).to(device).to(dtype)
141
+ N_BLOCK_Q = triton.cdiv(q_len, BLOCK)
142
+ block_mask_dense_output = block_mask_dense[-N_BLOCK_Q:].contiguous().to_sparse_csr()
143
+ if return_dense:
144
+ mask_dense = torch.kron(block_mask_dense, block_mask_dense.new_ones((BLOCK, BLOCK)))
145
+ causal_mask = torch.tril(torch.ones(N_CTX, N_CTX)).type_as(mask_dense)[-q_len:]
146
+ mask_dense = mask_dense[-q_len:, :N_CTX] * causal_mask
147
+ return (block_mask_dense_output.crow_indices(), block_mask_dense_output.col_indices()), block_mask_dense, mask_dense
148
+ else:
149
+ return (block_mask_dense_output.crow_indices(), block_mask_dense_output.col_indices()), block_mask_dense, None
150
+
151
+
152
+ def _get_sparse_attn_mask(n_heads, q_len, N_CTX, dtype, device, BLOCK=128, local_blocks=4, vert_stride=4, homo_head=True, return_dense=False):
153
+ '''
154
+ :return: a tuple of 3:
155
+ - tuple of crow_indices, col_indices representation of CSR format.
156
+ - block dense mask
157
+ - all token dense mask (be aware that it can be OOM if it is too big) if `return_dense==True`, otherwise, None
158
+ '''
159
+ if homo_head:
160
+ with torch.no_grad():
161
+ (crow, col), block_mask_dense, mask_dense = _get_sparse_attn_mask_homo_head(q_len, N_CTX, dtype, device, BLOCK, local_blocks, vert_stride, return_dense)
162
+ crow = crow[None].expand(n_heads, crow.shape[0])
163
+ col = col[None].expand(n_heads, col.shape[0])
164
+ if return_dense:
165
+ mask_dense = mask_dense[None].expand(n_heads, *mask_dense.shape)
166
+ return (crow, col), block_mask_dense, mask_dense
167
+
168
+ with torch.no_grad():
169
+ N_BLOCK = triton.cdiv(N_CTX, BLOCK)
170
+ q_pos = torch.arange(N_BLOCK)[None, :, None]
171
+ k_pos = torch.arange(N_BLOCK)[None, None]
172
+ head_sliding_step = max(1, int(vert_stride / n_heads)) # if vert_stride <= n_heads, rotating the heads
173
+ mask_vert_strided = [(torch.arange(N_BLOCK) + h * head_sliding_step + 1) % vert_stride == 0 for h in range(n_heads)]
174
+ mask_vert_strided = torch.vstack(mask_vert_strided).unsqueeze(1)
175
+ block_mask_dense = ((q_pos >= k_pos) & ((q_pos - k_pos < local_blocks) | mask_vert_strided)).to(device).to(dtype)
176
+ N_BLOCK_Q = triton.cdiv(q_len, BLOCK)
177
+ block_mask_dense_output = block_mask_dense[:, -N_BLOCK_Q:]
178
+ if return_dense:
179
+ mask_dense = torch.kron(block_mask_dense, block_mask_dense.new_ones((BLOCK, BLOCK)))
180
+ causal_mask = torch.tril(torch.ones(N_CTX, N_CTX)).type_as(mask_dense)[-q_len:]
181
+ mask_dense = mask_dense[..., -q_len:, :N_CTX] * causal_mask[None]
182
+ return dense_to_crow_col(block_mask_dense_output), block_mask_dense, mask_dense
183
+ else:
184
+ return dense_to_crow_col(block_mask_dense_output), block_mask_dense, None
185
+
186
+
187
+ def get_sparse_attn_mask(q, N_CTX, *args, **kwargs):
188
+ return _get_sparse_attn_mask(q.size(1), q.size(2), N_CTX, q.dtype, q.device, *args, **kwargs)
189
+
190
+ ###########################################################
191
+ ###########################################################
192
+
193
+ ###########################################################
194
+ ###################### Training Kernels ###################
195
+ ###########################################################
196
+
197
+ # TODO: only apply loading/saving mask on the last iteration for EVEN_N_BLOCK, useful for 1st iteration of inference.
198
+ # Experiment failed inside loop.
199
+ # Another idea: only on saving? load even out of boundary(will it causes illegal access error)?
200
+ @triton.jit
201
+ def _fwd_kernel(
202
+ Q, K, V, sm_scale,
203
+ layout_crow_ptr,
204
+ layout_col_ptr,
205
+ layout_crow_stride_h, layout_crow_stride_m,
206
+ layout_col_stride_h, layout_col_stride_m,
207
+ TMP, L, M, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug. TMP, L, M are assumed to have contiguous layouts
208
+ Out,
209
+ stride_qz, stride_qh, stride_qm, stride_qd,
210
+ stride_kz, stride_kh, stride_kn, stride_kd,
211
+ stride_vz, stride_vh, stride_vn, stride_vd,
212
+ stride_oz, stride_oh, stride_om, stride_od,
213
+ Z, H, N_CTX,
214
+ PAST_LEN,
215
+ Q_ROUNDED_LEN,
216
+ BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
217
+ BLOCK_N: tl.constexpr,
218
+ EVEN_M_BLOCK: tl.constexpr,
219
+ EVEN_N_BLOCK: tl.constexpr,
220
+ INFERENCE: tl.constexpr,
221
+ NUM_DBLOCKS: tl.constexpr,
222
+ ):
223
+ Q_LEN = N_CTX - PAST_LEN
224
+ start_m = tl.program_id(0)
225
+ off_hz = tl.program_id(1)
226
+ off_h = off_hz % H
227
+ off_z = off_hz // H
228
+ Q += off_z * stride_qz + off_h * stride_qh
229
+ K += off_z * stride_kz + off_h * stride_kh
230
+ V += off_z * stride_vz + off_h * stride_vh
231
+ # initialize offsets
232
+ offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
233
+ offs_n = tl.arange(0, BLOCK_N)
234
+ offs_d = tl.arange(0, BLOCK_DMODEL)
235
+ off_q = offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qd
236
+ # off_k = offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kd
237
+ off_k = offs_n[None, :] * stride_kn + offs_d[:, None] * stride_kd
238
+ off_v = offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vd
239
+ # Initialize pointers to Q, K, V
240
+ q_ptrs = Q + off_q
241
+ k_ptrs = K + off_k
242
+ v_ptrs = V + off_v
243
+ # initialize pointer to m and l
244
+ t_ptrs = TMP + off_hz * Q_ROUNDED_LEN + offs_m
245
+ m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf')
246
+ l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
247
+ acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
248
+ if NUM_DBLOCKS >= 2:
249
+ acc2 = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
250
+
251
+ # load q: it will stay in SRAM throughout
252
+ if EVEN_M_BLOCK:
253
+ q = tl.load(q_ptrs)
254
+ if NUM_DBLOCKS >= 2:
255
+ q2 = tl.load(q_ptrs + BLOCK_DMODEL * stride_qd)
256
+ else:
257
+ q = tl.load(q_ptrs, mask=offs_m[:, None] < Q_LEN)
258
+ if NUM_DBLOCKS >= 2:
259
+ q2 = tl.load(q_ptrs + BLOCK_DMODEL * stride_qd, mask=offs_m[:, None] < Q_LEN)
260
+
261
+ layout_ptr = layout_crow_ptr + off_h * layout_crow_stride_h + start_m * layout_crow_stride_m
262
+ start_l = tl.load(layout_ptr).to(tl.int32)
263
+ end_l = tl.load(layout_ptr + layout_crow_stride_m).to(tl.int32)
264
+
265
+ # loop over k, v and update accumulator
266
+ for col_idx_idx in range(start_l, end_l):
267
+ col_idx = tl.load(layout_col_ptr + off_h * layout_col_stride_h + col_idx_idx * layout_col_stride_m).to(tl.int32)
268
+ start_n = col_idx * BLOCK_N
269
+ # -- compute qk ----
270
+ if EVEN_N_BLOCK:
271
+ k = tl.load(k_ptrs + start_n * stride_kn)
272
+ else:
273
+ k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_n[None, :] + start_n < N_CTX)
274
+ qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
275
+ qk += tl.dot(q, k)
276
+
277
+ if NUM_DBLOCKS >= 2:
278
+ if EVEN_N_BLOCK:
279
+ k = tl.load(k_ptrs + start_n * stride_kn + BLOCK_DMODEL * stride_kd)
280
+ else:
281
+ k = tl.load(k_ptrs + start_n * stride_kn + BLOCK_DMODEL * stride_kd, mask=offs_n[None, :] + start_n < N_CTX)
282
+ qk += tl.dot(q2, k)
283
+
284
+ qk *= sm_scale
285
+ qk += tl.where(offs_m[:, None] + PAST_LEN >= (start_n + offs_n[None, :]), 0, float('-inf'))
286
+ # -- compute m_ij, p, l_ij
287
+ m_ij = tl.max(qk, 1)
288
+ p = tl.exp(qk - m_ij[:, None])
289
+ l_ij = tl.sum(p, 1)
290
+ # -- update m_i and l_i
291
+ m_i_new = tl.maximum(m_i, m_ij)
292
+ alpha = tl.exp(m_i - m_i_new)
293
+ beta = tl.exp(m_ij - m_i_new)
294
+ l_i_new = alpha * l_i + beta * l_ij
295
+ # -- update output accumulator --
296
+ # scale p
297
+ p_scale = beta / l_i_new
298
+ p = p * p_scale[:, None]
299
+ # scale acc
300
+ acc_scale = l_i / l_i_new * alpha
301
+ # tl.store(t_ptrs, acc_scale)
302
+ # acc_scale = tl.load(t_ptrs) # BUG: have to store and immediately load
303
+ acc = acc * acc_scale[:, None]
304
+ if NUM_DBLOCKS >= 2:
305
+ acc2 = acc2 * acc_scale[:, None]
306
+ p = p.to(Q.dtype.element_ty)
307
+ # update acc
308
+ if EVEN_N_BLOCK:
309
+ v = tl.load(v_ptrs + start_n * stride_vn)
310
+ else:
311
+ v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_n[:, None] + start_n < N_CTX)
312
+ acc += tl.dot(p, v)
313
+
314
+ if NUM_DBLOCKS >= 2:
315
+ if EVEN_N_BLOCK:
316
+ v = tl.load(v_ptrs + start_n * stride_vn + BLOCK_DMODEL * stride_vd)
317
+ else:
318
+ v = tl.load(v_ptrs + start_n * stride_vn + BLOCK_DMODEL * stride_vd, mask=offs_n[:, None] + start_n < N_CTX)
319
+ acc2 += tl.dot(p, v)
320
+
321
+ # update m_i and l_i
322
+ l_i = l_i_new
323
+ m_i = m_i_new
324
+
325
+ # rematerialize offsets to save registers
326
+ # start_m = tl.program_id(0)
327
+ # offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
328
+ # write back l and m
329
+ if not INFERENCE:
330
+ l_ptrs = L + off_hz * N_CTX + offs_m
331
+ m_ptrs = M + off_hz * N_CTX + offs_m
332
+ if EVEN_M_BLOCK:
333
+ tl.store(l_ptrs, l_i)
334
+ tl.store(m_ptrs, m_i)
335
+ else:
336
+ tl.store(l_ptrs, l_i, mask=offs_m < Q_LEN)
337
+ tl.store(m_ptrs, m_i, mask=offs_m < Q_LEN)
338
+ # initialize pointers to output
339
+ # offs_n = tl.arange(0, BLOCK_DMODEL)
340
+ off_o = off_z * stride_oz + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :] * stride_od
341
+ out_ptrs = Out + off_o
342
+ tl.store(out_ptrs, acc, mask=offs_m[:, None] < Q_LEN)
343
+ if NUM_DBLOCKS >= 2:
344
+ tl.store(out_ptrs + BLOCK_DMODEL * stride_od, acc2, mask=offs_m[:, None] < Q_LEN)
345
+
346
+
347
+ ## backward
348
+ @triton.heuristics(
349
+ {
350
+ 'EVEN_M_BLOCK': lambda kwargs: kwargs['N_CTX'] % kwargs['BLOCK_M'] == 0,
351
+ }
352
+ )
353
+ @triton.jit
354
+ def _bwd_preprocess(
355
+ Out, DO, L, # assume contiguous for Out, DO, L, NewDO, Delta layout.
356
+ NewDO, Delta,
357
+ N_CTX,
358
+ BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,
359
+ EVEN_M_BLOCK: tl.constexpr,
360
+ ):
361
+ off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
362
+ off_d = tl.arange(0, D_HEAD)
363
+ # load
364
+ if EVEN_M_BLOCK:
365
+ o = tl.load(Out + off_m[:, None] * D_HEAD + off_d[None, :]).to(tl.float32)
366
+ do = tl.load(DO + off_m[:, None] * D_HEAD + off_d[None, :]).to(tl.float32)
367
+ else:
368
+ o = tl.load(Out + off_m[:, None] * D_HEAD + off_d[None, :], mask=off_m[:, None] < N_CTX).to(tl.float32)
369
+ do = tl.load(DO + off_m[:, None] * D_HEAD + off_d[None, :], mask=off_m[:, None] < N_CTX).to(tl.float32)
370
+ denom = tl.load(L + off_m).to(tl.float32)
371
+ # compute
372
+ do = do / denom[:, None]
373
+ delta = tl.sum(o * do, axis=1)
374
+ # write-back
375
+ if EVEN_M_BLOCK:
376
+ tl.store(NewDO + off_m[:, None] * D_HEAD + off_d[None, :], do)
377
+ else:
378
+ tl.store(NewDO + off_m[:, None] * D_HEAD + off_d[None, :], do, mask=off_m[:, None] < N_CTX)
379
+ tl.store(Delta + off_m, delta)
380
+
381
+
382
+ # Does not suuport unequal seqlen(q) and seqlen(k)
383
+ @triton.heuristics(
384
+ {
385
+ 'EVEN_M_BLOCK': lambda kwargs: kwargs['N_CTX'] % kwargs['BLOCK_M'] == 0,
386
+ 'EVEN_N_BLOCK': lambda kwargs: kwargs['N_CTX'] % kwargs['BLOCK_N'] == 0,
387
+ }
388
+ )
389
+ @triton.jit
390
+ def _bwd_kernel(
391
+ Q, K, V, sm_scale,
392
+ layout_ccol_ptr,
393
+ layout_row_ptr,
394
+ layout_ccol_stride_h, layout_ccol_stride_m,
395
+ layout_row_stride_h, layout_row_stride_m,
396
+ Out, DO, # assume contigous: Out, Do, DQ, DK, DV, L, M, D, seq(q) == seq(k), with stride_oz, stride_oh, stride_om, stride_od,
397
+ DQ, DK, DV,
398
+ L, M,
399
+ D,
400
+ stride_qz, stride_qh, stride_qm, stride_qd,
401
+ stride_kz, stride_kh, stride_kn, stride_kd,
402
+ stride_vz, stride_vh, stride_vn, stride_vd,
403
+ stride_oz, stride_oh, stride_om, stride_od,
404
+ # stride_dz, stride_dh, stride_dm, stride_dd,
405
+ Z, H, N_CTX,
406
+ num_block,
407
+ BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
408
+ BLOCK_N: tl.constexpr,
409
+ EVEN_M_BLOCK: tl.constexpr,
410
+ EVEN_N_BLOCK: tl.constexpr,
411
+ NUM_DBLOCKS: tl.constexpr,
412
+ ):
413
+ start_n = tl.program_id(0)
414
+ off_hz = tl.program_id(1)
415
+ off_z = off_hz // H
416
+ off_h = off_hz % H
417
+ # offset pointers for batch/head
418
+ Q += off_z * stride_qz + off_h * stride_qh
419
+ K += off_z * stride_kz + off_h * stride_kh
420
+ V += off_z * stride_vz + off_h * stride_vh
421
+ DO += off_z * stride_oz + off_h * stride_oh
422
+ DQ += off_z * stride_oz + off_h * stride_oh
423
+ DK += off_z * stride_oz + off_h * stride_oh
424
+ DV += off_z * stride_oz + off_h * stride_oh
425
+ # Look like this loop can be parallelled
426
+ # for start_n in range(0, num_block):
427
+
428
+ offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
429
+ offs_m = tl.arange(0, BLOCK_M)
430
+ offs_d = tl.arange(0, BLOCK_DMODEL)
431
+ # initialize pointers to value-like data
432
+ k_ptrs = K + (offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kd)
433
+ v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vd)
434
+
435
+ # pointer to row-wise quantities in value-like data
436
+ D_ptrs = D + off_hz * N_CTX
437
+ m_ptrs = M + off_hz * N_CTX
438
+ # initialize dv amd dk
439
+ dv = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)
440
+ dk = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)
441
+ # k and v stay in SRAM throughout
442
+ if EVEN_N_BLOCK:
443
+ k = tl.load(k_ptrs)
444
+ v = tl.load(v_ptrs)
445
+ else:
446
+ k = tl.load(k_ptrs, mask=offs_n[:, None] < N_CTX)
447
+ v = tl.load(v_ptrs, mask=offs_n[:, None] < N_CTX)
448
+
449
+ if NUM_DBLOCKS >= 2:
450
+ dv2 = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)
451
+ dk2 = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)
452
+ if EVEN_N_BLOCK:
453
+ k2 = tl.load(k_ptrs + BLOCK_DMODEL * stride_kd)
454
+ v2 = tl.load(v_ptrs + BLOCK_DMODEL * stride_vd)
455
+ else:
456
+ k2 = tl.load(k_ptrs + BLOCK_DMODEL * stride_kd, mask=offs_n[:, None] < N_CTX)
457
+ v2 = tl.load(v_ptrs + BLOCK_DMODEL * stride_vd, mask=offs_n[:, None] < N_CTX)
458
+
459
+ # loop over rows
460
+
461
+ layout_ptr = layout_ccol_ptr + off_h * layout_ccol_stride_h + start_n * layout_ccol_stride_m
462
+ start_l = tl.load(layout_ptr).to(tl.int32)
463
+ end_l = tl.load(layout_ptr + layout_ccol_stride_m).to(tl.int32)
464
+
465
+ for row_idx_idx in range(start_l, end_l):
466
+ row_idx = tl.load(layout_row_ptr + off_h * layout_row_stride_h + row_idx_idx * layout_row_stride_m).to(tl.int32)
467
+ start_m = row_idx * BLOCK_M
468
+
469
+ # offs_qm = start_m + tl.arange(0, BLOCK_M)
470
+ offs_m_curr = start_m + offs_m
471
+ q_ptrs = Q + (offs_m_curr[:, None] * stride_qm + offs_d[None, :] * stride_qd)
472
+ do_ptrs = DO + (offs_m_curr[:, None] * stride_om + offs_d[None, :] * stride_od)
473
+ dq_ptrs = DQ + (offs_m_curr[:, None] * stride_om + offs_d[None, :] * stride_od)
474
+
475
+ # load q, k, v, do on-chip
476
+ if EVEN_M_BLOCK:
477
+ q = tl.load(q_ptrs)
478
+ else:
479
+ q = tl.load(q_ptrs, mask=offs_m_curr[:, None] < N_CTX)
480
+ # re-compute p = softmax(qk, dim=-1).T
481
+ # NOTE: `do` is pre-divided by `l`; no normalization here
482
+ qk = tl.dot(q, tl.trans(k))
483
+
484
+ if NUM_DBLOCKS >= 2:
485
+ if EVEN_M_BLOCK:
486
+ q2 = tl.load(q_ptrs + BLOCK_DMODEL * stride_qd)
487
+ else:
488
+ q2 = tl.load(q_ptrs + BLOCK_DMODEL * stride_qd, mask=offs_m_curr[:, None] < N_CTX)
489
+ qk += tl.dot(q2, tl.trans(k2))
490
+
491
+ qk += tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), 0, float('-inf'))
492
+
493
+ if EVEN_M_BLOCK:
494
+ m = tl.load(m_ptrs + offs_m_curr)
495
+ else:
496
+ m = tl.load(m_ptrs + offs_m_curr, mask=offs_m_curr < N_CTX)
497
+ p = tl.exp(qk * sm_scale - m[:, None])
498
+
499
+ # compute dv
500
+ if EVEN_M_BLOCK:
501
+ do = tl.load(do_ptrs)
502
+ else:
503
+ do = tl.load(do_ptrs, mask=offs_m_curr[:, None] < N_CTX)
504
+
505
+ if NUM_DBLOCKS >= 2:
506
+ if EVEN_M_BLOCK:
507
+ do2 = tl.load(do_ptrs + BLOCK_DMODEL * stride_od)
508
+ else:
509
+ do2 = tl.load(do_ptrs + BLOCK_DMODEL * stride_od, mask=offs_m_curr[:, None] < N_CTX)
510
+
511
+ dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do)
512
+
513
+ if NUM_DBLOCKS >= 2:
514
+ dv2 += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do2)
515
+
516
+ # compute dp = dot(v, do)
517
+ if EVEN_M_BLOCK:
518
+ Di = tl.load(D_ptrs + offs_m_curr)
519
+ else:
520
+ Di = tl.load(D_ptrs + offs_m_curr, mask=offs_m_curr < N_CTX)
521
+ dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
522
+ dp += tl.dot(do, tl.trans(v))
523
+
524
+ if NUM_DBLOCKS >= 2:
525
+ dp += tl.dot(do2, tl.trans(v2))
526
+
527
+ # compute ds = p * (dp - delta[:, None])
528
+ ds = p * dp * sm_scale
529
+ # compute dk = dot(ds.T, q)
530
+ dk += tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q)
531
+ if NUM_DBLOCKS >= 2:
532
+ dk2 += tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q2)
533
+
534
+ # # compute dq
535
+ dq = tl.dot(ds.to(Q.dtype.element_ty), k)
536
+ if EVEN_M_BLOCK:
537
+ tl.atomic_add(dq_ptrs, dq)
538
+ else:
539
+ tl.atomic_add(dq_ptrs, dq, mask=offs_m_curr[:, None] < N_CTX)
540
+
541
+ if NUM_DBLOCKS >= 2:
542
+ dq2 = tl.dot(ds.to(Q.dtype.element_ty), k2)
543
+ dq_ptrs2 = dq_ptrs + BLOCK_DMODEL * stride_od
544
+ if EVEN_M_BLOCK:
545
+ tl.atomic_add(dq_ptrs2, dq2)
546
+ else:
547
+ tl.atomic_add(dq_ptrs2, dq2, mask=offs_m_curr[:, None] < N_CTX)
548
+
549
+ # write-back
550
+ dv_ptrs = DV + (offs_n[:, None] * stride_om + offs_d[None, :] * stride_od)
551
+ dk_ptrs = DK + (offs_n[:, None] * stride_om + offs_d[None, :] * stride_od)
552
+ if EVEN_N_BLOCK:
553
+ tl.store(dv_ptrs, dv)
554
+ tl.store(dk_ptrs, dk)
555
+ else:
556
+ tl.store(dv_ptrs, dv, mask=offs_n[:, None] < N_CTX)
557
+ tl.store(dk_ptrs, dk, mask=offs_n[:, None] < N_CTX)
558
+
559
+ if NUM_DBLOCKS >= 2:
560
+ dv_ptrs2 = dv_ptrs + BLOCK_DMODEL * stride_od
561
+ dk_ptrs2 = dk_ptrs + BLOCK_DMODEL * stride_od
562
+ if EVEN_N_BLOCK:
563
+ tl.store(dv_ptrs2, dv2)
564
+ tl.store(dk_ptrs2, dk2)
565
+ else:
566
+ tl.store(dv_ptrs2, dv2, mask=offs_n[:, None] < N_CTX)
567
+ tl.store(dk_ptrs2, dk2, mask=offs_n[:, None] < N_CTX)
568
+
569
+
570
+
571
+ def _forward(ctx, q, k, v, layout_crow_indices, layout_col_indices, sm_scale, BLOCK_M, BLOCK_N, num_warps=None, num_stages=1, inference=None, out=None):
572
+ '''
573
+ :param q, k, v: [batch, n_heads, seq_len, model_dim]. len of q is allowed to be different than k/v.
574
+ :param layout_crow_indices, layout_col_indices: same as CSR.crow_indices, and CSR.col_indices used to preresent a sparse tensor.
575
+ Each element represent a block, i.e, all elements in a block to be attentdd, or not attended at all..
576
+ '''
577
+ assert q.shape[-1] == k.shape[-1] == v.shape[-1]
578
+ assert k.shape[2] == v.shape[2]
579
+ o = out if out is not None else torch.empty_like(q).contiguous()
580
+ grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1])
581
+
582
+ q_rounded_len = grid[0] * BLOCK_M
583
+ tmp = torch.empty((q.shape[0] * q.shape[1], q_rounded_len), device=q.device, dtype=torch.float32)
584
+
585
+ if inference is None:
586
+ inference = (not q.requires_grad) and (not k.requires_grad) and (not v.requires_grad)
587
+
588
+ if inference:
589
+ L, m = tmp, tmp # no need to use create new tensor
590
+ else:
591
+ L = torch.empty((q.shape[0] * q.shape[1], q_rounded_len), device=q.device, dtype=torch.float32)
592
+ m = torch.empty((q.shape[0] * q.shape[1], q_rounded_len), device=q.device, dtype=torch.float32)
593
+
594
+ if layout_col_indices.dim() == 1:
595
+ layout_crow_indices = layout_crow_indices[None].expand(q.shape[1] , -1)
596
+ layout_col_indices = layout_col_indices[None].expand(q.shape[1] , -1)
597
+
598
+ assert q.shape[-1] in [64, 128]
599
+ BLOCK_DMODEL = 64
600
+
601
+ if num_warps is None:
602
+ MIN_D = min(BLOCK_M, BLOCK_N, BLOCK_DMODEL)
603
+ num_warps = max(1, 2 ** int(math.log2(MIN_D / 16)))
604
+ # print(f'> {BLOCK_M=}, {BLOCK_N=}, {BLOCK_DMODEL=}, {num_warps=}, {num_stages=}')
605
+ else:
606
+ assert math.log2(num_warps) % 1 == 0, f'''"num_warps" should be power of 2, but got {num_warps}.'''
607
+
608
+ ## For debugging:
609
+ # print(f'>> {q.shape=}, {k.shape=}, {BLOCK_M=}, {BLOCK_N=}, {num_warps=}, {BLOCK_DMODEL=}, {q.stride()=}, {k.stride()=}')
610
+ # print(f'>> {layout_crow_indices=}\n{layout_col_indices=}\n {layout_crow_indices.stride()=}, {layout_crow_indices.stride()=}')
611
+ # print(f'> {q.shape=}, {k.shape=}, {layout_crow_indices.shape}, {layout_col_indices.shape}, {layout_crow_indices.stride()}, \
612
+ # {layout_col_indices.stride()}, {layout_crow_indices=}, {layout_col_indices=}')
613
+
614
+ with torch.cuda.device(q.device.index):
615
+ _fwd_kernel[grid](
616
+ q, k, v, sm_scale,
617
+ layout_crow_indices,
618
+ layout_col_indices,
619
+ layout_crow_indices.stride(0), layout_crow_indices.stride(1),
620
+ layout_col_indices.stride(0), layout_col_indices.stride(1),
621
+ tmp, L, m,
622
+ o,
623
+ q.stride(0), q.stride(1), q.stride(2), q.stride(3),
624
+ k.stride(0), k.stride(1), k.stride(2), k.stride(3),
625
+ v.stride(0), v.stride(1), v.stride(2), v.stride(3),
626
+ o.stride(0), o.stride(1), o.stride(2), o.stride(3),
627
+ q.shape[0], q.shape[1], k.shape[2],
628
+ k.shape[2] - q.shape[2],
629
+ q_rounded_len,
630
+ BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
631
+ BLOCK_DMODEL=BLOCK_DMODEL,
632
+ EVEN_M_BLOCK=q.shape[2] % BLOCK_M == 0,
633
+ EVEN_N_BLOCK=k.shape[2] % BLOCK_N == 0 ,
634
+ INFERENCE=inference,
635
+ NUM_DBLOCKS=q.shape[-1] // BLOCK_DMODEL,
636
+ num_warps=num_warps,
637
+ num_stages=num_stages,
638
+ )
639
+ if inference:
640
+ L, m = None, None
641
+
642
+ ctx.save_for_backward(q, k, v, o, L, m, layout_crow_indices, layout_col_indices)
643
+ ctx.BLOCK_M = BLOCK_M
644
+ ctx.BLOCK_N = BLOCK_N
645
+ ctx.BLOCK_DMODEL = BLOCK_DMODEL
646
+ # ctx.BLOCK = BLOCK
647
+ ctx.grid = grid
648
+ ctx.sm_scale = sm_scale
649
+ ctx.num_warps = num_warps
650
+ ctx.num_stages = num_stages
651
+ return o
652
+
653
+
654
+ def _backward(ctx, do, layout_ccol_indices, layout_row_indices, dq=None, dk=None, dv=None):
655
+ # q, k, v, o, l, m = ctx.saved_tensors
656
+ q, k, v, o, l, m, layout_crow_indices, layout_col_indices = ctx.saved_tensors
657
+
658
+ ## this following too slow to do online, so get it from inputs, which is cached.
659
+ # layout_ccol_indices, layout_row_indices = dense_to_ccol_row(crow_col_to_dense(ctx.layout_crow_indices, ctx.layout_col_indices))
660
+ # layout_ccol_indices, layout_row_indices = dense_to_ccol_row(crow_col_to_dense(layout_crow_indices, layout_col_indices))
661
+
662
+ if not do.is_contiguous():
663
+ do = do.contiguous()
664
+ ## for debugging
665
+ # print(f'----> do is not contiguous: {do.stride()=}')
666
+ # raise ValueError(f'>>>> output grad is not contiguous: {do.stride()=}')
667
+
668
+ if not o.is_contiguous():
669
+ # TODO: currently only work with contiguous q/k/v.
670
+ raise ValueError(f'--> output is not contiguous: {o.stride()=}. This is maybe caused by q/k/v not being contiguous.')
671
+
672
+
673
+ if layout_ccol_indices.dim() == 1:
674
+ layout_ccol_indices = layout_ccol_indices[None].expand(q.shape[1], -1)
675
+ layout_row_indices = layout_row_indices[None].expand(q.shape[1], -1)
676
+
677
+ # do = do.contiguous()
678
+ dq = dq if dq is not None else torch.zeros_like(q, dtype=torch.float32)
679
+ dk = dk if dk is not None else torch.empty_like(k)
680
+ dv =dv if dv is not None else torch.empty_like(v)
681
+ do_scaled = torch.empty_like(do)
682
+ delta = torch.empty_like(l)
683
+
684
+ assert o.stride() == dq.stride() == dk.stride() == dv.stride() == do_scaled.stride()
685
+
686
+ _bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](
687
+ o, do, l,
688
+ do_scaled, delta,
689
+ k.shape[2],
690
+ BLOCK_M=ctx.BLOCK_M, D_HEAD=q.shape[-1],
691
+ )
692
+
693
+ grid = (triton.cdiv(q.shape[2], ctx.BLOCK_N), ctx.grid[1])
694
+
695
+ _bwd_kernel[grid](
696
+ q, k, v, ctx.sm_scale,
697
+ layout_ccol_indices,
698
+ layout_row_indices,
699
+ layout_ccol_indices.stride(0), layout_ccol_indices.stride(1),
700
+ layout_row_indices.stride(0), layout_row_indices.stride(1),
701
+ o, do_scaled,
702
+ dq, dk, dv,
703
+ l, m,
704
+ delta,
705
+ q.stride(0), q.stride(1), q.stride(2), q.stride(3),
706
+ k.stride(0), k.stride(1), k.stride(2), k.stride(3),
707
+ v.stride(0), v.stride(1), v.stride(2), v.stride(3),
708
+ o.stride(0), o.stride(1), o.stride(2), o.stride(3),
709
+ q.shape[0], q.shape[1], q.shape[2],
710
+ ctx.grid[0],
711
+ BLOCK_M=ctx.BLOCK_M,
712
+ BLOCK_N=ctx.BLOCK_N,
713
+ BLOCK_DMODEL=ctx.BLOCK_DMODEL,
714
+ NUM_DBLOCKS=q.shape[-1] // ctx.BLOCK_DMODEL,
715
+ num_warps=ctx.num_warps,
716
+ num_stages=1,
717
+ )
718
+ return dq, dk, dv, None, None, None
719
+
720
+
721
+ class _sparse_attention(torch.autograd.Function):
722
+
723
+ @staticmethod
724
+ def forward(ctx, q, k, v, layout_crow_indices, layout_col_indices, sm_scale):
725
+ BLOCK = 128
726
+ # shape constraints
727
+ return _forward(ctx, q, k, v, layout_crow_indices, layout_col_indices, sm_scale, BLOCK, BLOCK)
728
+
729
+ @staticmethod
730
+ def backward(ctx, do):
731
+ # q, k, v, o, l, m = ctx.saved_tensors
732
+ q, k, v, o, l, m, layout_crow_indices, layout_col_indices = ctx.saved_tensors
733
+ # TODO: the following is very inefficient.
734
+ # layout_ccol_indices, layout_row_indices = dense_to_ccol_row(crow_col_to_dense(ctx.layout_crow_indices, ctx.layout_col_indices))
735
+ layout_ccol_indices, layout_row_indices = dense_to_ccol_row(crow_col_to_dense(layout_crow_indices, layout_col_indices))
736
+ return _backward(ctx, do, layout_ccol_indices, layout_row_indices)
737
+
738
+
739
+
740
+ # suppressed
741
+ class _sparse_attention_inference(_sparse_attention):
742
+ # TODO: does not work now, as BLOCK_M cannot be <1, as shape for tl.dot cannot be smaller than 16.
743
+ @staticmethod
744
+ def forward(ctx, q, k, v, layout_crow_indices, layout_col_indices, sm_scale):
745
+ BLOCK = 128
746
+ return _forward(ctx, q, k, v, layout_crow_indices, layout_col_indices, sm_scale, 1, BLOCK)
747
+
748
+
749
+
750
+ def sparse_attention_factory(BLOCK_M=128, BLOCK_N=128, **kwargs):
751
+ class _sparse_attention_config(_sparse_attention):
752
+ @staticmethod
753
+ def forward(ctx, q, k, v, layout_crow_indices, layout_col_indices, sm_scale):
754
+ # shape constraints
755
+ return _forward(ctx, q, k, v, layout_crow_indices, layout_col_indices, sm_scale, BLOCK_M, BLOCK_N,
756
+ **kwargs
757
+ )
758
+ return _sparse_attention_config.apply
759
+
760
+
761
+ @lru_cache(maxsize=8)
762
+ def get_local_strided_sparse_attention_op(
763
+ n_heads: int,
764
+ max_seq_len:int,
765
+ sparse_block_size: int=128,
766
+ local_blocks: int=4,
767
+ vert_stride: int=4,
768
+ homo_head: bool=False,
769
+ dtype=torch.bfloat16,
770
+ device='cuda',
771
+ active_head_range=None,
772
+ verbose=True,
773
+ **kwargs):
774
+ '''
775
+ :param n_heads: total number of attention heads (regardless of tensor/model parallel)
776
+ :param max_seq_len: max sequence length. Need to be bigger or equal to the length of sequences.
777
+ :param sparse_block_size: sparse block size. Default to 128
778
+ :param local_blocks: number of nearest block to attend to. Default to 4, i.e., attention to previous 4xblock_size tokens.
779
+ :param vert_stride: Default to 4. Meaning
780
+ :param homo_head: if all head shared the same pattern.
781
+ :param active_head_range: tuple of start & end of the heads, e..g, (8, 16). Default to use all heads.
782
+ Mainly for tensor/model parallelization where heads are splitted to different GPUs.
783
+ '''
784
+
785
+ if verbose:
786
+ print((f'> new block_sparse_attn op constructed with config: '
787
+ f'{n_heads=}, {max_seq_len=}, {sparse_block_size=}, {local_blocks=}, '
788
+ f'{vert_stride=}, {homo_head=}, {active_head_range=}, {kwargs=}'))
789
+ # assert math.log2(max_seq_len) % 2 == 0, f"max_seq_len should be power of 2 to be more efficient"
790
+ _, block_sparse_pattern, _ = _get_sparse_attn_mask(n_heads, max_seq_len, max_seq_len, dtype, device,
791
+ BLOCK=sparse_block_size, local_blocks=local_blocks,
792
+ vert_stride=vert_stride, homo_head=homo_head,
793
+ return_dense=False)
794
+ if (not homo_head) and (active_head_range is not None):
795
+ assert isinstance(active_head_range, tuple)
796
+ assert len(active_head_range) == 2, '"active_head_range" should be a tuple of start/end index of the heads.'
797
+ h_start, h_end = active_head_range
798
+ block_sparse_pattern = block_sparse_pattern[h_start:h_end]
799
+ # print(block_sparse_pattern)
800
+ return get_sparse_attn_op(block_sparse_pattern, sparse_block_size, **kwargs)
801
+
802
+
803
+ def get_sparse_attn_op(
804
+ sparse_pattern: torch.tensor,
805
+ sparse_block_size: int=128,
806
+ kernel_block_size=128,
807
+ qkv_format='q,k,v',
808
+ **kwargs):
809
+ '''
810
+ Ccreate a block-sparse op with fixed layout. This is to avoid the need to of create CSR layout and convert it to CSC layout everytime,
811
+ which is very inefficient (use python loops on CPU. PyTorch 1.13 supports CSR->CSC, may help.)
812
+
813
+ :param sparse_pattern: sparse pattern of the blocks. Should be `num_blocks(q) x num_blocks(k)` or `n_heads x num_blocks x num_blocks`.
814
+ This tensor should have lower-triangular matrices on the last 2 dimensions for causal attention
815
+ :param sparse_block_size: sparse block size. Default to 128
816
+ :param kernel_block_size: the tile/block size to launch a triton instance. Default to None, i.e., same as `sparse_block_size`
817
+ :param qkv_format: Choices=['q,k,v', 'q, kv', 'qkv'], i.e., separated q,k,v, or kv packed, or qkv packed. Currently, only 'q,k,v' is supported.
818
+
819
+ :param kwargs: keyward arguments passed to `_forward`
820
+ '''
821
+ # assert qkv_format in ('q,k,v', 'q, kv', 'qkv') # to save from running `concat` at forward/backward
822
+
823
+ assert qkv_format == 'q,k,v'
824
+
825
+ if kernel_block_size is None:
826
+ kernel_block_size = sparse_block_size
827
+ else:
828
+ assert sparse_block_size % kernel_block_size == 0, f"The sparse block size must be a multiple of {kernel_block_size}."
829
+ assert kernel_block_size >=16 and math.log2(kernel_block_size) % 1 == 0, f"block_size must be power of 2 and at least 16, but {kernel_block_size} is given"
830
+
831
+
832
+ # print(f'>> {sparse_pattern.shape=}')
833
+ # print(f'{sparse_pattern=}')
834
+ if sparse_block_size // kernel_block_size > 1:
835
+ _mul = sparse_block_size // kernel_block_size
836
+ # need to consider if block_m and block_n are different
837
+ sparse_pattern = torch.kron(sparse_pattern, sparse_pattern.new_ones(_mul, _mul))
838
+ num_sparse_blocks = sparse_pattern.size(-1)
839
+ block_causal_mask = torch.arange(0, num_sparse_blocks)[:, None] >= torch.arange(0, num_sparse_blocks)[None]
840
+ sparse_pattern *= block_causal_mask.type_as(sparse_pattern)
841
+ # print(f'>> after: {sparse_pattern.shape=}')
842
+ # print(f'{sparse_pattern=}')
843
+
844
+ BLOCK_N = kernel_block_size
845
+ NUM_BLOCK = sparse_pattern.size(-1)
846
+ MAX_SEQ_LEN = kernel_block_size * NUM_BLOCK
847
+
848
+ grand_layout_crow_indices, grand_layout_col_indices = dense_to_crow_col(sparse_pattern)
849
+ # sparse csc layout for backward
850
+ grand_layout_ccol_indices, grand_layout_row_indices = dense_to_ccol_row(sparse_pattern)
851
+
852
+
853
+ # cache GPU backward layout. limit the size to avoid OOM as time goes.
854
+ # For inference, one only needs to cache one block as sequence length always increases
855
+ # Therefore, this cache needs to be reconstructed per every `block_size`-steps.
856
+ # For training/finetune, set to 8 to increase cache hit.
857
+ # Given an input, the block_len will be the same for all layers, so cache is very helpful.
858
+
859
+ max_cache_size = 1 if kwargs.get('inference', False) else 8
860
+
861
+ @lru_cache(maxsize=max_cache_size)
862
+ def get_backward_layout_by_block_len(block_len):
863
+ assert block_len <= NUM_BLOCK
864
+ if block_len == NUM_BLOCK:
865
+ return (grand_layout_ccol_indices, grand_layout_row_indices)
866
+ return dense_to_ccol_row(sparse_pattern[..., :block_len, :block_len])
867
+
868
+ # for debugging
869
+ # if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
870
+ # print(f'> {sparse_pattern.cpu().tolist()=}')
871
+ # print('----')
872
+ # print(f'> {grand_layout_crow_indices.cpu().tolist()=}\n{grand_layout_col_indices.cpu().tolist()=}')
873
+
874
+
875
+ # q, k, v separated
876
+ class _q_k_v_sparse_attention(torch.autograd.Function):
877
+ @staticmethod
878
+ def forward(ctx, q, k, v, sm_scale):
879
+ # assert q.shape[2] == 1 or q.shape[2] == k.shape[2]
880
+ # shape constraints
881
+ MIN_BLOCK_SIZE = 16
882
+ assert BLOCK_N >= MIN_BLOCK_SIZE
883
+ BLOCK_M = 16 if q.shape[2] <= 16 else BLOCK_N # BLOCK_M has to be power of 2
884
+
885
+ # this following code only works for causal attention
886
+ K_BLOCKS = triton.cdiv(k.shape[2], kernel_block_size)
887
+ # Q_START_BLOCKS = K_BLOCKS - 1 if q.shape[2] == 1 else 0
888
+ Q_START_BLOCKS = K_BLOCKS - triton.cdiv(q.shape[2], BLOCK_N)
889
+ # print(Q_START_BLOCKS, K_BLOCKS)
890
+
891
+ layout_crow_indices = grand_layout_crow_indices[..., Q_START_BLOCKS:K_BLOCKS+1]
892
+ layout_col_indices = grand_layout_col_indices
893
+ # print(BLOCK_M, BLOCK_N, Q_START_BLOCKS, K_BLOCKS+1, layout_crow_indices, layout_col_indices)
894
+
895
+ return _forward(ctx, q, k, v, layout_crow_indices, layout_col_indices, sm_scale, BLOCK_M, BLOCK_N,
896
+ **kwargs
897
+ )
898
+ @staticmethod
899
+ def backward(ctx, do):
900
+ q, k = ctx.saved_tensors[:2]
901
+ assert q.shape[2] == k.shape[2], '> currently backward can only be done if q, k have same length. Contact @EricLin if you need it.'
902
+ # assume q, k have same length
903
+ block_len = triton.cdiv(do.shape[2], kernel_block_size)
904
+ backward_layout = get_backward_layout_by_block_len(block_len)
905
+ return _backward(ctx, do, *backward_layout)[:4]
906
+
907
+
908
+ def _q_k_v_sparse_attention_fn(*args):
909
+ return _q_k_v_sparse_attention.apply(*args)
910
+
911
+ _q_k_v_sparse_attention_fn.sparse_pattern = sparse_pattern
912
+ _q_k_v_sparse_attention_fn.grand_layout_crow_indices = grand_layout_crow_indices
913
+ _q_k_v_sparse_attention_fn.grand_layout_col_indices = grand_layout_col_indices
914
+ _q_k_v_sparse_attention_fn.grand_layout_ccol_indices = grand_layout_ccol_indices
915
+ _q_k_v_sparse_attention_fn.grand_layout_row_indices = grand_layout_row_indices
916
+
917
+ return _q_k_v_sparse_attention_fn
918
+
919
+ ###########################################################
920
+ ###########################################################
921
+
922
+ ###########################################################
923
+ ################ Inference Kernels ########################
924
+ ###########################################################
925
+
926
+ def blocksparse_flash_attn_padded_fwd(
927
+ q, k, v, # (batch, tokens, n_heads, head_size)
928
+ sm_scale,
929
+ sparse_layout,
930
+ *,
931
+ left_paddings = None,
932
+ seqlens = None,
933
+ block_size = 64,
934
+ max_seqlen = None
935
+ ):
936
+ '''
937
+ q, k, v: (batch, tokens, n_heads/n_kv_heads, head_size)
938
+ left_paddings: (batch, ), number of left paddings for each sample.
939
+ seqlens: can be used to specify right padding. No need to specify if left_paddings is used.
940
+ '''
941
+ batches, q_len, n_heads, head_size = q.shape
942
+ _, k_len, n_kv_heads, _ = k.shape
943
+
944
+
945
+ assert q.dim() == k.dim() == v.dim() == 4
946
+ assert q.size(2) % k.size(2) == 0
947
+ assert q.size(0) == k.size(0) and q.size(3) == k.size(3)
948
+ assert k.shape == v.shape # TODO: allow diff head_size for k, v
949
+ assert q_len == 1 or q_len == k_len, \
950
+ f'q length can only 1 for decoding for same as k length for prefilling.'
951
+
952
+ q_k_ratio = q.size(2) // k.size(2)
953
+
954
+ if max_seqlen:
955
+ assert k.size(1) <= max_seqlen, f'k has seqlen {k.size(1)} while max sequence length is set to {max_seqlen}.'
956
+
957
+ # paddings always has zero output, a little slower than using empty
958
+ out = q.new_zeros(q.shape)
959
+
960
+ layout_crow_indices, layout_col_indices = sparse_layout
961
+ block_d = triton.next_power_of_2(head_size)
962
+
963
+ if left_paddings is not None:
964
+ assert left_paddings.shape == (batches,)
965
+ k_batch_starts = left_paddings.to(q.device, dtype=torch.int32).contiguous()
966
+ else:
967
+ k_batch_starts = torch.zeros((batches,), dtype=torch.int32, device=q.device)
968
+
969
+ if seqlens is not None:
970
+ k_batch_ends = k_batch_starts + seqlens.type_as(k_batch_starts)
971
+ assert k_batch_ends.max() <= k_len, f'seqlens (+left_paddings if any) exceeds seqlen.'
972
+ else:
973
+ k_batch_ends = torch.zeros_like(k_batch_starts) + k_len
974
+
975
+ if q_len == 1:
976
+ q_batch_starts = torch.zeros_like(k_batch_starts)
977
+ q_batch_ends = q_batch_starts + 1
978
+ else:
979
+ q_batch_starts = k_batch_starts
980
+ q_batch_ends = k_batch_ends
981
+
982
+ # switch to use cpu to avoid too many kernel lauch when iterate over
983
+ q_lens = (q_batch_ends - q_batch_starts).cpu()
984
+ n_blocks = (q_lens + block_size - 1) // block_size
985
+
986
+ q_batch_ids = torch.tensor([i for i, n in enumerate(n_blocks) for _ in range(n)],
987
+ dtype=q_batch_starts.dtype,
988
+ device=q_batch_starts.device)
989
+ q_start_sids = torch.tensor([i * block_size for n in n_blocks for i in range(n)],
990
+ dtype=q_batch_starts.dtype,
991
+ device=q_batch_starts.device)
992
+
993
+ grid = (len(q_start_sids), n_heads)
994
+
995
+ with torch.cuda.device(q.device.index):
996
+ _fwd_kernel_batch_inference[grid](
997
+ q, k, v, out,
998
+ sm_scale,
999
+ q_batch_starts,
1000
+ q_batch_ends,
1001
+ k_batch_starts,
1002
+ k_batch_ends,
1003
+ q_batch_ids,
1004
+ q_start_sids,
1005
+
1006
+ *q.stride(),
1007
+ *k.stride(),
1008
+ *v.stride(),
1009
+ *out.stride(),
1010
+
1011
+ layout_crow_indices,
1012
+ layout_col_indices,
1013
+ *layout_crow_indices.stride(),
1014
+ *layout_col_indices.stride(),
1015
+
1016
+ q_k_ratio,
1017
+ HAS_BATCH_DIM = True,
1018
+ D_HEAD = head_size,
1019
+ BLOCK_M = block_size,
1020
+ BLOCK_N = block_size,
1021
+ BLOCK_D = block_d,
1022
+ BLOCK_M_LOADING = 16 if q_len == 1 else block_size, # smaller for decoding
1023
+ EVEN_D = block_d == head_size,
1024
+ num_warps = 1 if q_len == 1 else 4,
1025
+ num_stages = 1
1026
+ )
1027
+
1028
+
1029
+ return out
1030
+
1031
+
1032
+ def blocksparse_flash_attn_varlen_fwd(
1033
+ q, k, v, # (#tokens, n_heads, head_size)
1034
+ cu_seqlens_k,
1035
+ cu_seqlens_q,
1036
+ sm_scale,
1037
+ sparse_layout,
1038
+ *,
1039
+ block_size=64,
1040
+ max_seqlen = None
1041
+ ):
1042
+ # split q to blocks
1043
+ _, n_heads, head_size = q.shape
1044
+ batch_size = cu_seqlens_k.size(0) - 1
1045
+
1046
+
1047
+ # print(f'> {q.shape=}, {k.shape=}')
1048
+ assert q.dim() == k.dim() == v.dim() == 3
1049
+ assert q.size(1) % k.size(1) == 0
1050
+ assert q.size(2) == k.size(2)
1051
+ assert k.shape == v.shape # TODO: allow diff head_size for k, v
1052
+ assert cu_seqlens_k.dim() == 1
1053
+
1054
+ q_k_ratio = q.size(1) // k.size(1)
1055
+
1056
+ if cu_seqlens_q is None:
1057
+ if q.size(0) == batch_size: # decoding only
1058
+ cu_seqlens_q = torch.arange(0, batch_size + 1,
1059
+ dtype=cu_seqlens_k.dtype,
1060
+ device=cu_seqlens_k.device)
1061
+ elif q.size(0) == k.size(0):
1062
+ cu_seqlens_q = cu_seqlens_k
1063
+ else:
1064
+ raise ValueError('cu_seqlens_q must be specified if it is mix of prefilling and decoding.')
1065
+ else:
1066
+ assert cu_seqlens_k.size(0) == cu_seqlens_q.size(0)
1067
+
1068
+ # switch to use cpu to avoid too many kernel lauch when iterate over
1069
+ q_lens = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).cpu()
1070
+ k_lens = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).cpu()
1071
+
1072
+ assert torch.logical_or(q_lens == 1, k_lens == q_lens).all(), \
1073
+ 'length of q should either be 1 (decoding) or same as k (prefilling).'
1074
+
1075
+ if max_seqlen:
1076
+ assert k_lens.max() <= max_seqlen
1077
+
1078
+ n_blocks = (q_lens + block_size - 1) // block_size
1079
+
1080
+ q_batch_ids = torch.tensor([i for i, n in enumerate(n_blocks) for _ in range(n)],
1081
+ dtype=cu_seqlens_q.dtype,
1082
+ device=cu_seqlens_q.device)
1083
+ q_start_sids = torch.tensor([i * block_size for n in n_blocks for i in range(n)],
1084
+ dtype=cu_seqlens_q.dtype,
1085
+ device=cu_seqlens_q.device)
1086
+
1087
+
1088
+ out = q.new_empty(q.shape)
1089
+ cu_seqlens_q = cu_seqlens_q.contiguous()
1090
+ cu_seqlens_k = cu_seqlens_k.contiguous()
1091
+
1092
+ layout_crow_indices, layout_col_indices = sparse_layout
1093
+ block_d = triton.next_power_of_2(head_size)
1094
+
1095
+ decoding_only = (q_lens == 1).all()
1096
+
1097
+ grid = (len(q_start_sids), n_heads)
1098
+
1099
+ with torch.cuda.device(q.device.index):
1100
+ _fwd_kernel_batch_inference[grid](
1101
+ q, k, v, out,
1102
+ sm_scale,
1103
+ cu_seqlens_q[:-1],
1104
+ cu_seqlens_q[1:],
1105
+ cu_seqlens_k[:-1],
1106
+ cu_seqlens_k[1:],
1107
+ q_batch_ids,
1108
+ q_start_sids,
1109
+
1110
+ 0, *q.stride(),
1111
+ 0, *k.stride(),
1112
+ 0, *v.stride(),
1113
+ 0, *out.stride(),
1114
+
1115
+ layout_crow_indices,
1116
+ layout_col_indices,
1117
+ *layout_crow_indices.stride(),
1118
+ *layout_col_indices.stride(),
1119
+
1120
+ q_k_ratio,
1121
+ HAS_BATCH_DIM = False,
1122
+ D_HEAD = head_size,
1123
+ BLOCK_M = block_size,
1124
+ BLOCK_N = block_size,
1125
+ BLOCK_D = block_d,
1126
+ BLOCK_M_LOADING = 16 if decoding_only else block_size, # smaller for decoding
1127
+ EVEN_D = block_d == head_size,
1128
+ num_warps = 1 if decoding_only else 4,
1129
+ num_stages = 3
1130
+ )
1131
+
1132
+ return out
1133
+
1134
+
1135
+ @triton.jit
1136
+ def _fwd_kernel_inner(
1137
+ acc, l_i, m_i,
1138
+ q, Q,
1139
+ k_block_col_idx,
1140
+ layout_col_ptr,
1141
+ layout_col_stride_h, layout_col_stride_m,
1142
+ k_ptrs,
1143
+ v_ptrs,
1144
+ off_h, offs_m, offs_n, offs_d,
1145
+ stride_kt, stride_vt,
1146
+ sm_scale,
1147
+ k_seqlen,
1148
+ past_len,
1149
+ LAST_K_BLOCK: tl.constexpr,
1150
+ BLOCK_M_LOADING: tl.constexpr,
1151
+ BLOCK_N: tl.constexpr,
1152
+ D_HEAD: tl.constexpr,
1153
+ EVEN_D: tl.constexpr,
1154
+ M_LT_N: tl.constexpr
1155
+ ):
1156
+ k_block_id = tl.load(layout_col_ptr + off_h * layout_col_stride_h + k_block_col_idx * layout_col_stride_m).to(tl.int32)
1157
+ start_n = k_block_id * BLOCK_N
1158
+ # -- compute qk ----
1159
+ if LAST_K_BLOCK:
1160
+ if EVEN_D:
1161
+ k = tl.load(k_ptrs + start_n * stride_kt,
1162
+ mask=offs_n[None, :] + start_n < k_seqlen)
1163
+ else:
1164
+ # mask = mask & (offs_d[:, ])
1165
+ k = tl.load(k_ptrs + start_n * stride_kt,
1166
+ mask=(offs_n[None, :] + start_n < k_seqlen) & (offs_d[:, None] < D_HEAD))
1167
+ else:
1168
+ if EVEN_D:
1169
+ k = tl.load(k_ptrs + start_n * stride_kt)
1170
+ else:
1171
+ k = tl.load(k_ptrs + start_n * stride_kt,
1172
+ mask=offs_d[:, None] < D_HEAD)
1173
+
1174
+
1175
+ qk = tl.zeros([BLOCK_M_LOADING, BLOCK_N], dtype=tl.float32)
1176
+ qk += tl.dot(q, k)
1177
+
1178
+ qk *= sm_scale
1179
+
1180
+ # the following is needed only when LAST_K_BLOCK or BLOCK_M < BLOCK_N
1181
+ if LAST_K_BLOCK | M_LT_N:
1182
+ qk += tl.where(offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), 0, float('-inf'))
1183
+
1184
+ # -- compute m_ij, p, l_ij
1185
+ m_ij = tl.max(qk, 1)
1186
+ p = tl.exp(qk - m_ij[:, None])
1187
+
1188
+ l_ij = tl.sum(p, 1)
1189
+ # -- update m_i and l_i
1190
+ m_i_new = tl.maximum(m_i, m_ij)
1191
+ alpha = tl.exp(m_i - m_i_new)
1192
+ beta = tl.exp(m_ij - m_i_new)
1193
+ l_i_new = alpha * l_i + beta * l_ij
1194
+ # -- update output accumulator --
1195
+ # scale p
1196
+ p_scale = beta / l_i_new
1197
+ p = p * p_scale[:, None]
1198
+ # scale acc
1199
+ acc_scale = l_i / l_i_new * alpha
1200
+ acc = acc * acc_scale[:, None]
1201
+
1202
+ p = p.to(Q.dtype.element_ty)
1203
+ # update acc
1204
+ if LAST_K_BLOCK:
1205
+ if EVEN_D:
1206
+ v = tl.load(v_ptrs + start_n * stride_vt,
1207
+ mask=offs_n[:, None] + start_n < k_seqlen)
1208
+ else:
1209
+ v = tl.load(v_ptrs + start_n * stride_vt,
1210
+ mask=(offs_n[:, None] + start_n < k_seqlen) & (offs_d[None, :] < D_HEAD))
1211
+ else:
1212
+ if EVEN_D:
1213
+ v = tl.load(v_ptrs + start_n * stride_vt)
1214
+ else:
1215
+ v = tl.load(v_ptrs + start_n * stride_vt,
1216
+ mask=offs_d[None, :] < D_HEAD)
1217
+
1218
+ acc += tl.dot(p, v)
1219
+ # update m_i and l_i
1220
+ l_i = l_i_new
1221
+ m_i = m_i_new
1222
+ return acc, l_i, m_i
1223
+
1224
+
1225
+ @triton.heuristics(
1226
+ {
1227
+ 'M_LT_N': lambda kwargs: kwargs['BLOCK_M'] < kwargs['BLOCK_N'],
1228
+ }
1229
+ )
1230
+ @triton.jit
1231
+ def _fwd_kernel_batch_inference(
1232
+ Q, K, V, Out,
1233
+
1234
+ sm_scale,
1235
+ q_batch_starts,
1236
+ q_batch_ends,
1237
+ k_batch_starts,
1238
+ k_batch_ends,
1239
+ q_batch_ids,
1240
+ q_start_sids,
1241
+
1242
+ stride_qb, stride_qt, stride_qh, stride_qd,
1243
+ stride_kb, stride_kt, stride_kh, stride_kd,
1244
+ stride_vb, stride_vt, stride_vh, stride_vd,
1245
+ stride_ob, stride_ot, stride_oh, stride_od,
1246
+
1247
+ layout_crow_ptr,
1248
+ layout_col_ptr,
1249
+ layout_crow_stride_h, layout_crow_stride_m,
1250
+ layout_col_stride_h, layout_col_stride_m,
1251
+
1252
+ q_k_ratio,
1253
+
1254
+ HAS_BATCH_DIM: tl.constexpr,
1255
+ D_HEAD: tl.constexpr,
1256
+ BLOCK_M: tl.constexpr,
1257
+ BLOCK_N: tl.constexpr,
1258
+ BLOCK_D: tl.constexpr,
1259
+ BLOCK_M_LOADING: tl.constexpr,
1260
+ EVEN_D: tl.constexpr,
1261
+ M_LT_N: tl.constexpr
1262
+ ):
1263
+ '''
1264
+ NOTATION:
1265
+ pid: position id
1266
+ sid: storage id
1267
+ sbid: storage block id
1268
+ pbid: position block id
1269
+ offs_m, offs_n: storage offsets of m-dim(q, row) and n-dim(k, col)
1270
+
1271
+ q and blocks in KV needs to be contiguous
1272
+
1273
+ Arguments:
1274
+ kv_seq_lens: for compute past_len
1275
+ kv_storage_offsets: similar to block_tables in vllm, except it is dynamic.
1276
+ TODO: fix this
1277
+
1278
+ TODO:
1279
+ Optimize grouped-attn
1280
+
1281
+ CUDA graph support issue
1282
+ 1. grid is dynamic: vllm set up multiple cuda graph in decoding phase, with diff max token size (16, 32, ...)
1283
+ since we mix prompt and decoing phase here, it can be more complex.
1284
+ need to set up diff cuda-graph for diff (off_zm, off_z)
1285
+
1286
+ # indeed, q_batch_ids can be padded to maximum number of grid[0], i.e., assume all decoding
1287
+ therefore, cu_seqlens_q, kv_seq_lens
1288
+
1289
+ '''
1290
+ off_zm = tl.program_id(0)
1291
+ off_h = tl.program_id(1)
1292
+
1293
+ off_h_for_kv = off_h // q_k_ratio
1294
+ off_z = tl.load(q_batch_ids + off_zm).to(tl.int32) # [0, 0, 0, 1]
1295
+ q_start_sid = tl.load(q_start_sids + off_zm)
1296
+ start_m = q_start_sid // BLOCK_M
1297
+
1298
+ if HAS_BATCH_DIM:
1299
+ Q += off_z * stride_qb
1300
+ K += off_z * stride_kb
1301
+ V += off_z * stride_vb
1302
+ Out += off_z * stride_ob
1303
+
1304
+ offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M_LOADING)
1305
+ offs_n = tl.arange(0, BLOCK_N)
1306
+ offs_d = tl.arange(0, BLOCK_D)
1307
+
1308
+ q_cu_start = tl.load(q_batch_starts + off_z).to(tl.int32)
1309
+ q_seqlen = tl.load(q_batch_ends + off_z).to(tl.int32) - q_cu_start
1310
+
1311
+ k_cu_start = tl.load(k_batch_starts + off_z).to(tl.int32)
1312
+ k_seqlen = tl.load(k_batch_ends + off_z).to(tl.int32) - k_cu_start
1313
+
1314
+ past_len = k_seqlen - q_seqlen
1315
+
1316
+ Q += q_cu_start * stride_qt + off_h * stride_qh
1317
+ K += k_cu_start * stride_kt + off_h_for_kv * stride_kh
1318
+ V += k_cu_start * stride_vt + off_h_for_kv * stride_vh
1319
+ Out += q_cu_start * stride_ot + off_h * stride_oh
1320
+
1321
+ q_pbid = (past_len + q_start_sid) // BLOCK_M
1322
+
1323
+ if EVEN_D:
1324
+ q = tl.load(Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd,
1325
+ mask=offs_m[:, None] < q_seqlen)
1326
+ else:
1327
+ q = tl.load(Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd,
1328
+ mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD),
1329
+ other=0)
1330
+
1331
+ sparse_crow_ptr = layout_crow_ptr + off_h * layout_crow_stride_h + q_pbid * layout_crow_stride_m
1332
+
1333
+ # TODO: load at once, supported in new Triton
1334
+ k_block_start = tl.load(sparse_crow_ptr).to(tl.int32)
1335
+ k_block_end = tl.load(sparse_crow_ptr + 1).to(tl.int32)
1336
+
1337
+ m_i = tl.zeros([BLOCK_M_LOADING], dtype=tl.float32) - float('inf')
1338
+ l_i = tl.zeros([BLOCK_M_LOADING], dtype=tl.float32)
1339
+ acc = tl.zeros([BLOCK_M_LOADING, BLOCK_D], dtype=tl.float32)
1340
+
1341
+ k_ptrs = K + offs_n[None, :] * stride_kt + offs_d[:, None] * stride_kd
1342
+ v_ptrs = V + offs_n[:, None] * stride_vt + offs_d[None, :] * stride_vd
1343
+
1344
+ for k_block_col_idx in range(k_block_start, k_block_end - 1):
1345
+ acc, l_i, m_i = _fwd_kernel_inner(
1346
+ acc, l_i, m_i,
1347
+ q, Q,
1348
+ k_block_col_idx,
1349
+ layout_col_ptr,
1350
+ layout_col_stride_h, layout_col_stride_m,
1351
+ k_ptrs,
1352
+ v_ptrs,
1353
+ off_h, offs_m, offs_n, offs_d,
1354
+ stride_kt, stride_vt,
1355
+ sm_scale,
1356
+ k_seqlen,
1357
+ past_len,
1358
+ False,
1359
+ BLOCK_M_LOADING,
1360
+ BLOCK_N,
1361
+ D_HEAD,
1362
+ EVEN_D,
1363
+ M_LT_N
1364
+ )
1365
+
1366
+ acc, l_i, m_i = _fwd_kernel_inner(
1367
+ acc, l_i, m_i,
1368
+ q, Q,
1369
+ k_block_end - 1,
1370
+ layout_col_ptr,
1371
+ layout_col_stride_h, layout_col_stride_m,
1372
+ k_ptrs,
1373
+ v_ptrs,
1374
+ off_h, offs_m, offs_n, offs_d,
1375
+ stride_kt, stride_vt,
1376
+ sm_scale,
1377
+ k_seqlen,
1378
+ past_len,
1379
+ True,
1380
+ BLOCK_M_LOADING,
1381
+ BLOCK_N,
1382
+ D_HEAD,
1383
+ EVEN_D,
1384
+ M_LT_N
1385
+ )
1386
+
1387
+ # write output
1388
+ if EVEN_D:
1389
+ tl.store(Out + offs_m[:, None] * stride_ot + offs_d[None, :] * stride_od, acc,
1390
+ mask=offs_m[:, None] < q_seqlen)
1391
+ else:
1392
+ tl.store(Out + offs_m[:, None] * stride_ot + offs_d[None, :] * stride_od, acc,
1393
+ mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD))
1394
+
1395
+
1396
+ ###########################################################
1397
+ ###########################################################
1398
+
1399
+ ###########################################################
1400
+ ################## Testing Utilities ######################
1401
+ ###########################################################
1402
+
1403
+
1404
+ def torch_attention(q, k, v, attn_mask=None, sm_scale=None, block_attn_mask=None, block_size=128, do=None):
1405
+ '''
1406
+ q, k, v: shape=(batch, n_heads, seq, dim)
1407
+ '''
1408
+ # for verification
1409
+ if sm_scale is None:
1410
+ sm_scale = math.sqrt(float(q.size(-1)))
1411
+
1412
+ if block_attn_mask is not None:
1413
+ assert attn_mask is None
1414
+ outs = []
1415
+ for s in range(0, q.size(2), block_size):
1416
+ e = min(s + block_size, q.size(2))
1417
+ q_block = q[:, :, s:e]
1418
+ attn = torch.einsum('bhmd,bhnd->bhmn', q_block, k[:, :, :e]).float() * sm_scale
1419
+ mask = block_attn_mask[..., s // block_size, : (s // block_size + 1)]
1420
+ mask = torch.kron(mask, torch.ones(block_size, block_size, device=mask.device))
1421
+ mask[..., :, s:].masked_fill_(torch.arange(0, block_size)[:, None] <= torch.arange(0, block_size)[None, :], 0)
1422
+ attn = attn.masked_fill((1 - mask).bool(), float('-inf'))
1423
+ attn = attn.softmax(-1)
1424
+ out = torch.einsum('bhmn,bhnd->bhmd', attn.type_as(v), v[:, :, :e])
1425
+ outs.append(out)
1426
+ torch_output = torch.cat(outs, dim=2)
1427
+ else:
1428
+ attn = torch.einsum('bhmd,bhnd->bhmn', q, k).float() * sm_scale
1429
+ # import ipdb; ipdb.set_trace()
1430
+ if attn_mask is not None:
1431
+ attn = attn.masked_fill((1 - attn_mask).bool(), float('-inf'))
1432
+ # print(f'> torch attn: {attn.exp().sum(-1)=}')
1433
+
1434
+ attn = attn.softmax(-1)
1435
+ if do is not None:
1436
+ dv = torch.einsum('bhqk,bhqd->bhkd', attn.type_as(do), do)
1437
+ print(f'> torch_attn computed dv: {dv=}')
1438
+ torch_output = torch.einsum('bhmn,bhnd->bhmd', attn.type_as(v), v)
1439
+ return torch_output
1440
+
1441
+ ###########################################################
1442
+ ###########################################################
1443
+
1444
+ ###########################################################
1445
+ #################### Unit Tests ###########################
1446
+ ###########################################################
1447
+
1448
+
1449
+ @pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(2, 8, 2048, 128), (1, 4, 4096, 64)])
1450
+ def test_op(Z, H, N_CTX, D_HEAD, Q_LEN=None, dtype=torch.bfloat16, homo_head=True, kernel_block_size=None, sparse_block_size=128, backward=True,
1451
+ sparse_attention_fn=None, local_blocks=4, vert_stride=4, sm_scale=None, max_length=None):
1452
+ Q_LEN = Q_LEN or N_CTX
1453
+ torch.manual_seed(20)
1454
+ q = torch.empty((Z, H, Q_LEN, D_HEAD), dtype=dtype, device='cuda').normal_(mean=0, std=.5) # .requires_grad_()
1455
+ k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device='cuda').normal_(mean=0, std=.5) # .requires_grad_()
1456
+ v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device='cuda').normal_(mean=0, std=.5) # .requires_grad_()
1457
+
1458
+ if sm_scale is None:
1459
+ sm_scale = 1. / math.sqrt(D_HEAD)
1460
+
1461
+ # for debugging
1462
+ # print(f'>> {q.shape=}, {k.shape=}, {v.shape=}, {homo_head=}, {kernel_block_size=}, {sparse_block_size=}, {local_blocks=}, {vert_stride=}')
1463
+ sm_scale = 0.0078125
1464
+ if backward:
1465
+ q.requires_grad_(), k.requires_grad_(), v.requires_grad_()
1466
+
1467
+ # qkv = torch.empty((Z, N_CTX, 3*H*D_HEAD), dtype=dtype, device='cuda').normal_(mean=0, std=.5)
1468
+ # q = qkv[..., :H*D_HEAD]
1469
+ # k = qkv[..., H*D_HEAD:2*H*D_HEAD]
1470
+ # v = qkv[..., 2*H*D_HEAD:]
1471
+ # q = q.view(Z, N_CTX, H, -1).permute(0, 2, 1, 3)
1472
+ # k = k.view(Z, N_CTX, H, -1).permute(0, 2, 1, 3)
1473
+ # v = v.view(Z, N_CTX, H, -1).permute(0, 2, 1, 3)
1474
+
1475
+ # if Q_LEN and Q_LEN < N_CTX:
1476
+ # q = q[:, :, -Q_LEN:] # .contiguous()
1477
+
1478
+ # q = q.requires_grad_()
1479
+ # k = k.requires_grad_()
1480
+ # v = v.requires_grad_()
1481
+
1482
+ dout = torch.randn_like(q).contiguous()
1483
+
1484
+ # dout = torch.eye(N_CTX)[:, :D_HEAD][None, None].expand_as(q).type_as(q).contiguous()
1485
+ # print(dout)
1486
+
1487
+ mask_csr, _, mask_dense = get_sparse_attn_mask(q, N_CTX, BLOCK=sparse_block_size,
1488
+ local_blocks=local_blocks, vert_stride=vert_stride, homo_head=homo_head, return_dense=True)
1489
+
1490
+ if sparse_attention_fn is None:
1491
+ sparse_attention_fn = get_local_strided_sparse_attention_op(H, N_CTX,
1492
+ sparse_block_size=sparse_block_size,
1493
+ local_blocks=local_blocks,
1494
+ vert_stride=vert_stride,
1495
+ homo_head=homo_head,
1496
+ device=q.device,
1497
+ dtype=q.dtype,
1498
+ kernel_block_size=kernel_block_size)
1499
+ # reference implementation
1500
+ ref_out = torch_attention(q, k, v, mask_dense, sm_scale)
1501
+
1502
+ # lengths = torch.full((Z,), fill_value=N_CTX, device='cuda')
1503
+ # cu_seqlens = torch.zeros((Z + 1,), device='cuda', dtype=torch.int32)
1504
+ # cu_seqlens[1:] = lengths.cumsum(0)
1505
+ # # qkv = torch.randn((Z * N_CTX, 3, H, D_HEAD), dtype=dtype, device='cuda', requires_grad=True)
1506
+
1507
+ # qkv_list = list(map(lambda x: x.permute(0, 2, 1, 3).contiguous().view(Z * N_CTX, 1, H, D_HEAD), [q, k, v]))
1508
+ # qkv = torch.cat(qkv_list, dim=1)
1509
+ # ref_out0 = flash_attn_func(qkv, cu_seqlens, dropout_p=0, max_s=N_CTX, softmax_scale=sm_scale, causal=True)
1510
+ # ref_out = ref_out0.view(Z, N_CTX, H, D_HEAD).permute(0, 2, 1, 3).contiguous()
1511
+
1512
+
1513
+ if backward:
1514
+ ref_out.backward(dout)
1515
+ ref_dv, v.grad = v.grad.clone(), None
1516
+ ref_dk, k.grad = k.grad.clone(), None
1517
+ ref_dq, q.grad = q.grad.clone(), None
1518
+
1519
+ tri_out = sparse_attention_fn(q, k, v, sm_scale)
1520
+
1521
+ decimal = 1 if dtype == torch.bfloat16 else 2
1522
+ assert torch.allclose(ref_out.cpu(), tri_out.cpu(), atol=1e-2, rtol=0), f'>> {ref_out[0, 0, :, 0].tolist()=}\n\n{tri_out[0, 0, :, 0].tolist()=}'
1523
+
1524
+ if backward:
1525
+ tri_out.backward(dout)
1526
+ tri_dv, v.grad = v.grad.clone(), None
1527
+ tri_dk, k.grad = k.grad.clone(), None
1528
+ tri_dq, q.grad = q.grad.clone(), None
1529
+
1530
+ if backward:
1531
+ assert torch.allclose(ref_dv, tri_dv, atol=1e-2, rtol=1e-2)
1532
+ assert torch.allclose(ref_dk, tri_dk, atol=1e-2, rtol=0)
1533
+ assert torch.allclose(ref_dq, tri_dq, atol=1e-2, rtol=0)
1534
+
1535
+ print(f'> test passed: {Z=}, {H=}, {N_CTX=}, {D_HEAD=}, {Q_LEN=}, {dtype=}, {homo_head=}, {sparse_block_size=}')
1536
+
1537
+ ###########################################################
1538
+
1539
+ if __name__ == '__main__':
1540
+
1541
+ GPU_TYPE = os.popen('nvidia-smi --query-gpu=name --format=csv | tail -n 1').read().strip()
1542
+ # print(GPU_TYPE)
1543
+ support_backward = True # 'A100' in GPU_TYPE. Wasn't supportted in consumer A1000.
1544
+
1545
+ ###############
1546
+ # benchmarking
1547
+
1548
+ HAS_DENSE_TRITON_FLASH = False
1549
+ # try:
1550
+ # from triton.ops.flash_attention import attention as triton_attention
1551
+ # HAS_DENSE_TRITON_FLASH = True
1552
+ # except:
1553
+ # HAS_DENSE_TRITON_FLASH = False
1554
+ # print('> cannot import Trition flash attn')
1555
+
1556
+ try:
1557
+ from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_unpadded_func
1558
+ HAS_FLASH = True
1559
+ except BaseException:
1560
+ HAS_FLASH = False
1561
+ print('> cannot import flash_attn')
1562
+
1563
+
1564
+ # BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
1565
+ BATCH, N_HEADS, N_CTX, D_HEAD = 4, 32, 4096, 128 # 6.7B model, with 4k len
1566
+ # BATCH, N_HEADS, N_CTX, D_HEAD = 4, 16, 4096, 128 # 204m model
1567
+
1568
+ BLOCK_SIZE = 64
1569
+ LOCAl_BLOCKS = 8 # 4
1570
+ VERT_STRIDE = 1 # 16 # 8
1571
+ HOMO_HEAD = False
1572
+ sparse_type = 'home' if HOMO_HEAD else 'hetero'
1573
+ dtype = torch.bfloat16
1574
+
1575
+
1576
+ modes = ['fwd', 'bwd'] if support_backward else ['fwd']
1577
+
1578
+ configs = [triton.testing.Benchmark(
1579
+ x_names=['SEQ_LEN'],
1580
+ x_vals=[2**i for i in range(8, 16)],
1581
+ line_arg='provider',
1582
+ line_vals=(['triton'] if HAS_DENSE_TRITON_FLASH else []) + (['flash'] if HAS_FLASH else []) + ['triton_sparse'],
1583
+ line_names=(['Triton-Dense'] if HAS_DENSE_TRITON_FLASH else []) + (['Flash-Dense'] if HAS_FLASH else []) + ['Triton-Sparse'],
1584
+ styles=[('red', '-'), ('blue', '-'), ('green', '-')],
1585
+ ylabel='ms',
1586
+ plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-sparse-local{LOCAl_BLOCKS}-vert{VERT_STRIDE}-{sparse_type}-{dtype}-{mode}',
1587
+ args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'dtype': dtype, 'mode': mode}
1588
+ ) for mode in modes]
1589
+
1590
+
1591
+ @triton.testing.perf_report(configs)
1592
+ def bench_flash_attention(BATCH, H, SEQ_LEN, D_HEAD, mode, provider, dtype=torch.bfloat16, device='cuda', sparse_attention_fn=None):
1593
+ assert mode in ['fwd', 'bwd']
1594
+ warmup = 25
1595
+ rep = 100
1596
+ N_CTX = SEQ_LEN
1597
+ if provider == 'triton':
1598
+ q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=True)
1599
+ k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=True)
1600
+ v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=True)
1601
+ sm_scale = 1.3
1602
+ fn = lambda: triton_attention(q, k, v, sm_scale)
1603
+ if mode == 'bwd':
1604
+ o = fn()
1605
+ do = torch.randn_like(o)
1606
+ fn = lambda: o.backward(do, retain_graph=True)
1607
+ ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
1608
+ return ms
1609
+ if provider == 'triton_sparse':
1610
+ q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=True)
1611
+ k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=True)
1612
+ v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=True)
1613
+ sm_scale = 1.3
1614
+ # q_pos = torch.arange(N_CTX // BLOCK, device='cuda')[:, None]
1615
+ # k_pos = torch.arange(N_CTX // BLOCK, device='cuda')[None]
1616
+ # local_blocks = 4 # num_block per attn, block_size is tied to BLOCK
1617
+ # vert_stride =N_CTX + 1 # 4
1618
+ # mask_vert_strided = torch.arange(N_CTX // BLOCK, device='cuda') % vert_stride == vert_stride - 1
1619
+ # mask_dense = ((q_pos >= k_pos) & ((q_pos - k_pos < local_blocks) | mask_vert_strided)).type_as(q)
1620
+ # mask = mask_dense.to_sparse_csr()
1621
+ # mask_csr, _ = get_sparse_attn_mask(q, N_CTX, BLOCK=BLOCK, local_blocks=LOCAl_BLOCKS, vert_stride=VERT_STRIDE, homo_head=HOMO_HEAD)
1622
+
1623
+ if sparse_attention_fn is None:
1624
+ # sparse_attention_fn = sparse_attention
1625
+ sparse_attention_fn = get_local_strided_sparse_attention_op(H, SEQ_LEN,
1626
+ local_blocks=LOCAl_BLOCKS,
1627
+ vert_stride=VERT_STRIDE,
1628
+ homo_head=HOMO_HEAD,
1629
+ sparse_block_size=BLOCK_SIZE,
1630
+ kernel_block_size=BLOCK_SIZE,
1631
+ device=q.device)
1632
+ # sparse_attention_fn = sparse_attention_factory(128, 128, num_warps=8)
1633
+
1634
+ # fn = lambda: sparse_attention_fn(q, k, v, mask_csr[0], mask_csr[1], sm_scale)
1635
+ fn = lambda: sparse_attention_fn(q, k, v, sm_scale)
1636
+ if mode == 'bwd':
1637
+ o = fn()
1638
+ do = torch.randn_like(o)
1639
+ fn = lambda: o.backward(do, retain_graph=True)
1640
+ ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
1641
+ return ms
1642
+ if provider == 'flash':
1643
+ lengths = torch.full((BATCH,), fill_value=N_CTX, device=device)
1644
+ cu_seqlens = torch.zeros((BATCH + 1,), device=device, dtype=torch.int32)
1645
+ cu_seqlens[1:] = lengths.cumsum(0)
1646
+ qkv = torch.randn((BATCH * N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True)
1647
+ fn = lambda: flash_attn_func(qkv, cu_seqlens, 0., N_CTX, causal=True)
1648
+ if mode == 'bwd':
1649
+ o = fn()
1650
+ do = torch.randn_like(o)
1651
+ fn = lambda: o.backward(do, retain_graph=True)
1652
+ ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
1653
+ return ms
1654
+
1655
+ # if provider == 'torch':
1656
+ # q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=True)
1657
+ # k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=True)
1658
+ # v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=True)
1659
+ # sm_scale = 1.3
1660
+ # causal_mask = torch.tril(torch.ones(N_CTX, N_CTX)).type_as(q)
1661
+ # fn = lambda: torch_attention(q, k, v, causal_mask, sm_scale)
1662
+ # ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)
1663
+ # return ms
1664
+
1665
+
1666
+ BATCH, N_HEADS, N_CTX, D_HEAD, Q_LEN = 4, 32, 4096, 128, 1 # 6.7B model, with 4k len
1667
+
1668
+ BLOCK_SIZE = 64
1669
+ LOCAl_BLOCKS = 8 # 4
1670
+ VERT_STRIDE = 16 # 8
1671
+ HOMO_HEAD = False
1672
+ sparse_type = 'home' if HOMO_HEAD else 'hetero'
1673
+ dtype = torch.bfloat16
1674
+ MAX_N_CTX = 8192
1675
+
1676
+ configs = [triton.testing.Benchmark(
1677
+ x_names=['PAST_LEN'],
1678
+ x_vals=[2**i - 1 for i in range(8, 14)],
1679
+ line_arg='provider',
1680
+ line_vals=['torch'] + (['flash'] if HAS_FLASH else []) + ['triton_sparse', 'triton_dense'],
1681
+ line_names=['Torch'] + (['Flash-Dense'] if HAS_FLASH else []) + ['Triton-Sparse', 'Triton-Dense'],
1682
+ styles=[('red', '-'), ('blue', '-'), ('green', '-'), ('cyan', '-')],
1683
+ ylabel='ms',
1684
+ plot_name=f'fused-attention-inference-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-sparse-local{LOCAl_BLOCKS}-vert{VERT_STRIDE}-{sparse_type}',
1685
+ args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'Q_LEN': Q_LEN, 'dtype': torch.float16, 'mode': mode}
1686
+ ) for mode in ['fwd']]
1687
+ @triton.testing.perf_report(configs)
1688
+ def bench_flash_attention_inference(BATCH, H, PAST_LEN, D_HEAD, Q_LEN, mode, provider, dtype=torch.bfloat16, device='cuda'):
1689
+ assert mode in ['fwd']
1690
+ warmup = 25
1691
+ rep = 100
1692
+ N_CTX = PAST_LEN + Q_LEN
1693
+ if provider == 'torch':
1694
+ q = torch.randn((BATCH, H, Q_LEN, D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
1695
+ k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
1696
+ v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
1697
+ sm_scale = 1.3
1698
+ mask_csr, _, mask_dense = get_sparse_attn_mask(q, N_CTX, BLOCK=BLOCK_SIZE,
1699
+ local_blocks=LOCAl_BLOCKS, vert_stride=VERT_STRIDE, homo_head=VERT_STRIDE, return_dense=True)
1700
+
1701
+ fn = lambda: torch_attention(q, k, v, mask_dense, sm_scale=sm_scale, block_size=2048)
1702
+ ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
1703
+ return ms
1704
+ if provider == 'triton_sparse':
1705
+ q = torch.randn((BATCH, H, Q_LEN, D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
1706
+ k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
1707
+ v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
1708
+ sm_scale = 1.3
1709
+ sparse_attention_fn = get_local_strided_sparse_attention_op(H, MAX_N_CTX,
1710
+ local_blocks=LOCAl_BLOCKS,
1711
+ vert_stride=VERT_STRIDE,
1712
+ homo_head=HOMO_HEAD,
1713
+ sparse_block_size=BLOCK_SIZE,
1714
+ kernel_block_size=BLOCK_SIZE,
1715
+ device=q.device,
1716
+ inference=True)
1717
+
1718
+ fn = lambda: sparse_attention_fn(q, k, v, sm_scale)
1719
+ ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
1720
+ return ms
1721
+ if provider == 'triton_dense':
1722
+ q = torch.randn((BATCH, H, Q_LEN, D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
1723
+ k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
1724
+ v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
1725
+ sm_scale = 1.3
1726
+ sparse_attention_fn = get_local_strided_sparse_attention_op(H, MAX_N_CTX,
1727
+ local_blocks=1,
1728
+ vert_stride=1,
1729
+ homo_head=True,
1730
+ sparse_block_size=BLOCK_SIZE,
1731
+ kernel_block_size=BLOCK_SIZE,
1732
+ device=q.device,
1733
+ inference=True)
1734
+
1735
+ fn = lambda: sparse_attention_fn(q, k, v, sm_scale)
1736
+ ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
1737
+ return ms
1738
+ if provider == 'flash':
1739
+ assert Q_LEN == 1
1740
+ lengths = torch.full((BATCH,), fill_value=N_CTX, device=device)
1741
+ cu_seqlens = torch.zeros((BATCH + 1,), device=device, dtype=torch.int32)
1742
+ cu_seqlens[1:] = lengths.cumsum(0)
1743
+ cu_seqlens_q = torch.arange(BATCH + 1, device=device, dtype=torch.int32)
1744
+
1745
+ # (total_q, nheads, headdim),
1746
+ q = torch.randn((BATCH, H, D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
1747
+ k = torch.randn((BATCH*N_CTX, H, D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
1748
+ v = torch.randn((BATCH*N_CTX, H, D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
1749
+
1750
+ fn = lambda: flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens, 1, N_CTX, dropout_p=0, softmax_scale=1.3, causal=False)
1751
+ ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
1752
+ return ms
1753
+
1754
+
1755
+ test_op(1, 4, 512, 128, dtype=torch.float16, homo_head=False, backward=support_backward)
1756
+ # bench_flash_attention.run(save_path='.', print_data=True)
1757
+
1758
+ bench_flash_attention_inference.run(save_path='.', print_data=True)
1759
+ exit()
1760
+ # head_dim=64
1761
+ test_op(1, 2, 1024, 64, kernel_block_size=64, sparse_block_size=64,
1762
+ dtype=torch.bfloat16, homo_head=False, backward=support_backward)
1763
+ # uneven length, bf16
1764
+ test_op(1, 16, 224, 128, dtype=torch.bfloat16, homo_head=False, backward=False, sparse_block_size=128,
1765
+ kernel_block_size=64, local_blocks=8, vert_stride=8)
1766
+ test_op(3, 2, 2047, 128, homo_head=False, backward=False)
1767
+
1768
+ # diff kernel/sparse block size
1769
+ test_op(1, 16, 224, 128, dtype=torch.bfloat16, homo_head=False, backward=False, kernel_block_size=64)
1770
+ # inference
1771
+ # test_op(1, 4, 512 + 256, 128, Q_LEN=1, dtype=torch.bfloat16, homo_head=False, backward=support_backward)
1772
+
1773
+ # dense flash attn
1774
+ test_op(1, 2, 1024, 128, kernel_block_size=128, sparse_block_size=128, dtype=torch.bfloat16, homo_head=False,
1775
+ backward=support_backward, local_blocks=1, vert_stride=1)
1776
+
1777
+ # fp16
1778
+ test_op(1, 4, 512 + 256, 128, dtype=torch.float16, homo_head=False, backward=support_backward)
1779
+
1780
+ # longer sequence
1781
+ test_op(2, 4, 8192, 64, homo_head=False, backward=support_backward)
1782
+ test_op(2, 4, 8192, 128, dtype=torch.bfloat16, homo_head=False, backward=support_backward)
1783
+
1784
+ # homo head
1785
+ test_op(3, 2, 2048, 64, homo_head=True, dtype=torch.bfloat16, backward=False)
1786
+ test_op(3, 2, 2048, 64, homo_head=True, backward=support_backward)
1787
+
1788
+ # sparse_attention_fn = sparse_attention_factory(16, 128, num_warps=1, INFERENCE=True)
1789
+ # test_op(8, 1, 2047, 128, 1, backward=False, sparse_attention_fn=None)
1790
+ # test_op_inference(3, 2, 2048, 128, 2048)
1791
+ # test_op_inference(3, 2, 2047, 64, 2047)
1792
+ # test_op_inference(3, 2, 256, 64, 128)
1793
+ # test_op_inference(3, 2, 2048, 64, 1)
1794
+
1795
+ bench_flash_attention.run(save_path='.', print_data=True)
1796
+ # bench_flash_attention_inference.run(save_path='.', print_data=True)
1797
+
1798
+ # ========================
1799
+ # Some Benchmark Results #
1800
+ # ========================
1801
+
1802
+ # fused-attention-batch4-head48-d64-sparse-local4-vert4-hetero-fwd
1803
+ # SEQ_LEN Triton-Dense Flash-Dense Triton-Sparse
1804
+ # 0 256.0 0.057184 0.069646 0.052567
1805
+ # 1 512.0 0.131688 0.187658 0.110212
1806
+ # 2 1024.0 0.391844 0.524990 0.247875
1807
+ # 3 2048.0 1.305190 1.456685 0.596506
1808
+ # 4 4096.0 4.623019 4.968653 1.600277
1809
+ # 5 8192.0 17.513062 18.332262 4.802458
1810
+ # 6 16384.0 68.453377 70.337540 16.052908
1811
+ # 7 32768.0 270.655487 276.020233 57.938946
1812
+ # fused-attention-batch4-head48-d64-sparse-local4-vert4-hetero-bwd (num_warp=8):
1813
+ # SEQ_LEN Triton-Dense Flash-Dense Triton-Sparse
1814
+ # 0 256.0 0.190120 0.150313 0.181451
1815
+ # 1 512.0 0.406348 0.391767 0.391177
1816
+ # 2 1024.0 1.029704 1.182967 0.885741
1817
+ # 3 2048.0 2.985456 3.843399 2.040469
1818
+ # 4 4096.0 9.808897 13.073701 5.069609
1819
+ # 5 8192.0 34.995201 47.863808 13.948782
1820
+ # 6 16384.0 132.740097 182.579193 42.816513
1821
+ # 7 32768.0 542.223389 714.820618 147.053574
1822
+ # fused-attention-inference-batch4-head32-d128-sparse-local4-vert4-hetero:
1823
+ # PAST_LEN Torch-Dense Flash-Dense Triton-Sparse
1824
+ # 0 256.0 0.050949 0.032357 0.107513
1825
+ # 1 512.0 0.073624 0.050651 0.199086
1826
+ # 2 1024.0 0.107472 0.080379 0.245445
1827
+ # 3 2048.0 0.178423 0.129448 0.338259
1828
+ # 4 4096.0 0.327647 0.223106 0.517048
1829
+ # 5 8192.0 0.588423 0.411263 0.884606
1830
+ # 6 16384.0 1.098898 0.798941 1.611809
1831
+ # 7 32768.0 2.094537 1.594726 3.044160
1832
+
1833
+
1834
+ # 6.7B
1835
+ # fused-attention-batch4-head32-d128-sparse-local4-vert4-hetero-fwd:
1836
+ # SEQ_LEN Triton-Dense Flash-Dense Triton-Sparse
1837
+ # 0 256.0 0.069208 0.082156 0.065097
1838
+ # 1 512.0 0.138271 0.201393 0.144467
1839
+ # 2 1024.0 0.391521 0.624614 0.322382
1840
+ # 3 2048.0 1.268443 2.406325 0.784367
1841
+ # 4 4096.0 4.455703 9.139097 2.100856
1842
+ # 5 8192.0 16.764315 35.289600 6.328320
1843
+ # 6 16384.0 65.221634 138.401794 21.069057
1844
+ # 7 32768.0 257.251343 548.085754 76.111870
1845
+ # fused-attention-batch4-head32-d128-sparse-local4-vert4-hetero-bwd:
1846
+ # SEQ_LEN Triton-Dense Flash-Dense Triton-Sparse
1847
+ # 0 256.0 0.297118 0.266469 0.255255
1848
+ # 1 512.0 0.672826 0.613685 0.552954
1849
+ # 2 1024.0 1.718434 1.705066 1.251953
1850
+ # 3 2048.0 4.936755 5.403875 2.927895
1851
+ # 4 4096.0 15.911594 18.959362 7.436288
1852
+ # 5 8192.0 55.357441 70.808578 21.140224
1853
+ # 6 16384.0 208.188416 273.617920 68.018173
1854
+ # 7 32768.0 806.037476 1081.453613 218.720261
1855
+ # fused-attention-inference-batch4-head32-d128-sparse-local4-vert4-hetero:
1856
+ # PAST_LEN Torch-Dense Flash-Dense Triton-Sparse
1857
+ # 0 256.0 0.050151 0.032337 0.107593
1858
+ # 1 512.0 0.073409 0.051737 0.200200
1859
+ # 2 1024.0 0.107533 0.082099 0.247067
1860
+ # 3 2048.0 0.177259 0.128891 0.338510
1861
+ # 4 4096.0 0.325866 0.223621 0.524842
1862
+ # 5 8192.0 0.586926 0.408913 0.885490
1863
+ # 6 16384.0 1.100834 0.793277 1.612271
1864
+ # 7 32768.0 2.098851 1.595831 3.064544
1865
+
1866
+ # fused-attention-batch4-head32-d128-sparse-local4-vert8-hetero-fwd:
1867
+ # SEQ_LEN Triton-Dense Flash-Dense Triton-Sparse
1868
+ # 0 256.0 0.066673 0.082037 0.065085
1869
+ # 1 512.0 0.137379 0.201880 0.143473
1870
+ # 2 1024.0 0.390675 0.624234 0.312046
1871
+ # 3 2048.0 1.267739 2.406950 0.696045
1872
+ # 4 4096.0 4.445138 9.136333 1.665788
1873
+ # 5 8192.0 16.768614 35.265533 4.380486
1874
+ # 6 16384.0 65.235970 138.393600 12.997633
1875
+ # 7 32768.0 257.317902 550.442993 42.821121
1876
+ # fused-attention-batch4-head32-d128-sparse-local4-vert8-hetero-bwd:
1877
+ # SEQ_LEN Triton-Dense Flash-Dense Triton-Sparse
1878
+ # 0 256.0 0.296461 0.266581 0.254022
1879
+ # 1 512.0 0.671427 0.613643 0.551283
1880
+ # 2 1024.0 1.719918 1.704295 1.229982
1881
+ # 3 2048.0 4.945305 5.403364 2.721906
1882
+ # 4 4096.0 15.934293 18.960999 6.259371
1883
+ # 5 8192.0 55.406593 70.832130 15.676929
1884
+ # 6 16384.0 208.750595 275.004425 44.837891
1885
+ # 7 32768.0 808.057861 1080.647705 141.856766
1886
+ # fused-attention-inference-batch4-head32-d128-sparse-local4-vert8-hetero:
1887
+ # PAST_LEN Torch-Dense Flash-Dense Triton-Sparse
1888
+ # 0 256.0 0.050739 0.032886 0.107837
1889
+ # 1 512.0 0.073507 0.051996 0.200293
1890
+ # 2 1024.0 0.106394 0.080679 0.240610
1891
+ # 3 2048.0 0.177659 0.127660 0.287625
1892
+ # 4 4096.0 0.326326 0.226971 0.377500
1893
+ # 5 8192.0 0.586339 0.407367 0.559266
1894
+ # 6 16384.0 1.102279 0.786221 0.920976
1895
+ # 7 32768.0 2.097370 1.545090 1.644288
1896
+
1897
+
1898
+ ################
1899
+ ##### fp16 #####
1900
+ ################
1901
+
1902
+ # fused-attention-batch4-head16-d64-sparse-local4-vert8-hetero-fwd:
1903
+ # SEQ_LEN Triton-Dense Flash-Dense Triton-Sparse
1904
+ # 0 256.0 0.032518 0.035472 0.029939
1905
+ # 1 512.0 0.054266 0.087841 0.054320
1906
+ # 2 1024.0 0.133447 0.263090 0.102045
1907
+ # 3 2048.0 0.384615 1.023293 0.201763
1908
+ # 4 4096.0 1.300890 4.023936 0.449555
1909
+ # 5 8192.0 4.774144 15.816704 1.150854
1910
+ # 6 16384.0 18.220032 62.771198 3.356001
1911
+ # 7 32768.0 71.405571 250.273788 10.976142
1912
+ # fused-attention-batch4-head16-d64-sparse-local4-vert8-hetero-bwd:
1913
+ # SEQ_LEN Triton-Dense Flash-Dense Triton-Sparse
1914
+ # 0 256.0 0.083342 0.069742 0.079496
1915
+ # 1 512.0 0.159894 0.170995 0.151705
1916
+ # 2 1024.0 0.386071 0.522407 0.331443
1917
+ # 3 2048.0 1.067715 1.737333 0.715248
1918
+ # 4 4096.0 3.382731 6.219520 1.597457
1919
+ # 5 8192.0 11.857793 23.560448 3.879035
1920
+ # 6 16384.0 44.422142 91.251709 10.626843
1921
+ # 7 32768.0 175.011841 359.473145 32.340992
1922
+
1923
+
1924
+ ################
1925
+ ##### bf16 #####
1926
+ ################
1927
+
1928
+ # fused-attention-batch4-head16-d64-sparse-local4-vert8-hetero-fwd:
1929
+ # SEQ_LEN Triton-Dense Flash-Dense Triton-Sparse
1930
+ # 0 256.0 0.037636 0.035902 0.031512
1931
+ # 1 512.0 0.058591 0.087229 0.058125
1932
+ # 2 1024.0 0.143337 0.263919 0.108443
1933
+ # 3 2048.0 0.414458 1.025985 0.214114
1934
+ # 4 4096.0 1.390841 4.020010 0.480550
1935
+ # 5 8192.0 5.067938 15.808171 1.230874
1936
+ # 6 16384.0 19.442280 62.765057 3.597274
1937
+ # 7 32768.0 75.501572 250.443771 11.768959
1938
+ # fused-attention-batch4-head16-d64-sparse-local4-vert8-hetero-bwd:
1939
+ # SEQ_LEN Triton-Dense Flash-Dense Triton-Sparse
1940
+ # 0 256.0 0.084404 0.070663 0.082613
1941
+ # 1 512.0 0.161510 0.172882 0.157661
1942
+ # 2 1024.0 0.388954 0.526047 0.339855
1943
+ # 3 2048.0 1.075814 1.736057 0.732420
1944
+ # 4 4096.0 3.401622 6.221376 1.636039
1945
+ # 5 8192.0 11.915136 23.483391 3.968725
1946
+ # 6 16384.0 44.660225 91.302910 10.857130
1947
+ # 7 32768.0 175.038467 359.048187 32.778240