gugarosa commited on
Commit
914c8fb
1 Parent(s): 3a705a2

Upload modeling_phi.py

Browse files
Files changed (1) hide show
  1. modeling_phi.py +12 -7
modeling_phi.py CHANGED
@@ -358,7 +358,9 @@ class PhiAttention(nn.Module):
358
  key_states = repeat_kv(key_states, self.num_key_value_groups)
359
  value_states = repeat_kv(value_states, self.num_key_value_groups)
360
 
361
- attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
 
 
362
 
363
  if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
364
  raise ValueError(
@@ -374,7 +376,7 @@ class PhiAttention(nn.Module):
374
  attn_weights = attn_weights + attention_mask
375
 
376
  # upcast attention to fp32
377
- attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
378
  attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
379
 
380
  attn_output = torch.matmul(attn_weights, value_states)
@@ -483,8 +485,10 @@ class PhiFlashAttention2(PhiAttention):
483
  # in fp32.
484
 
485
  if query_states.dtype == torch.float32:
 
 
486
  # Handle the case where the model is quantized
487
- if hasattr(self.config, "_pre_quantization_dtype"):
488
  target_dtype = self.config._pre_quantization_dtype
489
  else:
490
  target_dtype = self.q_proj.weight.dtype
@@ -1093,7 +1097,7 @@ class PhiForCausalLM(PhiPreTrainedModel):
1093
 
1094
  # Keep only the unprocessed tokens:
1095
  # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1096
- # some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as
1097
  # input)
1098
  if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
1099
  input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
@@ -1225,9 +1229,10 @@ class PhiForSequenceClassification(PhiPreTrainedModel):
1225
  sequence_lengths = -1
1226
  else:
1227
  if input_ids is not None:
1228
- sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to(
1229
- logits.device
1230
- )
 
1231
  else:
1232
  sequence_lengths = -1
1233
 
 
358
  key_states = repeat_kv(key_states, self.num_key_value_groups)
359
  value_states = repeat_kv(value_states, self.num_key_value_groups)
360
 
361
+ attn_weights = torch.matmul(
362
+ query_states.to(torch.float32), key_states.to(torch.float32).transpose(2, 3)
363
+ ) / math.sqrt(self.head_dim)
364
 
365
  if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
366
  raise ValueError(
 
376
  attn_weights = attn_weights + attention_mask
377
 
378
  # upcast attention to fp32
379
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype)
380
  attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
381
 
382
  attn_output = torch.matmul(attn_weights, value_states)
 
485
  # in fp32.
486
 
487
  if query_states.dtype == torch.float32:
488
+ if torch.is_autocast_enabled():
489
+ target_dtype = torch.get_autocast_gpu_dtype()
490
  # Handle the case where the model is quantized
491
+ elif hasattr(self.config, "_pre_quantization_dtype"):
492
  target_dtype = self.config._pre_quantization_dtype
493
  else:
494
  target_dtype = self.q_proj.weight.dtype
 
1097
 
1098
  # Keep only the unprocessed tokens:
1099
  # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1100
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
1101
  # input)
1102
  if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
1103
  input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
 
1229
  sequence_lengths = -1
1230
  else:
1231
  if input_ids is not None:
1232
+ # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
1233
+ sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
1234
+ sequence_lengths = sequence_lengths % input_ids.shape[-1]
1235
+ sequence_lengths = sequence_lengths.to(logits.device)
1236
  else:
1237
  sequence_lengths = -1
1238