# app.py
from threading import Thread
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
import torch
import spaces
# ---------------------------------------------
# 1. Load the model and tokenizer
# ---------------------------------------------
# If your model needs special settings (acceleration, quantization, etc.),
# pass the corresponding arguments to from_pretrained(),
# e.g. device_map='auto' or trust_remote_code=True.
model_name = "agentica-org/DeepScaleR-1.5B-Preview"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
# Optionally cast or quantize the model here, e.g. with .half()/.float(),
# for example:
# model.half()
# or, if your model and environment support a model-specific quantize API:
# model = model.quantize(4)  # 4- or 8-bit
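# A minimal half-precision loading sketch (an assumption, not something the
# model card requires; torch_dtype is a standard from_pretrained argument):
# model = AutoModelForCausalLM.from_pretrained(
#     model_name, device_map="auto", torch_dtype=torch.float16
# )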
# ---------------------------------------------
# 2. Conversation-history handling
# ---------------------------------------------
def preprocess_messages(history):
    """
    Concatenate all user and assistant messages into a single text prompt.
    This is the simplest possible format:
        User: ...
        Assistant: ...
    Finally append "Assistant: " so the model continues the assistant turn.
    Replace this with whatever chat template your model expects.
    """
    prompt = ""
    for user_msg, assistant_msg in history:
        if user_msg:
            prompt += f"User: {user_msg}\n"
        if assistant_msg:
            prompt += f"Assistant: {assistant_msg}\n"
    # Prompt the model to continue writing the assistant turn
    prompt += "Assistant: "
    return prompt
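# A hedged alternative: recent tokenizers often ship a chat template, in which
# case transformers' standard apply_chat_template() builds the prompt in the
# model's expected format. This sketch assumes DeepScaleR's tokenizer defines
# such a template; verify that before swapping it in for the function above.
def preprocess_messages_with_template(history):
    messages = []
    for user_msg, assistant_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})
    # tokenize=False returns the formatted string; add_generation_prompt=True
    # appends the assistant-turn marker, mirroring the manual "Assistant: "
    return tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )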
# ---------------------------------------------
# 3. Prediction function
# ---------------------------------------------
@spaces.GPU()  # GPU decorator for Hugging Face Spaces
def predict(history, max_length, top_p, temperature):
    """
    Takes history (the conversation so far) plus a few sampling
    hyperparameters and streams the generated reply: every new token
    is yielded back to Gradio, which updates the UI.
    """
    prompt = preprocess_messages(history)
    # Build the model inputs
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs["input_ids"].to(model.device)
    attention_mask = inputs["attention_mask"].to(model.device)
    # Use TextIteratorStreamer for streaming output
    streamer = TextIteratorStreamer(
        tokenizer=tokenizer,
        timeout=60,
        skip_prompt=True,
        skip_special_tokens=True
    )
    generate_kwargs = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "max_new_tokens": max_length,
        "do_sample": True,
        "top_p": top_p,
        "temperature": temperature,
        "repetition_penalty": 1.2,
        "streamer": streamer,
        # Add custom special tokens or other parameters here if needed
        # "eos_token_id": ...
    }
    # Run generate() in a background thread; the main thread consumes the stream
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    # history[-1][1] holds the in-progress assistant reply, so keep appending
    partial_output = ""
    for new_token in streamer:
        partial_output += new_token
        history[-1][1] = partial_output
        yield history
    t.join()  # make sure the generation thread has finished
# ---------------------------------------------
# 4. Build the Gradio UI
# ---------------------------------------------
def main():
    with gr.Blocks() as demo:
        gr.HTML("<h1 align='center'>DeepScaleR-1.5B-Preview Chat Demo</h1>")
        # Chat window
        chatbot = gr.Chatbot()
        with gr.Row():
            with gr.Column(scale=2):
                user_input = gr.Textbox(
                    show_label=True,
                    placeholder="Type your question...",
                    label="User Input"
                )
                submitBtn = gr.Button("Submit")
                emptyBtn = gr.Button("Clear History")
            with gr.Column(scale=1):
                max_length = gr.Slider(
                    minimum=0,
                    maximum=32000,  # adjust to your model's context limit
                    value=512,
                    step=1,
                    label="Max New Tokens",
                    interactive=True
                )
                top_p = gr.Slider(
                    minimum=0,
                    maximum=1,
                    value=0.8,
                    step=0.01,
                    label="Top P",
                    interactive=True
                )
                temperature = gr.Slider(
                    minimum=0.01,
                    maximum=2.0,
                    value=0.7,
                    step=0.01,
                    label="Temperature",
                    interactive=True
                )

        # Append the user's input to the chatbot history
        def user(query, history):
            return "", history + [[query, ""]]

        # Submit:
        # 1) user() appends a (user input, "") pair to the history
        # 2) predict() generates from the updated history
        submitBtn.click(
            fn=user,
            inputs=[user_input, chatbot],
            outputs=[user_input, chatbot],
            queue=False
        ).then(
            fn=predict,
            inputs=[chatbot, max_length, top_p, temperature],
            outputs=chatbot
        )

        # Clear: reset the conversation history and the input box
        def clear_history():
            return [], ""

        emptyBtn.click(
            fn=clear_history,
            inputs=[],
            outputs=[chatbot, user_input],
            queue=False
        )

    # Optional: let Gradio queue and schedule concurrent requests
    demo.queue()
    demo.launch()
if __name__ == "__main__":
    main()
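
# A possible requirements.txt for this Space (the package set is an assumption
# inferred from the imports above; pin versions as needed):
#   gradio
#   transformers
#   torch
#   accelerate  # needed for device_map="auto"
#   spaces      # Hugging Face Spaces GPU decorator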