Spaces:
Runtime error
Runtime error
DongfuJiang
commited on
Commit
·
335eee6
1
Parent(s):
c862a9f
update
Browse files- app.py +9 -5
- models/mllava/__init__.py +1 -1
- models/mllava/utils.py +11 -3
- requirements.txt +2 -1
app.py
CHANGED
@@ -4,10 +4,12 @@ import os
|
|
4 |
import time
|
5 |
from PIL import Image
|
6 |
import functools
|
7 |
-
from models.mllava import MLlavaProcessor, LlavaForConditionalGeneration,
|
|
|
8 |
from typing import List
|
9 |
processor = MLlavaProcessor.from_pretrained("TIGER-Lab/Mantis-8B-siglip-llama3")
|
10 |
model = LlavaForConditionalGeneration.from_pretrained("TIGER-Lab/Mantis-8B-siglip-llama3")
|
|
|
11 |
|
12 |
@spaces.GPU
|
13 |
def generate(text:str, images:List[Image.Image], history: List[dict], **kwargs):
|
@@ -15,7 +17,7 @@ def generate(text:str, images:List[Image.Image], history: List[dict], **kwargs):
|
|
15 |
model = model.to("cuda")
|
16 |
if not images:
|
17 |
images = None
|
18 |
-
for text, history in
|
19 |
yield text
|
20 |
|
21 |
return text
|
@@ -38,15 +40,17 @@ def print_like_dislike(x: gr.LikeData):
|
|
38 |
|
39 |
def get_chat_history(history):
|
40 |
chat_history = []
|
|
|
|
|
41 |
for i, message in enumerate(history):
|
42 |
if isinstance(message[0], str):
|
43 |
-
chat_history.append({"role":
|
44 |
if i != len(history) - 1:
|
45 |
assert message[1], "The bot message is not provided, internal error"
|
46 |
-
chat_history.append({"role":
|
47 |
else:
|
48 |
assert not message[1], "the bot message internal error, get: {}".format(message[1])
|
49 |
-
chat_history.append({"role":
|
50 |
return chat_history
|
51 |
|
52 |
|
|
|
4 |
import time
|
5 |
from PIL import Image
|
6 |
import functools
|
7 |
+
from models.mllava import MLlavaProcessor, LlavaForConditionalGeneration, chat_mllava_stream, MLlavaForConditionalGeneration
|
8 |
+
from models.conversation import conv_templates
|
9 |
from typing import List
|
10 |
processor = MLlavaProcessor.from_pretrained("TIGER-Lab/Mantis-8B-siglip-llama3")
|
11 |
model = LlavaForConditionalGeneration.from_pretrained("TIGER-Lab/Mantis-8B-siglip-llama3")
|
12 |
+
conv_template = conv_templates['llama_3']
|
13 |
|
14 |
@spaces.GPU
|
15 |
def generate(text:str, images:List[Image.Image], history: List[dict], **kwargs):
|
|
|
17 |
model = model.to("cuda")
|
18 |
if not images:
|
19 |
images = None
|
20 |
+
for text, history in chat_mllava_stream(text, images, model, processor, history=history, **kwargs):
|
21 |
yield text
|
22 |
|
23 |
return text
|
|
|
40 |
|
41 |
def get_chat_history(history):
|
42 |
chat_history = []
|
43 |
+
user_role = conv_template.roles[0]
|
44 |
+
assistant_role = conv_template.roles[1]
|
45 |
for i, message in enumerate(history):
|
46 |
if isinstance(message[0], str):
|
47 |
+
chat_history.append({"role": user_role, "text": message[0]})
|
48 |
if i != len(history) - 1:
|
49 |
assert message[1], "The bot message is not provided, internal error"
|
50 |
+
chat_history.append({"role": assistant_role, "text": message[1]})
|
51 |
else:
|
52 |
assert not message[1], "the bot message internal error, get: {}".format(message[1])
|
53 |
+
chat_history.append({"role": assistant_role, "text": ""})
|
54 |
return chat_history
|
55 |
|
56 |
|
models/mllava/__init__.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
from .modeling_llava import LlavaForConditionalGeneration, MLlavaForConditionalGeneration
|
2 |
from .processing_llava import MLlavaProcessor
|
3 |
from .configuration_llava import LlavaConfig
|
4 |
-
from .utils import chat_mllava
|
|
|
1 |
from .modeling_llava import LlavaForConditionalGeneration, MLlavaForConditionalGeneration
|
2 |
from .processing_llava import MLlavaProcessor
|
3 |
from .configuration_llava import LlavaConfig
|
4 |
+
from .utils import chat_mllava, chat_mllava_stream
|
models/mllava/utils.py
CHANGED
@@ -44,7 +44,6 @@ def chat_mllava(
|
|
44 |
conv.messages = []
|
45 |
if history is not None:
|
46 |
for message in history:
|
47 |
-
message["role"] = message["role"].upper()
|
48 |
assert message["role"] in conv.roles
|
49 |
conv.append_message(message["role"], message["text"])
|
50 |
else:
|
@@ -105,11 +104,20 @@ def chat_mllava_stream(
|
|
105 |
|
106 |
|
107 |
"""
|
108 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
109 |
conv.messages = []
|
110 |
if history is not None:
|
111 |
for message in history:
|
112 |
-
message["role"] = message["role"].upper()
|
113 |
assert message["role"] in conv.roles
|
114 |
conv.append_message(message["role"], message["text"])
|
115 |
else:
|
|
|
44 |
conv.messages = []
|
45 |
if history is not None:
|
46 |
for message in history:
|
|
|
47 |
assert message["role"] in conv.roles
|
48 |
conv.append_message(message["role"], message["text"])
|
49 |
else:
|
|
|
104 |
|
105 |
|
106 |
"""
|
107 |
+
if "llama-3" in model.language_model.name_or_path.lower():
|
108 |
+
conv = conv_templates['llama_3']
|
109 |
+
terminators = [
|
110 |
+
processor.tokenizer.eos_token_id,
|
111 |
+
processor.tokenizer.convert_tokens_to_ids("<|eot_id|>")
|
112 |
+
]
|
113 |
+
else:
|
114 |
+
conv = default_conv
|
115 |
+
terminators = None
|
116 |
+
kwargs["eos_token_id"] = terminators
|
117 |
+
conv = conv.copy()
|
118 |
conv.messages = []
|
119 |
if history is not None:
|
120 |
for message in history:
|
|
|
121 |
assert message["role"] in conv.roles
|
122 |
conv.append_message(message["role"], message["text"])
|
123 |
else:
|
requirements.txt
CHANGED
@@ -3,4 +3,5 @@ transformers
|
|
3 |
Pillow
|
4 |
gradio
|
5 |
spaces
|
6 |
-
multiprocess
|
|
|
|
3 |
Pillow
|
4 |
gradio
|
5 |
spaces
|
6 |
+
multiprocess
|
7 |
+
flash-attn
|