Update modeling_llama.py
modeling_llama.py CHANGED (+55 -21)
@@ -32,19 +32,52 @@ from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
 from .configuration_clex import CLEXLlamaConfig
 from .clex_layer import LlamaCLEXScalingRotaryEmbedding
-
-
 from einops import rearrange
-from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func, flash_attn_qkvpacked_func, flash_attn_with_kvcache
-# from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
-from flash_attn.bert_padding import unpad_input, pad_input
+import importlib.metadata
+import importlib.util
 
 
 logger = logging.get_logger(__name__)
 
+def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[Tuple[bool, str], bool]:
+    # Check we're not importing a "pkg_name" directory somewhere but the actual library by trying to grab the version
+    package_exists = importlib.util.find_spec(pkg_name) is not None
+    package_version = "N/A"
+    if package_exists:
+        try:
+            package_version = importlib.metadata.version(pkg_name)
+            package_exists = True
+        except importlib.metadata.PackageNotFoundError:
+            package_exists = False
+        logger.info(f"Detected {pkg_name} version {package_version}")
+    if return_version:
+        return package_exists, package_version
+    else:
+        return package_exists
+
+def is_flash_attn_available():
+    if not _is_package_available("torch", return_version=True):
+        return False
+
+    # Let's add an extra check to see if cuda is available
+    import torch
+
+    return _is_package_available("flash_attn") and torch.cuda.is_available()
+
+if is_flash_attn_available():
+    from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func, flash_attn_qkvpacked_func, flash_attn_with_kvcache
+    # from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
+    from flash_attn.bert_padding import unpad_input, pad_input
+
+
+
+
 _CONFIG_FOR_DOC = "CLEXLlamaConfig"
 
 
+
+
+
 # Copied from transformers.models.bart.modeling_bart._make_causal_mask
 def _make_causal_mask(
     input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
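Note (not part of the diff): the guard added above lets the module import cleanly when flash_attn is absent. A minimal, self-contained sketch of the same availability-check pattern, using only the standard-library importlib calls seen in the hunk; the helper name _pkg_available is hypothetical:

import importlib.metadata
import importlib.util

def _pkg_available(name: str) -> bool:
    # find_spec tells us whether the module can be imported at all;
    # metadata.version distinguishes a real installation from a stray directory on sys.path.
    if importlib.util.find_spec(name) is None:
        return False
    try:
        importlib.metadata.version(name)
        return True
    except importlib.metadata.PackageNotFoundError:
        return False

# Example: "torch" is usually installed, "flash_attn" typically only on CUDA setups.
print(_pkg_available("torch"), _pkg_available("flash_attn"))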
@@ -137,13 +170,13 @@ def rotate_half(x):
     return torch.cat((-x2, x1), dim=-1)
 
 
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
+def apply_rotary_pos_emb(q, k, cos, sin, q_len, position_ids):
     # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
     cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
     sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
     cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
     sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
-    q_embed = (q * cos) + (rotate_half(q) * sin)
+    q_embed = (q * cos[:, :, -q_len:, :]) + (rotate_half(q) * sin[:, :, -q_len:, :])
     k_embed = (k * cos) + (rotate_half(k) * sin)
     return q_embed, k_embed
 
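Note (not part of the diff): a shape sketch of the new q_len slicing, with assumed toy sizes. During cached decoding q covers only the last q_len positions, while cos/sin are gathered for all kv_seq_len positions, so the query takes the trailing slice and the key uses the full table:

import torch

bs, nh, head_dim = 1, 2, 4
kv_seq_len, q_len = 6, 2                         # 4 cached tokens + 2 new tokens
q = torch.randn(bs, nh, q_len, head_dim)         # queries exist only for the new tokens
k = torch.randn(bs, nh, kv_seq_len, head_dim)    # keys cover cache + new tokens
cos = torch.randn(bs, 1, kv_seq_len, head_dim)   # [bs, 1, seq_len, dim], as in the function above
sin = torch.randn(bs, 1, kv_seq_len, head_dim)

q_rot = q * cos[:, :, -q_len:, :]   # cos is sliced to the last q_len positions; the head dim broadcasts
k_rot = k * cos                     # keys are rotated at every position
print(q_rot.shape, k_rot.shape)     # torch.Size([1, 2, 2, 4]) torch.Size([1, 2, 6, 4])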
@@ -247,19 +280,17 @@ class LlamaAttention(nn.Module):
         value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
 
         kv_seq_len = key_states.shape[-2]
-        if past_key_value is not None:
-            kv_seq_len += past_key_value[0].shape[-2]
-        # [bsz, nh, t, hd]
 
         if past_key_value is not None:
             kv_seq_len += past_key_value[0].shape[-2]
             key_states = torch.cat([past_key_value[0], key_states], dim=2)
 
         if pack_cos_sin is not None:
-            cos, sin = pack_cos_sin
+            cos, sin = pack_cos_sin.to(query_states.device)
         else:
             cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+        key_position_ids = torch.arange(kv_seq_len, dtype=torch.long, device=position_ids.device).unsqueeze(0).view(-1, kv_seq_len)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, q_len, key_position_ids)
 
         if past_key_value is not None:
             # reuse k, v, self_attention
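Note (not part of the diff): the key_position_ids expression above simply enumerates every cached-plus-current position; a tiny sketch with an assumed kv_seq_len of 5:

import torch

kv_seq_len = 5
key_position_ids = torch.arange(kv_seq_len, dtype=torch.long).unsqueeze(0).view(-1, kv_seq_len)
print(key_position_ids)        # tensor([[0, 1, 2, 3, 4]])
print(key_position_ids.shape)  # torch.Size([1, 5])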
@@ -267,12 +298,13 @@ class LlamaAttention(nn.Module):
 
         past_key_value = (key_states, value_states) if use_cache else None
 
+        use_flashatn = self.config.use_flashattn and is_flash_attn_available()
 
         if self.log_scale:
             log_n = torch.log(torch.tensor(kv_seq_len*1.0)).to(query_states.device, dtype=query_states.dtype) / \
                 torch.log(torch.tensor(self.config.max_position_embeddings)).to(query_states.device, dtype=query_states.dtype)
             query_states = query_states * log_n
-        if query_states.shape[-2] == 1 or query_states.shape[-2] != key_states.shape[-2] or self.config.use_flashattn:
+        if query_states.shape[-2] == 1 or query_states.shape[-2] != key_states.shape[-2] or use_flashatn:
             attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
 
             if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
@@ -308,6 +340,7 @@ class LlamaAttention(nn.Module):
             attn_weights = None
 
             return attn_output, attn_weights, past_key_value
+        # use flash attention
         elif past_key_value is not None:
             output = flash_attn_with_kvcache(
                 query_states.transpose(1, 2),
@@ -614,13 +647,15 @@ class LlamaModel(LlamaPreTrainedModel):
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
         # embed positions
-        if attention_mask is None:
-            attention_mask = torch.ones(
-                (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
-            )
-        attention_mask = self._prepare_decoder_attention_mask(
-            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
-        )
+        # if attention_mask is None:
+        #     attention_mask = torch.ones(
+        #         (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
+        #     )
+        # attention_mask = self._prepare_decoder_attention_mask(
+        #     attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
+        # )
+        attention_mask = None
+
 
         hidden_states = inputs_embeds
 
@@ -802,7 +837,6 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
             # Enable model parallelism
             shift_labels = shift_labels.to(shift_logits.device)
             loss = loss_fct(shift_logits, shift_labels)
-
         if not return_dict:
             output = (logits,) + outputs[1:]
             return (loss,) + output if loss is not None else output