mattcracker committed
Commit a493bde · verified · 1 Parent(s): ff7ad4f

Update app.py

Files changed (1)
  app.py  +77 -78
app.py CHANGED
@@ -1,36 +1,34 @@
 # app.py
-from threading import Thread
 import gradio as gr
-from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
-import torch
 import spaces
+from threading import Thread
+import torch
+
+from transformers import (
+    AutoTokenizer,
+    AutoModelForCausalLM,
+    TextIteratorStreamer,
+)
 
-# ---------------------------------------------
+# ------------------------------
 # 1. Load the model and tokenizer
-# ---------------------------------------------
-# If your model needs acceleration/quantization or other special configuration, add the parameters in from_pretrained()
-# e.g. device_map='auto' or trust_remote_code=True
+# ------------------------------
 model_name = "agentica-org/DeepScaleR-1.5B-Preview"
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
 
-# Apply .half()/.float()/.quantize() etc. as needed
-# e.g.
-# model.half()
-# or
-# model = model.quantize(4/8)  # if your model and environment support it
+# If the tokenizer has no pad_token_id, explicitly set it to eos_token_id
+if tokenizer.pad_token_id is None:
+    tokenizer.pad_token_id = tokenizer.eos_token_id
+
 
-# ---------------------------------------------
-# 2. Conversation-history handling
-# ---------------------------------------------
+# ------------------------------
+# 2. Conversation history -> prompt format
+# ------------------------------
 def preprocess_messages(history):
     """
-    Concatenate every user and assistant message into one text prompt.
-    This shows only the simplest form:
-        User: ...
-        Assistant: ...
-    ending with "Assistant: " to prompt the model to continue generating.
-    You can change this to whatever conversation template you need.
+    Join the chat history into the simplest possible prompt.
+    You can customize a prompt format or special tokens better suited to this model.
     """
     prompt = ""
     for user_msg, assistant_msg in history:
@@ -38,27 +36,33 @@ def preprocess_messages(history):
         prompt += f"User: {user_msg}\n"
         if assistant_msg:
             prompt += f"Assistant: {assistant_msg}\n"
-
-    # When continuing generation, have the model carry on after "Assistant:"
+    # Prompt with "Assistant:" so the model continues from here
     prompt += "Assistant: "
     return prompt
 
-# ---------------------------------------------
-# 3. Prediction function
-# ---------------------------------------------
-@spaces.GPU()  # GPU decorator for Hugging Face Spaces
+
+# ------------------------------
+# 3. Prediction / inference function
+# ------------------------------
+@spaces.GPU()  # let Hugging Face Spaces allocate a GPU
 def predict(history, max_length, top_p, temperature):
     """
-    Takes the history (conversation so far) and several hyperparameters, and streams the generated result.
-    Each generated token is yielded back to Gradio to update the UI.
+    Generate text from the current history.
+    Uses the TextIteratorStreamer provided by HF for streaming generation.
     """
     prompt = preprocess_messages(history)
 
-    # Assemble the input
-    inputs = tokenizer(prompt, return_tensors="pt")
+    inputs = tokenizer(
+        prompt,
+        return_tensors="pt",
+        padding=True,      # pad automatically
+        truncation=True,   # truncate overlong input
+        max_length=2048    # adjust for available VRAM or the model's context limit
+    )
     input_ids = inputs["input_ids"].to(model.device)
+    attention_mask = inputs["attention_mask"].to(model.device)
 
-    # Use TextIteratorStreamer for streaming output
+    # Streaming output helper
     streamer = TextIteratorStreamer(
         tokenizer=tokenizer,
         timeout=60,
@@ -68,103 +72,98 @@ def predict(history, max_length, top_p, temperature):
 
     generate_kwargs = {
         "input_ids": input_ids,
-        "max_new_tokens": max_length,
+        "attention_mask": attention_mask,
+        "max_new_tokens": max_length,  # number of new tokens to generate
         "do_sample": True,
         "top_p": top_p,
         "temperature": temperature,
         "repetition_penalty": 1.2,
         "streamer": streamer,
-        # Add custom special tokens or other parameters here if needed
-        # "eos_token_id": ...
     }
 
-    # Start a thread to run generate; the main thread reads the streamed output
+    # Run generate in a background thread; the main thread loops over new tokens
     t = Thread(target=model.generate, kwargs=generate_kwargs)
     t.start()
 
-    # history[-1][1] holds the latest assistant reply, so keep appending to it
+    # Append each newly generated token to history[-1][1]
    partial_output = ""
     for new_token in streamer:
         partial_output += new_token
         history[-1][1] = partial_output
         yield history
 
-# ---------------------------------------------
-# 4. Build the Gradio UI
-# ---------------------------------------------
+
+# ------------------------------
+# 4. Gradio UI
+# ------------------------------
 def main():
     with gr.Blocks() as demo:
-        gr.HTML("<h1 align='center'>DeepScaleR-1.5B-Preview Chat Demo</h1>")
+        gr.HTML("<h1 align='center'>DeepScaleR-1.5B Chat Demo</h1>")
 
-        # Chat window
         chatbot = gr.Chatbot()
 
         with gr.Row():
             with gr.Column(scale=2):
                 user_input = gr.Textbox(
-                    show_label=True,
-                    placeholder="Enter your question...",
+                    show_label=True,
+                    placeholder="Enter your question...",
                     label="User Input"
                 )
                 submitBtn = gr.Button("Submit")
-                emptyBtn = gr.Button("Clear History")
+                clearBtn = gr.Button("Clear History")
             with gr.Column(scale=1):
                 max_length = gr.Slider(
-                    minimum=0,
-                    maximum=32000,  # adjust to the model's capabilities
-                    value=512,
-                    step=1,
-                    label="Max New Tokens",
+                    minimum=0,
+                    maximum=1024,   # raise or lower as needed
+                    value=512,
+                    step=1,
+                    label="Max New Tokens",
                     interactive=True
                 )
                 top_p = gr.Slider(
-                    minimum=0,
-                    maximum=1,
-                    value=0.8,
-                    step=0.01,
-                    label="Top P",
+                    minimum=0,
+                    maximum=1,
+                    value=0.8,
+                    step=0.01,
+                    label="Top P",
                     interactive=True
                 )
                 temperature = gr.Slider(
-                    minimum=0.01,
-                    maximum=2.0,
-                    value=0.7,
-                    step=0.01,
-                    label="Temperature",
+                    minimum=0.0,
+                    maximum=2.0,
+                    value=0.7,
+                    step=0.01,
+                    label="Temperature",
                     interactive=True
                )
 
-        # Insert the user's input into the chatbot history
+        # On Submit, first append the input to history, then call predict to generate
         def user(query, history):
             return "", history + [[query, ""]]
 
-        # Submit:
-        # 1) user() -> append a (user input, "") record to the conversation
-        # 2) predict() -> generate from the updated history
         submitBtn.click(
-            fn=user,
-            inputs=[user_input, chatbot],
+            fn=user,
+            inputs=[user_input, chatbot],
             outputs=[user_input, chatbot],
-            queue=False
+            queue=False  # no queueing
         ).then(
-            fn=predict,
-            inputs=[chatbot, max_length, top_p, temperature],
+            fn=predict,
+            inputs=[chatbot, max_length, top_p, temperature],
             outputs=chatbot
         )
 
-        # Clear: wipe the conversation history
+        # Clear the chat history
         def clear_history():
             return [], []
-        emptyBtn.click(
-            fn=clear_history,
-            inputs=[],
-            outputs=[chatbot, user_input],
-            queue=False
-        )
 
-    # Optional: let Gradio schedule queued requests
-    demo.queue()
+        clearBtn.click(fn=clear_history, inputs=[], outputs=[chatbot, user_input], queue=False)
+
+    # Optional: enable the queue to avoid concurrency conflicts
+    demo.queue(concurrency_count=1)
     demo.launch()
 
+# ------------------------------
+# Entry point
+# ------------------------------
 if __name__ == "__main__":
     main()
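
For reference, the prompt template that the updated preprocess_messages builds can be checked in isolation. A minimal sketch: history is a list of [user, assistant] pairs, with the pending turn's assistant slot left empty. One line of the loop body falls between the hunks shown above, so the `if user_msg:` guard here is an assumption about that hidden context line.

# Standalone sketch of the prompt format built by preprocess_messages.
def preprocess_messages(history):
    prompt = ""
    for user_msg, assistant_msg in history:
        if user_msg:  # assumption: guard on the line hidden between hunks
            prompt += f"User: {user_msg}\n"
        if assistant_msg:
            prompt += f"Assistant: {assistant_msg}\n"
    prompt += "Assistant: "
    return prompt

history = [["Hi", "Hello! How can I help?"], ["What is 2+2?", ""]]
print(preprocess_messages(history))
# User: Hi
# Assistant: Hello! How can I help?
# User: What is 2+2?
# Assistant:

The final empty assistant slot is skipped by the falsy check, so the prompt ends with the bare "Assistant: " that the model is asked to complete.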
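
The background-thread plus TextIteratorStreamer pattern in predict() can also be exercised without Gradio. A minimal sketch, assuming the tiny test checkpoint sshleifer/tiny-gpt2 purely so it runs quickly; any causal LM works the same way:

# Minimal streaming smoke test for the Thread + TextIteratorStreamer pattern.
from threading import Thread

from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer

model_name = "sshleifer/tiny-gpt2"  # stand-in for agentica-org/DeepScaleR-1.5B-Preview
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

inputs = tokenizer("User: Hi\nAssistant: ", return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, timeout=60)

# generate() blocks until done, so it runs in a worker thread while the main
# thread consumes decoded text chunks from the streamer as they arrive.
t = Thread(target=model.generate, kwargs={
    "input_ids": inputs["input_ids"],
    "attention_mask": inputs["attention_mask"],
    "max_new_tokens": 20,
    "streamer": streamer,
})
t.start()

partial = ""
for chunk in streamer:
    partial += chunk
    print(partial)
t.join()

This consumption loop is the same one predict() wraps with `yield history`, which is what lets Gradio re-render the chat after each streamed chunk.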