File size: 9,037 Bytes
dc15a3f
 
374e122
d81f4b6
03d2f46
d81f4b6
 
161b347
03d2f46
161b347
03d2f46
 
dc15a3f
03d2f46
dc15a3f
03d2f46
 
 
 
 
d81f4b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2941c6d
d81f4b6
469e885
dc15a3f
 
 
 
469e885
 
 
 
 
 
 
dc15a3f
 
 
 
 
 
 
 
374e122
af0c8f0
d81f4b6
1325e72
d81f4b6
 
 
 
bc0b758
 
 
d81f4b6
 
 
161b347
d81f4b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f68ed0e
d81f4b6
 
 
 
 
 
 
 
 
 
f68ed0e
d81f4b6
95016f0
 
 
 
 
 
 
d81f4b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
161b347
 
03d2f46
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
from spaces import GPU

from threading import Thread
from transformers import Qwen2VLForConditionalGeneration, Qwen2VLProcessor, TextIteratorStreamer, AutoProcessor
from qwen_vl_utils import process_vision_info

from gradio import ChatMessage, Chatbot, MultimodalTextbox, Slider, Checkbox, CheckboxGroup, Textbox, JSON, Blocks, Row, Column, Markdown, FileData

# Hugging Face repository id of the Softie vision-language model (Qwen2-VL based).
model_path = "Pectics/Softie-VL-7B-250123"

# Load the chat model once at import time. device_map="auto" places the weights
# on the available accelerator(s); flash_attention_2 requires a compatible GPU
# stack -- NOTE(review): this will fail on CPU-only hosts, confirm deployment target.
model = Qwen2VLForConditionalGeneration.from_pretrained(
    model_path,
    torch_dtype="auto",
    device_map="auto",
    attn_implementation="flash_attention_2",
)
# Pixel-count bounds the processor rescales images into before patching
# (28*28 is the patch area used by Qwen2-VL's vision encoder).
min_pixels = 256 * 28 * 28
max_pixels = 1280 * 28 * 28
processor: Qwen2VLProcessor = AutoProcessor.from_pretrained(model_path, min_pixels=min_pixels, max_pixels=max_pixels)

# System prompt injected as the first message of every conversation.
# The Chinese lines document the per-turn metadata layout (user_id / nickname /
# content) that process_model prepends to each user message before inference.
SYSTEM_PROMPT = """
You are Softie, or 小软 in Chinese.
You are an intelligent assistant developed by the School of Software at Hefei University of Technology.
You like to chat with people and help them solve problems.
You will interact with the user in the QQ group chat, and the user message you receive will satisfy the following format:
user_id: 用户ID
nickname: 用户昵称
content: 用户消息内容
You will directly output your reply content without the need to format it.
""".strip()

# True while a generation stream is in progress (set/cleared by process_model).
STREAMING_FLAG = False
# Set by interrupt() to ask the running stream to stop at the next token.
STREAMING_STOP_FLAG = False

# Set by process_input when a submission is rejected, telling process_model to skip it.
FORBIDDING_FLAG = False
# NOTE(review): these module-level flags are shared across all Gradio sessions,
# so one user's stream blocks/interrupts everyone -- confirm single-user intent.

def interrupt() -> None:
    """Request that the in-progress generation stream stop.

    No-op when nothing is currently streaming; otherwise raises the stop
    flag, which process_model checks between streamed tokens.
    """
    global STREAMING_FLAG, STREAMING_STOP_FLAG
    if not STREAMING_FLAG:
        return
    STREAMING_STOP_FLAG = True

def callback(
    input: dict,
    history: list[ChatMessage],
    messages: list
) -> tuple[str, list[ChatMessage], list]:
    """Retry handler: undo the last exchange and restore the user's text.

    Drops the trailing assistant turn from the chat display, the trailing
    assistant+user pair from the raw message list, and returns the removed
    user content so the textbox can re-submit it.
    """
    # With only the system turn (or nothing) present there is nothing to retry.
    if len(history) > 1 and len(messages) > 1:
        history.pop()              # assistant turn in the chat display
        del messages[-2:]          # assistant + user entries in the raw list
        previous = history.pop()   # user turn; its content gets re-submitted
        return previous["content"], history, messages
    return input["text"], history, messages

@GPU
def core_infer(
    inputs: tuple,
    max_tokens: int,
    temperature: float,
    top_p: float,
):
    """Run one generation pass on the GPU, yielding decoded text fragments.

    Args:
        inputs: (text, images, videos) triple prepared by process_model.
        max_tokens: maximum number of new tokens to generate.
        temperature: sampling temperature.
        top_p: nucleus-sampling threshold.

    Yields:
        Decoded text chunks as they are produced by the streamer.
    """
    inputs = processor(
        text=[inputs[0]],
        images=inputs[1],
        videos=inputs[2],
        padding=True,
        return_tensors="pt",
    ).to("cuda")
    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_tokens,
        # Fix: without do_sample=True, generate() uses greedy decoding and
        # silently ignores temperature/top_p, making the UI sliders inert.
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
    )
    # Run generation in a background thread so this generator can consume the
    # streamer concurrently.
    Thread(target=model.generate, kwargs=kwargs).start()
    for token in streamer:
        yield token

