zhiqu22
committed on
Commit
·
23f06b4
1
Parent(s):
6ff51cd
improve kv cache
Browse files- modeling_mitre.py +143 -50
modeling_mitre.py
CHANGED
@@ -280,22 +280,48 @@ class MitreDecoder(MitrePreTrainedModel):
|
|
280 |
registers = input_ids[range(batch_size), torch.argmax(input_ids, dim=-1)].unsqueeze(1).repeat(1, max_register_nums)
|
281 |
return registers, register_nums, total_token_nums
|
282 |
|
283 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
284 |
'''
|
285 |
return a expanded_src_tokens for positional embedding.
|
286 |
'''
|
287 |
pads = torch.full_like(registers, self.padding_idx)
|
288 |
expanded_src_tokens = torch.cat((pads, input_ids, registers), dim=1)
|
289 |
-
|
290 |
-
|
291 |
-
|
292 |
-
|
293 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
294 |
|
295 |
def fill_with_neg_inf(self, t):
    '''
    Return a tensor with the shape and dtype of `t` in which every element is -inf.

    The fill happens on the float32 version of `t` and the result is cast back
    with `type_as(t)`.
    NOTE(review): for a float32 `t`, `.float()` presumably returns `t` itself,
    so `t` would be overwritten in place -- confirm against the PyTorch docs.
    '''
    as_float = t.float()
    as_float.fill_(float("-inf"))
    return as_float.type_as(t)
