TongkunGuan commited on
Commit
0619b8a
·
verified ·
1 Parent(s): 17cc18a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -14
app.py CHANGED
@@ -122,10 +122,15 @@ def format_bpe_display(bpe):
122
  # 使用HTML标签来设置字体大小、颜色,加粗,并居中
123
  return f"<div style='text-align:center; font-size:20px;'><strong>Current BPE: <span style='color:red;'>{bpe}</span></strong></div>"
124
 
125
- @spaces.GPU
 
 
 
 
 
 
 
126
  def update_slider_index(x):
127
- global current_vis, current_bpe, current_index
128
- print(f"x: {x}, current_vis length: {len(current_vis)}, current_bpe length: {len(current_bpe)}")
129
  if 0 <= x < len(current_vis) and 0 <= x < len(current_bpe):
130
  return current_vis[x], format_bpe_display(current_bpe[x])
131
  else:
@@ -172,18 +177,29 @@ with gr.Blocks(title="BPE Visualization Demo") as demo:
172
  bpe_display = gr.Markdown("Current BPE: ", visible=False)
173
 
174
  # 事件处理
 
 
 
 
 
 
 
 
 
 
 
175
  @spaces.GPU
176
  def on_run_clicked(model_type, image, text):
177
  global current_vis, current_bpe, current_index
178
  current_index = 0 # Reset index when new image is processed
179
- image, vis, bpe = process_image(*load_model(model_type), model_type, image, text)
180
- # Update the slider range and set value to 0
 
 
181
  slider_max_val = len(current_bpe) - 1
182
- bpe_text = format_bpe_display(bpe)
183
- print("current_vis",len(current_vis))
184
- print("current_bpe",len(current_bpe))
185
- return image, vis, bpe_text, slider_max_val
186
-
187
  run_btn.click(
188
  on_run_clicked,
189
  inputs=[model_type, image_input, text_input],
@@ -210,11 +226,16 @@ with gr.Blocks(title="BPE Visualization Demo") as demo:
210
  # outputs=[heatmap, bpe_display]
211
  # )
212
 
 
 
 
 
 
213
  index_slider.change(
214
- update_slider_index,
215
- inputs=index_slider,
216
- outputs=[heatmap, bpe_display]
217
- )
218
 
219
  if __name__ == "__main__":
220
  demo.launch()
 
122
  # 使用HTML标签来设置字体大小、颜色,加粗,并居中
123
  return f"<div style='text-align:center; font-size:20px;'><strong>Current BPE: <span style='color:red;'>{bpe}</span></strong></div>"
124
 
125
+ # def update_slider_index(x):
126
+ # global current_vis, current_bpe, current_index
127
+ # print(f"x: {x}, current_vis length: {len(current_vis)}, current_bpe length: {len(current_bpe)}")
128
+ # if 0 <= x < len(current_vis) and 0 <= x < len(current_bpe):
129
+ # return current_vis[x], format_bpe_display(current_bpe[x])
130
+ # else:
131
+ # return None, "索引超出范围"
132
+
133
  def update_slider_index(x):
 
 
134
  if 0 <= x < len(current_vis) and 0 <= x < len(current_bpe):
135
  return current_vis[x], format_bpe_display(current_bpe[x])
136
  else:
 
177
  bpe_display = gr.Markdown("Current BPE: ", visible=False)
178
 
179
  # 事件处理
180
+ # @spaces.GPU
181
+ # def on_run_clicked(model_type, image, text):
182
+ # global current_vis, current_bpe, current_index
183
+ # current_index = 0 # Reset index when new image is processed
184
+ # image, vis, bpe = process_image(*load_model(model_type), model_type, image, text)
185
+ # # Update the slider range and set value to 0
186
+ # slider_max_val = len(current_bpe) - 1
187
+ # bpe_text = format_bpe_display(bpe)
188
+ # print("current_vis",len(current_vis))
189
+ # print("current_bpe",len(current_bpe))
190
+ # return image, vis, bpe_text, slider_max_val
191
  @spaces.GPU
192
  def on_run_clicked(model_type, image, text):
193
  global current_vis, current_bpe, current_index
194
  current_index = 0 # Reset index when new image is processed
195
+ model, tokenizer, transform, device = load_model(model_type)
196
+ image, vis, bpe = process_image(model, tokenizer, transform, device, model_type, image, text)
197
+ current_vis = vis
198
+ current_bpe = bpe
199
  slider_max_val = len(current_bpe) - 1
200
+ bpe_text = format_bpe_display(current_bpe[current_index])
201
+ return image, current_vis[current_index], bpe_text, slider_max_val
202
+
 
 
203
  run_btn.click(
204
  on_run_clicked,
205
  inputs=[model_type, image_input, text_input],
 
226
  # outputs=[heatmap, bpe_display]
227
  # )
228
 
229
+ # index_slider.change(
230
+ # update_slider_index,
231
+ # inputs=index_slider,
232
+ # outputs=[heatmap, bpe_display]
233
+ # )
234
  index_slider.change(
235
+ update_slider_index,
236
+ inputs=[index_slider],
237
+ outputs=[heatmap, bpe_display])
238
+
239
 
240
  if __name__ == "__main__":
241
  demo.launch()