# NOTE: The three lines below were Hugging Face Spaces page-header text
# ("Spaces" / "Runtime error") captured along with the source; they are not
# part of the program and are kept only as this comment so the file parses.
# app.py
from threading import Thread
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
import torch
# NOTE(review): `spaces` is imported but never used in this file — it is
# presumably intended for the Hugging Face ZeroGPU `@spaces.GPU` decorator;
# confirm against the Space's hardware configuration.
import spaces
# ---------------------------------------------
# 1. Load model and tokenizer
# ---------------------------------------------
# If your model needs acceleration / quantization, add the corresponding
# arguments to from_pretrained(), e.g. device_map='auto' or trust_remote_code=True.
model_name = "agentica-org/DeepScaleR-1.5B-Preview"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
# Add .half()/.float()/.quantize() etc. here if needed, e.g.:
# model.half()
# or
# model = model.quantize(4/8)  # if your model and environment support it
# --------------------------------------------- | |
# 2. 对话历史处理 | |
# --------------------------------------------- | |
def preprocess_messages(history):
    """Flatten a chat history into a single text prompt.

    Produces the simplest possible template::

        User: ...
        Assistant: ...

    and ends with a trailing ``"Assistant: "`` so the model continues the
    assistant turn. Adapt the template here if your model expects another
    chat format.

    Args:
        history: list of ``[user_msg, assistant_msg]`` pairs; either entry
            may be empty/falsy, in which case that turn is skipped.

    Returns:
        The assembled prompt string.
    """
    lines = []
    for user_msg, assistant_msg in history:
        if user_msg:
            lines.append(f"User: {user_msg}\n")
        if assistant_msg:
            lines.append(f"Assistant: {assistant_msg}\n")
    # Trailing cue for the model to write the next assistant turn.
    lines.append("Assistant: ")
    return "".join(lines)
# ---------------------------------------------
# 3. Prediction function
# ---------------------------------------------
# Hugging Face Spaces GPU decorator: the `spaces` module was imported but the
# decorator was never applied, so on ZeroGPU hardware generation would run
# without a GPU allocation and the Space fails at runtime.
@spaces.GPU
def predict(history, max_length, top_p, temperature):
    """Stream a model reply for the latest turn in *history*.

    Args:
        history: list of ``[user_msg, assistant_msg]`` pairs; the last pair's
            assistant slot is filled in incrementally as tokens arrive.
        max_length: maximum number of new tokens to generate.
        top_p: nucleus-sampling probability mass.
        temperature: sampling temperature.

    Yields:
        The updated ``history`` after each new token, so Gradio can refresh
        the chat window live.
    """
    # Nothing to answer if the UI handed us an empty history; previously this
    # would raise IndexError at history[-1] below.
    if not history:
        yield history
        return
    prompt = preprocess_messages(history)
    # Tokenize and move tensors to the model's device. The attention mask is
    # forwarded explicitly — omitting it makes generate() infer one and warn,
    # and can mis-handle inputs whose pad and eos tokens coincide.
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs["input_ids"].to(model.device)
    attention_mask = inputs["attention_mask"].to(model.device)
    # TextIteratorStreamer turns generate() into an iterator of text chunks.
    streamer = TextIteratorStreamer(
        tokenizer=tokenizer,
        timeout=60,
        skip_prompt=True,
        skip_special_tokens=True
    )
    generate_kwargs = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "max_new_tokens": max_length,
        "do_sample": True,
        "top_p": top_p,
        "temperature": temperature,
        "repetition_penalty": 1.2,
        "streamer": streamer,
        # Add custom special tokens or other parameters here if needed,
        # e.g. "eos_token_id": ...
    }
    # Run generate() in a worker thread; the main thread drains the streamer.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    # history[-1][1] holds the in-progress assistant reply; keep appending.
    partial_output = ""
    for new_token in streamer:
        partial_output += new_token
        history[-1][1] = partial_output
        yield history
# ---------------------------------------------
# 4. Build the Gradio UI
# ---------------------------------------------
def main():
    """Assemble the Blocks UI and launch the demo server."""
    with gr.Blocks() as demo:
        gr.HTML("<h1 align='center'>DeepScaleR-1.5B-Preview Chat Demo</h1>")
        # Chat window
        chatbot = gr.Chatbot()
        with gr.Row():
            with gr.Column(scale=2):
                user_input = gr.Textbox(
                    show_label=True,
                    placeholder="请输入您的问题...",
                    label="User Input"
                )
                submitBtn = gr.Button("Submit")
                emptyBtn = gr.Button("Clear History")
            with gr.Column(scale=1):
                max_length = gr.Slider(
                    minimum=0,
                    maximum=32000,  # adjust to your model's context budget
                    value=512,
                    step=1,
                    label="Max New Tokens",
                    interactive=True
                )
                top_p = gr.Slider(
                    minimum=0,
                    maximum=1,
                    value=0.8,
                    step=0.01,
                    label="Top P",
                    interactive=True
                )
                temperature = gr.Slider(
                    minimum=0.01,
                    maximum=2.0,
                    value=0.7,
                    step=0.01,
                    label="Temperature",
                    interactive=True
                )

        def user(query, history):
            """Append the user's message (with an empty reply slot) to history
            and clear the input box."""
            return "", history + [[query, ""]]

        # Submit:
        # 1) user()    -> append a new (user input, "") record to the history
        # 2) predict() -> stream a reply based on the updated history
        submitBtn.click(
            fn=user,
            inputs=[user_input, chatbot],
            outputs=[user_input, chatbot],
            queue=False
        ).then(
            fn=predict,
            inputs=[chatbot, max_length, top_p, temperature],
            outputs=chatbot
        )

        def clear_history():
            """Reset the chat window and the input box.

            BUGFIX: the Textbox output previously received ``[]`` (a list);
            a Textbox value must be a string, so return ``""`` for it.
            """
            return [], ""

        emptyBtn.click(
            fn=clear_history,
            inputs=[],
            outputs=[chatbot, user_input],
            queue=False
        )
        # Optional: let Gradio queue and schedule concurrent requests.
        demo.queue()
        demo.launch()
if __name__ == "__main__":
    main()