BenkHel committed on
Commit
5c45d3a
·
verified ·
1 Parent(s): 3917c52

Update cumo/model/language_model/llava_llama.py

Browse files
cumo/model/language_model/llava_llama.py CHANGED
@@ -72,24 +72,25 @@ class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
72
 
73
  if inputs_embeds is None:
74
  (
75
- inputs,
76
  position_ids,
77
  attention_mask,
78
- _,
79
  inputs_embeds,
80
- _,
81
  *_
82
  ) = self.prepare_inputs_labels_for_multimodal(
83
- inputs,
84
  position_ids,
85
  attention_mask,
86
- None,
87
- None,
88
  images,
89
- image_sizes=image_sizes
90
  )
91
 
92
 
 
93
  return super().forward(
94
  input_ids=input_ids,
95
  attention_mask=attention_mask,
@@ -116,25 +117,26 @@ class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
116
  if "inputs_embeds" in kwargs:
117
  raise NotImplementedError("`inputs_embeds` is not supported")
118
 
119
- if images is not None:
120
  (
121
- inputs,
122
  position_ids,
123
  attention_mask,
124
- _,
125
  inputs_embeds,
126
- _,
127
  *_
128
  ) = self.prepare_inputs_labels_for_multimodal(
129
- inputs,
130
  position_ids,
131
  attention_mask,
132
- None,
133
- None,
134
  images,
135
- image_sizes=image_sizes
136
  )
137
 
 
138
  else:
139
  inputs_embeds = self.get_model().embed_tokens(inputs)
140
 
 
72
 
73
  if inputs_embeds is None:
74
  (
75
+ input_ids,
76
  position_ids,
77
  attention_mask,
78
+ past_key_values,
79
  inputs_embeds,
80
+ labels,
81
  *_
82
  ) = self.prepare_inputs_labels_for_multimodal(
83
+ input_ids,
84
  position_ids,
85
  attention_mask,
86
+ past_key_values,
87
+ labels,
88
  images,
89
+ image_sizes
90
  )
91
 
92
 
93
+
94
  return super().forward(
95
  input_ids=input_ids,
96
  attention_mask=attention_mask,
 
117
  if "inputs_embeds" in kwargs:
118
  raise NotImplementedError("`inputs_embeds` is not supported")
119
 
120
+ if inputs_embeds is None:
121
  (
122
+ input_ids,
123
  position_ids,
124
  attention_mask,
125
+ past_key_values,
126
  inputs_embeds,
127
+ labels,
128
  *_
129
  ) = self.prepare_inputs_labels_for_multimodal(
130
+ input_ids,
131
  position_ids,
132
  attention_mask,
133
+ past_key_values,
134
+ labels,
135
  images,
136
+ image_sizes
137
  )
138
 
139
+
140
  else:
141
  inputs_embeds = self.get_model().embed_tokens(inputs)
142