Spaces:
Runtime error
Runtime error
File size: 5,604 Bytes
ba12288 b6f75f9 a11ec0f b6f75f9 ba12288 b6f75f9 a11ec0f ba12288 b6f75f9 ba12288 b6f75f9 ba12288 b6f75f9 ba12288 b6f75f9 ba12288 b6f75f9 ba12288 b6f75f9 ba12288 b6f75f9 ba12288 b6f75f9 ba12288 b6f75f9 ba12288 b6f75f9 ba12288 b6f75f9 ba12288 b6f75f9 ba12288 b6f75f9 ba12288 ff7ad4f ba12288 b6f75f9 ba12288 b6f75f9 ba12288 b6f75f9 ba12288 b6f75f9 ba12288 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 |
# app.py
from threading import Thread
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
import torch
import spaces
# ---------------------------------------------
# 1. 加载模型与 Tokenizer
# ---------------------------------------------
# 如果你的模型需要加速/量化等特殊配置,可在 from_pretrained() 中添加相应参数
# 例如 device_map='auto' 或 trust_remote_code=True 等
model_name = "agentica-org/DeepScaleR-1.5B-Preview"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
# 根据需要加上 .half()/.float()/.quantize() 等操作
# 例如
# model.half()
# 或者
# model = model.quantize(4/8) # 如果你的模型和环境支持
# ---------------------------------------------
# 2. 对话历史处理
# ---------------------------------------------
def preprocess_messages(history):
"""
将所有的用户与回复消息拼成一个文本 prompt。
这里仅示例最简单的形式:
User: ...
Assistant: ...
最后再接上 "Assistant: " 用于提示模型继续生成。
你也可以修改为自己需要的对话模板。
"""
prompt = ""
for user_msg, assistant_msg in history:
if user_msg:
prompt += f"User: {user_msg}\n"
if assistant_msg:
prompt += f"Assistant: {assistant_msg}\n"
# 继续生成时,让模型再续写 "Assistant:"
prompt += "Assistant: "
return prompt
# ---------------------------------------------
# 3. 预测函数
# ---------------------------------------------
@spaces.GPU() # 使用 huggingface spaces 的 GPU 装饰器
def predict(history, max_length, top_p, temperature):
"""
输入为 history(对话历史)和若干超参,输出流式生成的结果。
每生成一个 token,就通过 yield 返回给 Gradio,更新界面。
"""
prompt = preprocess_messages(history)
# 组装输入
inputs = tokenizer(prompt, return_tensors="pt")
input_ids = inputs["input_ids"].to(model.device)
# 使用 TextIteratorStreamer 来实现流式输出
streamer = TextIteratorStreamer(
tokenizer=tokenizer,
timeout=60,
skip_prompt=True,
skip_special_tokens=True
)
generate_kwargs = {
"input_ids": input_ids,
"max_new_tokens": max_length,
"do_sample": True,
"top_p": top_p,
"temperature": temperature,
"repetition_penalty": 1.2,
"streamer": streamer,
# 如果需要自定义一些特殊 token 或其他参数可在此补充
# "eos_token_id": ...
}
# 启动一个线程去执行 generate,然后主线程读取流式输出
t = Thread(target=model.generate, kwargs=generate_kwargs)
t.start()
# history[-1][1] 存放当前最新的 assistant 回复,因此不断累加
partial_output = ""
for new_token in streamer:
partial_output += new_token
history[-1][1] = partial_output
yield history
# ---------------------------------------------
# 4. 搭建 Gradio 界面
# ---------------------------------------------
def main():
with gr.Blocks() as demo:
gr.HTML("<h1 align='center'>DeepScaleR-1.5B-Preview Chat Demo</h1>")
# 聊天窗口
chatbot = gr.Chatbot()
with gr.Row():
with gr.Column(scale=2):
user_input = gr.Textbox(
show_label=True,
placeholder="请输入您的问题...",
label="User Input"
)
submitBtn = gr.Button("Submit")
emptyBtn = gr.Button("Clear History")
with gr.Column(scale=1):
max_length = gr.Slider(
minimum=0,
maximum=32000, # 根据模型能力自行调整
value=512,
step=1,
label="Max New Tokens",
interactive=True
)
top_p = gr.Slider(
minimum=0,
maximum=1,
value=0.8,
step=0.01,
label="Top P",
interactive=True
)
temperature = gr.Slider(
minimum=0.01,
maximum=2.0,
value=0.7,
step=0.01,
label="Temperature",
interactive=True
)
# 用于将用户输入插入到 chatbot 历史中
def user(query, history):
return "", history + [[query, ""]]
# Submit:
# 1) user() -> 新增一条 (user输入,"") 的对话记录
# 2) predict() -> 基于更新后的 history 进行生成
submitBtn.click(
fn=user,
inputs=[user_input, chatbot],
outputs=[user_input, chatbot],
queue=False
).then(
fn=predict,
inputs=[chatbot, max_length, top_p, temperature],
outputs=chatbot
)
# Clear: 清空对话历史
def clear_history():
return [], []
emptyBtn.click(
fn=clear_history,
inputs=[],
outputs=[chatbot, user_input],
queue=False
)
# 可选:让 Gradio 自动对排队请求进行调度
demo.queue()
demo.launch()
if __name__ == "__main__":
main()
|