BenkHel committed on
Commit
3917c52
·
verified ·
1 Parent(s): adbc5b6

Update cumo/model/language_model/llava_llama.py

Browse files
cumo/model/language_model/llava_llama.py CHANGED
@@ -72,22 +72,24 @@ class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
72
 
73
  if inputs_embeds is None:
74
  (
75
- input_ids,
76
  position_ids,
77
  attention_mask,
78
- past_key_values,
79
  inputs_embeds,
80
- labels
 
81
  ) = self.prepare_inputs_labels_for_multimodal(
82
- input_ids,
83
  position_ids,
84
  attention_mask,
85
- past_key_values,
86
- labels,
87
  images,
88
- image_sizes
89
  )
90
 
 
91
  return super().forward(
92
  input_ids=input_ids,
93
  attention_mask=attention_mask,
@@ -121,7 +123,8 @@ class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
121
  attention_mask,
122
  _,
123
  inputs_embeds,
124
- _
 
125
  ) = self.prepare_inputs_labels_for_multimodal(
126
  inputs,
127
  position_ids,
@@ -131,6 +134,7 @@ class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
131
  images,
132
  image_sizes=image_sizes
133
  )
 
134
  else:
135
  inputs_embeds = self.get_model().embed_tokens(inputs)
136
 
 
72
 
73
  if inputs_embeds is None:
74
  (
75
+ inputs,
76
  position_ids,
77
  attention_mask,
78
+ _,
79
  inputs_embeds,
80
+ _,
81
+ *_
82
  ) = self.prepare_inputs_labels_for_multimodal(
83
+ inputs,
84
  position_ids,
85
  attention_mask,
86
+ None,
87
+ None,
88
  images,
89
+ image_sizes=image_sizes
90
  )
91
 
92
+
93
  return super().forward(
94
  input_ids=input_ids,
95
  attention_mask=attention_mask,
 
123
  attention_mask,
124
  _,
125
  inputs_embeds,
126
+ _,
127
+ *_
128
  ) = self.prepare_inputs_labels_for_multimodal(
129
  inputs,
130
  position_ids,
 
134
  images,
135
  image_sizes=image_sizes
136
  )
137
+
138
  else:
139
  inputs_embeds = self.get_model().embed_tokens(inputs)
140