Commit 68fded4 · committed by root · 1 parent: 24af34e
Batch generation

modeling_ovis.py  CHANGED  (+18 -13)
@@ -353,7 +353,8 @@ class Ovis(OvisPreTrainedModel):
         text_input_ids: torch.Tensor,
         text_attention_masks: torch.Tensor,
         text_labels: Optional[torch.Tensor],
-        pixel_values: List[Optional[torch.Tensor]]
+        pixel_values: List[Optional[torch.Tensor]],
+        left_padding: bool = False
     ):
         input_device = text_input_ids.device
         visual_vocab_szie = self.get_visual_tokenizer().config.vocab_size
@@ -393,8 +394,8 @@ class Ovis(OvisPreTrainedModel):
         visual_embeds = [None] * len(num_images)
         visual_input_ids = [None] * len(num_images)
         visual_labels = [None] * len(num_images)
-
-
+        if text_labels is None:
+            text_labels = torch.full(text_input_ids.shape, IGNORE_ID, dtype=torch.long, device=input_device)

         input_embeds = []
         attention_masks = []
@@ -451,16 +452,20 @@ class Ovis(OvisPreTrainedModel):
             input_embeds[0] = torch.nn.ConstantPad2d((0, 0, 0, padding_size), 0.0)(input_embeds[0])
             attention_masks[0] = torch.nn.ConstantPad1d((0, padding_size), False)(attention_masks[0])
             labels[0] = torch.nn.ConstantPad1d((0, padding_size), IGNORE_ID)(labels[0])
-        batch_input_embeds = torch.nn.utils.rnn.pad_sequence(input_embeds, batch_first=True, padding_value=0.0)[:,
-                             :self.config.multimodal_max_length, :]
-        batch_attention_mask = torch.nn.utils.rnn.pad_sequence(attention_masks, batch_first=True, padding_value=False)[
-                               :,
-                               :self.config.multimodal_max_length]
-        batch_labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_ID)[:,
-                       :self.config.multimodal_max_length]
+        batch_input_embeds = self.pad_truncate_sequence(input_embeds, batch_first=True, padding_value=0.0, left_padding=left_padding)
+        batch_attention_mask = self.pad_truncate_sequence(attention_masks, batch_first=True, padding_value=False, left_padding=left_padding)
+        batch_labels = self.pad_truncate_sequence(labels, batch_first=True, padding_value=IGNORE_ID, left_padding=left_padding)

         return visual_input_ids, batch_input_embeds, batch_labels, batch_attention_mask

+    def pad_truncate_sequence(self, sequences: List[torch.Tensor], batch_first: bool = True, padding_value: float = 0.0, left_padding: bool = False) -> torch.Tensor:
+        if left_padding == False:
+            pad_sequence = torch.nn.utils.rnn.pad_sequence(sequences, batch_first=batch_first, padding_value=padding_value)
+            return pad_sequence[:,:self.config.multimodal_max_length]
+        else:
+            pad_sequence = torch.nn.utils.rnn.pad_sequence([i.flip(dims=[0]) for i in sequences],batch_first=True, padding_value=padding_value).flip(dims=[1])
+            return pad_sequence[:,-self.config.multimodal_max_length:]
+
     def preprocess_inputs(
         self,
         text_or_conversations: Union[List[Dict], str],
@@ -580,16 +585,16 @@ class Ovis(OvisPreTrainedModel):
         inputs: Optional[torch.Tensor] = None,
         **kwargs
     ) -> Union[GenerateOutput, torch.LongTensor]:
-        assert inputs.shape[0] == 1, 'Currently, only support `batch_size=1`'
         _, inputs_embeds, labels, attention_mask = self.merge_multimodal(
             text_input_ids=inputs,
             text_attention_masks=kwargs.pop('attention_mask'),
             text_labels=None,
-            pixel_values=kwargs.pop('pixel_values')
+            pixel_values=kwargs.pop('pixel_values'),
+            left_padding=True
         )
         if getattr(self.generation_config, 'cache_implementation') == 'hybrid':  # mainly for Gemma2
             kwargs['past_key_values'] = self._get_hybrid_cache_for_llm(
-                getattr(kwargs, "num_beams", 1), kwargs['max_new_tokens'] + inputs_embeds.shape[-2])
+                getattr(kwargs, "num_beams", inputs_embeds.shape[0]), kwargs['max_new_tokens'] + inputs_embeds.shape[-2])
             self.get_llm()._supports_cache_class = True
             kwargs['cache_implementation'] = None

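
Note on the new helper: torch.nn.utils.rnn.pad_sequence only pads on the right, so pad_truncate_sequence obtains left padding by flipping each sequence, right-padding the flipped batch, and flipping the result back along the time dimension; truncation then keeps the last multimodal_max_length positions rather than the first. A minimal standalone sketch of that flip-pad-flip trick (the function name left_pad_truncate and the max_len argument are illustrative stand-ins, not part of the model code):

import torch

# Left padding via flip -> right-pad -> flip back, as in pad_truncate_sequence.
# Truncation keeps the *last* max_len positions, so the tokens nearest to the
# generation point survive when a prompt is too long.
def left_pad_truncate(sequences, padding_value=0, max_len=16):
    padded = torch.nn.utils.rnn.pad_sequence(
        [s.flip(dims=[0]) for s in sequences],  # reverse each sequence
        batch_first=True,
        padding_value=padding_value,            # right-pads the reversed sequences
    ).flip(dims=[1])                            # reverse back: padding ends up on the left
    return padded[:, -max_len:]

a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5])
print(left_pad_truncate([a, b]))
# tensor([[1, 2, 3],
#         [0, 4, 5]])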
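Why generate now passes left_padding=True: a decoder-only LLM predicts the next token from the last position of each row, so when prompts of different lengths are batched, the shorter ones must be padded on the left; with right padding their final position would hold a pad embedding and the first generated token would be conditioned on padding. A toy illustration using attention masks only (the mask values are made up for the example):

import torch

# Right padding: the shorter prompt (row 1) ends on a pad position.
right_padded_mask = torch.tensor([[1, 1, 1, 1],
                                  [1, 1, 0, 0]], dtype=torch.bool)
# Left padding: every prompt ends on a real token at the last position.
left_padded_mask = torch.tensor([[1, 1, 1, 1],
                                 [0, 0, 1, 1]], dtype=torch.bool)
print(right_padded_mask[:, -1])  # tensor([ True, False])
print(left_padded_mask[:, -1])   # tensor([True, True])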