TongkunGuan commited on
Commit
a1cdc55
·
verified ·
1 Parent(s): 0619b8a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -24
app.py CHANGED
@@ -129,12 +129,13 @@ def format_bpe_display(bpe):
129
  # return current_vis[x], format_bpe_display(current_bpe[x])
130
  # else:
131
  # return None, "索引超出范围"
132
-
133
- def update_slider_index(x):
134
- if 0 <= x < len(current_vis) and 0 <= x < len(current_bpe):
135
- return current_vis[x], format_bpe_display(current_bpe[x])
136
  else:
137
- return None, "索引超出范围"
 
138
 
139
  # Gradio界面
140
  with gr.Blocks(title="BPE Visualization Demo") as demo:
@@ -188,26 +189,38 @@ with gr.Blocks(title="BPE Visualization Demo") as demo:
188
  # print("current_vis",len(current_vis))
189
  # print("current_bpe",len(current_bpe))
190
  # return image, vis, bpe_text, slider_max_val
 
191
  @spaces.GPU
192
  def on_run_clicked(model_type, image, text):
193
- global current_vis, current_bpe, current_index
194
- current_index = 0 # Reset index when new image is processed
195
  model, tokenizer, transform, device = load_model(model_type)
 
196
  image, vis, bpe = process_image(model, tokenizer, transform, device, model_type, image, text)
197
- current_vis = vis
198
- current_bpe = bpe
199
- slider_max_val = len(current_bpe) - 1
200
- bpe_text = format_bpe_display(current_bpe[current_index])
201
- return image, current_vis[current_index], bpe_text, slider_max_val
 
 
202
 
 
 
 
 
 
 
 
 
 
 
203
  run_btn.click(
204
  on_run_clicked,
205
  inputs=[model_type, image_input, text_input],
206
- outputs=[orig_img, heatmap, bpe_display, index_slider],
207
  ).then(
208
- lambda max_val: (gr.update(visible=True), gr.update(visible=True, maximum=max_val, value=0), gr.update(visible=True), gr.update(visible=True)),
209
  inputs=index_slider,
210
- outputs=[prev_btn, index_slider, next_btn, bpe_display],
211
  )
212
 
213
  prev_btn.click(
@@ -219,12 +232,7 @@ with gr.Blocks(title="BPE Visualization Demo") as demo:
219
  lambda: (*update_index(1), current_index),
220
  outputs=[heatmap, bpe_display, index_slider]
221
  )
222
-
223
- # index_slider.change(
224
- # lambda x: (current_vis[x], format_bpe_display(current_bpe[x])) if 0<=x<len(current_vis else (None,"Invaild")
225
- # inputs=index_slider,
226
- # outputs=[heatmap, bpe_display]
227
- # )
228
 
229
  # index_slider.change(
230
  # update_slider_index,
@@ -232,9 +240,10 @@ with gr.Blocks(title="BPE Visualization Demo") as demo:
232
  # outputs=[heatmap, bpe_display]
233
  # )
234
  index_slider.change(
235
- update_slider_index,
236
- inputs=[index_slider],
237
- outputs=[heatmap, bpe_display])
 
238
 
239
 
240
  if __name__ == "__main__":
 
129
  # return current_vis[x], format_bpe_display(current_bpe[x])
130
  # else:
131
  # return None, "索引超出范围"
132
+ # 状态更新函数,利用传递的状态(vis, bpe)
133
+ def update_slider_index(x, vis, bpe):
134
+ if 0 <= x < len(vis):
135
+ return vis[x], format_bpe_display(bpe[x]), vis, bpe
136
  else:
137
+ return None, "索引超出范围", vis, bpe
138
+
139
 
140
  # Gradio界面
141
  with gr.Blocks(title="BPE Visualization Demo") as demo:
 
189
  # print("current_vis",len(current_vis))
190
  # print("current_bpe",len(current_bpe))
191
  # return image, vis, bpe_text, slider_max_val
192
+
193
  @spaces.GPU
194
  def on_run_clicked(model_type, image, text):
 
 
195
  model, tokenizer, transform, device = load_model(model_type)
196
+ current_index = 0 # Reset index when new image is processed
197
  image, vis, bpe = process_image(model, tokenizer, transform, device, model_type, image, text)
198
+ slider_max_val = len(bpe) - 1
199
+ bpe_text = format_bpe_display(bpe[current_index])
200
+ # 将处理结果传递给后续步骤
201
+ return image, vis[current_index], bpe_text, slider_max_val, vis, bpe
202
+
203
+
204
+
205
 
206
+ # run_btn.click(
207
+ # on_run_clicked,
208
+ # inputs=[model_type, image_input, text_input],
209
+ # outputs=[orig_img, heatmap, bpe_display, index_slider],
210
+ # ).then(
211
+ # lambda max_val: (gr.update(visible=True), gr.update(visible=True, maximum=max_val, value=0), gr.update(visible=True), gr.update(visible=True)),
212
+ # inputs=index_slider,
213
+ # outputs=[prev_btn, index_slider, next_btn, bpe_display],
214
+ # )
215
+ # Gradio 按钮点击后的处理
216
  run_btn.click(
217
  on_run_clicked,
218
  inputs=[model_type, image_input, text_input],
219
+ outputs=[orig_img, heatmap, bpe_display, index_slider, 'state', 'state']
220
  ).then(
221
+ lambda outputs: (gr.update(visible=True), gr.update(visible=True, maximum=outputs[3], value=0), gr.update(visible=True), gr.update(visible=True), outputs[4], outputs[5]),
222
  inputs=index_slider,
223
+ outputs=[prev_btn, index_slider, next_btn, bpe_display, 'state', 'state']
224
  )
225
 
226
  prev_btn.click(
 
232
  lambda: (*update_index(1), current_index),
233
  outputs=[heatmap, bpe_display, index_slider]
234
  )
235
+
 
 
 
 
 
236
 
237
  # index_slider.change(
238
  # update_slider_index,
 
240
  # outputs=[heatmap, bpe_display]
241
  # )
242
  index_slider.change(
243
+ update_slider_index,
244
+ inputs=[index_slider, 'state', 'state'],
245
+ outputs=[heatmap, bpe_display, 'state', 'state']
246
+ )
247
 
248
 
249
  if __name__ == "__main__":