JustinLin610
committed on
Commit
·
9f94ac2
1
Parent(s):
f6498e5
support cpu inference, fix conflicts between fp32 and flash-attn
Browse files
- modeling_qwen.py +64 -25
modeling_qwen.py
CHANGED
@@ -15,6 +15,7 @@ from torch.cuda.amp import autocast
|
|
15 |
from torch.nn import CrossEntropyLoss
|
16 |
from transformers import PreTrainedTokenizer, GenerationConfig, StoppingCriteriaList
|
17 |
from transformers.generation.logits_process import LogitsProcessorList
|
|
|
18 |
if TYPE_CHECKING:
|
19 |
from transformers.generation.streamers import BaseStreamer
|
20 |
from transformers.generation.utils import GenerateOutput
|
@@ -38,15 +39,19 @@ try:
|
|
38 |
use_flash_rotary = True
|
39 |
except ImportError:
|
40 |
use_flash_rotary = False
|
41 |
-
print(
|
42 |
-
|
|
|
|
|
43 |
|
44 |
try:
|
45 |
from flash_attn.ops.rms_norm import rms_norm
|
46 |
except ImportError:
|
47 |
rms_norm = None
|
48 |
-
print(
|
49 |
-
|
|
|
|
|
50 |
|
51 |
from .configuration_qwen import QWenConfig
|
52 |
from .qwen_generation_utils import (
|
@@ -69,8 +74,10 @@ try:
|
|
69 |
from flash_attn.flash_attn_interface import flash_attn_unpadded_func
|
70 |
except ImportError:
|
71 |
flash_attn_unpadded_func = None
|
72 |
-
print(
|
73 |
-
|
|
|
|
|
74 |
|
75 |
|
76 |
class FlashSelfAttention(torch.nn.Module):
|
@@ -177,8 +184,12 @@ class QWenAttention(nn.Module):
|
|
177 |
config.hidden_size, self.projection_size, bias=not config.no_bias
|
178 |
)
|
179 |
|
180 |
-
self.is_fp32 = not(config.bf16 or config.fp16)
|
181 |
-
if
|
|
|
|
|
|
|
|
|
182 |
self.core_attention_flash = FlashSelfAttention(
|
183 |
causal=True, attention_dropout=config.attn_pdrop
|
184 |
)
|
@@ -197,14 +208,15 @@ class QWenAttention(nn.Module):
|
|
197 |
if self.rotary_ndims is not None
|
198 |
else self.hidden_size_per_attention_head
|
199 |
)
|
200 |
-
self.rotary_emb = RotaryEmbedding(
|
201 |
-
dim, base=config.rotary_emb_base
|
202 |
-
)
|
203 |
|
204 |
self.use_dynamic_ntk = config.use_dynamic_ntk
|
205 |
self.use_logn_attn = config.use_logn_attn
|
206 |
|
207 |
-
logn_list = [
|
|
|
|
|
|
|
208 |
self.logn_tensor = torch.Tensor(logn_list)[None, :, None, None]
|
209 |
self._ntk_cached = 1.0
|
210 |
|
@@ -335,14 +347,20 @@ class QWenAttention(nn.Module):
|
|
335 |
if layer_past:
|
336 |
# layer past[0] shape: bs * seq_len * head_num * dim
|
337 |
kv_seq_len += layer_past[0].shape[1]
|
338 |
-
if
|
|
|
|
|
|
|
|
|
339 |
context_value = math.log(kv_seq_len / self.seq_length, 2) + 1
|
340 |
ntk_alpha = 2 ** math.ceil(context_value) - 1
|
341 |
ntk_alpha = max(ntk_alpha, 1)
|
342 |
self._ntk_cached = ntk_alpha
|
343 |
else:
|
344 |
ntk_alpha = self._ntk_cached
|
345 |
-
rotary_pos_emb = self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha).to(
|
|
|
|
|
346 |
|
347 |
if rotary_pos_emb is not None:
|
348 |
if isinstance(rotary_pos_emb, tuple):
|
@@ -377,7 +395,12 @@ class QWenAttention(nn.Module):
|
|
377 |
logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :]
|
378 |
query = query * logn_tensor.expand_as(query)
|
379 |
|
380 |
-
if
|
|
|
|
|
|
|
|
|
|
|
381 |
q, k, v = query, key, value
|
382 |
context_layer = self.core_attention_flash(q, k, v)
|
383 |
|
@@ -398,7 +421,11 @@ class QWenAttention(nn.Module):
|
|
398 |
attn_output = self.c_proj(context_layer)
|
399 |
outputs = (attn_output, present)
|
400 |
if output_attentions:
|
401 |
-
if
|
|
|
|
|
|
|
|
|
402 |
raise ValueError("Cannot output attentions while using flash-attn")
|
403 |
else:
|
404 |
outputs += (attn_weight,)
|
@@ -750,7 +777,9 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
|
750 |
super().__init__(config)
|
751 |
self.transformer = QWenModel(config)
|
752 |
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
|
753 |
-
assert not
|
|
|
|
|
754 |
if config.bf16:
|
755 |
self.transformer.bfloat16()
|
756 |
self.lm_head.bfloat16()
|
@@ -929,21 +958,25 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
|
929 |
generation_config: Optional[GenerationConfig] = None,
|
930 |
logits_processor: Optional[LogitsProcessorList] = None,
|
931 |
stopping_criteria: Optional[StoppingCriteriaList] = None,
|
932 |
-
prefix_allowed_tokens_fn: Optional[
|
|
|
|
|
933 |
synced_gpus: Optional[bool] = None,
|
934 |
streamer: Optional["BaseStreamer"] = None,
|
935 |
**kwargs,
|
936 |
) -> Union[GenerateOutput, torch.LongTensor]:
|
937 |
# Process stop_words_ids.
|
938 |
-
stop_words_ids = kwargs.pop(
|
939 |
if stop_words_ids is None and generation_config is not None:
|
940 |
-
stop_words_ids = getattr(generation_config,
|
941 |
if stop_words_ids is None:
|
942 |
-
stop_words_ids = getattr(self.generation_config,
|
943 |
|
944 |
if stop_words_ids is not None:
|
945 |
stop_words_logits_processor = StopWordsLogitsProcessor(
|
946 |
-
stop_words_ids=stop_words_ids,
|
|
|
|
|
947 |
if logits_processor is None:
|
948 |
logits_processor = LogitsProcessorList([stop_words_logits_processor])
|
949 |
else:
|
@@ -978,7 +1011,13 @@ class RotaryEmbedding(torch.nn.Module):
|
|
978 |
seqlen = max_seq_len + offset
|
979 |
if seqlen > self._seq_len_cached or ntk_alpha != self._ntk_alpha_cached:
|
980 |
base = self.base * ntk_alpha ** (self.dim / (self.dim - 2))
|
981 |
-
self.inv_freq = 1.0 / (
|
|
|
|
|
|
|
|
|
|
|
|
|
982 |
self._seq_len_cached = seqlen
|
983 |
self._ntk_alpha_cached = ntk_alpha
|
984 |
seq = torch.arange(seqlen, device=self.inv_freq.device)
|
@@ -1028,8 +1067,8 @@ class RMSNorm(torch.nn.Module):
|
|
1028 |
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
1029 |
|
1030 |
def forward(self, x):
|
1031 |
-
if rms_norm is not None:
|
1032 |
return rms_norm(x, self.weight, self.eps)
|
1033 |
else:
|
1034 |
output = self._norm(x.float()).type_as(x)
|
1035 |
-
return output * self.weight
|
|
|
15 |
from torch.nn import CrossEntropyLoss
|
16 |
from transformers import PreTrainedTokenizer, GenerationConfig, StoppingCriteriaList
|
17 |
from transformers.generation.logits_process import LogitsProcessorList
|
18 |
+
|
19 |
if TYPE_CHECKING:
|
20 |
from transformers.generation.streamers import BaseStreamer
|
21 |
from transformers.generation.utils import GenerateOutput
|
|
|
39 |
use_flash_rotary = True
|
40 |
except ImportError:
|
41 |
use_flash_rotary = False
|
42 |
+
print(
|
43 |
+
"Warning: import flash_attn rotary fail, please install FlashAttention rotary to get better performance "
|
44 |
+
"https://github.com/Dao-AILab/flash-attention/tree/main/csrc/rotary"
|
45 |
+
)
|
46 |
|
47 |
try:
|
48 |
from flash_attn.ops.rms_norm import rms_norm
|
49 |
except ImportError:
|
50 |
rms_norm = None
|
51 |
+
print(
|
52 |
+
"Warning: import flash_attn rms_norm fail, please install FlashAttention layer_norm to get better performance "
|
53 |
+
"https://github.com/Dao-AILab/flash-attention/tree/main/csrc/layer_norm"
|
54 |
+
)
|
55 |
|
56 |
from .configuration_qwen import QWenConfig
|
57 |
from .qwen_generation_utils import (
|
|
|
74 |
from flash_attn.flash_attn_interface import flash_attn_unpadded_func
|
75 |
except ImportError:
|
76 |
flash_attn_unpadded_func = None
|
77 |
+
print(
|
78 |
+
"Warning: import flash_attn fail, please install FlashAttention "
|
79 |
+
"https://github.com/Dao-AILab/flash-attention"
|
80 |
+
)
|
81 |
|
82 |
|
83 |
class FlashSelfAttention(torch.nn.Module):
|
|
|
184 |
config.hidden_size, self.projection_size, bias=not config.no_bias
|
185 |
)
|
186 |
|
187 |
+
self.is_fp32 = not (config.bf16 or config.fp16)
|
188 |
+
if (
|
189 |
+
self.use_flash_attn
|
190 |
+
and flash_attn_unpadded_func is not None
|
191 |
+
and not self.is_fp32
|
192 |
+
):
|
193 |
self.core_attention_flash = FlashSelfAttention(
|
194 |
causal=True, attention_dropout=config.attn_pdrop
|
195 |
)
|
|
|
208 |
if self.rotary_ndims is not None
|
209 |
else self.hidden_size_per_attention_head
|
210 |
)
|
211 |
+
self.rotary_emb = RotaryEmbedding(dim, base=config.rotary_emb_base)
|
|
|
|
|
212 |
|
213 |
self.use_dynamic_ntk = config.use_dynamic_ntk
|
214 |
self.use_logn_attn = config.use_logn_attn
|
215 |
|
216 |
+
logn_list = [
|
217 |
+
math.log(i, self.seq_length) if i > self.seq_length else 1
|
218 |
+
for i in range(1, 32768)
|
219 |
+
]
|
220 |
self.logn_tensor = torch.Tensor(logn_list)[None, :, None, None]
|
221 |
self._ntk_cached = 1.0
|
222 |
|
|
|
347 |
if layer_past:
|
348 |
# layer past[0] shape: bs * seq_len * head_num * dim
|
349 |
kv_seq_len += layer_past[0].shape[1]
|
350 |
+
if (
|
351 |
+
self.use_dynamic_ntk
|
352 |
+
and kv_seq_len == hidden_states.size()[1]
|
353 |
+
and not self.training
|
354 |
+
):
|
355 |
context_value = math.log(kv_seq_len / self.seq_length, 2) + 1
|
356 |
ntk_alpha = 2 ** math.ceil(context_value) - 1
|
357 |
ntk_alpha = max(ntk_alpha, 1)
|
358 |
self._ntk_cached = ntk_alpha
|
359 |
else:
|
360 |
ntk_alpha = self._ntk_cached
|
361 |
+
rotary_pos_emb = self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha).to(
|
362 |
+
hidden_states.device
|
363 |
+
)
|
364 |
|
365 |
if rotary_pos_emb is not None:
|
366 |
if isinstance(rotary_pos_emb, tuple):
|
|
|
395 |
logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :]
|
396 |
query = query * logn_tensor.expand_as(query)
|
397 |
|
398 |
+
if (
|
399 |
+
self.use_flash_attn
|
400 |
+
and flash_attn_unpadded_func is not None
|
401 |
+
and not self.is_fp32
|
402 |
+
and query.is_cuda
|
403 |
+
):
|
404 |
q, k, v = query, key, value
|
405 |
context_layer = self.core_attention_flash(q, k, v)
|
406 |
|
|
|
421 |
attn_output = self.c_proj(context_layer)
|
422 |
outputs = (attn_output, present)
|
423 |
if output_attentions:
|
424 |
+
if (
|
425 |
+
self.use_flash_attn
|
426 |
+
and flash_attn_unpadded_func is not None
|
427 |
+
and not self.is_fp32
|
428 |
+
):
|
429 |
raise ValueError("Cannot output attentions while using flash-attn")
|
430 |
else:
|
431 |
outputs += (attn_weight,)
|
|
|
777 |
super().__init__(config)
|
778 |
self.transformer = QWenModel(config)
|
779 |
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
|
780 |
+
assert not (
|
781 |
+
config.bf16 and config.fp16
|
782 |
+
), "In config, bf16 and fp16 cannot both be true"
|
783 |
if config.bf16:
|
784 |
self.transformer.bfloat16()
|
785 |
self.lm_head.bfloat16()
|
|
|
958 |
generation_config: Optional[GenerationConfig] = None,
|
959 |
logits_processor: Optional[LogitsProcessorList] = None,
|
960 |
stopping_criteria: Optional[StoppingCriteriaList] = None,
|
961 |
+
prefix_allowed_tokens_fn: Optional[
|
962 |
+
Callable[[int, torch.Tensor], List[int]]
|
963 |
+
] = None,
|
964 |
synced_gpus: Optional[bool] = None,
|
965 |
streamer: Optional["BaseStreamer"] = None,
|
966 |
**kwargs,
|
967 |
) -> Union[GenerateOutput, torch.LongTensor]:
|
968 |
# Process stop_words_ids.
|
969 |
+
stop_words_ids = kwargs.pop("stop_words_ids", None)
|
970 |
if stop_words_ids is None and generation_config is not None:
|
971 |
+
stop_words_ids = getattr(generation_config, "stop_words_ids", None)
|
972 |
if stop_words_ids is None:
|
973 |
+
stop_words_ids = getattr(self.generation_config, "stop_words_ids", None)
|
974 |
|
975 |
if stop_words_ids is not None:
|
976 |
stop_words_logits_processor = StopWordsLogitsProcessor(
|
977 |
+
stop_words_ids=stop_words_ids,
|
978 |
+
eos_token_id=self.generation_config.eos_token_id,
|
979 |
+
)
|
980 |
if logits_processor is None:
|
981 |
logits_processor = LogitsProcessorList([stop_words_logits_processor])
|
982 |
else:
|
|
|
1011 |
seqlen = max_seq_len + offset
|
1012 |
if seqlen > self._seq_len_cached or ntk_alpha != self._ntk_alpha_cached:
|
1013 |
base = self.base * ntk_alpha ** (self.dim / (self.dim - 2))
|
1014 |
+
self.inv_freq = 1.0 / (
|
1015 |
+
base
|
1016 |
+
** (
|
1017 |
+
torch.arange(0, self.dim, 2, device=self.inv_freq.device).float()
|
1018 |
+
/ self.dim
|
1019 |
+
)
|
1020 |
+
)
|
1021 |
self._seq_len_cached = seqlen
|
1022 |
self._ntk_alpha_cached = ntk_alpha
|
1023 |
seq = torch.arange(seqlen, device=self.inv_freq.device)
|
|
|
1067 |
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
1068 |
|
1069 |
def forward(self, x):
|
1070 |
+
if rms_norm is not None and x.is_cuda:
|
1071 |
return rms_norm(x, self.weight, self.eps)
|
1072 |
else:
|
1073 |
output = self._norm(x.float()).type_as(x)
|
1074 |
+
return output * self.weight
|