Commit 68fded4 · 1 Parent(s): 24af34e · committed by root

Batch generation: lift the batch_size=1 restriction in generate() by left-padding the merged multimodal sequences.

Files changed (1): modeling_ovis.py (+18, -13)
modeling_ovis.py CHANGED

@@ -353,7 +353,8 @@ class Ovis(OvisPreTrainedModel):
         text_input_ids: torch.Tensor,
         text_attention_masks: torch.Tensor,
         text_labels: Optional[torch.Tensor],
-        pixel_values: List[Optional[torch.Tensor]]
+        pixel_values: List[Optional[torch.Tensor]],
+        left_padding: bool = False
     ):
         input_device = text_input_ids.device
         visual_vocab_szie = self.get_visual_tokenizer().config.vocab_size
@@ -393,8 +394,8 @@ class Ovis(OvisPreTrainedModel):
         visual_embeds = [None] * len(num_images)
         visual_input_ids = [None] * len(num_images)
         visual_labels = [None] * len(num_images)
-        # just placeholders
-        text_labels = torch.full(text_input_ids.shape, IGNORE_ID, dtype=torch.long, device=input_device)
+        if text_labels is None:
+            text_labels = torch.full(text_input_ids.shape, IGNORE_ID, dtype=torch.long, device=input_device)

         input_embeds = []
         attention_masks = []
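
Note on the hunk above: the old code unconditionally overwrote text_labels with IGNORE_ID placeholders, which was harmless for generation but clobbered real labels in training. The new guard only synthesizes placeholders when the caller passes text_labels=None. A minimal sketch of the pattern, assuming the conventional IGNORE_ID = -100 (the actual constant is defined elsewhere in this repo):

import torch

IGNORE_ID = -100  # assumed value; the real constant lives elsewhere in the repo

def resolve_labels(text_input_ids, text_labels):
    # generate() passes text_labels=None, so placeholders are created;
    # training labels now pass through instead of being overwritten.
    if text_labels is None:
        text_labels = torch.full(text_input_ids.shape, IGNORE_ID,
                                 dtype=torch.long, device=text_input_ids.device)
    return text_labels

ids = torch.tensor([[1, 2, 3]])
assert resolve_labels(ids, None).tolist() == [[-100, -100, -100]]
labels = torch.tensor([[7, 8, 9]])
assert resolve_labels(ids, labels) is labels  # real labels untouched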
@@ -451,16 +452,20 @@ class Ovis(OvisPreTrainedModel):
             input_embeds[0] = torch.nn.ConstantPad2d((0, 0, 0, padding_size), 0.0)(input_embeds[0])
             attention_masks[0] = torch.nn.ConstantPad1d((0, padding_size), False)(attention_masks[0])
             labels[0] = torch.nn.ConstantPad1d((0, padding_size), IGNORE_ID)(labels[0])
-        batch_input_embeds = torch.nn.utils.rnn.pad_sequence(input_embeds, batch_first=True, padding_value=0.0)[:, :self.config.multimodal_max_length, :]
-        batch_attention_mask = torch.nn.utils.rnn.pad_sequence(attention_masks, batch_first=True, padding_value=False)[:, :self.config.multimodal_max_length]
-        batch_labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_ID)[:, :self.config.multimodal_max_length]
+        batch_input_embeds = self.pad_truncate_sequence(input_embeds, batch_first=True, padding_value=0.0, left_padding=left_padding)
+        batch_attention_mask = self.pad_truncate_sequence(attention_masks, batch_first=True, padding_value=False, left_padding=left_padding)
+        batch_labels = self.pad_truncate_sequence(labels, batch_first=True, padding_value=IGNORE_ID, left_padding=left_padding)

         return visual_input_ids, batch_input_embeds, batch_labels, batch_attention_mask

+    def pad_truncate_sequence(self, sequences: List[torch.Tensor], batch_first: bool = True, padding_value: float = 0.0, left_padding: bool = False) -> torch.Tensor:
+        if left_padding == False:
+            pad_sequence = torch.nn.utils.rnn.pad_sequence(sequences, batch_first=batch_first, padding_value=padding_value)
+            return pad_sequence[:, :self.config.multimodal_max_length]
+        else:
+            pad_sequence = torch.nn.utils.rnn.pad_sequence([i.flip(dims=[0]) for i in sequences], batch_first=True, padding_value=padding_value).flip(dims=[1])
+            return pad_sequence[:, -self.config.multimodal_max_length:]
+
     def preprocess_inputs(
         self,
         text_or_conversations: Union[List[Dict], str],
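
The new pad_truncate_sequence helper above is the core of the change. torch.nn.utils.rnn.pad_sequence only pads on the right, so left padding is implemented as flip-pad-flip: reverse each sequence along time, right-pad, then reverse the padded batch back; truncation keeps the last multimodal_max_length positions instead of the first, so overflow is cut from the padded (left) end. A standalone sketch of the trick, with an arbitrary max length:

import torch
from torch.nn.utils.rnn import pad_sequence

def left_pad_truncate(sequences, padding_value=0, max_length=8):
    # pad_sequence only right-pads, so reverse each sequence along time,
    # right-pad, then reverse the padded batch back along time.
    padded = pad_sequence([s.flip(dims=[0]) for s in sequences],
                          batch_first=True,
                          padding_value=padding_value).flip(dims=[1])
    # Keep the trailing positions so truncation removes from the left,
    # where the padding sits, rather than from the end of the prompt.
    return padded[:, -max_length:]

a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5])
print(left_pad_truncate([a, b]))
# tensor([[1, 2, 3],
#         [0, 4, 5]])

Left padding is what makes batched decoding work: with right padding, a shorter prompt would end in pad positions and generation would continue from a pad embedding rather than from the end of the prompt.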
@@ -580,16 +585,16 @@ class Ovis(OvisPreTrainedModel):
         inputs: Optional[torch.Tensor] = None,
         **kwargs
     ) -> Union[GenerateOutput, torch.LongTensor]:
-        assert inputs.shape[0] == 1, 'Currently, only support `batch_size=1`'
         _, inputs_embeds, labels, attention_mask = self.merge_multimodal(
             text_input_ids=inputs,
             text_attention_masks=kwargs.pop('attention_mask'),
             text_labels=None,
-            pixel_values=kwargs.pop('pixel_values')
+            pixel_values=kwargs.pop('pixel_values'),
+            left_padding=True
         )
         if getattr(self.generation_config, 'cache_implementation') == 'hybrid':  # mainly for Gemma2
             kwargs['past_key_values'] = self._get_hybrid_cache_for_llm(
-                getattr(kwargs, "num_beams", 1), kwargs['max_new_tokens'] + inputs_embeds.shape[-2])
+                getattr(kwargs, "num_beams", inputs_embeds.shape[0]), kwargs['max_new_tokens'] + inputs_embeds.shape[-2])
             self.get_llm()._supports_cache_class = True
             kwargs['cache_implementation'] = None
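
With the batch_size=1 assertion gone and left_padding=True wired into generate(), batched inference works end to end. A hedged usage sketch: preprocess_inputs is defined in this file, but its (prompt, input_ids, pixel_values) return convention, get_text_tokenizer(), and the checkpoint name follow the model card and are assumptions, not part of this diff:

import torch
from transformers import AutoModelForCausalLM

# Checkpoint name is illustrative; any Ovis checkpoint shipping this
# modeling_ovis.py should behave the same way.
model = AutoModelForCausalLM.from_pretrained('AIDC-AI/Ovis1.6-Gemma2-9B',
                                             torch_dtype=torch.bfloat16,
                                             trust_remote_code=True).cuda()
text_tokenizer = model.get_text_tokenizer()

queries = ['<image>\nDescribe the image.', '<image>\nWhat is the main color?']
# `images` is assumed: one PIL.Image per query, loaded by the caller.

id_lists, pixel_values = [], []
for query, image in zip(queries, images):
    _, input_ids, pv = model.preprocess_inputs(query, [image])
    id_lists.append(input_ids)
    pixel_values.append(pv.to(device=model.device, dtype=model.dtype))

# Left-pad the token ids to a common length with the pad token;
# merge_multimodal(left_padding=True) then left-pads the merged
# embeddings, labels, and attention mask to match.
pad_id = text_tokenizer.pad_token_id
max_len = max(ids.shape[-1] for ids in id_lists)
input_ids = torch.stack([
    torch.nn.functional.pad(ids, (max_len - ids.shape[-1], 0), value=pad_id)
    for ids in id_lists
]).to(model.device)
attention_mask = torch.ne(input_ids, pad_id)

output_ids = model.generate(input_ids, pixel_values=pixel_values,
                            attention_mask=attention_mask, max_new_tokens=512)
print(text_tokenizer.batch_decode(output_ids, skip_special_tokens=True))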
 
 