# app.py
import gradio as gr
import spaces
from threading import Thread
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TextIteratorStreamer,
)
# ------------------------------
# 1. Load the model and tokenizer
# ------------------------------
model_name = "agentica-org/DeepScaleR-1.5B-Preview"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
# If the tokenizer has no pad_token_id, fall back to eos_token_id
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id
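# Optional (assumption, not in the original file): load the weights in half
# precision to roughly halve GPU memory use, e.g.
#   model = AutoModelForCausalLM.from_pretrained(
#       model_name, torch_dtype=torch.float16, device_map="auto"
#   )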
# ------------------------------
# 2. Chat history -> prompt format
# ------------------------------
def preprocess_messages(history):
    """
    Concatenate the chat history into a minimal prompt.
    You can substitute a prompt format or special tokens better suited to this model.
    """
    prompt = ""
    for user_msg, assistant_msg in history:
        if user_msg:
            prompt += f"User: {user_msg}\n"
        if assistant_msg:
            prompt += f"Assistant: {assistant_msg}\n"
    # Cue the next turn with "Assistant:"
    prompt += "Assistant: "
    return prompt
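# Worked example (hypothetical history, traced from the function above):
#   preprocess_messages([["Hi", "Hello!"], ["What is 2+2?", ""]])
# returns
#   "User: Hi\nAssistant: Hello!\nUser: What is 2+2?\nAssistant: "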
# ------------------------------
# 3. Prediction / inference function
# ------------------------------
@spaces.GPU()  # request a GPU when running on Hugging Face Spaces
def predict(history, max_length, top_p, temperature):
    """
    Generate text from the current history.
    Streams output via Hugging Face's TextIteratorStreamer.
    """
    prompt = preprocess_messages(history)
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        padding=True,     # pad automatically
        truncation=True,  # truncate over-long input
        max_length=2048,  # adjust for available VRAM or the model's context limit
    )
    input_ids = inputs["input_ids"].to(model.device)
    attention_mask = inputs["attention_mask"].to(model.device)
    # Streaming output helper
    streamer = TextIteratorStreamer(
        tokenizer=tokenizer,
        timeout=60,
        skip_prompt=True,
        skip_special_tokens=True,
    )
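    # Note: timeout=60 means iterating the streamer raises an exception if no
    # new token arrives within 60 seconds (e.g. if generate() fails early).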
    generate_kwargs = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "max_new_tokens": max_length,  # number of newly generated tokens
        "do_sample": True,
        "top_p": top_p,
        "temperature": temperature,
        "repetition_penalty": 1.2,
        "streamer": streamer,
    }
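    # Assumption (not in the original): forwarding pad_token_id explicitly
    # avoids generate()'s "Setting `pad_token_id` to `eos_token_id`" warning.
    generate_kwargs["pad_token_id"] = tokenizer.pad_token_id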
    # Run generate() in a background thread; the main thread reads new tokens
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    # Append each newly generated token to history[-1][1]
    partial_output = ""
    for new_token in streamer:
        partial_output += new_token
        history[-1][1] = partial_output
        yield history
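# A quick smoke test for predict() outside Gradio (hypothetical example, not in
# the original file): it is a generator, so iterate it to stream the reply.
#
#     for h in predict([["Hello", ""]], max_length=64, top_p=0.8, temperature=0.7):
#         print(h[-1][1])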
# ------------------------------
# 4. Gradio UI
# ------------------------------
def main():
    with gr.Blocks() as demo:
        gr.HTML("<h1 align='center'>DeepScaleR-1.5B Chat Demo</h1>")
        chatbot = gr.Chatbot()
        with gr.Row():
            with gr.Column(scale=2):
                user_input = gr.Textbox(
                    show_label=True,
                    placeholder="Type your question...",
                    label="User Input",
                )
                submitBtn = gr.Button("Submit")
                clearBtn = gr.Button("Clear History")
            with gr.Column(scale=1):
                max_length = gr.Slider(
                    minimum=0,
                    maximum=1024,  # adjust up or down as needed
                    value=512,
                    step=1,
                    label="Max New Tokens",
                    interactive=True,
                )
                top_p = gr.Slider(
                    minimum=0,
                    maximum=1,
                    value=0.8,
                    step=0.01,
                    label="Top P",
                    interactive=True,
                )
                temperature = gr.Slider(
                    minimum=0.0,
                    maximum=2.0,
                    value=0.7,
                    step=0.01,
                    label="Temperature",
                    interactive=True,
                )
        # On Submit: first append the input to the history, then call predict
        def user(query, history):
            return "", history + [[query, ""]]

        submitBtn.click(
            fn=user,
            inputs=[user_input, chatbot],
            outputs=[user_input, chatbot],
            queue=False,  # don't queue this step
        ).then(
            fn=predict,
            inputs=[chatbot, max_length, top_p, temperature],
            outputs=chatbot,
        )
        # Clear the chat history; the Textbox output expects a string, not a list
        def clear_history():
            return [], ""

        clearBtn.click(fn=clear_history, inputs=[], outputs=[chatbot, user_input], queue=False)

    # Optional: enable the queue so concurrent requests don't clash.
    # Note: Gradio 4.x renamed queue(concurrency_count=...) to
    # queue(default_concurrency_limit=...); use the name your Gradio version expects.
    demo.queue(default_concurrency_limit=1)
    demo.launch()
# ------------------------------
# Entry point
# ------------------------------
if __name__ == "__main__":
    main()
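# To run this demo locally (assumption: the same dependencies as the Space):
#   pip install gradio transformers torch accelerate spaces
#   python app.py
# `accelerate` is required for device_map="auto"; the `spaces` package provides
# the @spaces.GPU decorator and is a no-op outside Hugging Face Spaces.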