hezhihui committed
Commit 6d7ce17
Parent(s): b352d20
multi-images
Browse files
- modeling_minicpmv.py +26 -5
modeling_minicpmv.py
CHANGED
```diff
@@ -3,6 +3,7 @@ import json
 import torch
 from threading import Thread
 from copy import deepcopy
+from PIL import Image
 from torchvision import transforms
 from transformers import LlamaPreTrainedModel, LlamaForCausalLM, TextIteratorStreamer
 from transformers.models.idefics2.modeling_idefics2 import Idefics2VisionTransformer
@@ -291,17 +292,37 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
             msgs = json.loads(msgs)
         copy_msgs = deepcopy(msgs)
 
-        assert len(msgs) > 0, 'msgs is empty'
-        assert sampling or not stream, 'if use stream mode, make sure sampling=True'
+        assert len(msgs) > 0, "msgs is empty"
+        assert sampling or not stream, "if use stream mode, make sure sampling=True"
+
+        if image is not None and isinstance(copy_msgs[0]["content"], str):
+            # copy_msgs[0]['content'] = '(<image>./</image>)\n' + copy_msgs[0]['content']
+            copy_msgs[0]["content"] = [image, copy_msgs[0]["content"]]
+
+        images = []
+        for i, msg in enumerate(copy_msgs):
+            role = msg["role"]
+            content = msg["content"]
+            assert role in ["user", "assistant"]
+            if i == 0:
+                assert role == "user", "The role of first msg should be user"
+            if isinstance(content, str):
+                content = [content]
+            cur_msgs = []
+            for c in content:
+                if isinstance(c, Image.Image):
+                    images.append(c)
+                    cur_msgs.append("(<image>./</image>)")
+                elif isinstance(c, str):
+                    cur_msgs.append(c)
+            msg["content"] = "\n".join(cur_msgs)
 
-        if image is not None and isinstance(msgs[0]['content'], str):
-            copy_msgs[0]['content'] = '(<image>./</image>)\n' + copy_msgs[0]['content']
         if system_prompt:
             sys_msg = {'role': 'system', 'content': system_prompt}
             copy_msgs = [sys_msg] + copy_msgs
 
         prompt = processor.tokenizer.apply_chat_template(copy_msgs, tokenize=False, add_generation_prompt=True)
-        inputs = processor(prompt, [image], return_tensors="pt", max_length=max_inp_length).to(self.device)
+        inputs = processor(prompt, images, return_tensors="pt", max_length=max_inp_length).to(self.device)
 
         if sampling:
             generation_config = {
```
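For reference, a minimal usage sketch of the multi-image message format this commit enables. This is a hedged example, not part of the commit: `model` and `tokenizer` stand for an already-loaded MiniCPMV checkpoint and its tokenizer, the image paths are placeholders, and the `chat()` keyword arguments follow the model card's quick-start rather than anything shown in this diff.

```python
# Hedged usage sketch: `model`, `tokenizer`, and the image paths are
# assumptions for illustration, not part of this commit.
from PIL import Image

img_a = Image.open("photo_a.jpg").convert("RGB")
img_b = Image.open("photo_b.jpg").convert("RGB")

# With this commit, a message's "content" may be a list that interleaves
# PIL images and strings. chat() replaces each image with a
# "(<image>./</image>)" placeholder in the prompt text and collects the
# images into the list handed to the processor.
msgs = [
    {"role": "user", "content": [img_a, img_b, "What differs between these two photos?"]},
]

answer = model.chat(
    image=None,  # legacy single-image argument; may still be passed instead
    msgs=msgs,
    tokenizer=tokenizer,
)
print(answer)
```

Note the backward compatibility: when `image` is given and the first message's `content` is a plain string, the new code prepends the image to that content (turning it into a list), so the old single-image call style flows through the same normalization loop as multi-image input.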