TongkunGuan committed on
Commit 464ed7e · verified · 1 Parent(s): 79d5e07

Update app.py

Files changed (1):
  1. app.py +87 -47
app.py CHANGED
@@ -21,6 +21,10 @@ CHECKPOINTS = {
 
 # Global variables
 HF_TOKEN = os.getenv("HF_TOKEN")
+current_vis = []
+current_bpe = []
+current_index = 0
+
 
 def load_model(check_type):
     # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -50,7 +54,8 @@ def load_model(check_type):
 
     return model.to(device), tokenizer, transform, device
 
-def process_image(model, tokenizer, transform, device, check_type, image, text, state):
+def process_image(model, tokenizer, transform, device, check_type, image, text):
+    global current_vis, current_bpe, current_index
     src_size = image.size
     if 'TokenOCR' in check_type:
         images, target_ratio = dynamic_preprocess(image, min_num=1, max_num=12,
@@ -75,6 +80,10 @@ def process_image(model, tokenizer, transform, device, check_type, image, text,
     text_embeds = model.tok_embeddings(input_ids)
 
     vit_embeds, size1 = model.forward_tokenocr(pixel_values.to(torch.bfloat16).to(device))
+    print("vit_embeds", vit_embeds)
+    print("vit_embeds.shape", vit_embeds.shape)
+    print("target_ratio", target_ratio)
+    print("check_type", check_type)
     vit_embeds, size2 = post_process(vit_embeds, target_ratio, check_type)
 
     # Compute similarity
@@ -83,22 +92,47 @@ def process_image(model, tokenizer, transform, device, check_type, image, text,
     similarity = text_embeds @ vit_embeds.T
     resized_size = size1 if size1 is not None else size2
 
+    # print(f"text_embeds shape: {text_embeds.shape}, numel: {text_embeds.numel()}")  # text_embeds shape: torch.Size([4, 2048]), numel: 8192
+    # print(f"vit_embeds shape: {vit_embeds.shape}, numel: {vit_embeds.numel()}")  # vit_embeds shape: torch.Size([9728, 2048]), numel: 19922944
+    # print(f"similarity shape: {similarity.shape}, numel: {similarity.numel()}")  # similarity shape: torch.Size([4, 9728]), numel: 38912
+
+
+    # Generate the visualization
     attn_map = similarity.reshape(len(text_embeds), resized_size[0], resized_size[1])
+    # attn_map = similarity.reshape(len(text_embeds), *target_ratio)
     all_bpe_strings = [tokenizer.decode(input_id) for input_id in input_ids]
-    vis = generate_similiarity_map([image], attn_map,
-                                   [tokenizer.decode([i]) for i in input_ids],
-                                   [], target_ratio, src_size)
+    current_vis = generate_similiarity_map([image], attn_map,
+                                           [tokenizer.decode([i]) for i in input_ids],
+                                           [], target_ratio, src_size)
 
-    bpe = [tokenizer.decode([i]) for i in input_ids]
-    bpe[-1] = text
+    current_bpe = [tokenizer.decode([i]) for i in input_ids]
+    # current_bpe[-1] = 'Input text'
+    current_bpe[-1] = text
+
+    return image, current_vis[0], current_bpe[0]
+
+# Event handler functions
+def update_index(change):
+    global current_vis, current_bpe, current_index
+    current_index = max(0, min(len(current_vis) - 1, current_index + change))
+    return current_vis[current_index], format_bpe_display(current_bpe[current_index])
+
+def format_bpe_display(bpe):
+    # Use HTML tags to set the font size and color, bold the text, and center it
+    return f"<div style='text-align:center; font-size:20px;'><strong>Current BPE: <span style='color:red;'>{bpe}</span></strong></div>"
+
+def update_slider_index(x):
+    global current_vis, current_bpe, current_index
+    print(f"x: {x}, current_vis length: {len(current_vis)}, current_bpe length: {len(current_bpe)}")
+    if 0 <= x < len(current_vis) and 0 <= x < len(current_bpe):
+        return current_vis[x], format_bpe_display(current_bpe[x])
+    else:
+        return None, "Index out of range"
+
 
-    # Store results in state
-    state['current_vis'] = vis
-    state['current_bpe'] = bpe
-    return image, vis[0], bpe[0], len(vis) - 1
 
 # Gradio UI
-with gr.Blocks() as demo:
+with gr.Blocks(title="BPE Visualization Demo") as demo:
     gr.Markdown("## BPE Visualization Demo - TokenFD base model capability visualization")
 
     with gr.Row():
@@ -106,11 +140,13 @@ with gr.Blocks() as demo:
             model_type = gr.Dropdown(
                 choices=["TokenFD_4096_English_seg", "TokenFD_2048_Bilingual_seg", "R50", "R50_siglip"],
                 label="Select model type",
-                value="TokenOCR_4096_English_seg"
+                value="TokenOCR_4096_English_seg"  # Default to the first option
             )
             image_input = gr.Image(label="Upload images", type="pil")
             text_input = gr.Textbox(label="Input text")
+
             run_btn = gr.Button("RUN")
+
             gr.Examples(
                 examples=[
                     [os.path.join("examples", "examples0.jpg"), "Veterans and Benefits"],
@@ -123,58 +159,62 @@ with gr.Blocks() as demo:
 
         with gr.Column(scale=2):
             gr.Markdown("<p style='font-size:20px;'><span style='color:red;'>If the input text is not included in the image</span>, the attention map will show a lot of noise (the actual response value is very low), since we normalize the attention map according to the relative value.</p>")
-            orig_img = gr.Image(label="Original picture", interactive=False)
-            heatmap = gr.Image(label="BPE visualization", interactive=False)
-            prev_btn = gr.Button("⬅ Last", visible=False)
-            index_slider = gr.Slider(0, 1, value=0, step=1, label="BPE index", visible=False)
-            next_btn = gr.Button("⮕ Next", visible=False)
-            bpe_display = gr.Markdown("Current BPE: ", visible=False)
 
-    state = gr.State()
-    state['current_vis'] = []
-    state['current_bpe'] = []
-    state['current_index'] = 0
+            with gr.Row():
+                orig_img = gr.Image(label="Original picture", interactive=False)
+                heatmap = gr.Image(label="BPE visualization", interactive=False)
+
+            with gr.Row() as controls:
+                prev_btn = gr.Button("⬅ Last", visible=False)
+                index_slider = gr.Slider(0, 1, value=0, step=1, label="BPE index", visible=False)
+                next_btn = gr.Button("⮕ Next", visible=False)
+
+            bpe_display = gr.Markdown("Current BPE: ", visible=False)
 
+    # Event handling
     @spaces.GPU
-    def on_run_clicked(model_type, image, text, state):
-        image, vis, bpe, slider_max_val = process_image(*load_model(model_type), model_type, image, text, state)
-        state['current_vis'] = vis
-        state['current_bpe'] = bpe
-        state['current_index'] = 0
+    def on_run_clicked(model_type, image, text):
+        global current_vis, current_bpe, current_index
+        current_index = 0  # Reset index when a new image is processed
+        image, vis, bpe = process_image(*load_model(model_type), model_type, image, text)
+        # Update the slider range and set value to 0
+        slider_max_val = len(current_bpe) - 1
         bpe_text = format_bpe_display(bpe)
+        print("len_current_vis", len(current_vis))
+        print("len_current_bpe", len(current_bpe))
+        print("current_vis", current_vis)
+        print("current_bpe", current_bpe)
         return image, vis, bpe_text, slider_max_val
 
+
     run_btn.click(
         on_run_clicked,
-        inputs=[model_type, image_input, text_input, state],
-        outputs=[orig_img, heatmap, bpe_display],
-        _js="""
-        (orig_img, heatmap, bpe_display, slider_max_val) => {
-            index_slider.update({ visible: true, maximum: slider_max_val, value: 0 });
-            prev_btn.update({ visible: true });
-            next_btn.update({ visible: true });
-            return [orig_img, heatmap, bpe_display];
-        }
-        """
+        inputs=[model_type, image_input, text_input],
+        outputs=[orig_img, heatmap, bpe_display, index_slider],
+    ).then(
+        lambda max_val: (gr.update(visible=True), gr.update(visible=True, maximum=max_val, value=0), gr.update(visible=True), gr.update(visible=True)),
+        inputs=index_slider,
+        outputs=[prev_btn, index_slider, next_btn, bpe_display],
     )
-
+
     prev_btn.click(
-        lambda state: update_index(-1, state),
-        inputs=[state],
+        lambda: (*update_index(-1), current_index),
        outputs=[heatmap, bpe_display, index_slider]
    )
-
    next_btn.click(
-        lambda state: update_index(1, state),
-        inputs=[state],
+        lambda: (*update_index(1), current_index),
        outputs=[heatmap, bpe_display, index_slider]
    )
 
+
    index_slider.change(
-        lambda x, state: update_slider_index(x, state),
-        inputs=[index_slider, state],
-        outputs=[heatmap, bpe_display]
-    )
+        update_slider_index,
+        inputs=index_slider,
+        outputs=[heatmap, bpe_display]
+    )
+
+
 
 if __name__ == "__main__":
     demo.launch()
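
For reference, the heart of the unchanged process_image logic above is a single matrix product: each BPE token embedding is scored against every ViT patch embedding, and the flat score vector is reshaped into one heatmap per token. Below is a minimal, self-contained sketch of that computation; the shapes follow the commented-out debug prints in the diff (4 tokens, 2048-dim embeddings, 9728 patches), while the random tensors are illustrative stand-ins for the real model outputs.

    import torch

    T, C, H, W = 4, 2048, 64, 152            # 64 * 152 = 9728 patches
    text_embeds = torch.randn(T, C)          # stand-in for model.tok_embeddings(input_ids)
    vit_embeds = torch.randn(H * W, C)       # stand-in for post-processed ViT features

    # One dot-product score per (token, patch) pair.
    similarity = text_embeds @ vit_embeds.T  # shape (T, H*W)

    # One H x W attention map per BPE token, as in process_image.
    attn_map = similarity.reshape(T, H, W)
    print(attn_map.shape)                    # torch.Size([4, 64, 152])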
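
The main UI change is replacing the removed client-side _js callback with server-side event chaining: run_btn.click first writes the new slider maximum into the still-hidden slider, then a chained .then() reads that value back and reveals the navigation controls via gr.update. Here is a minimal sketch of the pattern with illustrative handler and component names, not the committed app:

    import gradio as gr

    with gr.Blocks() as demo:
        run_btn = gr.Button("RUN")
        index_slider = gr.Slider(0, 1, value=0, step=1, label="BPE index", visible=False)
        prev_btn = gr.Button("⬅ Last", visible=False)
        next_btn = gr.Button("⮕ Next", visible=False)

        def on_run():
            # The real app returns len(current_bpe) - 1 here; 4 is a stand-in.
            return 4

        # Step 1: the click handler's return value lands in the hidden slider.
        # Step 2: .then() reads that value and reveals/configures the controls.
        run_btn.click(on_run, outputs=index_slider).then(
            lambda max_val: (gr.update(visible=True),
                             gr.update(visible=True, maximum=max_val, value=0),
                             gr.update(visible=True)),
            inputs=index_slider,
            outputs=[prev_btn, index_slider, next_btn],
        )

    demo.launch()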