|
|
|
|
|
|
|
297 |
|
298 |
-
def build_future_mask(self, embeds, src_length, register_nums,
|
299 |
b = register_nums.size(0)
|
300 |
ns = src_length - register_nums
|
301 |
if past_key_values_length == 0:
|
@@ -331,11 +357,6 @@ class MitreDecoder(MitrePreTrainedModel):
|
|
331 |
batch_mask[batch_indices[target_indices], row_indices[target_indices], col_indices[target_indices]] = float('-inf')
|
332 |
# shape: batch_size, head_num (1 for broadcasting), seq_len, seq_len
|
333 |
batch_mask = batch_mask.unsqueeze(1)
|
334 |
-
# 6. masking pads
|
335 |
-
if padding_mask is not None:
|
336 |
-
if padding_mask.any():
|
337 |
-
padding_mask = padding_mask.to(batch_mask.device).unsqueeze(1).unsqueeze(2)
|
338 |
-
batch_mask = batch_mask.masked_fill(padding_mask == 1, float('-inf'))
|
339 |
|
340 |
elif past_key_values_length > 0:
|
341 |
# in generation
|
@@ -350,7 +371,6 @@ class MitreDecoder(MitrePreTrainedModel):
|
|
350 |
batch_mask[batch_indices[target_to_source_mask], token_indices[target_to_source_mask]] = float('-inf')
|
351 |
batch_mask = batch_mask.unsqueeze(1)
|
352 |
|
353 |
-
# ensure contiguous
|
354 |
batch_mask = batch_mask.view(b, 1, batch_mask.shape[-2], batch_mask.shape[-1])
|
355 |
return batch_mask
|
356 |
|
@@ -359,13 +379,12 @@ class MitreDecoder(MitrePreTrainedModel):
|
|
359 |
self,
|
360 |
input_ids: Optional[torch.Tensor] = None,
|
361 |
decoder_input_ids: Optional[torch.Tensor] = None,
|
|
|
362 |
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
363 |
use_cache: Optional[bool] = None,
|
364 |
-
output_attentions: Optional[bool] = None,
|
365 |
output_hidden_states: Optional[bool] = None,
|
366 |
registering_cache: dict = None,
|
367 |
):
|
368 |
-
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
369 |
output_hidden_states = (
|
370 |
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
371 |
)
|
@@ -374,33 +393,49 @@ class MitreDecoder(MitrePreTrainedModel):
|
|
374 |
# past_key_values_length
|
375 |
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
376 |
|
377 |
-
decoder_input_shape = decoder_input_ids.size()
|
378 |
-
decoder_input_ids = decoder_input_ids.view(-1, decoder_input_shape[-1])
|
379 |
-
padding_mask = None
|
380 |
-
|
381 |
if past_key_values_length > 0:
|
382 |
register_nums = registering_cache["register_nums"]
|
383 |
src_length = registering_cache["src_length"]
|
384 |
|
385 |
if input_ids is not None and past_key_values_length == 0:
|
386 |
-
#
|
387 |
-
|
388 |
-
|
389 |
-
|
390 |
-
registers, register_nums, total_token_nums = self.create_registers(input_ids)
|
391 |
-
expanded_src_tokens, batch_indices, indices = self.combine_src_and_registers(input_ids, registers, register_nums, total_token_nums)
|
392 |
-
|
393 |
-
# positional embedding for source tokens and registers
|
394 |
-
inputs_embeds = self.embed_tokens(expanded_src_tokens)
|
395 |
-
inputs_embeds_1 = inputs_embeds[:,:total_token_nums,:] + self.src_embed_positions(expanded_src_tokens[:,:total_token_nums])
|
396 |
-
inputs_embeds_2 = inputs_embeds[:,total_token_nums:,:] + self.register_embed_positions(expanded_src_tokens[:,total_token_nums:])
|
397 |
-
inputs_embeds = torch.cat((inputs_embeds_1, inputs_embeds_2), dim=1)
|
398 |
-
inputs_embeds = inputs_embeds[batch_indices, indices]
|
399 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
400 |
|
401 |
-
|
402 |
-
|
403 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
404 |
|
405 |
# replace the inference trigger with langtok
|
406 |
# namely, enc-tgt-dec-tgt strategy
|
@@ -408,16 +443,47 @@ class MitreDecoder(MitrePreTrainedModel):
|
|
408 |
decoder_input_ids[:, 0] = source_tokens[:, -1]
|
409 |
|
410 |
tokens = torch.cat([source_tokens, decoder_input_ids], dim=1)
|
411 |
-
|
412 |
|
413 |
decoder_inputs_embeds = self.embed_tokens(decoder_input_ids)
|
414 |
decoder_inputs_embeds = decoder_inputs_embeds + self.tgt_embed_positions(decoder_input_ids, past_key_values_length, src_length=src_length)
|
|
|
415 |
if past_key_values_length == 0:
|
416 |
hidden_states = torch.cat([inputs_embeds, decoder_inputs_embeds], dim=1)
|
417 |
else:
|
418 |
hidden_states = decoder_inputs_embeds
|
419 |
|
420 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
421 |
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
422 |
|
423 |
if self.gradient_checkpointing and self.training:
|
@@ -429,8 +495,6 @@ class MitreDecoder(MitrePreTrainedModel):
|
|
429 |
|
430 |
# decoder layers
|
431 |
all_hidden_states = () if output_hidden_states else None
|
432 |
-
all_self_attns = () if output_attentions else None
|
433 |
-
all_cross_attentions = () if output_attentions else None
|
434 |
next_decoder_cache = () if use_cache else None
|
435 |
|
436 |
for idx, decoder_layer in enumerate(self.layers):
|
@@ -458,7 +522,16 @@ class MitreDecoder(MitrePreTrainedModel):
|
|
458 |
hidden_states = layer_outputs[0]
|
459 |
|
460 |
if use_cache:
|
461 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
462 |
|
463 |
if past_key_values_length == 0:
|
464 |
hidden_states = hidden_states[:,src_length:,:]
|
@@ -475,13 +548,19 @@ class MitreDecoder(MitrePreTrainedModel):
|
|
475 |
last_hidden_state=hidden_states,
|
476 |
past_key_values=next_cache,
|
477 |
hidden_states=all_hidden_states,
|
478 |
-
attentions=all_self_attns,
|
479 |
-
cross_attentions=all_cross_attentions,
|
480 |
)
|
481 |
-
|
482 |
-
|
483 |
-
|
484 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
485 |
return model_output
|
486 |
|
487 |
|
@@ -579,6 +658,7 @@ class MitreModel(MitrePreTrainedModel):
|
|
579 |
self,
|
580 |
input_ids: Optional[torch.LongTensor] = None,
|
581 |
decoder_input_ids: Optional[torch.Tensor] = None,
|
|
|
582 |
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
583 |
use_cache: Optional[bool] = None,
|
584 |
output_attentions: Optional[bool] = None,
|
@@ -594,6 +674,7 @@ class MitreModel(MitrePreTrainedModel):
|
|
594 |
decoder_outputs = self.decoder(
|
595 |
input_ids=input_ids,
|
596 |
decoder_input_ids=decoder_input_ids,
|
|
|
597 |
past_key_values=past_key_values,
|
598 |
use_cache=use_cache,
|
599 |
output_hidden_states=output_hidden_states,
|
@@ -634,15 +715,18 @@ class MitreForConditionalGeneration(MitrePreTrainedModel, GenerationMixin):
|
|
634 |
self,
|
635 |
input_ids: Optional[torch.LongTensor] = None,
|
636 |
decoder_input_ids: Optional[torch.LongTensor] = None,
|
|
|
637 |
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
638 |
labels: Optional[torch.LongTensor] = None,
|
639 |
use_cache: Optional[bool] = None,
|
640 |
output_hidden_states: Optional[bool] = None,
|
641 |
registering_cache: dict = None,
|
642 |
) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
|
|
|
643 |
outputs = self.model(
|
644 |
input_ids=input_ids,
|
645 |
decoder_input_ids=decoder_input_ids,
|
|
|
646 |
past_key_values=past_key_values,
|
647 |
use_cache=use_cache,
|
648 |
output_hidden_states=output_hidden_states,
|
@@ -674,8 +758,8 @@ class MitreForConditionalGeneration(MitrePreTrainedModel, GenerationMixin):
|
|
674 |
return reordered_past
|
675 |
|
676 |
@staticmethod
|
677 |
-
def
|
678 |
-
return
|
679 |
|
680 |
@staticmethod
|
681 |
def _expand_inputs_for_generation(
|
@@ -752,6 +836,7 @@ class MitreForConditionalGeneration(MitrePreTrainedModel, GenerationMixin):
|
|
752 |
this_peer_finished = False
|
753 |
past_key_values = None
|
754 |
registering_cache = None
|
|
|
755 |
|
756 |
logits_processor = LogitsProcessorList()
|
757 |
stopping_criteria = StoppingCriteriaList()
|
@@ -763,6 +848,12 @@ class MitreForConditionalGeneration(MitrePreTrainedModel, GenerationMixin):
|
|
763 |
|
764 |
if past_key_values is not None:
|
765 |
decoder_input_ids_for_generation = decoder_input_ids[:, -1:]
|
|
|
|
|
|
|
|
|
|
|
|
|
766 |
else:
|
767 |
decoder_input_ids_for_generation = decoder_input_ids
|
768 |
|
@@ -817,8 +908,10 @@ class MitreForConditionalGeneration(MitrePreTrainedModel, GenerationMixin):
|
|
817 |
del outputs
|
818 |
|
819 |
past_key_values = self._reorder_cache(past_key_values, beam_idx)
|
820 |
-
registering_cache["register_nums"] = self.
|
821 |
-
|
|
|
|
|
822 |
cur_len = cur_len + 1
|
823 |
|
824 |
if beam_scorer.is_done:
|
|
|
280 |
registers = input_ids[range(batch_size), torch.argmax(input_ids, dim=-1)].unsqueeze(1).repeat(1, max_register_nums)
|
281 |
return registers, register_nums, total_token_nums
|
282 |
|
283 |
+
def get_token_indices(self, input_ids, total_token_nums, register_nums):
    '''
    Return token_indices for selecting source tokens from expanded_src_tokens.

    Row i of the result is
    [register_nums[i], register_nums[i] + 1, ..., register_nums[i] + total_token_nums - 1],
    i.e. a window of `total_token_nums` consecutive column positions shifted by
    the per-sample value in `register_nums`.

    Args:
        input_ids: tensor whose first dimension gives the number of rows and
            whose device is used for the result.
        total_token_nums: width of the window (number of columns per row).
        register_nums: 1-D tensor with one offset per row.

    Returns:
        A tensor with `input_ids.size(0)` rows and `total_token_nums` columns.
    '''
    # Allocate the base range directly on the target device instead of building
    # it on the CPU and transferring it afterwards.
    base_positions = torch.arange(total_token_nums, device=input_ids.device)
    base_positions = base_positions.expand(input_ids.size(0), -1)
    # Broadcasting the per-row offset materialises a fresh tensor.
    return base_positions + register_nums.unsqueeze(1)
|
290 |
+
|
291 |
+
def get_batch_indices(self, input_ids, token_indices):
    '''
    Return batch_indices for selecting source tokens from expanded_src_tokens.

    Row i of the result is filled with the value i and has as many columns as
    `token_indices`, so the pair (batch_indices, token_indices) can be used
    together for advanced indexing.

    Args:
        input_ids: tensor whose first dimension gives the number of rows.
        token_indices: tensor whose second dimension gives the number of
            columns and whose device is used for the result.

    Returns:
        A contiguous integer tensor of shape
        (input_ids.shape[0], token_indices.size(1)).
    '''
    # Create the row index on the same device as its indexing partner so that
    # using the pair does not trigger an implicit host-to-device copy.
    row_ids = torch.arange(input_ids.shape[0], device=token_indices.device)
    return row_ids.unsqueeze(1).expand(-1, token_indices.size(1)).contiguous()
|
297 |
+
|
298 |
+
def combine_src_and_registers(self, input_ids, registers):
    '''
    Build the expanded source sequence used for positional embedding.

    Along dim 1 the result is laid out as [pads | input_ids | registers], where
    the pad block has the same shape as `registers` and is filled with
    `self.padding_idx`.
    '''
    pad_block = registers.new_full(registers.shape, self.padding_idx)
    return torch.cat([pad_block, input_ids, registers], dim=1)
|
305 |
+
|
306 |
+
def source_tokens_embedding_with_positions(self, expanded_src_tokens, total_token_nums, batch_indices, indices):
    '''
    Return the position-aware embeddings of the selected source tokens.

    The expanded sequence is embedded with `self.embed_tokens`. The first
    `total_token_nums` columns receive `self.src_embed_positions`, the remaining
    columns receive `self.register_embed_positions`. Finally the entries
    addressed by (batch_indices, indices) are gathered.
    '''
    split_at = total_token_nums
    token_embeds = self.embed_tokens(expanded_src_tokens)

    # Columns before `split_at` use the source positional embedding.
    src_part = token_embeds[:, :split_at, :] + self.src_embed_positions(expanded_src_tokens[:, :split_at])
    # Columns from `split_at` onwards use the register positional embedding.
    register_part = token_embeds[:, split_at:, :] + self.register_embed_positions(expanded_src_tokens[:, split_at:])

    positioned = torch.cat((src_part, register_part), dim=1)
    return positioned[batch_indices, indices]
|
317 |
|
318 |
def fill_with_neg_inf(self, t):
    '''
    Return a tensor with the shape and dtype of `t` in which every element is -inf.

    The fill happens on the float32 version of `t` and the result is cast back
    with `type_as(t)`.
    NOTE(review): for a float32 `t`, `.float()` presumably returns `t` itself,
    so `t` would be overwritten in place -- confirm against the PyTorch docs.
    '''
    as_float = t.float()
    as_float.fill_(float("-inf"))
    return as_float.type_as(t)
|
320 |
+
|
321 |
+
def check_contiguous(self, t: torch.Tensor):
    '''
    Return `t` unchanged when it is already contiguous, otherwise return the
    result of `t.contiguous()`.
    '''
    if t.is_contiguous():
        return t
    return t.contiguous()
|
323 |
|
324 |
+
def build_future_mask(self, embeds, src_length, register_nums, past_key_values_length=0):
|
325 |
b = register_nums.size(0)
|
326 |
ns = src_length - register_nums
|
327 |
if past_key_values_length == 0:
|
|
|
357 |
batch_mask[batch_indices[target_indices], row_indices[target_indices], col_indices[target_indices]] = float('-inf')
|
358 |
# shape: batch_size, head_num (1 for broadcasting), seq_len, seq_len
|
359 |
batch_mask = batch_mask.unsqueeze(1)
|
|
|
|
|
|
|
|
|
|
|
360 |
|
361 |
elif past_key_values_length > 0:
|
362 |
# in generation
|
|
|
371 |
batch_mask[batch_indices[target_to_source_mask], token_indices[target_to_source_mask]] = float('-inf')
|
372 |
batch_mask = batch_mask.unsqueeze(1)
|
373 |
|
|
|
374 |
batch_mask = batch_mask.view(b, 1, batch_mask.shape[-2], batch_mask.shape[-1])
|
375 |
return batch_mask
|
376 |
|
|
|
379 |
self,
|
380 |
input_ids: Optional[torch.Tensor] = None,
|
381 |
decoder_input_ids: Optional[torch.Tensor] = None,
|
382 |
+
attention_mask: Optional[torch.Tensor] = None,
|
383 |
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
384 |
use_cache: Optional[bool] = None,
|
|
|
385 |
output_hidden_states: Optional[bool] = None,
|
386 |
registering_cache: dict = None,
|
387 |
):
|
|
|
388 |
output_hidden_states = (
|
389 |
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
390 |
)
|
|
|
393 |
# past_key_values_length
|
394 |
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
395 |
|
|
|
|
|
|
|
|
|
396 |
if past_key_values_length > 0:
|
397 |
register_nums = registering_cache["register_nums"]
|
398 |
src_length = registering_cache["src_length"]
|
399 |
|
400 |
if input_ids is not None and past_key_values_length == 0:
|
401 |
+
# ensure contiguous
|
402 |
+
input_ids = self.check_contiguous(input_ids)
|
403 |
+
decoder_input_ids = self.check_contiguous(decoder_input_ids)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
404 |
|
405 |
+
if attention_mask is None:
|
406 |
+
# create registers from input_ids
|
407 |
+
registers, register_nums, total_token_nums = self.create_registers(input_ids)
|
408 |
+
# 'expanded_src_tokens' is combined by input_ids, registers, and pads.
|
409 |
+
expanded_src_tokens = self.combine_src_and_registers(input_ids, registers)
|
410 |
+
token_indices = self.get_token_indices(input_ids, total_token_nums, register_nums)
|
411 |
+
batch_indices = self.get_batch_indices(input_ids, token_indices)
|
412 |
+
# source tokens (input_ids + registers)
|
413 |
+
source_tokens = expanded_src_tokens[batch_indices, token_indices]
|
414 |
|
415 |
+
else:
|
416 |
+
# although we do not give the attention mask in training and the 1st step of generation,
|
417 |
+
# we still leave this block here.
|
418 |
+
if registering_cache is None or \
|
419 |
+
not all(key in registering_cache for key in \
|
420 |
+
("register_nums", "total_token_nums", "expanded_src_tokens",\
|
421 |
+
"batch_indices", "token_indices", "source_tokens")):
|
422 |
+
raise ValueError(
|
423 |
+
"If you generate registers by external codes, \
|
424 |
+
you must provide 'register_nums', 'total_token_nums', \
|
425 |
+
'expanded_src_tokens', 'batch_indices', 'token_indices' \
|
426 |
+
and 'source_tokens' in 'registering_cache' in the training."
|
427 |
+
)
|
428 |
+
register_nums, total_token_nums = registering_cache["register_nums"], registering_cache["total_token_nums"]
|
429 |
+
expanded_src_tokens = registering_cache["expanded_src_tokens"]
|
430 |
+
batch_indices, token_indices = registering_cache["batch_indices"], registering_cache["token_indices"]
|
431 |
+
source_tokens = registering_cache["source_tokens"]
|
432 |
+
|
433 |
+
# ensure contiguous
|
434 |
+
expanded_src_tokens = self.check_contiguous(expanded_src_tokens)
|
435 |
+
source_tokens = self.check_contiguous(source_tokens)
|
436 |
+
|
437 |
+
# get embeds with positions for source tokens (input_ids + registers)
|
438 |
+
inputs_embeds = self.source_tokens_embedding_with_positions(expanded_src_tokens, total_token_nums, batch_indices, token_indices)
|
439 |
|
440 |
# replace the inference trigger with langtok
|
441 |
# namely, enc-tgt-dec-tgt strategy
|
|
|
443 |
decoder_input_ids[:, 0] = source_tokens[:, -1]
|
444 |
|
445 |
tokens = torch.cat([source_tokens, decoder_input_ids], dim=1)
|
446 |
+
src_length = source_tokens.shape[1]
|
447 |
|
448 |
decoder_inputs_embeds = self.embed_tokens(decoder_input_ids)
|
449 |
decoder_inputs_embeds = decoder_inputs_embeds + self.tgt_embed_positions(decoder_input_ids, past_key_values_length, src_length=src_length)
|
450 |
+
|
451 |
if past_key_values_length == 0:
|
452 |
hidden_states = torch.cat([inputs_embeds, decoder_inputs_embeds], dim=1)
|
453 |
else:
|
454 |
hidden_states = decoder_inputs_embeds
|
455 |
|
456 |
+
# ensure contiguous
|
457 |
+
hidden_states = self.check_contiguous(hidden_states)
|
458 |
+
|
459 |
+
# if attention_mask is NOT given, we build the attention mask from current hyperparams
|
460 |
+
# if attention_mask is given, check the shape of attention mask
|
461 |
+
if attention_mask is None:
|
462 |
+
attention_mask = self.build_future_mask(hidden_states, src_length, register_nums, past_key_values_length)
|
463 |
+
else:
|
464 |
+
bsz, src_len = hidden_states.shape[0], hidden_states.shape[1]
|
465 |
+
tgt_len = hidden_states.shape[1] if past_key_values_length == 0 else past_key_values_length + 1
|
466 |
+
if attention_mask.size() != (bsz, 1, src_len, tgt_len):
|
467 |
+
raise ValueError(
|
468 |
+
f"Attention mask should be of size {(bsz, 1, src_len, tgt_len)}, but is {attention_mask.size()}"
|
469 |
+
)
|
470 |
+
|
471 |
+
# ensure contiguous
|
472 |
+
attention_mask = self.check_contiguous(attention_mask)
|
473 |
+
|
474 |
+
# this is a param to truncate kv cache
|
475 |
+
# in training, it's None, namely, unactivated.
|
476 |
+
max_register_num = None
|
477 |
+
# masking pads for attention_mask in the training or the 1st step of generation
|
478 |
+
if past_key_values_length == 0:
|
479 |
+
# if in generation, activate
|
480 |
+
max_register_num = register_nums.max().item() if use_cache else None
|
481 |
+
|
482 |
+
padding_mask = tokens.eq(self.padding_idx)
|
483 |
+
if padding_mask.any():
|
484 |
+
padding_mask = padding_mask.unsqueeze(1).unsqueeze(2)
|
485 |
+
attention_mask = attention_mask.masked_fill(padding_mask == 1, float('-inf'))
|
486 |
+
|
487 |
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
488 |
|
489 |
if self.gradient_checkpointing and self.training:
|
|
|
495 |
|
496 |
# decoder layers
|
497 |
all_hidden_states = () if output_hidden_states else None
|
|
|
|
|
498 |
next_decoder_cache = () if use_cache else None
|
499 |
|
500 |
for idx, decoder_layer in enumerate(self.layers):
|
|
|
522 |
hidden_states = layer_outputs[0]
|
523 |
|
524 |
if use_cache:
|
525 |
+
if past_key_values_length > 0:
|
526 |
+
next_decoder_cache += (layer_outputs[1],)
|
527 |
+
else:
|
528 |
+
cache_key, cache_value = layer_outputs[1]
|
529 |
+
clipped_rep = (
|
530 |
+
cache_key[:, :, src_length - max_register_num:, :],
|
531 |
+
cache_value[:, :, src_length - max_register_num:, :]
|
532 |
+
)
|
533 |
+
next_decoder_cache += (clipped_rep,)
|
534 |
+
|
535 |
|
536 |
if past_key_values_length == 0:
|
537 |
hidden_states = hidden_states[:,src_length:,:]
|
|
|
548 |
last_hidden_state=hidden_states,
|
549 |
past_key_values=next_cache,
|
550 |
hidden_states=all_hidden_states,
|
|
|
|
|
551 |
)
|
552 |
+
|
553 |
+
# the registering cache used in generation
|
554 |
+
# in the 1st step, we truncate the kv cache to save cost, so we have to change the src_length
|
555 |
+
if use_cache:
|
556 |
+
model_output.registering_cache = {
|
557 |
+
"register_nums": register_nums,
|
558 |
+
"src_length": src_length if past_key_values_length > 0 else max_register_num,
|
559 |
+
"attention_mask": attention_mask if past_key_values_length > 0 else None
|
560 |
+
}
|
561 |
+
else:
|
562 |
+
model_output.registering_cache = None
|
563 |
+
|
564 |
return model_output
|
565 |
|
566 |
|
|
|
658 |
self,
|
659 |
input_ids: Optional[torch.LongTensor] = None,
|
660 |
decoder_input_ids: Optional[torch.Tensor] = None,
|
661 |
+
attention_mask: Optional[torch.Tensor] = None,
|
662 |
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
663 |
use_cache: Optional[bool] = None,
|
664 |
output_attentions: Optional[bool] = None,
|
|
|
674 |
decoder_outputs = self.decoder(
|
675 |
input_ids=input_ids,
|
676 |
decoder_input_ids=decoder_input_ids,
|
677 |
+
attention_mask=attention_mask,
|
678 |
past_key_values=past_key_values,
|
679 |
use_cache=use_cache,
|
680 |
output_hidden_states=output_hidden_states,
|
|
|
715 |
self,
|
716 |
input_ids: Optional[torch.LongTensor] = None,
|
717 |
decoder_input_ids: Optional[torch.LongTensor] = None,
|
718 |
+
attention_mask: Optional[torch.Tensor] = None,
|
719 |
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
720 |
labels: Optional[torch.LongTensor] = None,
|
721 |
use_cache: Optional[bool] = None,
|
722 |
output_hidden_states: Optional[bool] = None,
|
723 |
registering_cache: dict = None,
|
724 |
) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
|
725 |
+
|
726 |
outputs = self.model(
|
727 |
input_ids=input_ids,
|
728 |
decoder_input_ids=decoder_input_ids,
|
729 |
+
attention_mask=attention_mask,
|
730 |
past_key_values=past_key_values,
|
731 |
use_cache=use_cache,
|
732 |
output_hidden_states=output_hidden_states,
|
|
|
758 |
return reordered_past
|
759 |
|
760 |
@staticmethod
|
761 |
+
def _reorder_register_cache(t, beam_idx):
|
762 |
+
return t.index_select(dim=0, index=beam_idx.to(t.device))
|
763 |
|
764 |
@staticmethod
|
765 |
def _expand_inputs_for_generation(
|
|
|
836 |
this_peer_finished = False
|
837 |
past_key_values = None
|
838 |
registering_cache = None
|
839 |
+
attention_mask = None
|
840 |
|
841 |
logits_processor = LogitsProcessorList()
|
842 |
stopping_criteria = StoppingCriteriaList()
|
|
|
848 |
|
849 |
if past_key_values is not None:
|
850 |
decoder_input_ids_for_generation = decoder_input_ids[:, -1:]
|
851 |
+
attention_mask = registering_cache["attention_mask"]
|
852 |
+
# Get the mask when the first time using kv cache.
|
853 |
+
# After it, we can simply repeat 0. (the last column of mask) to get the next mask.
|
854 |
+
# As a result, we avoid generating the mask from scratch when using the kv cache, which saves memory.
|
855 |
+
if attention_mask is not None:
|
856 |
+
attention_mask = torch.cat((attention_mask, attention_mask[..., -1:]), dim=-1)
|
857 |
else:
|
858 |
decoder_input_ids_for_generation = decoder_input_ids
|
859 |
|
|
|
908 |
del outputs
|
909 |
|
910 |
past_key_values = self._reorder_cache(past_key_values, beam_idx)
|
911 |
+
registering_cache["register_nums"] = self._reorder_register_cache(registering_cache["register_nums"], beam_idx)
|
912 |
+
if registering_cache["attention_mask"] is not None:
|
913 |
+
registering_cache["attention_mask"] = self._reorder_register_cache(registering_cache["attention_mask"], beam_idx)
|
914 |
+
|
915 |
cur_len = cur_len + 1
|
916 |
|
917 |
if beam_scorer.is_done:
|