gugarosa commited on
Commit
5fd430c
·
1 Parent(s): d212a78

Fixes exceeding maximum sequence length when using generate().

Browse files
Files changed (1) hide show
  1. modeling_phi.py +16 -8
modeling_phi.py CHANGED
@@ -481,7 +481,7 @@ def _update_kv_cache(kv: torch.FloatTensor, inference_params: InferenceParams, l
481
  num_heads, head_dim = kv.shape[-2:]
482
 
483
  if layer_idx not in inference_params.key_value_memory_dict:
484
- kv_cache = torch.empty(
485
  inference_params.max_batch_size,
486
  inference_params.max_seqlen,
487
  2,
@@ -490,9 +490,6 @@ def _update_kv_cache(kv: torch.FloatTensor, inference_params: InferenceParams, l
490
  dtype=kv.dtype,
491
  device=kv.device,
492
  )
493
- inference_params.key_value_memory_dict[layer_idx] = kv_cache
494
- else:
495
- kv_cache = inference_params.key_value_memory_dict[layer_idx]
496
 
497
  batch_start = inference_params.batch_size_offset
498
  batch_end = batch_start + kv.shape[0]
@@ -500,9 +497,14 @@ def _update_kv_cache(kv: torch.FloatTensor, inference_params: InferenceParams, l
500
  sequence_start = inference_params.seqlen_offset
501
  sequence_end = sequence_start + kv.shape[1]
502
 
503
- kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
504
- kv = kv_cache[batch_start:batch_end, :sequence_end, ...]
 
 
505
 
 
 
 
506
  return kv
507
 
508
 
@@ -710,7 +712,6 @@ class MHA(nn.Module):
710
  attention_mask: Optional[Union[torch.LongTensor, torch.BoolTensor]] = None,
711
  **kwargs,
712
  ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
713
- # TODO: Need an alternative way for dynamic control flow: torch.any(~attention_mask.bool())
714
  if attention_mask is not None:
715
  attention_mask = attention_mask.bool()
716
  else:
@@ -863,6 +864,13 @@ class PhiPreTrainedModel(PreTrainedModel):
863
  attention_mask: Optional[Union[torch.LongTensor, torch.BoolTensor]] = None,
864
  **kwargs,
865
  ) -> Dict[str, Any]:
 
 
 
 
 
 
 
866
  if past_key_values is None or not (isinstance(past_key_values, InferenceParams)):
867
  past_key_values = InferenceParams(
868
  max_seqlen=self.config.n_positions,
@@ -874,7 +882,7 @@ class PhiPreTrainedModel(PreTrainedModel):
874
  )
875
  else:
876
  # Assume that `past_key_values` has cached all tokens up to the last token in `input_ids`
877
- past_key_values.seqlen_offset = len(input_ids[0]) - 1
878
  input_ids = input_ids[:, -1].unsqueeze(-1)
879
 
880
  return {
 
481
  num_heads, head_dim = kv.shape[-2:]
482
 
483
  if layer_idx not in inference_params.key_value_memory_dict:
484
+ inference_params.key_value_memory_dict[layer_idx] = torch.empty(
485
  inference_params.max_batch_size,
486
  inference_params.max_seqlen,
487
  2,
 
490
  dtype=kv.dtype,
491
  device=kv.device,
492
  )
 
 
 
493
 
494
  batch_start = inference_params.batch_size_offset
495
  batch_end = batch_start + kv.shape[0]
 
497
  sequence_start = inference_params.seqlen_offset
498
  sequence_end = sequence_start + kv.shape[1]
499
 
500
+ # When the current sequence length is equal to or larger than the maximum sequence length,
501
+ # we need to roll the cache to the left and update it
502
+ if sequence_end >= inference_params.max_seqlen:
503
+ inference_params.key_value_memory_dict[layer_idx] = inference_params.key_value_memory_dict[layer_idx].roll(-(sequence_end - sequence_start), 1)
504
 
505
+ inference_params.key_value_memory_dict[layer_idx][batch_start:batch_end, sequence_start:sequence_end, ...] = kv
506
+ kv = inference_params.key_value_memory_dict[layer_idx][batch_start:batch_end, :sequence_end, ...]
507
+
508
  return kv
509
 
510
 
 
712
  attention_mask: Optional[Union[torch.LongTensor, torch.BoolTensor]] = None,
713
  **kwargs,
714
  ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
 
715
  if attention_mask is not None:
716
  attention_mask = attention_mask.bool()
717
  else:
 
864
  attention_mask: Optional[Union[torch.LongTensor, torch.BoolTensor]] = None,
865
  **kwargs,
866
  ) -> Dict[str, Any]:
867
+ # Truncate `input_ids` and `attention_mask` (if necessary) to prevent exceeding
868
+ # the maximum sequence length
869
+ if input_ids.shape[1] > self.config.n_positions:
870
+ input_ids = input_ids[:, -self.config.n_positions :]
871
+ if attention_mask is not None:
872
+ attention_mask = attention_mask[:, -self.config.n_positions :]
873
+
874
  if past_key_values is None or not (isinstance(past_key_values, InferenceParams)):
875
  past_key_values = InferenceParams(
876
  max_seqlen=self.config.n_positions,
 
882
  )
883
  else:
884
  # Assume that `past_key_values` has cached all tokens up to the last token in `input_ids`
885
+ past_key_values.seqlen_offset = input_ids.shape[1] - 1
886
  input_ids = input_ids[:, -1].unsqueeze(-1)
887
 
888
  return {