File size: 5,131 Bytes
ba12288
a11ec0f
b6f75f9
a493bde
 
 
 
 
 
 
 
a11ec0f
a493bde
ba12288
a493bde
ba12288
 
 
 
a493bde
 
 
 
ba12288
a493bde
 
 
b6f75f9
ba12288
a493bde
 
ba12288
 
 
b6f75f9
ba12288
 
 
a493bde
ba12288
 
 
a493bde
 
 
 
 
b6f75f9
ba12288
a493bde
 
ba12288
 
 
a493bde
 
 
 
 
 
 
ba12288
a493bde
ba12288
a493bde
ba12288
 
 
 
 
 
 
b6f75f9
ba12288
a493bde
 
b6f75f9
 
 
 
ba12288
b6f75f9
 
a493bde
b6f75f9
 
ba12288
a493bde
ba12288
b6f75f9
ba12288
 
b6f75f9
 
a493bde
 
 
 
b6f75f9
 
a493bde
b6f75f9
ba12288
b6f75f9
 
 
ba12288
a493bde
 
ba12288
 
b6f75f9
a493bde
b6f75f9
ba12288
a493bde
 
 
 
 
ba12288
 
 
a493bde
 
 
 
 
ba12288
 
 
a493bde
 
 
 
 
ba12288
 
 
a493bde
b6f75f9
 
 
ba12288
a493bde
 
ba12288
a493bde
ba12288
a493bde
 
ba12288
b6f75f9
 
a493bde
ba12288
 
b6f75f9
a493bde
 
 
 
ba12288
b6f75f9
a493bde
 
 
b6f75f9
ba12288
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
# app.py
import gradio as gr
import spaces
from threading import Thread
import torch

from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TextIteratorStreamer,
)

# ------------------------------
# 1. Load the model and tokenizer
# ------------------------------
model_name = "agentica-org/DeepScaleR-1.5B-Preview"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# device_map="auto" lets accelerate place the weights on available devices (GPU/CPU).
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")

# If the tokenizer defines no pad_token_id, fall back to eos_token_id so that
# tokenizer(..., padding=True) in predict() below does not fail.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id


# ------------------------------
# 2. 对话历史 -> Prompt 格式
# ------------------------------
def preprocess_messages(history):
    """Flatten a chat history into a minimal plain-text prompt.

    Each (user, assistant) turn contributes "User: ...\n" and/or
    "Assistant: ...\n" lines (empty entries are skipped); a trailing
    "Assistant: " invites the model to continue. Customize this if the
    model expects a different template or special tokens.
    """
    pieces = []
    for user_msg, assistant_msg in history:
        if user_msg:
            pieces.append(f"User: {user_msg}\n")
        if assistant_msg:
            pieces.append(f"Assistant: {assistant_msg}\n")
    pieces.append("Assistant: ")
    return "".join(pieces)


# ------------------------------
# 3. Prediction / inference
# ------------------------------
@spaces.GPU()  # request a GPU from Hugging Face Spaces for this call
def predict(history, max_length, top_p, temperature):
    """Stream a completion for the current chat history.

    model.generate() runs in a background thread; TextIteratorStreamer
    hands freshly decoded text back to this generator, which yields the
    updated history so the Gradio chatbot refreshes incrementally.
    """
    encoded = tokenizer(
        preprocess_messages(history),
        return_tensors="pt",
        padding=True,       # pad automatically
        truncation=True,    # cut off over-long prompts
        max_length=2048     # tune for VRAM / the model's context limit
    )

    # Streams decoded text chunks as generate() produces tokens.
    streamer = TextIteratorStreamer(
        tokenizer=tokenizer,
        timeout=60,
        skip_prompt=True,
        skip_special_tokens=True
    )

    worker = Thread(
        target=model.generate,
        kwargs={
            "input_ids": encoded["input_ids"].to(model.device),
            "attention_mask": encoded["attention_mask"].to(model.device),
            "max_new_tokens": max_length,   # number of newly generated tokens
            "do_sample": True,
            "top_p": top_p,
            "temperature": temperature,
            "repetition_penalty": 1.2,
            "streamer": streamer,
        },
    )
    worker.start()

    # Accumulate streamed text into the assistant slot of the latest turn.
    generated = ""
    for chunk in streamer:
        generated += chunk
        history[-1][1] = generated
        yield history


# ------------------------------
# 4. Gradio UI
# ------------------------------
def main():
    """Build and launch the Gradio chat interface."""
    with gr.Blocks() as demo:
        gr.HTML("<h1 align='center'>DeepScaleR-1.5B Chat Demo</h1>")

        chatbot = gr.Chatbot()

        with gr.Row():
            with gr.Column(scale=2):
                user_input = gr.Textbox(
                    show_label=True,
                    placeholder="请输入您的问题...",
                    label="User Input"
                )
                submitBtn = gr.Button("Submit")
                clearBtn = gr.Button("Clear History")
            with gr.Column(scale=1):
                max_length = gr.Slider(
                    minimum=0, maximum=1024,  # raise/lower as needed
                    value=512, step=1,
                    label="Max New Tokens", interactive=True
                )
                top_p = gr.Slider(
                    minimum=0, maximum=1,
                    value=0.8, step=0.01,
                    label="Top P", interactive=True
                )
                temperature = gr.Slider(
                    minimum=0.0, maximum=2.0,
                    value=0.7, step=0.01,
                    label="Temperature", interactive=True
                )

        def append_user_turn(query, history):
            # Clear the textbox and open a new [user, assistant] turn;
            # predict() then fills in the assistant half via streaming.
            return "", history + [[query, ""]]

        submitBtn.click(
            fn=append_user_turn,
            inputs=[user_input, chatbot],
            outputs=[user_input, chatbot],
            queue=False  # run immediately, bypassing the queue
        ).then(
            fn=predict,
            inputs=[chatbot, max_length, top_p, temperature],
            outputs=chatbot
        )

        def reset_chat():
            # Empty both the chatbot history and the input textbox.
            return [], []

        clearBtn.click(fn=reset_chat, inputs=[], outputs=[chatbot, user_input], queue=False)

        # Optional: serialize requests to avoid concurrent generation.
        # NOTE(review): `concurrency_count` was removed in Gradio 4.x
        # (replaced by `default_concurrency_limit`) — confirm the pinned
        # Gradio version for this Space.
        demo.queue(concurrency_count=1)
        demo.launch()

# ------------------------------
# Entry point
# ------------------------------
if __name__ == "__main__":
    main()