File size: 5,604 Bytes
ba12288
b6f75f9
a11ec0f
b6f75f9
ba12288
b6f75f9
a11ec0f
ba12288
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b6f75f9
ba12288
 
 
 
 
 
 
 
 
 
b6f75f9
ba12288
 
 
 
 
 
 
 
 
 
 
 
b6f75f9
ba12288
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b6f75f9
ba12288
b6f75f9
 
 
 
 
ba12288
 
 
b6f75f9
 
ba12288
b6f75f9
 
ba12288
 
 
b6f75f9
ba12288
 
b6f75f9
 
ba12288
 
 
b6f75f9
 
ba12288
b6f75f9
ba12288
 
b6f75f9
 
 
ba12288
 
 
 
 
b6f75f9
 
 
ba12288
 
ff7ad4f
ba12288
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b6f75f9
 
 
ba12288
 
 
 
 
 
 
 
 
 
 
 
b6f75f9
 
ba12288
 
 
 
 
 
 
 
 
b6f75f9
ba12288
 
 
b6f75f9
 
ba12288
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
# app.py
from threading import Thread
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
import torch
import spaces

# ---------------------------------------------
# 1. 加载模型与 Tokenizer
# ---------------------------------------------
# 如果你的模型需要加速/量化等特殊配置,可在 from_pretrained() 中添加相应参数
# 例如 device_map='auto' 或 trust_remote_code=True 等
model_name = "agentica-org/DeepScaleR-1.5B-Preview"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")

# 根据需要加上 .half()/.float()/.quantize() 等操作
# 例如
# model.half()
# 或者
# model = model.quantize(4/8)  # 如果你的模型和环境支持

# ---------------------------------------------
# 2. 对话历史处理
# ---------------------------------------------
def preprocess_messages(history):
    """
    将所有的用户与回复消息拼成一个文本 prompt。
    这里仅示例最简单的形式:
       User: ...
       Assistant: ...
    最后再接上 "Assistant: " 用于提示模型继续生成。
    你也可以修改为自己需要的对话模板。
    """
    prompt = ""
    for user_msg, assistant_msg in history:
        if user_msg:
            prompt += f"User: {user_msg}\n"
        if assistant_msg:
            prompt += f"Assistant: {assistant_msg}\n"

    # 继续生成时,让模型再续写 "Assistant:"
    prompt += "Assistant: "
    return prompt

# ---------------------------------------------
# 3. 预测函数
# ---------------------------------------------
@spaces.GPU()  # 使用 huggingface spaces 的 GPU 装饰器
def predict(history, max_length, top_p, temperature):
    """
    输入为 history(对话历史)和若干超参,输出流式生成的结果。
    每生成一个 token,就通过 yield 返回给 Gradio,更新界面。
    """
    prompt = preprocess_messages(history)

    # 组装输入
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs["input_ids"].to(model.device)

    # 使用 TextIteratorStreamer 来实现流式输出
    streamer = TextIteratorStreamer(
        tokenizer=tokenizer,
        timeout=60,
        skip_prompt=True,
        skip_special_tokens=True
    )

    generate_kwargs = {
        "input_ids": input_ids,
        "max_new_tokens": max_length,
        "do_sample": True,
        "top_p": top_p,
        "temperature": temperature,
        "repetition_penalty": 1.2,
        "streamer": streamer,
        # 如果需要自定义一些特殊 token 或其他参数可在此补充
        # "eos_token_id": ...
    }

    # 启动一个线程去执行 generate,然后主线程读取流式输出
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    # history[-1][1] 存放当前最新的 assistant 回复,因此不断累加
    partial_output = ""
    for new_token in streamer:
        partial_output += new_token
        history[-1][1] = partial_output
        yield history

# ---------------------------------------------
# 4. 搭建 Gradio 界面
# ---------------------------------------------
def main():
    with gr.Blocks() as demo:
        gr.HTML("<h1 align='center'>DeepScaleR-1.5B-Preview Chat Demo</h1>")

        # 聊天窗口
        chatbot = gr.Chatbot()

        with gr.Row():
            with gr.Column(scale=2):
                user_input = gr.Textbox(
                    show_label=True, 
                    placeholder="请输入您的问题...", 
                    label="User Input"
                )
                submitBtn = gr.Button("Submit")
                emptyBtn = gr.Button("Clear History")
            with gr.Column(scale=1):
                max_length = gr.Slider(
                    minimum=0, 
                    maximum=32000,  # 根据模型能力自行调整
                    value=512, 
                    step=1, 
                    label="Max New Tokens", 
                    interactive=True
                )
                top_p = gr.Slider(
                    minimum=0, 
                    maximum=1, 
                    value=0.8, 
                    step=0.01, 
                    label="Top P", 
                    interactive=True
                )
                temperature = gr.Slider(
                    minimum=0.01, 
                    maximum=2.0, 
                    value=0.7, 
                    step=0.01, 
                    label="Temperature", 
                    interactive=True
                )

        # 用于将用户输入插入到 chatbot 历史中
        def user(query, history):
            return "", history + [[query, ""]]

        # Submit: 
        # 1) user() -> 新增一条 (user输入,"") 的对话记录
        # 2) predict() -> 基于更新后的 history 进行生成
        submitBtn.click(
            fn=user, 
            inputs=[user_input, chatbot], 
            outputs=[user_input, chatbot],
            queue=False
        ).then(
            fn=predict, 
            inputs=[chatbot, max_length, top_p, temperature], 
            outputs=chatbot
        )

        # Clear: 清空对话历史
        def clear_history():
            return [], []
        emptyBtn.click(
            fn=clear_history, 
            inputs=[], 
            outputs=[chatbot, user_input], 
            queue=False
        )

        # 可选:让 Gradio 自动对排队请求进行调度
        demo.queue()
        demo.launch()

if __name__ == "__main__":
    main()