zhiqu22 committed
Commit 23f06b4 · 1 Parent(s): 6ff51cd

improve kv cache

Files changed (1)
  1. modeling_mitre.py +143 -50
modeling_mitre.py CHANGED
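In short: after the first decoding step only the register slots of the source are kept in each layer's key/value cache, the attention mask is stashed in registering_cache and extended by one column per step instead of being rebuilt, and padding masking moves from build_future_mask into forward. A rough sketch of the cache clipping with made-up shapes (the names and sizes below are illustrative, not taken from the model config):

import torch

# toy settings, purely illustrative
batch, heads, head_dim = 2, 4, 16
src_length, max_register_num, tgt_len = 10, 3, 1

# per-layer cache after the first decoder pass over [source + registers + first target]:
# shape (batch, heads, src_length + tgt_len, head_dim)
cache_key = torch.randn(batch, heads, src_length + tgt_len, head_dim)
cache_value = torch.randn(batch, heads, src_length + tgt_len, head_dim)

# later steps only attend to the register slots and the generated targets,
# so the plain source positions can be dropped from the cache
clipped_rep = (
    cache_key[:, :, src_length - max_register_num:, :],
    cache_value[:, :, src_length - max_register_num:, :],
)
print(clipped_rep[0].shape)  # torch.Size([2, 4, 4, 16]): 3 register slots + 1 target slot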
@@ -280,22 +280,48 @@ class MitreDecoder(MitrePreTrainedModel):
         registers = input_ids[range(batch_size), torch.argmax(input_ids, dim=-1)].unsqueeze(1).repeat(1, max_register_nums)
         return registers, register_nums, total_token_nums
 
-    def combine_src_and_registers(self, input_ids, registers, register_nums, total_token_nums):
+    def get_token_indices(self, input_ids, total_token_nums, register_nums):
+        '''
+        return a token_indices for selecting source tokens from expanded_src_tokens
+        '''
+        token_indices = torch.arange(total_token_nums).expand(input_ids.size(0), -1).to(input_ids.device)
+        token_indices = token_indices + register_nums.unsqueeze(1)
+        return token_indices
+
+    def get_batch_indices(self, input_ids, token_indices):
+        '''
+        return a batch_indices for selecting source tokens from expanded_src_tokens
+        '''
+        batch_indices = torch.arange(input_ids.shape[0]).unsqueeze(1).expand(-1, token_indices.size(1)).contiguous()
+        return batch_indices
+
+    def combine_src_and_registers(self, input_ids, registers):
         '''
         return a expanded_src_tokens for positional embedding.
         '''
         pads = torch.full_like(registers, self.padding_idx)
         expanded_src_tokens = torch.cat((pads, input_ids, registers), dim=1)
-        indices = torch.arange(total_token_nums).expand(input_ids.size(0), -1).to(input_ids.device)
-        indices = indices + register_nums.unsqueeze(1)
-
-        batch_indices = torch.arange(input_ids.shape[0]).unsqueeze(1).expand(-1, indices.size(1)).contiguous()
-        return expanded_src_tokens, batch_indices, indices
+        return expanded_src_tokens
+
+    def source_tokens_embedding_with_positions(self, expanded_src_tokens, total_token_nums, batch_indices, indices):
+        '''
+        return the embeds of source tokens
+        '''
+        inputs_embeds = self.embed_tokens(expanded_src_tokens)
+        inputs_embeds_1 = inputs_embeds[:,:total_token_nums,:] + self.src_embed_positions(expanded_src_tokens[:,:total_token_nums])
+        inputs_embeds_2 = inputs_embeds[:,total_token_nums:,:] + self.register_embed_positions(expanded_src_tokens[:,total_token_nums:])
+        inputs_embeds = torch.cat((inputs_embeds_1, inputs_embeds_2), dim=1)
+        inputs_embeds = inputs_embeds[batch_indices, indices]
+
+        return inputs_embeds
 
     def fill_with_neg_inf(self, t):
         return t.float().fill_(float("-inf")).type_as(t)
+
+    def check_contiguous(self, t: torch.Tensor):
+        return t if t.is_contiguous() else t.contiguous()
 
-    def build_future_mask(self, embeds, src_length, register_nums, padding_mask=None, past_key_values_length=0):
+    def build_future_mask(self, embeds, src_length, register_nums, past_key_values_length=0):
         b = register_nums.size(0)
         ns = src_length - register_nums
         if past_key_values_length == 0:
@@ -331,11 +357,6 @@ class MitreDecoder(MitrePreTrainedModel):
             batch_mask[batch_indices[target_indices], row_indices[target_indices], col_indices[target_indices]] = float('-inf')
             # shape: batch_size, head_num (1 for broadcasting), seq_len, seq_len
             batch_mask = batch_mask.unsqueeze(1)
-            # 6. masking pads
-            if padding_mask is not None:
-                if padding_mask.any():
-                    padding_mask = padding_mask.to(batch_mask.device).unsqueeze(1).unsqueeze(2)
-                    batch_mask = batch_mask.masked_fill(padding_mask == 1, float('-inf'))
 
         elif past_key_values_length > 0:
             # in generation
@@ -350,7 +371,6 @@ class MitreDecoder(MitrePreTrainedModel):
             batch_mask[batch_indices[target_to_source_mask], token_indices[target_to_source_mask]] = float('-inf')
             batch_mask = batch_mask.unsqueeze(1)
 
-        # ensure contiguous
         batch_mask = batch_mask.view(b, 1, batch_mask.shape[-2], batch_mask.shape[-1])
         return batch_mask
 
@@ -359,13 +379,12 @@ class MitreDecoder(MitrePreTrainedModel):
         self,
         input_ids: Optional[torch.Tensor] = None,
         decoder_input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         past_key_values: Optional[List[torch.FloatTensor]] = None,
         use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
         registering_cache: dict = None,
     ):
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )
@@ -374,33 +393,49 @@ class MitreDecoder(MitrePreTrainedModel):
         # past_key_values_length
         past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
 
-        decoder_input_shape = decoder_input_ids.size()
-        decoder_input_ids = decoder_input_ids.view(-1, decoder_input_shape[-1])
-        padding_mask = None
-
         if past_key_values_length > 0:
             register_nums = registering_cache["register_nums"]
             src_length = registering_cache["src_length"]
 
         if input_ids is not None and past_key_values_length == 0:
-            # .view() additionally ensure that the memory is contiguous
-            input_shape = input_ids.size()
-            input_ids = input_ids.view(-1, input_shape[-1])
-
-            registers, register_nums, total_token_nums = self.create_registers(input_ids)
-            expanded_src_tokens, batch_indices, indices = self.combine_src_and_registers(input_ids, registers, register_nums, total_token_nums)
-
-            # positional embedding for source tokens and registers
-            inputs_embeds = self.embed_tokens(expanded_src_tokens)
-            inputs_embeds_1 = inputs_embeds[:,:total_token_nums,:] + self.src_embed_positions(expanded_src_tokens[:,:total_token_nums])
-            inputs_embeds_2 = inputs_embeds[:,total_token_nums:,:] + self.register_embed_positions(expanded_src_tokens[:,total_token_nums:])
-            inputs_embeds = torch.cat((inputs_embeds_1, inputs_embeds_2), dim=1)
-            inputs_embeds = inputs_embeds[batch_indices, indices]
-
-            # padding mask
-            source_tokens = expanded_src_tokens[batch_indices, indices]
-            src_length = source_tokens.shape[1]
+            # ensure contiguous
+            input_ids = self.check_contiguous(input_ids)
+            decoder_input_ids = self.check_contiguous(decoder_input_ids)
 
+            if attention_mask is None:
+                # create registers from input_ids
+                registers, register_nums, total_token_nums = self.create_registers(input_ids)
+                # 'expanded_src_tokens' is the combination of input_ids, registers, and pads.
+                expanded_src_tokens = self.combine_src_and_registers(input_ids, registers)
+                token_indices = self.get_token_indices(input_ids, total_token_nums, register_nums)
+                batch_indices = self.get_batch_indices(input_ids, token_indices)
+                # source tokens (input_ids + registers)
+                source_tokens = expanded_src_tokens[batch_indices, token_indices]
+
+            else:
+                # although we do not pass the attention mask in training or in the 1st step of generation,
+                # we still leave this block here.
+                if registering_cache is None or \
+                    not all(key in registering_cache for key in \
+                    ("register_nums", "total_token_nums", "expanded_src_tokens",\
+                    "batch_indices", "token_indices", "source_tokens")):
+                    raise ValueError(
+                        "If you generate registers by external codes, \
+                        you must provide 'register_nums', 'total_token_nums', \
+                        'expanded_src_tokens', 'batch_indices', 'token_indices' \
+                        and 'source_tokens' in 'registering_cache' in the training."
+                    )
+                register_nums, total_token_nums = registering_cache["register_nums"], registering_cache["total_token_nums"]
+                expanded_src_tokens = registering_cache["expanded_src_tokens"]
+                batch_indices, token_indices = registering_cache["batch_indices"], registering_cache["token_indices"]
+                source_tokens = registering_cache["source_tokens"]
+
+            # ensure contiguous
+            expanded_src_tokens = self.check_contiguous(expanded_src_tokens)
+            source_tokens = self.check_contiguous(source_tokens)
+
+            # get embeds with positions for source tokens (input_ids + registers)
+            inputs_embeds = self.source_tokens_embedding_with_positions(expanded_src_tokens, total_token_nums, batch_indices, token_indices)
 
             # replace the inference trigger with langtok
             # namely, enc-tgt-dec-tgt strategy
@@ -408,16 +443,47 @@ class MitreDecoder(MitrePreTrainedModel):
             decoder_input_ids[:, 0] = source_tokens[:, -1]
 
             tokens = torch.cat([source_tokens, decoder_input_ids], dim=1)
-            padding_mask = tokens.eq(self.padding_idx)
+            src_length = source_tokens.shape[1]
 
         decoder_inputs_embeds = self.embed_tokens(decoder_input_ids)
         decoder_inputs_embeds = decoder_inputs_embeds + self.tgt_embed_positions(decoder_input_ids, past_key_values_length, src_length=src_length)
+
         if past_key_values_length == 0:
             hidden_states = torch.cat([inputs_embeds, decoder_inputs_embeds], dim=1)
         else:
             hidden_states = decoder_inputs_embeds
 
-        attention_mask = self.build_future_mask(hidden_states, src_length, register_nums, padding_mask, past_key_values_length)
+        # ensure contiguous
+        hidden_states = self.check_contiguous(hidden_states)
+
+        # if attention_mask is NOT given, we build the attention mask from current hyperparams
+        # if attention_mask is given, check the shape of attention mask
+        if attention_mask is None:
+            attention_mask = self.build_future_mask(hidden_states, src_length, register_nums, past_key_values_length)
+        else:
+            bsz, src_len = hidden_states.shape[0], hidden_states.shape[1]
+            tgt_len = hidden_states.shape[1] if past_key_values_length == 0 else past_key_values_length + 1
+            if attention_mask.size() != (bsz, 1, src_len, tgt_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, src_len, tgt_len)}, but is {attention_mask.size()}"
+                )
+
+        # ensure contiguous
+        attention_mask = self.check_contiguous(attention_mask)
+
+        # this is a param to truncate the kv cache
+        # in training, it's None, namely, not activated.
+        max_register_num = None
+        # masking pads for attention_mask in the training or the 1st step of generation
+        if past_key_values_length == 0:
+            # if in generation, activate
+            max_register_num = register_nums.max().item() if use_cache else None
+
+            padding_mask = tokens.eq(self.padding_idx)
+            if padding_mask.any():
+                padding_mask = padding_mask.unsqueeze(1).unsqueeze(2)
+                attention_mask = attention_mask.masked_fill(padding_mask == 1, float('-inf'))
+
         hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
 
         if self.gradient_checkpointing and self.training:
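With the padding logic now living in forward (training and the first generation step), pads are folded into the additive float mask via masked_fill. A minimal, self-contained illustration of that pattern on toy tensors; the real mask additionally encodes the register/future structure built by build_future_mask, which is omitted here:

import torch

padding_idx = 1
tokens = torch.tensor([[5, 6, 7, padding_idx],
                       [8, 9, padding_idx, padding_idx]])
bsz, seq_len = tokens.shape

# additive float mask, broadcast over heads: (batch, 1, seq_len, seq_len)
attention_mask = torch.zeros(bsz, 1, seq_len, seq_len)

padding_mask = tokens.eq(padding_idx)                      # (batch, seq_len)
if padding_mask.any():
    padding_mask = padding_mask.unsqueeze(1).unsqueeze(2)  # (batch, 1, 1, seq_len)
    attention_mask = attention_mask.masked_fill(padding_mask == 1, float('-inf'))

print(attention_mask[1, 0, 0])  # tensor([0., 0., -inf, -inf])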
@@ -429,8 +495,6 @@ class MitreDecoder(MitrePreTrainedModel):
 
         # decoder layers
         all_hidden_states = () if output_hidden_states else None
-        all_self_attns = () if output_attentions else None
-        all_cross_attentions = () if output_attentions else None
         next_decoder_cache = () if use_cache else None
 
         for idx, decoder_layer in enumerate(self.layers):
@@ -458,7 +522,16 @@ class MitreDecoder(MitrePreTrainedModel):
             hidden_states = layer_outputs[0]
 
             if use_cache:
-                next_decoder_cache += (layer_outputs[1],)
+                if past_key_values_length > 0:
+                    next_decoder_cache += (layer_outputs[1],)
+                else:
+                    cache_key, cache_value = layer_outputs[1]
+                    clipped_rep = (
+                        cache_key[:, :, src_length - max_register_num:, :],
+                        cache_value[:, :, src_length - max_register_num:, :]
+                    )
+                    next_decoder_cache += (clipped_rep,)
+
 
         if past_key_values_length == 0:
             hidden_states = hidden_states[:,src_length:,:]
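To see why clipping the first-step cache pays off, a back-of-the-envelope comparison is enough; every number below is an invented example setting, not MITRE's actual configuration:

def cache_bytes(cached_src_positions, generated=1,
                layers=24, heads=16, head_dim=64, batch_beam=8, elem_bytes=2):
    # 2x for keys and values, fp16 elements assumed
    per_layer = 2 * batch_beam * heads * (cached_src_positions + generated) * head_dim * elem_bytes
    return layers * per_layer

src_length, max_register_num = 256, 8
full = cache_bytes(src_length)           # caching every source position
clipped = cache_bytes(max_register_num)  # caching register slots only
print(f"{full / 2**20:.1f} MiB vs {clipped / 2**20:.1f} MiB cached after the first step")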
@@ -475,13 +548,19 @@ class MitreDecoder(MitrePreTrainedModel):
             last_hidden_state=hidden_states,
             past_key_values=next_cache,
             hidden_states=all_hidden_states,
-            attentions=all_self_attns,
-            cross_attentions=all_cross_attentions,
         )
-        model_output.registering_cache = {
-            "register_nums": register_nums,
-            "src_length": src_length
-        }
+
+        # the registering cache used in generation
+        # in the 1st step, we truncate the kv cache to save cost, so we have to change the src_length
+        if use_cache:
+            model_output.registering_cache = {
+                "register_nums": register_nums,
+                "src_length": src_length if past_key_values_length > 0 else max_register_num,
+                "attention_mask": attention_mask if past_key_values_length > 0 else None
+            }
+        else:
+            model_output.registering_cache = None
+
         return model_output
 
 
@@ -579,6 +658,7 @@ class MitreModel(MitrePreTrainedModel):
         self,
         input_ids: Optional[torch.LongTensor] = None,
         decoder_input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
@@ -594,6 +674,7 @@ class MitreModel(MitrePreTrainedModel):
         decoder_outputs = self.decoder(
             input_ids=input_ids,
             decoder_input_ids=decoder_input_ids,
+            attention_mask=attention_mask,
             past_key_values=past_key_values,
             use_cache=use_cache,
             output_hidden_states=output_hidden_states,
@@ -634,15 +715,18 @@ class MitreForConditionalGeneration(MitrePreTrainedModel, GenerationMixin):
         self,
         input_ids: Optional[torch.LongTensor] = None,
         decoder_input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         registering_cache: dict = None,
     ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
+
         outputs = self.model(
             input_ids=input_ids,
             decoder_input_ids=decoder_input_ids,
+            attention_mask=attention_mask,
             past_key_values=past_key_values,
             use_cache=use_cache,
             output_hidden_states=output_hidden_states,
@@ -674,8 +758,8 @@ class MitreForConditionalGeneration(MitrePreTrainedModel, GenerationMixin):
         return reordered_past
 
     @staticmethod
-    def _reorder_register_nums(register_nums, beam_idx):
-        return register_nums.index_select(0, beam_idx.to(register_nums.device))
+    def _reorder_register_cache(t, beam_idx):
+        return t.index_select(dim=0, index=beam_idx.to(t.device))
 
     @staticmethod
     def _expand_inputs_for_generation(
@@ -752,6 +836,7 @@ class MitreForConditionalGeneration(MitrePreTrainedModel, GenerationMixin):
         this_peer_finished = False
         past_key_values = None
         registering_cache = None
+        attention_mask = None
 
         logits_processor = LogitsProcessorList()
         stopping_criteria = StoppingCriteriaList()
@@ -763,6 +848,12 @@ class MitreForConditionalGeneration(MitrePreTrainedModel, GenerationMixin):
 
             if past_key_values is not None:
                 decoder_input_ids_for_generation = decoder_input_ids[:, -1:]
+                attention_mask = registering_cache["attention_mask"]
+                # Get the mask the first time the kv cache is used.
+                # After that, we can simply repeat 0. (the last column of the mask) to get the next mask.
+                # As a result, we avoid generating the mask from scratch under the kv cache and save memory.
+                if attention_mask is not None:
+                    attention_mask = torch.cat((attention_mask, attention_mask[..., -1:]), dim=-1)
             else:
                 decoder_input_ids_for_generation = decoder_input_ids
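The mask bookkeeping above can be looked at in isolation: during incremental decoding the query length is 1, so the cached mask only has to grow along the key dimension, and copying its last column is sufficient. A tiny sketch, assuming a (batch, 1, 1, kv_len) additive mask:

import torch

# additive mask carried over from the previous step: (batch, 1, 1, kv_len)
attention_mask = torch.tensor([[[[float('-inf'), 0.0, 0.0]]]])

# the newest position is visible to the next query, so its column is just
# a copy of the previous last column (0.)
attention_mask = torch.cat((attention_mask, attention_mask[..., -1:]), dim=-1)
print(attention_mask)  # tensor([[[[-inf, 0., 0., 0.]]]])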
 
@@ -817,8 +908,10 @@ class MitreForConditionalGeneration(MitrePreTrainedModel, GenerationMixin):
             del outputs
 
             past_key_values = self._reorder_cache(past_key_values, beam_idx)
-            registering_cache["register_nums"] = self._reorder_register_nums(registering_cache["register_nums"], beam_idx)
-
+            registering_cache["register_nums"] = self._reorder_register_cache(registering_cache["register_nums"], beam_idx)
+            if registering_cache["attention_mask"] is not None:
+                registering_cache["attention_mask"] = self._reorder_register_cache(registering_cache["attention_mask"], beam_idx)
+
             cur_len = cur_len + 1
 
             if beam_scorer.is_done:
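register_nums and the cached attention mask are batch-first, so the same index_select that reorders the key/value cache keeps them aligned with the surviving beams. A toy example of what _reorder_register_cache does:

import torch

def _reorder_register_cache(t, beam_idx):
    # mirrors the committed helper: select rows along the batch/beam dimension
    return t.index_select(dim=0, index=beam_idx.to(t.device))

register_nums = torch.tensor([2, 3, 2, 4])   # one entry per beam
beam_idx = torch.tensor([1, 1, 0, 3])        # beams chosen by the scorer this step
print(_reorder_register_cache(register_nums, beam_idx))  # tensor([3, 3, 2, 4])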
 