TongkunGuan committed
Commit 333ea05 · verified · 1 Parent(s): aa84990

Update app.py

Files changed (1):
  1. app.py +70 -22
app.py CHANGED
@@ -21,12 +21,14 @@ CHECKPOINTS = {
 
 # Global variables
 HF_TOKEN = os.getenv("HF_TOKEN")
-current_vis = []    # stores all heatmaps
-current_bpe = []    # stores all BPE strings
-current_index = 0   # index of the currently displayed heatmap / BPE
+current_vis = []
+current_bpe = []
+current_index = 0
+
 
 def load_model(check_type):
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    device = torch.device("cuda")
     if check_type == 'R50':
         tokenizer = load_tokenizer('tokenizer_path')
         model = build_model(argparse.Namespace()).eval()
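On ZeroGPU Spaces the GPU is only attached while a `@spaces.GPU`-decorated function runs (here `on_run_clicked`, further down), which is presumably why the dynamic device check was replaced with a hard-coded `"cuda"`. For reference, a minimal sketch of the fallback pattern the removed line used, handy when testing on a CPU-only machine (the `get_device` helper is illustrative, not part of app.py):

```python
import torch

def get_device() -> torch.device:
    # Same fallback the removed line used: prefer CUDA, otherwise run on CPU.
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")
```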
@@ -78,6 +80,10 @@ def process_image(model, tokenizer, transform, device, check_type, image, text):
     text_embeds = model.tok_embeddings(input_ids)
 
     vit_embeds, size1 = model.forward_tokenocr(pixel_values.to(torch.bfloat16).to(device))
+    print("vit_embeds", vit_embeds)
+    print("vit_embeds,shape", vit_embeds.shape)
+    print("target_ratio", target_ratio)
+    print("check_type", check_type)
     vit_embeds, size2 = post_process(vit_embeds, target_ratio, check_type)
 
     # Compute similarity
@@ -86,26 +92,45 @@ def process_image(model, tokenizer, transform, device, check_type, image, text):
     similarity = text_embeds @ vit_embeds.T
     resized_size = size1 if size1 is not None else size2
 
+    # print(f"text_embeds shape: {text_embeds.shape}, numel: {text_embeds.numel()}")   # text_embeds shape: torch.Size([4, 2048]), numel: 8192
+    # print(f"vit_embeds shape: {vit_embeds.shape}, numel: {vit_embeds.numel()}")      # vit_embeds shape: torch.Size([9728, 2048]), numel: 19922944
+    # print(f"similarity shape: {similarity.shape}, numel: {similarity.numel()}")      # similarity shape: torch.Size([4, 9728]), numel: 38912
+
+
     # Generate the visualization
     attn_map = similarity.reshape(len(text_embeds), resized_size[0], resized_size[1])
+    # attn_map = similarity.reshape(len(text_embeds), *target_ratio)
     all_bpe_strings = [tokenizer.decode(input_id) for input_id in input_ids]
     current_vis = generate_similiarity_map([image], attn_map,
                                            [tokenizer.decode([i]) for i in input_ids],
                                            [], target_ratio, src_size)
 
     current_bpe = [tokenizer.decode([i]) for i in input_ids]
+    # current_bpe[-1] = 'Input text'
     current_bpe[-1] = text
-    current_index = 0  # reset the index
-    return image, current_vis[current_index], format_bpe_display(current_bpe[current_index])
-
-def format_bpe_display(bpe):
-    return f"<div style='text-align:center; font-size:20px;'><strong>Current BPE: <span style='color:red;'>{bpe}</span></strong></div>"
+    print("current_vis", len(current_vis))
+    print("current_bpe", len(current_bpe))
+    return image, current_vis[0], current_bpe[0]
 
+# Event handling functions
 def update_index(change):
     global current_vis, current_bpe, current_index
     current_index = max(0, min(len(current_vis) - 1, current_index + change))
     return current_vis[current_index], format_bpe_display(current_bpe[current_index])
 
+def format_bpe_display(bpe):
+    # Use HTML tags to set the font size and color, bold the text, and center it
+    return f"<div style='text-align:center; font-size:20px;'><strong>Current BPE: <span style='color:red;'>{bpe}</span></strong></div>"
+
+def update_slider_index(x):
+    current_vis = x[1]
+    current_bpe = x[2]
+    print(f"x: {x[0]}, current_vis length: {len(current_vis)}, current_bpe length: {len(current_bpe)}")
+    if 0 <= x < len(current_vis) and 0 <= x < len(current_bpe):
+        return current_vis[x], format_bpe_display(current_bpe[x])
+    else:
+        return None, "索引超出范围"
+
 # Gradio interface
 with gr.Blocks(title="BPE Visualization Demo") as demo:
     gr.Markdown("## BPE Visualization Demo - TokenFD基座模型能力可视化")
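For context on the block above: `text_embeds @ vit_embeds.T` scores every BPE token against every visual token, and the reshape lays each row out on the patch grid so `generate_similiarity_map` can upsample one heatmap per token onto the source image. A minimal sketch of that shape logic, using the dimensions from the commented-out debug prints (4 tokens, 2048-dim embeddings, 9728 patches); the 76×128 grid split is an assumption for illustration:

```python
import torch

num_tokens, dim = 4, 2048              # 4 BPE tokens, 2048-d embeddings
grid_h, grid_w = 76, 128               # assumed patch grid: 76 * 128 = 9728 patches
text_embeds = torch.randn(num_tokens, dim)
vit_embeds = torch.randn(grid_h * grid_w, dim)

# One similarity score per (token, patch) pair.
similarity = text_embeds @ vit_embeds.T                      # shape (4, 9728)

# Each row becomes a low-resolution heatmap on the patch grid; in app.py,
# generate_similiarity_map then resizes these maps to the original image size.
attn_map = similarity.reshape(num_tokens, grid_h, grid_w)    # shape (4, 76, 128)
print(attn_map.shape)                                        # torch.Size([4, 76, 128])
```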
@@ -115,7 +140,7 @@ with gr.Blocks(title="BPE Visualization Demo") as demo:
         model_type = gr.Dropdown(
             choices=["TokenFD_4096_English_seg", "TokenFD_2048_Bilingual_seg", "R50", "R50_siglip"],
             label="Select model type",
-            value="TokenOCR_4096_English_seg"
+            value="TokenOCR_4096_English_seg"  # default to the first option
         )
         image_input = gr.Image(label="Upload images", type="pil")
         text_input = gr.Textbox(label="Input text")
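One detail worth flagging in the hunk above: the default `value="TokenOCR_4096_English_seg"` does not appear in `choices`, which all use the `TokenFD_` prefix. If the intent is simply to preselect the first option, deriving the default from the list avoids the mismatch; a small sketch assuming that intent:

```python
import gradio as gr

choices = ["TokenFD_4096_English_seg", "TokenFD_2048_Bilingual_seg", "R50", "R50_siglip"]

with gr.Blocks() as demo:
    # Preselect the first entry instead of hard-coding a value that is not in the list.
    model_type = gr.Dropdown(choices=choices, label="Select model type", value=choices[0])
```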
@@ -139,34 +164,57 @@ with gr.Blocks(title="BPE Visualization Demo") as demo:
         orig_img = gr.Image(label="Original picture", interactive=False)
         heatmap = gr.Image(label="BPE visualization", interactive=False)
 
-        with gr.Row():
-            prev_btn = gr.Button("⬅ Previous")
-            next_btn = gr.Button("Next ")
+        with gr.Row() as controls:
+            prev_btn = gr.Button("⬅ Last", visible=False)
+            index_slider = gr.Slider(0, 1, value=0, step=1, label="BPE index", visible=False)
+            next_btn = gr.Button("⮕ Next", visible=False)
 
-        bpe_display = gr.Markdown("Current BPE: ", visible=True)
+        bpe_display = gr.Markdown("Current BPE: ", visible=False)
 
     # Event handling
     @spaces.GPU
     def on_run_clicked(model_type, image, text):
         global current_vis, current_bpe, current_index
+        current_index = 0  # Reset index when a new image is processed
         image, vis, bpe = process_image(*load_model(model_type), model_type, image, text)
-        return image, vis, bpe
+        # Update the slider range and set value to 0
+        slider_max_val = len(current_bpe) - 1
+        bpe_text = format_bpe_display(bpe)
+        print("current_vis", len(current_vis))
+        print("current_bpe", len(current_bpe))
+        return image, vis, bpe_text, slider_max_val
 
     run_btn.click(
         on_run_clicked,
         inputs=[model_type, image_input, text_input],
-        outputs=[orig_img, heatmap, bpe_display],
+        outputs=[orig_img, heatmap, bpe_display, index_slider],
+    ).then(
+        lambda max_val: (gr.update(visible=True), gr.update(visible=True, maximum=max_val, value=0), gr.update(visible=True), gr.update(visible=True)),
+        inputs=index_slider,
+        outputs=[prev_btn, index_slider, next_btn, bpe_display],
     )
-
+
     prev_btn.click(
-        lambda: update_index(-1),
-        outputs=[heatmap, bpe_display]
+        lambda: (*update_index(-1), current_index),
+        outputs=[heatmap, bpe_display, index_slider]
     )
-
+
     next_btn.click(
-        lambda: update_index(1),
-        outputs=[heatmap, bpe_display]
+        lambda: (*update_index(1), current_index),
+        outputs=[heatmap, bpe_display, index_slider]
     )
+
+    # index_slider.change(
+    #     lambda x: (current_vis[x], format_bpe_display(current_bpe[x])) if 0<=x<len(current_vis else (None,"Invaild")
+    #     inputs=index_slider,
+    #     outputs=[heatmap, bpe_display]
+    # )
+
+    index_slider.change(
+        update_slider_index,
+        inputs=[index_slider, current_vis, current_bpe],
+        outputs=[heatmap, bpe_display]
+    )
 
 if __name__ == "__main__":
     demo.launch()
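A final note on the event wiring: `index_slider.change(update_slider_index, inputs=[index_slider, current_vis, current_bpe], ...)` passes plain Python lists where Gradio expects input components, so this wiring will not deliver `current_vis`/`current_bpe` to the handler, and the handler's `x[1]`/`x[2]` indexing of the slider value cannot work as written. Below is a hedged sketch of one way to keep the per-run results in `gr.State` components instead of module-level globals; all component names and the `run` stand-in for `process_image` are illustrative, not taken from app.py:

```python
import gradio as gr

def run(image, text):
    # Hypothetical stand-in for process_image: it would return one heatmap per BPE token.
    vis = [image, image, image]
    bpe = ["to", "ken", "s"]
    slider = gr.update(visible=True, maximum=len(bpe) - 1, value=0)
    return vis[0], f"Current BPE: {bpe[0]}", vis, bpe, slider

def show_index(i, vis, bpe):
    # The slider value plus the two State lists arrive as ordinary Python objects.
    i = max(0, min(len(vis) - 1, int(i)))
    return vis[i], f"Current BPE: {bpe[i]}"

with gr.Blocks() as demo:
    image_in = gr.Image(type="pil")
    text_in = gr.Textbox()
    run_btn = gr.Button("Run")
    heatmap = gr.Image()
    bpe_md = gr.Markdown()
    index_slider = gr.Slider(0, 1, step=1, visible=False)
    vis_state, bpe_state = gr.State([]), gr.State([])

    run_btn.click(run, [image_in, text_in],
                  [heatmap, bpe_md, vis_state, bpe_state, index_slider])
    index_slider.change(show_index, [index_slider, vis_state, bpe_state],
                        [heatmap, bpe_md])

if __name__ == "__main__":
    demo.launch()
```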