from spaces import GPU

from threading import Thread

from transformers import Qwen2VLForConditionalGeneration, Qwen2VLProcessor, TextIteratorStreamer, AutoProcessor
from qwen_vl_utils import process_vision_info

from gradio import ChatMessage, Chatbot, MultimodalTextbox, Slider, Checkbox, CheckboxGroup, Textbox, JSON, Blocks, Row, Column, Markdown, FileData
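
# Load the Qwen2-VL chat model and its processor. min_pixels/max_pixels bound how far
# each uploaded image is resized, which caps the number of visual tokens per image.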
model_path = "Pectics/Softie-VL-7B-250123"

model = Qwen2VLForConditionalGeneration.from_pretrained(
    model_path,
    torch_dtype="auto",
    device_map="auto",
    attn_implementation="flash_attention_2",
)
min_pixels = 256 * 28 * 28
max_pixels = 1280 * 28 * 28
processor: Qwen2VLProcessor = AutoProcessor.from_pretrained(model_path, min_pixels=min_pixels, max_pixels=max_pixels)

SYSTEM_PROMPT = """
You are Softie, or 小软 in Chinese.
You are an intelligent assistant developed by the School of Software at Hefei University of Technology.
You like to chat with people and help them solve problems.
You will interact with the user in the QQ group chat, and the user message you receive will satisfy the following format:
user_id: 用户ID
nickname: 用户昵称
content: 用户消息内容
You will directly output your reply content without the need to format it.
""".strip()

# Flags shared between the Gradio event handlers and the streaming loop.
STREAMING_FLAG = False       # a streamed reply is currently being generated
STREAMING_STOP_FLAG = False  # the user pressed stop; finish the current stream early
FORBIDDING_FLAG = False      # the last submission was rejected, so generation is skipped
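

# Stop-button handler: ask the streaming loop in process_model to stop after the
# token currently being emitted.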
def interrupt() -> None:
    global STREAMING_STOP_FLAG
    if STREAMING_FLAG:
        STREAMING_STOP_FLAG = True
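

# Retry handler for the Chatbot: drop the last assistant reply and the user turn that
# produced it from both the visible history and the raw message list, then put the
# user's text back into the input box so it can be edited and resubmitted.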
def callback(
    input: dict,
    history: list[ChatMessage],
    messages: list,
) -> tuple[str, list[ChatMessage], list]:
    if len(history) <= 1 or len(messages) <= 1:
        return input["text"], history, messages
    history.pop()   # assistant reply shown in the chat window
    messages.pop()  # assistant reply in the raw message list
    messages.pop()  # user message that triggered it
    return history.pop()["content"], history, messages
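

# GPU worker: tokenize the prepared text/image/video inputs, launch model.generate on a
# background thread, and yield decoded tokens as the streamer produces them. On ZeroGPU
# Spaces, the spaces.GPU decorator allocates a GPU for the duration of the call.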
@GPU
def core_infer(
    inputs: tuple,
    max_tokens: int,
    temperature: float,
    top_p: float,
):
    text_inputs, image_inputs, video_inputs = inputs
    inputs = processor(
        text=[text_inputs],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to("cuda")
    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
    )
    Thread(target=model.generate, kwargs=kwargs).start()
    for token in streamer:
        yield token
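

# Chat step (exposed as the "api" endpoint): prefix the latest user message with the
# user_id/nickname header described in SYSTEM_PROMPT, build the model inputs, and
# stream the assistant reply into both the raw message list and the Chatbot history
# (passed in as args[0]). use_tools/use_agents are accepted but not used yet.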
def process_model(
    messages: str | list[object],
    user_id: str,
    nickname: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
    use_tools: bool,
    use_agents: bool,
    *args: list,
):
    global STREAMING_FLAG, STREAMING_STOP_FLAG, FORBIDDING_FLAG
    if STREAMING_FLAG or FORBIDDING_FLAG or len(messages) <= 0 or messages[-1]["role"] != "user":
        yield None, messages, args[0]
        return

    # Inject the user_id/nickname header into the latest user message.
    _msgs_copy = messages.copy()
    if isinstance(_msgs_copy[-1]["content"], list):
        if len(_msgs_copy[-1]["content"]) <= 0 or _msgs_copy[-1]["content"][-1]["type"] != "text":
            _msgs_copy[-1]["content"].insert(0, {
                "type": "text",
                "text": f"""
user_id: {user_id}
nickname: {nickname}
content: """.lstrip(),
            })
        else:
            _msgs_copy[-1]["content"][-1]["text"] = f"""
user_id: {user_id}
nickname: {nickname}
content: {_msgs_copy[-1]["content"][-1]["text"]}
""".strip()
    else:
        _msgs_copy[-1]["content"] = f"""
user_id: {user_id}
nickname: {nickname}
content: {_msgs_copy[-1]["content"]}
""".strip()

    text_inputs = processor.apply_chat_template(_msgs_copy, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(_msgs_copy)
    response = ""
    args[0].append(ChatMessage(role="assistant", content=""))
    messages.append({"role": "assistant", "content": ""})
    STREAMING_FLAG = True
    for token in core_infer((text_inputs, image_inputs, video_inputs), max_tokens, temperature, top_p):
        if STREAMING_STOP_FLAG:
            response += "...(Interrupted)"
            args[0][-1].content = response
            messages[-1]["content"] = response
            STREAMING_STOP_FLAG = False
            yield response, messages, args[0]
            break
        response += token
        args[0][-1].content = response
        messages[-1]["content"] = response
        yield response, messages, args[0]
    STREAMING_FLAG = False
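

# Submit handler: validate the new input, make sure the system prompt is the first
# message, convert uploaded files into Qwen-style image entries, and append the new
# user turn to both the Chatbot history and the raw message list.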
def process_input(
    input: dict,
    history: list[ChatMessage],
    checkbox: list[str],
    messages: list,
) -> tuple[str, list[ChatMessage], bool, bool, list]:
    global STREAMING_FLAG, FORBIDDING_FLAG
    if STREAMING_FLAG or not isinstance(input["text"], str) or input["text"].strip() == "":
        FORBIDDING_FLAG = True
        return (
            input["text"],
            history,
            "允许使用工具" in checkbox,
            "允许使用代理模型" in checkbox,
            messages,
        )
    FORBIDDING_FLAG = False
    if len(history) <= 0:
        messages = [{"role": "system", "content": SYSTEM_PROMPT}]
        history.append(ChatMessage(role="system", content=SYSTEM_PROMPT))
    else:
        if history[0]["role"] != "system":
            history.insert(0, ChatMessage(role="system", content=SYSTEM_PROMPT))
        if messages[0]["role"] != "system":
            messages.insert(0, {"role": "system", "content": SYSTEM_PROMPT})
    message = {"role": "user", "content": []}
    while isinstance(input["files"], list) and len(input["files"]) > 0:
        path = input["files"].pop(0)
        message["content"].append({"type": "image", "image": f"file://{path}"})
        history.append(ChatMessage(role="user", content=FileData(path=path)))
    message["content"].append({"type": "text", "text": input["text"]})
    if len(message["content"]) == 1 and message["content"][0]["type"] == "text":
        message["content"] = message["content"][0]["text"]
        history.append(ChatMessage(role="user", content=input["text"]))
    messages.append(message)
    return (
        "",
        history,
        "允许使用工具" in checkbox,
        "允许使用代理模型" in checkbox,
        messages,
    )
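

# Gradio UI. text_response, use_tools and use_agents are hidden components used only to
# carry values between the event handlers below.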
with Blocks() as app:
    text_response = Textbox("", visible=False, interactive=False)
    use_tools = Checkbox(visible=False, interactive=False)
    use_agents = Checkbox(visible=False, interactive=False)
    Markdown("# 小软Softie")
    with Row():
        with Column(scale=3, min_width=500):
            chatbot = Chatbot(type="messages", avatar_images=(None, "avatar.jpg"), scale=3, min_height=640, min_width=500, show_label=False, show_copy_button=True, show_copy_all_button=True)
            textbox = MultimodalTextbox(file_types=[".jpg", ".jpeg", ".png"], file_count="multiple", stop_btn=True, show_label=False, autofocus=False, placeholder="在此输入内容")
        with Column(scale=1, min_width=300):
            with Column(scale=0):
                max_tokens = Slider(interactive=True, minimum=1, maximum=2048, value=512, step=1, label="max_tokens", info="最大生成长度")
                temperature = Slider(interactive=True, minimum=0.01, maximum=4.0, value=0.75, step=0.01, label="temperature", info="温度系数")
                top_p = Slider(interactive=True, minimum=0.01, maximum=1.0, value=0.5, step=0.01, label="top_p", info="核取样系数")
                checkbox = CheckboxGroup(["允许使用工具", "允许使用代理模型"], label="options", info="功能选项(开发中)")
            with Column(scale=0):
                user_id = Textbox(value="123456789", label="user_id", info="用户ID")
                nickname = Textbox(value="用户1234", label="nickname", info="用户昵称")
                json_messages = JSON([], max_height=25, label="json_messages")
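
    # Event wiring: retry re-opens the last user turn, submit runs input processing and
    # then streaming generation (the generation step is exposed as the "api" endpoint),
    # and the stop button interrupts the stream.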
    chatbot.retry(
        callback,
        [textbox, chatbot, json_messages],
        [textbox, chatbot, json_messages],
        api_name=False,
        show_api=False,
    )
    chatbot.like(lambda: None, api_name=False, show_api=False)

    textbox.submit(
        process_input,
        [textbox, chatbot, checkbox, json_messages],
        [textbox, chatbot, use_tools, use_agents, json_messages],
        queue=False,
        api_name=False,
        show_api=False,
        show_progress="hidden",
        trigger_mode="once",
    ).then(
        process_model,
        [json_messages, user_id, nickname, max_tokens, temperature, top_p, use_tools, use_agents, chatbot],
        [text_response, json_messages, chatbot],
        queue=True,
        api_name="api",
        show_api=True,
        show_progress="hidden",
        trigger_mode="once",
    )

    textbox.stop(interrupt, api_name=False, show_api=False)


if __name__ == "__main__":
    app.launch()