TongkunGuan commited on
Commit
b70aad2
·
verified ·
1 Parent(s): 3cefd04

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -137
app.py CHANGED
@@ -21,10 +21,6 @@ CHECKPOINTS = {
21
 
22
  # 全局变量
23
  HF_TOKEN = os.getenv("HF_TOKEN")
24
- current_vis = []
25
- current_bpe = []
26
- current_index = 0
27
-
28
 
29
  def load_model(check_type):
30
  # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -54,129 +50,52 @@ def load_model(check_type):
54
 
55
  return model.to(device), tokenizer, transform, device
56
 
57
- # def process_image(model, tokenizer, transform, device, check_type, image, text):
58
- # global current_vis, current_bpe, current_index
59
- # src_size = image.size
60
- # if 'TokenOCR' in check_type:
61
- # images, target_ratio = dynamic_preprocess(image, min_num=1, max_num=12,
62
- # image_size=model.config.force_image_size,
63
- # use_thumbnail=model.config.use_thumbnail,
64
- # return_ratio=True)
65
- # pixel_values = torch.stack([transform(img) for img in images]).to(device)
66
- # else:
67
- # pixel_values = torch.stack([transform(image)]).to(device)
68
- # target_ratio = (1, 1)
69
-
70
- # # 文本处理
71
- # text += ' '
72
- # input_ids = tokenizer(text)['input_ids'][1:]
73
- # input_ids = torch.tensor(input_ids, device=device)
74
-
75
- # # 获取嵌入
76
- # with torch.no_grad():
77
- # if 'R50' in check_type:
78
- # text_embeds = model.language_embedding(input_ids)
79
- # else:
80
- # text_embeds = model.tok_embeddings(input_ids)
81
-
82
- # vit_embeds, size1 = model.forward_tokenocr(pixel_values.to(torch.bfloat16).to(device))
83
- # print("vit_embeds",vit_embeds)
84
- # print("vit_embeds,shape",vit_embeds.shape)
85
- # print("target_ratio",target_ratio)
86
- # print("check_type",check_type)
87
- # vit_embeds, size2 = post_process(vit_embeds, target_ratio, check_type)
88
-
89
- # # 计算相似度
90
- # text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
91
- # vit_embeds = vit_embeds / vit_embeds.norm(dim=-1, keepdim=True)
92
- # similarity = text_embeds @ vit_embeds.T
93
- # resized_size = size1 if size1 is not None else size2
94
-
95
- # # print(f"text_embeds shape: {text_embeds.shape}, numel: {text_embeds.numel()}") # text_embeds shape: torch.Size([4, 2048]), numel: 8192
96
- # # print(f"vit_embeds shape: {vit_embeds.shape}, numel: {vit_embeds.numel()}") # vit_embeds shape: torch.Size([9728, 2048]), numel: 19922944
97
- # # print(f"similarity shape: {similarity.shape}, numel: {similarity.numel()}")# similarity shape: torch.Size([4, 9728]), numel: 38912
98
-
99
-
100
- # # 生成可视化
101
- # attn_map = similarity.reshape(len(text_embeds), resized_size[0], resized_size[1])
102
- # # attn_map = similarity.reshape(len(text_embeds), *target_ratio)
103
- # all_bpe_strings = [tokenizer.decode(input_id) for input_id in input_ids]
104
- # current_vis = generate_similiarity_map([image], attn_map,
105
- # [tokenizer.decode([i]) for i in input_ids],
106
- # [], target_ratio, src_size)
107
-
108
- # current_bpe = [tokenizer.decode([i]) for i in input_ids]
109
- # # current_bpe[-1] = 'Input text'
110
- # current_bpe[-1] = text
111
- # print("current_vis",len(current_vis))
112
- # print("current_bpe",len(current_bpe))
113
- # return image, current_vis[0], current_bpe[0]
114
-
115
- def process_image(model, tokenizer, transform, device, check_type, image, text):
116
- global current_vis, current_bpe, current_index
117
  src_size = image.size
118
-
119
- # Convert PIL Image to Tensor and move to the appropriate device
120
  if 'TokenOCR' in check_type:
121
- # If dynamic preprocessing is required, handle differently
122
  images, target_ratio = dynamic_preprocess(image, min_num=1, max_num=12,
123
  image_size=model.config.force_image_size,
124
  use_thumbnail=model.config.use_thumbnail,
125
  return_ratio=True)
126
- pixel_values = torch.stack([transform(img).to(device) for img in images])
127
  else:
128
- # Standard image processing for a single image
129
- pixel_values = transform(image).unsqueeze(0).to(device) # Add batch dimension and move to device
130
  target_ratio = (1, 1)
131
 
 
132
  text += ' '
133
- input_ids = tokenizer(text, return_tensors='pt').input_ids.to(device) # Ensure tokens are on the same device
134
-
 
 
135
  with torch.no_grad():
136
  if 'R50' in check_type:
137
  text_embeds = model.language_embedding(input_ids)
138
  else:
139
  text_embeds = model.tok_embeddings(input_ids)
140
-
141
- # vit_embeds, size1 = model.forward_tokenocr(pixel_values)
142
- vit_embeds, size1 = model.forward_tokenocr(pixel_values.to(torch.bfloat16))
143
-
144
  vit_embeds, size2 = post_process(vit_embeds, target_ratio, check_type)
145
-
 
146
  text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
147
  vit_embeds = vit_embeds / vit_embeds.norm(dim=-1, keepdim=True)
148
  similarity = text_embeds @ vit_embeds.T
149
  resized_size = size1 if size1 is not None else size2
150
 
151
  attn_map = similarity.reshape(len(text_embeds), resized_size[0], resized_size[1])
152
- current_vis = generate_similiarity_map([image], attn_map,
153
- [tokenizer.decode([i]) for i in input_ids.squeeze()],
154
- [], target_ratio, src_size)
155
-
156
- current_bpe = [tokenizer.decode([i]) for i in input_ids.squeeze()]
157
- current_bpe[-1] = text
158
- return image, current_vis[0], current_bpe[0]
159
-
160
-
161
- # 事件处理函数
162
- def update_index(change):
163
- global current_vis, current_bpe, current_index
164
- current_index = max(0, min(len(current_vis) - 1, current_index + change))
165
- return current_vis[current_index], format_bpe_display(current_bpe[current_index])
166
-
167
- def format_bpe_display(bpe):
168
- # 使用HTML标签来设置字体大小、颜色,加粗,并居中
169
- return f"<div style='text-align:center; font-size:20px;'><strong>Current BPE: <span style='color:red;'>{bpe}</span></strong></div>"
170
-
171
- def update_slider_index(x):
172
- global current_vis, current_bpe, current_index
173
- print(f"x: {x}, current_vis length: {len(current_vis)}, current_bpe length: {len(current_bpe)}")
174
- if 0 <= x < len(current_vis) and 0 <= x < len(current_bpe):
175
- return current_vis[x], format_bpe_display(current_bpe[x])
176
- else:
177
- return None, "索引超出范围"
178
-
179
 
 
 
 
 
180
 
181
  # Gradio界面
182
  with gr.Blocks(title="BPE Visualization Demo") as demo:
@@ -218,58 +137,48 @@ with gr.Blocks(title="BPE Visualization Demo") as demo:
218
 
219
  bpe_display = gr.Markdown("Current BPE: ", visible=False)
220
 
221
- # 事件处理
222
- # @spaces.GPU
223
- # def on_run_clicked(model_type, image, text):
224
- # global current_vis, current_bpe, current_index
225
- # current_index = 0 # Reset index when new image is processed
226
- # image, vis, bpe = process_image(*load_model(model_type), model_type, image, text)
227
- # # Update the slider range and set value to 0
228
- # slider_max_val = len(current_bpe) - 1
229
- # bpe_text = format_bpe_display(bpe)
230
- # print("current_vis",len(current_vis))
231
- # print("current_bpe",len(current_bpe))
232
- # return image, vis, bpe_text, slider_max_val
233
 
234
  @spaces.GPU
235
- def on_run_clicked(model_type, image, text):
236
- global current_vis, current_bpe, current_index
237
- current_index = 0
238
- model, tokenizer, transform, device = load_model(model_type)
239
- image, vis, bpe = process_image(model, tokenizer, transform, device, model_type, image, text)
240
- slider_max_val = len(current_bpe) - 1
241
  bpe_text = format_bpe_display(bpe)
242
- return image, vis, bpe_text, slider_max_val
 
 
 
 
 
 
 
 
 
 
243
 
244
-
245
  run_btn.click(
246
  on_run_clicked,
247
- inputs=[model_type, image_input, text_input],
248
- outputs=[orig_img, heatmap, bpe_display, index_slider],
249
- ).then(
250
- lambda max_val: (gr.update(visible=True), gr.update(visible=True, maximum=max_val, value=0), gr.update(visible=True), gr.update(visible=True)),
251
- inputs=index_slider,
252
- outputs=[prev_btn, index_slider, next_btn, bpe_display],
253
  )
254
 
255
  prev_btn.click(
256
- lambda: (*update_index(-1), current_index),
 
257
  outputs=[heatmap, bpe_display, index_slider]
258
  )
259
 
260
  next_btn.click(
261
- lambda: (*update_index(1), current_index),
 
262
  outputs=[heatmap, bpe_display, index_slider]
263
  )
264
 
265
 
266
  index_slider.change(
267
- update_slider_index,
268
- inputs=index_slider,
269
  outputs=[heatmap, bpe_display]
270
  )
271
 
272
-
273
-
274
  if __name__ == "__main__":
275
- demo.launch()
 
21
 
22
  # 全局变量
23
  HF_TOKEN = os.getenv("HF_TOKEN")
 
 
 
 
24
 
25
  def load_model(check_type):
26
  # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
50
 
51
  return model.to(device), tokenizer, transform, device
52
 
53
+ def process_image(model, tokenizer, transform, device, check_type, image, text, state):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  src_size = image.size
 
 
55
  if 'TokenOCR' in check_type:
 
56
  images, target_ratio = dynamic_preprocess(image, min_num=1, max_num=12,
57
  image_size=model.config.force_image_size,
58
  use_thumbnail=model.config.use_thumbnail,
59
  return_ratio=True)
60
+ pixel_values = torch.stack([transform(img) for img in images]).to(device)
61
  else:
62
+ pixel_values = torch.stack([transform(image)]).to(device)
 
63
  target_ratio = (1, 1)
64
 
65
+ # 文本处理
66
  text += ' '
67
+ input_ids = tokenizer(text)['input_ids'][1:]
68
+ input_ids = torch.tensor(input_ids, device=device)
69
+
70
+ # 获取嵌入
71
  with torch.no_grad():
72
  if 'R50' in check_type:
73
  text_embeds = model.language_embedding(input_ids)
74
  else:
75
  text_embeds = model.tok_embeddings(input_ids)
76
+
77
+ vit_embeds, size1 = model.forward_tokenocr(pixel_values.to(torch.bfloat16).to(device))
 
 
78
  vit_embeds, size2 = post_process(vit_embeds, target_ratio, check_type)
79
+
80
+ # 计算相似度
81
  text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
82
  vit_embeds = vit_embeds / vit_embeds.norm(dim=-1, keepdim=True)
83
  similarity = text_embeds @ vit_embeds.T
84
  resized_size = size1 if size1 is not None else size2
85
 
86
  attn_map = similarity.reshape(len(text_embeds), resized_size[0], resized_size[1])
87
+ all_bpe_strings = [tokenizer.decode(input_id) for input_id in input_ids]
88
+ vis = generate_similiarity_map([image], attn_map,
89
+ [tokenizer.decode([i]) for i in input_ids],
90
+ [], target_ratio, src_size)
91
+
92
+ bpe = [tokenizer.decode([i]) for i in input_ids]
93
+ bpe[-1] = text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
 
95
+ # Store results in state
96
+ state['current_vis'] = vis
97
+ state['current_bpe'] = bpe
98
+ return image, vis[0], bpe[0], len(vis) - 1
99
 
100
  # Gradio界面
101
  with gr.Blocks(title="BPE Visualization Demo") as demo:
 
137
 
138
  bpe_display = gr.Markdown("Current BPE: ", visible=False)
139
 
140
+ state = gr.State(current_vis=[], current_bpe=[], current_index=0)
 
 
 
 
 
 
 
 
 
 
 
141
 
142
  @spaces.GPU
143
+ def on_run_clicked(model_type, image, text, state):
144
+ image, vis, bpe, slider_max_val = process_image(*load_model(model_type), model_type, image, text, state)
 
 
 
 
145
  bpe_text = format_bpe_display(bpe)
146
+ index_slider.update(visible=True, maximum=slider_max_val, value=0)
147
+ prev_btn.update(visible=True)
148
+ next_btn.update(visible=True)
149
+ return image, vis, bpe_text
150
+
151
+ def update_index(change, state):
152
+ state['current_index'] = max(0, min(len(state['current_vis']) - 1, state['current_index'] + change))
153
+ return state['current_vis'][state['current_index']], format_bpe_display(state['current_bpe'][state['current_index']])
154
+
155
+ def format_bpe_display(bpe):
156
+ return f"<div style='text-align:center; font-size:20px;'><strong>Current BPE: <span style='color:red;'>{bpe}</span></strong></div>"
157
 
 
158
  run_btn.click(
159
  on_run_clicked,
160
+ inputs=[model_type, image_input, text_input, state],
161
+ outputs=[orig_img, heatmap, bpe_display],
 
 
 
 
162
  )
163
 
164
  prev_btn.click(
165
+ lambda state: (*update_index(-1, state), state['current_index']),
166
+ inputs=[state],
167
  outputs=[heatmap, bpe_display, index_slider]
168
  )
169
 
170
  next_btn.click(
171
+ lambda state: (*update_index(1, state), state['current_index']),
172
+ inputs=[state],
173
  outputs=[heatmap, bpe_display, index_slider]
174
  )
175
 
176
 
177
  index_slider.change(
178
+ lambda x, state: update_slider_index(x, state),
179
+ inputs=[index_slider, state],
180
  outputs=[heatmap, bpe_display]
181
  )
182
 
 
 
183
  if __name__ == "__main__":
184
+ demo.launch()