Spaces: Running on Zero
Update cumo/model/language_model/llava_llama.py
Browse files
cumo/model/language_model/llava_llama.py
CHANGED
@@ -72,24 +72,25 @@ class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
|
|
72 |
|
73 |
if inputs_embeds is None:
|
74 |
(
|
75 |
-
|
76 |
position_ids,
|
77 |
attention_mask,
|
78 |
-
|
79 |
inputs_embeds,
|
80 |
-
|
81 |
*_
|
82 |
) = self.prepare_inputs_labels_for_multimodal(
|
83 |
-
|
84 |
position_ids,
|
85 |
attention_mask,
|
86 |
-
|
87 |
-
|
88 |
images,
|
89 |
-
image_sizes
|
90 |
)
|
91 |
|
92 |
|
|
|
93 |
return super().forward(
|
94 |
input_ids=input_ids,
|
95 |
attention_mask=attention_mask,
|
@@ -116,25 +117,26 @@ class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
|
|
116 |
if "inputs_embeds" in kwargs:
|
117 |
raise NotImplementedError("`inputs_embeds` is not supported")
|
118 |
|
119 |
-
if
|
120 |
(
|
121 |
-
|
122 |
position_ids,
|
123 |
attention_mask,
|
124 |
-
|
125 |
inputs_embeds,
|
126 |
-
|
127 |
*_
|
128 |
) = self.prepare_inputs_labels_for_multimodal(
|
129 |
-
|
130 |
position_ids,
|
131 |
attention_mask,
|
132 |
-
|
133 |
-
|
134 |
images,
|
135 |
-
image_sizes
|
136 |
)
|
137 |
|
|
|
138 |
else:
|
139 |
inputs_embeds = self.get_model().embed_tokens(inputs)
|
140 |
|
|
|
72 |
|
73 |
if inputs_embeds is None:
|
74 |
(
|
75 |
+
input_ids,
|
76 |
position_ids,
|
77 |
attention_mask,
|
78 |
+
past_key_values,
|
79 |
inputs_embeds,
|
80 |
+
labels,
|
81 |
*_
|
82 |
) = self.prepare_inputs_labels_for_multimodal(
|
83 |
+
input_ids,
|
84 |
position_ids,
|
85 |
attention_mask,
|
86 |
+
past_key_values,
|
87 |
+
labels,
|
88 |
images,
|
89 |
+
image_sizes
|
90 |
)
|
91 |
|
92 |
|
93 |
+
|
94 |
return super().forward(
|
95 |
input_ids=input_ids,
|
96 |
attention_mask=attention_mask,
|
|
|
117 |
if "inputs_embeds" in kwargs:
|
118 |
raise NotImplementedError("`inputs_embeds` is not supported")
|
119 |
|
120 |
+
if inputs_embeds is None:
|
121 |
(
|
122 |
+
input_ids,
|
123 |
position_ids,
|
124 |
attention_mask,
|
125 |
+
past_key_values,
|
126 |
inputs_embeds,
|
127 |
+
labels,
|
128 |
*_
|
129 |
) = self.prepare_inputs_labels_for_multimodal(
|
130 |
+
input_ids,
|
131 |
position_ids,
|
132 |
attention_mask,
|
133 |
+
past_key_values,
|
134 |
+
labels,
|
135 |
images,
|
136 |
+
image_sizes
|
137 |
)
|
138 |
|
139 |
+
|
140 |
else:
|
141 |
inputs_embeds = self.get_model().embed_tokens(inputs)
|
142 |
|