def process_model(
    messages: str | list[object],
    user_id: str,
    nickname: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
    use_tools: bool,
    use_agents: bool,
    *args: list,
):
    """Stream an assistant reply for the pending user message.

    Args:
        messages: raw Qwen2-VL-format history; the last entry must be a user turn.
        user_id: QQ user id embedded into the prompt metadata.
        nickname: QQ nickname embedded into the prompt metadata.
        max_tokens: maximum number of new tokens to generate.
        temperature: sampling temperature forwarded to core_infer.
        top_p: nucleus-sampling threshold forwarded to core_infer.
        use_tools: tool-use option (currently unused; reserved).
        use_agents: agent option (currently unused; reserved).
        *args: args[0] is the Gradio chatbot history mirrored alongside.

    Yields:
        (response_text, messages, chatbot_history) after every streamed token.
    """
    global STREAMING_FLAG, STREAMING_STOP_FLAG, FORBIDDING_FLAG
    # Refuse re-entrant calls, rejected submissions, and histories that do not
    # end with a pending user turn.
    if STREAMING_FLAG or FORBIDDING_FLAG or len(messages) <= 0 or messages[-1]["role"] != "user":
        yield None, messages, args[0]
        return
    # Embed the user metadata into a COPY of the last user turn so the stored
    # history keeps the raw content. (The previous messages.copy() was shallow:
    # it still mutated the shared message dict, leaking the metadata into the
    # JSON state and into retry.)
    raw = messages[-1]
    if isinstance(raw["content"], list):
        patched = {**raw, "content": [dict(part) for part in raw["content"]]}
        if len(patched["content"]) <= 0 or patched["content"][-1]["type"] != "text":
            # No trailing text part: prepend a metadata-only text part.
            patched["content"].insert(0, {
                "type": "text",
                "text": f"""
user_id: {user_id}
nickname: {nickname}
content: """.lstrip(),
            })
        else:
            patched["content"][-1]["text"] = f"""
user_id: {user_id}
nickname: {nickname}
content: {patched["content"][-1]["text"]}
            """.strip()
    else:
        patched = {**raw, "content": f"""
user_id: {user_id}
nickname: {nickname}
content: {raw["content"]}
        """.strip()}
    _msgs_copy = messages[:-1] + [patched]
    # Render the prompt and collect any image/video inputs.
    text_inputs = processor.apply_chat_template(_msgs_copy, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(_msgs_copy)
    response = ""
    # Placeholder assistant turns that get filled in as tokens arrive.
    args[0].append(ChatMessage(role="assistant", content=""))
    messages.append({"role": "assistant", "content": ""})
    STREAMING_FLAG = True
    try:
        for token in core_infer((text_inputs, image_inputs, video_inputs), max_tokens, temperature, top_p):
            if STREAMING_STOP_FLAG:
                # interrupt() was called: mark the reply as truncated and stop.
                response += "...(Interrupted)"
                args[0][-1].content = response
                messages[-1]["content"] = response
                yield response, messages, args[0]
                break
            response += token
            args[0][-1].content = response
            messages[-1]["content"] = response
            yield response, messages, args[0]
    finally:
        # Fix: always clear both flags -- even if generation raises -- so the
        # app cannot lock itself out of further input.
        STREAMING_STOP_FLAG = False
        STREAMING_FLAG = False

def process_input(
    input: dict,
    history: list[ChatMessage],
    checkbox: list[str],
    messages: list
) -> tuple[str, list[ChatMessage], bool, bool, list]:
    """Validate a submission and append it to both chat histories.

    Args:
        input: MultimodalTextbox value: {"text": str, "files": list[str]}.
        history: Gradio chatbot display history.
        checkbox: selected option labels from the options CheckboxGroup.
        messages: raw Qwen2-VL message history (JSON state).

    Returns:
        (cleared_text, history, use_tools, use_agents, messages).
    """
    global STREAMING_FLAG, FORBIDDING_FLAG
    # Reject while a stream is running or when the text is empty;
    # FORBIDDING_FLAG tells the chained process_model call to skip this turn.
    if STREAMING_FLAG or not isinstance(input["text"], str) or input["text"].strip() == "":
        FORBIDDING_FLAG = True
        return (
            input["text"],
            history,
            "允许使用工具" in checkbox,
            "允许使用代理模型" in checkbox,
            messages
        )
    FORBIDDING_FLAG = False
    # Ensure both histories start with the system prompt.
    if len(history) <= 0:
        messages = [{"role": "system", "content": SYSTEM_PROMPT}]
        history.append(ChatMessage(role="system", content=SYSTEM_PROMPT))
    else:
        if history[0]["role"] != "system":
            history.insert(0, ChatMessage(role="system", content=SYSTEM_PROMPT))
        # Fix: guard against an empty raw history -- the previous code indexed
        # messages[0] unconditionally and raised IndexError whenever `history`
        # had entries but `messages` did not.
        if len(messages) <= 0 or messages[0]["role"] != "system":
            messages.insert(0, {"role": "system", "content": SYSTEM_PROMPT})
    message = {"role": "user", "content": []}
    # Attach each uploaded image to the raw message (Qwen2VL file-URI format)
    # and to the chat display.
    while isinstance(input["files"], list) and len(input["files"]) > 0:
        path = input["files"].pop(0)
        message["content"].append({"type": "image", "image": f"file://{path}"}) # Qwen2VL format
        history.append(ChatMessage(role="user", content=FileData(path=path)))
    message["content"].append({"type": "text", "text": input["text"]})
    # Collapse a pure-text message to a plain string for compactness.
    if len(message["content"]) == 1 and message["content"][0]["type"] == "text":
        message["content"] = message["content"][0]["text"]
    history.append(ChatMessage(role="user", content=input["text"]))
    messages.append(message)
    return (
        "",
        history,
        "允许使用工具" in checkbox,
        "允许使用代理模型" in checkbox,
        messages,
    )

# Gradio UI: chat column on the left, generation settings on the right.
with Blocks() as app:
    # Hidden components used as API outputs / state carriers between events.
    text_response = Textbox("", visible=False, interactive=False)
    use_tools = Checkbox(visible=False, interactive=False)
    use_agents = Checkbox(visible=False, interactive=False)
    Markdown("# 小软Softie")
    with Row():
        with Column(scale=3, min_width=500):
            # Chat display plus a multimodal input that accepts image uploads.
            chatbot = Chatbot(type="messages", avatar_images=(None, "avatar.jpg"), scale=3, min_height=640, min_width=500, show_label=False, show_copy_button=True, show_copy_all_button=True)
            textbox = MultimodalTextbox(file_types=[".jpg", ".jpeg", ".png"], file_count="multiple", stop_btn=True, show_label=False, autofocus=False, placeholder="在此输入内容")
        with Column(scale=1, min_width=300):
            # Generation parameters forwarded to process_model / core_infer.
            with Column(scale=0):
                max_tokens = Slider(interactive=True, minimum=1, maximum=2048, value=512, step=1, label="max_tokens", info="最大生成长度")
                temperature = Slider(interactive=True, minimum=0.01, maximum=4.0, value=0.75, step=0.01, label="temperature", info="温度系数")
                top_p = Slider(interactive=True, minimum=0.01, maximum=1.0, value=0.5, step=0.01, label="top_p", info="核取样系数")
            checkbox = CheckboxGroup(["允许使用工具", "允许使用代理模型"], label="options", info="功能选项(开发中)")
            # Simulated QQ identity fields embedded into each prompt.
            with Column(scale=0):
                user_id = Textbox(value=123456789, label="user_id", info="用户ID")
                nickname = Textbox(value="用户1234", label="nickname", info="用户昵称")
            # Raw Qwen2-VL message history, kept in sync with the chat display.
            json_messages = JSON([], max_height=25, label="json_messages")

    ## Not supported in Gradio 5.0.1 -- clearing would also reset the raw state.
    # chatbot.clear(
    #     lambda: ("[]", ""),
    #     outputs=[json_messages, text_response],
    #     api_name=False,
    #     show_api=False,
    # )
    # Retry: pop the last exchange and put the user's text back in the textbox.
    chatbot.retry(
        callback,
        [textbox, chatbot, json_messages],
        [textbox, chatbot, json_messages],
        api_name=False,
        show_api=False,
    )
    # Like/dislike events are accepted but ignored.
    chatbot.like(lambda: None, api_name=False, show_api=False)

    # Submit pipeline: first record the message synchronously (unqueued), then
    # stream the model reply (queued; exposed as the public "api" endpoint).
    textbox.submit(
        process_input,
        [textbox, chatbot, checkbox, json_messages],
        [textbox, chatbot, use_tools, use_agents, json_messages],
        queue=False,
        api_name=False,
        show_api=False,
        show_progress="hidden",
        trigger_mode="once",
    ).then(
        process_model,
        [json_messages, user_id, nickname, max_tokens, temperature, top_p, use_tools, use_agents, chatbot],
        [text_response, json_messages, chatbot],
        queue=True,
        api_name="api",
        show_api=True,
        show_progress="hidden",
        trigger_mode="once",
    )

    # Stop button on the textbox raises the interrupt flag.
    textbox.stop(interrupt, api_name=False, show_api=False)

# Launch the Gradio server when run as a script.
if __name__ == "__main__":
    app.launch()