Spaces: Running on Zero
Update cumo/model/language_model/llava_llama.py
cumo/model/language_model/llava_llama.py CHANGED
@@ -107,7 +107,7 @@ class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
     @torch.no_grad()
     def generate(
         self,
-
+        input_ids: Optional[torch.Tensor] = None,
         images: Optional[torch.Tensor] = None,
         image_sizes: Optional[torch.Tensor] = None,
         **kwargs,
@@ -116,30 +116,28 @@ class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
         attention_mask = kwargs.pop("attention_mask", None)
         if "inputs_embeds" in kwargs:
             raise NotImplementedError("`inputs_embeds` is not supported")
-
-        if
+
+        if images is not None:
             (
                 input_ids,
                 position_ids,
                 attention_mask,
-
+                _,
                 inputs_embeds,
-
+                _,
                 *_
             ) = self.prepare_inputs_labels_for_multimodal(
                 input_ids,
                 position_ids,
                 attention_mask,
-
-
+                None,
+                None,
                 images,
-                image_sizes
+                image_sizes=image_sizes
             )
-
-
         else:
-            inputs_embeds = self.get_model().embed_tokens(
-
+            inputs_embeds = self.get_model().embed_tokens(input_ids)
+
         return super().generate(
             position_ids=position_ids,
             attention_mask=attention_mask,
@@ -147,6 +145,7 @@ class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
             **kwargs
         )
 
+
     def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
                                       inputs_embeds=None, **kwargs):
         images = kwargs.pop("images", None)
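For context: in upstream LLaVA, from which this file derives, prepare_inputs_labels_for_multimodal takes (input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes=None) and returns the matching six-tuple. Assuming CuMo keeps that interface, the two None arguments fill the unused past_key_values and labels parameters, the `_` placeholders discard the returned past_key_values and labels, and `*_` absorbs anything extra. Below is a minimal sketch of how a caller might drive the patched generate; the loader and preprocessing helpers (load_pretrained_model, process_images, tokenizer_image_token, IMAGE_TOKEN_INDEX) are assumed to match LLaVA's, and the checkpoint path and prompt format are illustrative, not part of this commit.

# Sketch only: assumes CuMo keeps LLaVA's helper API; names flagged as
# assumptions below are not shown anywhere in this commit.
import torch
from PIL import Image

from cumo.constants import IMAGE_TOKEN_INDEX                       # assumed, as in LLaVA
from cumo.mm_utils import process_images, tokenizer_image_token    # assumed, as in LLaVA
from cumo.model.builder import load_pretrained_model               # assumed, as in LLaVA

# Illustrative checkpoint path and model name; substitute a real CuMo checkpoint.
tokenizer, model, image_processor, _ = load_pretrained_model(
    "path/to/CuMo-checkpoint", None, "cumo")

image = Image.open("example.jpg")
image_tensor = process_images([image], image_processor, model.config)
image_tensor = image_tensor.to(model.device, dtype=torch.float16)

# Prompt format is illustrative; the Space builds prompts from its own
# conversation templates.
prompt = "USER: <image>\nDescribe this image. ASSISTANT:"
input_ids = tokenizer_image_token(
    prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
).unsqueeze(0).to(model.device)

with torch.inference_mode():
    # After this commit, the first parameter of generate is input_ids, and
    # image_sizes is forwarded by keyword to prepare_inputs_labels_for_multimodal.
    output_ids = model.generate(
        input_ids,
        images=image_tensor,
        image_sizes=[image.size],   # original (width, height) per image
        max_new_tokens=256,
    )

print(tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0])

Passing image_sizes by keyword, as the patch does internally, also keeps the call unambiguous if optional parameters are ever added before it in the helper's signature.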