TongkunGuan commited on
Commit
9c191d1
·
verified ·
1 Parent(s): a1cdc55

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -14
app.py CHANGED
@@ -25,6 +25,11 @@ current_vis = []
25
  current_bpe = []
26
  current_index = 0
27
 
 
 
 
 
 
28
 
29
  def load_model(check_type):
30
  # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -130,11 +135,17 @@ def format_bpe_display(bpe):
130
  # else:
131
  # return None, "索引超出范围"
132
  # 状态更新函数,利用传递的状态(vis, bpe)
133
- def update_slider_index(x, vis, bpe):
 
 
 
134
  if 0 <= x < len(vis):
135
- return vis[x], format_bpe_display(bpe[x]), vis, bpe
136
  else:
137
- return None, "索引超出范围", vis, bpe
 
 
 
138
 
139
 
140
  # Gradio界面
@@ -191,14 +202,17 @@ with gr.Blocks(title="BPE Visualization Demo") as demo:
191
  # return image, vis, bpe_text, slider_max_val
192
 
193
  @spaces.GPU
194
- def on_run_clicked(model_type, image, text):
195
  model, tokenizer, transform, device = load_model(model_type)
196
  current_index = 0 # Reset index when new image is processed
197
  image, vis, bpe = process_image(model, tokenizer, transform, device, model_type, image, text)
198
  slider_max_val = len(bpe) - 1
199
  bpe_text = format_bpe_display(bpe[current_index])
200
- # 将处理结果传递给后续步骤
201
- return image, vis[current_index], bpe_text, slider_max_val, vis, bpe
 
 
 
202
 
203
 
204
 
@@ -213,14 +227,12 @@ with gr.Blocks(title="BPE Visualization Demo") as demo:
213
  # outputs=[prev_btn, index_slider, next_btn, bpe_display],
214
  # )
215
  # Gradio 按钮点击后的处理
 
216
  run_btn.click(
217
  on_run_clicked,
218
- inputs=[model_type, image_input, text_input],
219
- outputs=[orig_img, heatmap, bpe_display, index_slider, 'state', 'state']
220
- ).then(
221
- lambda outputs: (gr.update(visible=True), gr.update(visible=True, maximum=outputs[3], value=0), gr.update(visible=True), gr.update(visible=True), outputs[4], outputs[5]),
222
- inputs=index_slider,
223
- outputs=[prev_btn, index_slider, next_btn, bpe_display, 'state', 'state']
224
  )
225
 
226
  prev_btn.click(
@@ -241,8 +253,8 @@ with gr.Blocks(title="BPE Visualization Demo") as demo:
241
  # )
242
  index_slider.change(
243
  update_slider_index,
244
- inputs=[index_slider, 'state', 'state'],
245
- outputs=[heatmap, bpe_display, 'state', 'state']
246
  )
247
 
248
 
 
25
  current_bpe = []
26
  current_index = 0
27
 
28
+ # 设置初始状态
29
+ initial_state = {
30
+ "vis": [],
31
+ "bpe": []
32
+ }
33
 
34
  def load_model(check_type):
35
  # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
135
  # else:
136
  # return None, "索引超出范围"
137
  # 状态更新函数,利用传递的状态(vis, bpe)
138
+ # 使用状态信息来处理滑动条改变
139
+ def update_slider_index(x, state):
140
+ vis = state['vis']
141
+ bpe = state['bpe']
142
  if 0 <= x < len(vis):
143
+ return vis[x], format_bpe_display(bpe[x]), state
144
  else:
145
+ return None, "索引超出范围", state
146
+
147
+
148
+
149
 
150
 
151
  # Gradio界面
 
202
  # return image, vis, bpe_text, slider_max_val
203
 
204
  @spaces.GPU
205
+ def on_run_clicked(model_type, image, text, state):
206
  model, tokenizer, transform, device = load_model(model_type)
207
  current_index = 0 # Reset index when new image is processed
208
  image, vis, bpe = process_image(model, tokenizer, transform, device, model_type, image, text)
209
  slider_max_val = len(bpe) - 1
210
  bpe_text = format_bpe_display(bpe[current_index])
211
+ # 更新状态并返回
212
+ state['vis'] = vis
213
+ state['bpe'] = bpe
214
+ return image, vis[current_index], bpe_text, slider_max_val, state
215
+
216
 
217
 
218
 
 
227
  # outputs=[prev_btn, index_slider, next_btn, bpe_display],
228
  # )
229
  # Gradio 按钮点击后的处理
230
+ # Gradio 按钮点击后的处理
231
  run_btn.click(
232
  on_run_clicked,
233
+ inputs=[model_type, image_input, text_input, 'state'],
234
+ outputs=[orig_img, heatmap, bpe_display, index_slider, 'state'],
235
+ _js="{state: { vis: [], bpe: []}}"
 
 
 
236
  )
237
 
238
  prev_btn.click(
 
253
  # )
254
  index_slider.change(
255
  update_slider_index,
256
+ inputs=[index_slider, 'state'],
257
+ outputs=[heatmap, bpe_display, 'state']
258
  )
259
 
260