TongkunGuan committed on
Commit
79d5e07
·
verified ·
1 Parent(s): b70aad2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -36
app.py CHANGED
@@ -98,7 +98,7 @@ def process_image(model, tokenizer, transform, device, check_type, image, text,
98
  return image, vis[0], bpe[0], len(vis) - 1
99
 
100
  # Gradio界面
101
- with gr.Blocks(title="BPE Visualization Demo") as demo:
102
  gr.Markdown("## BPE Visualization Demo - TokenFD基座模型能力可视化")
103
 
104
  with gr.Row():
@@ -106,13 +106,11 @@ with gr.Blocks(title="BPE Visualization Demo") as demo:
106
  model_type = gr.Dropdown(
107
  choices=["TokenFD_4096_English_seg", "TokenFD_2048_Bilingual_seg", "R50", "R50_siglip"],
108
  label="Select model type",
109
- value="TokenOCR_4096_English_seg" # 设置默认值为第一个选项
110
  )
111
  image_input = gr.Image(label="Upload images", type="pil")
112
  text_input = gr.Textbox(label="Input text")
113
-
114
  run_btn = gr.Button("RUN")
115
-
116
  gr.Examples(
117
  examples=[
118
  [os.path.join("examples", "examples0.jpg"), "Veterans and Benefits"],
@@ -125,60 +123,58 @@ with gr.Blocks(title="BPE Visualization Demo") as demo:
125
 
126
  with gr.Column(scale=2):
127
  gr.Markdown("<p style='font-size:20px;'><span style='color:red;'>If the input text is not included in the image</span>, the attention map will show a lot of noise (the actual response value is very low), since we normalize the attention map according to the relative value.</p>")
128
-
129
- with gr.Row():
130
- orig_img = gr.Image(label="Original picture", interactive=False)
131
- heatmap = gr.Image(label="BPE visualization", interactive=False)
132
-
133
- with gr.Row() as controls:
134
- prev_btn = gr.Button("⬅ Last", visible=False)
135
- index_slider = gr.Slider(0, 1, value=0, step=1, label="BPE index", visible=False)
136
- next_btn = gr.Button("⮕ Next", visible=False)
137
-
138
  bpe_display = gr.Markdown("Current BPE: ", visible=False)
139
 
140
- state = gr.State(current_vis=[], current_bpe=[], current_index=0)
 
 
 
141
 
142
  @spaces.GPU
143
  def on_run_clicked(model_type, image, text, state):
144
  image, vis, bpe, slider_max_val = process_image(*load_model(model_type), model_type, image, text, state)
 
 
 
145
  bpe_text = format_bpe_display(bpe)
146
- index_slider.update(visible=True, maximum=slider_max_val, value=0)
147
- prev_btn.update(visible=True)
148
- next_btn.update(visible=True)
149
- return image, vis, bpe_text
150
-
151
- def update_index(change, state):
152
- state['current_index'] = max(0, min(len(state['current_vis']) - 1, state['current_index'] + change))
153
- return state['current_vis'][state['current_index']], format_bpe_display(state['current_bpe'][state['current_index']])
154
-
155
- def format_bpe_display(bpe):
156
- return f"<div style='text-align:center; font-size:20px;'><strong>Current BPE: <span style='color:red;'>{bpe}</span></strong></div>"
157
 
158
  run_btn.click(
159
  on_run_clicked,
160
  inputs=[model_type, image_input, text_input, state],
161
  outputs=[orig_img, heatmap, bpe_display],
 
 
 
 
 
 
 
 
162
  )
163
-
164
  prev_btn.click(
165
- lambda state: (*update_index(-1, state), state['current_index']),
166
  inputs=[state],
167
  outputs=[heatmap, bpe_display, index_slider]
168
  )
169
-
170
  next_btn.click(
171
- lambda state: (*update_index(1, state), state['current_index']),
172
  inputs=[state],
173
  outputs=[heatmap, bpe_display, index_slider]
174
  )
175
 
176
-
177
  index_slider.change(
178
- lambda x, state: update_slider_index(x, state),
179
- inputs=[index_slider, state],
180
- outputs=[heatmap, bpe_display]
181
- )
182
 
183
  if __name__ == "__main__":
184
- demo.launch()
 
98
  return image, vis[0], bpe[0], len(vis) - 1
99
 
100
  # Gradio界面
101
+ with gr.Blocks() as demo:
102
  gr.Markdown("## BPE Visualization Demo - TokenFD基座模型能力可视化")
103
 
104
  with gr.Row():
 
106
  model_type = gr.Dropdown(
107
  choices=["TokenFD_4096_English_seg", "TokenFD_2048_Bilingual_seg", "R50", "R50_siglip"],
108
  label="Select model type",
109
+ value="TokenOCR_4096_English_seg"
110
  )
111
  image_input = gr.Image(label="Upload images", type="pil")
112
  text_input = gr.Textbox(label="Input text")
 
113
  run_btn = gr.Button("RUN")
 
114
  gr.Examples(
115
  examples=[
116
  [os.path.join("examples", "examples0.jpg"), "Veterans and Benefits"],
 
123
 
124
  with gr.Column(scale=2):
125
  gr.Markdown("<p style='font-size:20px;'><span style='color:red;'>If the input text is not included in the image</span>, the attention map will show a lot of noise (the actual response value is very low), since we normalize the attention map according to the relative value.</p>")
126
+ orig_img = gr.Image(label="Original picture", interactive=False)
127
+ heatmap = gr.Image(label="BPE visualization", interactive=False)
128
+ prev_btn = gr.Button(" Last", visible=False)
129
+ index_slider = gr.Slider(0, 1, value=0, step=1, label="BPE index", visible=False)
130
+ next_btn = gr.Button("⮕ Next", visible=False)
 
 
 
 
 
131
  bpe_display = gr.Markdown("Current BPE: ", visible=False)
132
 
133
+ state = gr.State()
134
+ state['current_vis'] = []
135
+ state['current_bpe'] = []
136
+ state['current_index'] = 0
137
 
138
  @spaces.GPU
139
  def on_run_clicked(model_type, image, text, state):
140
  image, vis, bpe, slider_max_val = process_image(*load_model(model_type), model_type, image, text, state)
141
+ state['current_vis'] = vis
142
+ state['current_bpe'] = bpe
143
+ state['current_index'] = 0
144
  bpe_text = format_bpe_display(bpe)
145
+ return image, vis, bpe_text, slider_max_val
 
 
 
 
 
 
 
 
 
 
146
 
147
  run_btn.click(
148
  on_run_clicked,
149
  inputs=[model_type, image_input, text_input, state],
150
  outputs=[orig_img, heatmap, bpe_display],
151
+ _js="""
152
+ (orig_img, heatmap, bpe_display, slider_max_val) => {
153
+ index_slider.update({ visible: true, maximum: slider_max_val, value: 0 });
154
+ prev_btn.update({ visible: true });
155
+ next_btn.update({ visible: true });
156
+ return [orig_img, heatmap, bpe_display];
157
+ }
158
+ """
159
  )
160
+
161
  prev_btn.click(
162
+ lambda state: update_index(-1, state),
163
  inputs=[state],
164
  outputs=[heatmap, bpe_display, index_slider]
165
  )
166
+
167
  next_btn.click(
168
+ lambda state: update_index(1, state),
169
  inputs=[state],
170
  outputs=[heatmap, bpe_display, index_slider]
171
  )
172
 
 
173
  index_slider.change(
174
+ lambda x, state: update_slider_index(x, state),
175
+ inputs=[index_slider, state],
176
+ outputs=[heatmap, bpe_display]
177
+ )
178
 
179
  if __name__ == "__main__":
180
+ demo.launch()