Upload modeling_phi.py
modeling_phi.py  +12 -7
@@ -358,7 +358,9 @@ class PhiAttention(nn.Module):
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)

-        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+        attn_weights = torch.matmul(
+            query_states.to(torch.float32), key_states.to(torch.float32).transpose(2, 3)
+        ) / math.sqrt(self.head_dim)

         if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
             raise ValueError(
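Why the upcast: float16 overflows to inf above ~65504, and each attention logit is a sum of head_dim products, so moderately large query/key activations can overflow before the softmax's own fp32 upcast ever runs. A minimal standalone sketch of the failure mode (illustrative values, not from the model; head_dim = 80 as in Phi-2):

import math
import torch

head_dim = 80
q = torch.full((head_dim,), 30.0)  # float32 reference values
k = torch.full((head_dim,), 30.0)

raw = q.dot(k)                                         # 72000.0, fine in float32
print(raw.to(torch.float16))                           # inf: exceeds float16 max (~65504)
print((raw / math.sqrt(head_dim)).to(torch.float16))   # ~8048, finite after scaling

A float16 matmul materializes its output in float16, so the un-scaled dot product already saturates before the division by sqrt(head_dim); computing the matmul in float32, as the new lines do, keeps the logits finite.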
@@ -374,7 +376,7 @@ class PhiAttention(nn.Module):
             attn_weights = attn_weights + attention_mask

         # upcast attention to fp32
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype)
         attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)

         attn_output = torch.matmul(attn_weights, value_states)
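Casting the softmax output to value_states.dtype rather than query_states.dtype makes the intent explicit: the probabilities must match the dtype of the tensor they are about to be multiplied with, since torch.matmul does not promote across floating-point dtypes. A minimal sketch (shapes and dtypes are illustrative):

import torch
import torch.nn as nn

attn_weights = torch.randn(1, 2, 4, 4, dtype=torch.float32)   # fp32 logits from the upcast matmul
value_states = torch.randn(1, 2, 4, 8, dtype=torch.bfloat16)  # model dtype

# softmax in fp32 for stability, then cast to the value dtype for the matmul
probs = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype)
attn_output = torch.matmul(probs, value_states)               # operand dtypes agree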
@@ -483,8 +485,10 @@ class PhiFlashAttention2(PhiAttention):
         # in fp32.

         if query_states.dtype == torch.float32:
+            if torch.is_autocast_enabled():
+                target_dtype = torch.get_autocast_gpu_dtype()
             # Handle the case where the model is quantized
-            if hasattr(self.config, "_pre_quantization_dtype"):
+            elif hasattr(self.config, "_pre_quantization_dtype"):
                 target_dtype = self.config._pre_quantization_dtype
             else:
                 target_dtype = self.q_proj.weight.dtype
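The new branch gives an active autocast context priority when choosing the dtype that fp32 hidden states are cast to before the flash-attention call (whose kernels only accept fp16/bf16). A sketch of the selection order; pick_target_dtype is a hypothetical standalone helper, not part of the model code:

import torch

def pick_target_dtype(config, weight):
    # hypothetical helper mirroring the branch above, for illustration only
    # 1. an active autocast context wins
    if torch.is_autocast_enabled():
        return torch.get_autocast_gpu_dtype()
    # 2. a dtype recorded before quantization, if the config carries one
    if hasattr(config, "_pre_quantization_dtype"):
        return config._pre_quantization_dtype
    # 3. otherwise fall back to the dtype of an unquantized weight
    return weight.dtype

Without the autocast check, fp32 weights would steer the cast back to fp32 under torch.autocast, and the flash-attention kernels would reject the inputs.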
@@ -1093,7 +1097,7 @@ class PhiForCausalLM(PhiPreTrainedModel):

         # Keep only the unprocessed tokens:
         # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
-        # some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as
+        # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
         # input)
         if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
             input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
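A worked illustration of the trimming logic in the context lines (hypothetical values): with past_length = 8 tokens already in the KV cache and an attention_mask covering 10 positions, only the last 10 - 8 = 2 tokens still need processing:

import torch

past_length = 8
attention_mask = torch.ones(1, 10, dtype=torch.long)  # full sequence seen so far
input_ids = torch.arange(9).unsqueeze(0)              # shorter than the mask: part of the input is in the cache

if attention_mask.shape[1] > input_ids.shape[1]:
    input_ids = input_ids[:, -(attention_mask.shape[1] - past_length):]
print(input_ids)  # tensor([[7, 8]]) -- only the unprocessed tail remains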
@@ -1225,9 +1229,10 @@ class PhiForSequenceClassification(PhiPreTrainedModel):
             sequence_lengths = -1
         else:
             if input_ids is not None:
-                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to(
-                    logits.device
-                )
+                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
+                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
+                sequence_lengths = sequence_lengths % input_ids.shape[-1]
+                sequence_lengths = sequence_lengths.to(logits.device)
             else:
                 sequence_lengths = -1

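Why the modulo works: argmax over input_ids == pad_token_id returns 0 when a row contains no pad token at all, so subtracting 1 yields -1; taking the result modulo the sequence length maps -1 to the last position without negative indexing, which ONNX export does not handle reliably. A standalone sketch with illustrative values:

import torch

pad_token_id = 0
input_ids = torch.tensor([[5, 3, 9, 0, 0, 0],   # padded after 3 real tokens
                          [5, 3, 9, 7, 2, 4]])  # no padding at all

sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
print(sequence_lengths)                         # tensor([ 2, -1])
sequence_lengths = sequence_lengths % input_ids.shape[-1]
print(sequence_lengths)                         # tensor([2, 5]) -- last real token per row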