BenkHel committed · Commit 6e93653 · verified · Parent: 5c45d3a

Update cumo/model/language_model/llava_llama.py

cumo/model/language_model/llava_llama.py CHANGED

@@ -107,7 +107,7 @@ class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
     @torch.no_grad()
     def generate(
         self,
-        inputs: Optional[torch.Tensor] = None,
+        input_ids: Optional[torch.Tensor] = None,
         images: Optional[torch.Tensor] = None,
         image_sizes: Optional[torch.Tensor] = None,
         **kwargs,
@@ -116,30 +116,28 @@ class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
         attention_mask = kwargs.pop("attention_mask", None)
         if "inputs_embeds" in kwargs:
             raise NotImplementedError("`inputs_embeds` is not supported")
-
-        if inputs_embeds is None:
+
+        if images is not None:
             (
                 input_ids,
                 position_ids,
                 attention_mask,
-                past_key_values,
+                _,
                 inputs_embeds,
-                labels,
+                _,
                 *_
             ) = self.prepare_inputs_labels_for_multimodal(
                 input_ids,
                 position_ids,
                 attention_mask,
-                past_key_values,
-                labels,
+                None,
+                None,
                 images,
-                image_sizes
+                image_sizes=image_sizes
             )
-
-
         else:
-            inputs_embeds = self.get_model().embed_tokens(inputs)
-
+            inputs_embeds = self.get_model().embed_tokens(input_ids)
+
         return super().generate(
             position_ids=position_ids,
             attention_mask=attention_mask,
@@ -147,6 +145,7 @@ class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
             **kwargs
         )
 
+
     def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
                                       inputs_embeds=None, **kwargs):
         images = kwargs.pop("images", None)
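
In short, this commit renames generate()'s first optional parameter from `inputs` to `input_ids`, branches on `images is not None` instead of on `inputs_embeds` (which was never assigned at that point in the old code), discards the `past_key_values` and `labels` return slots with `_`, passes `None` for those two arguments, and forwards `image_sizes` by keyword. A minimal calling sketch under the new signature follows; the checkpoint path, tokenizer setup, and dummy tensors are illustrative assumptions, not part of this commit.

# Minimal calling sketch for the updated generate() signature. The checkpoint
# path and image tensors are hypothetical placeholders; only the keyword names
# (input_ids, images, image_sizes) come from this commit.
import torch
from transformers import AutoTokenizer
from cumo.model.language_model.llava_llama import LlavaLlamaForCausalLM

model = LlavaLlamaForCausalLM.from_pretrained("path/to/cumo-checkpoint")  # placeholder path
tokenizer = AutoTokenizer.from_pretrained("path/to/cumo-checkpoint")      # placeholder path

input_ids = tokenizer("Describe this image.", return_tensors="pt").input_ids
images = torch.randn(1, 3, 336, 336)      # stand-in for a preprocessed image batch
image_sizes = torch.tensor([[336, 336]])  # original (height, width) per image

# Because images is not None, generate() routes through
# prepare_inputs_labels_for_multimodal to build inputs_embeds;
# with images=None it falls back to embed_tokens(input_ids).
output_ids = model.generate(
    input_ids=input_ids,
    images=images,
    image_sizes=image_sizes,
    max_new_tokens=64,
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))

Branching on `images` also removes the old code's read of `inputs_embeds` before assignment, which would have raised a NameError on any call that reached that check.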