TongkunGuan commited on
Commit
3d2b840
·
verified ·
1 Parent(s): 9c191d1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +91 -80
app.py CHANGED
@@ -25,11 +25,6 @@ current_vis = []
25
  current_bpe = []
26
  current_index = 0
27
 
28
- # 设置初始状态
29
- initial_state = {
30
- "vis": [],
31
- "bpe": []
32
- }
33
 
34
  def load_model(check_type):
35
  # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -59,9 +54,70 @@ def load_model(check_type):
59
 
60
  return model.to(device), tokenizer, transform, device
61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  def process_image(model, tokenizer, transform, device, check_type, image, text):
63
  global current_vis, current_bpe, current_index
64
  src_size = image.size
 
 
 
65
  if 'TokenOCR' in check_type:
66
  images, target_ratio = dynamic_preprocess(image, min_num=1, max_num=12,
67
  image_size=model.config.force_image_size,
@@ -72,50 +128,33 @@ def process_image(model, tokenizer, transform, device, check_type, image, text):
72
  pixel_values = torch.stack([transform(image)]).to(device)
73
  target_ratio = (1, 1)
74
 
75
- # 文本处理
76
  text += ' '
77
  input_ids = tokenizer(text)['input_ids'][1:]
78
  input_ids = torch.tensor(input_ids, device=device)
79
-
80
- # 获取嵌入
81
  with torch.no_grad():
82
  if 'R50' in check_type:
83
  text_embeds = model.language_embedding(input_ids)
84
  else:
85
  text_embeds = model.tok_embeddings(input_ids)
86
-
87
- vit_embeds, size1 = model.forward_tokenocr(pixel_values.to(torch.bfloat16).to(device))
88
- print("vit_embeds",vit_embeds)
89
- print("vit_embeds,shape",vit_embeds.shape)
90
- print("target_ratio",target_ratio)
91
- print("check_type",check_type)
92
  vit_embeds, size2 = post_process(vit_embeds, target_ratio, check_type)
93
-
94
- # 计算相似度
95
  text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
96
  vit_embeds = vit_embeds / vit_embeds.norm(dim=-1, keepdim=True)
97
  similarity = text_embeds @ vit_embeds.T
98
  resized_size = size1 if size1 is not None else size2
99
 
100
- # print(f"text_embeds shape: {text_embeds.shape}, numel: {text_embeds.numel()}") # text_embeds shape: torch.Size([4, 2048]), numel: 8192
101
- # print(f"vit_embeds shape: {vit_embeds.shape}, numel: {vit_embeds.numel()}") # vit_embeds shape: torch.Size([9728, 2048]), numel: 19922944
102
- # print(f"similarity shape: {similarity.shape}, numel: {similarity.numel()}")# similarity shape: torch.Size([4, 9728]), numel: 38912
103
-
104
-
105
- # 生成可视化
106
  attn_map = similarity.reshape(len(text_embeds), resized_size[0], resized_size[1])
107
- # attn_map = similarity.reshape(len(text_embeds), *target_ratio)
108
  all_bpe_strings = [tokenizer.decode(input_id) for input_id in input_ids]
109
- current_vis = generate_similiarity_map([image], attn_map,
110
  [tokenizer.decode([i]) for i in input_ids],
111
  [], target_ratio, src_size)
112
-
113
  current_bpe = [tokenizer.decode([i]) for i in input_ids]
114
- # current_bpe[-1] = 'Input text'
115
  current_bpe[-1] = text
116
- print("current_vis",len(current_vis))
117
- print("current_bpe",len(current_bpe))
118
- return image, current_vis[0], current_bpe[0]
119
 
120
  # 事件处理函数
121
  def update_index(change):
@@ -127,24 +166,13 @@ def format_bpe_display(bpe):
127
  # 使用HTML标签来设置字体大小、颜色,加粗,并居中
128
  return f"<div style='text-align:center; font-size:20px;'><strong>Current BPE: <span style='color:red;'>{bpe}</span></strong></div>"
129
 
130
- # def update_slider_index(x):
131
- # global current_vis, current_bpe, current_index
132
- # print(f"x: {x}, current_vis length: {len(current_vis)}, current_bpe length: {len(current_bpe)}")
133
- # if 0 <= x < len(current_vis) and 0 <= x < len(current_bpe):
134
- # return current_vis[x], format_bpe_display(current_bpe[x])
135
- # else:
136
- # return None, "索引超出范围"
137
- # 状态更新函数,利用传递的状态(vis, bpe)
138
- # 使用状态信息来处理滑动条改变
139
- def update_slider_index(x, state):
140
- vis = state['vis']
141
- bpe = state['bpe']
142
- if 0 <= x < len(vis):
143
- return vis[x], format_bpe_display(bpe[x]), state
144
  else:
145
- return None, "索引超出范围", state
146
-
147
-
148
 
149
 
150
 
@@ -202,37 +230,24 @@ with gr.Blocks(title="BPE Visualization Demo") as demo:
202
  # return image, vis, bpe_text, slider_max_val
203
 
204
  @spaces.GPU
205
- def on_run_clicked(model_type, image, text, state):
 
 
206
  model, tokenizer, transform, device = load_model(model_type)
207
- current_index = 0 # Reset index when new image is processed
208
  image, vis, bpe = process_image(model, tokenizer, transform, device, model_type, image, text)
209
- slider_max_val = len(bpe) - 1
210
- bpe_text = format_bpe_display(bpe[current_index])
211
- # 更新状态并返回
212
- state['vis'] = vis
213
- state['bpe'] = bpe
214
- return image, vis[current_index], bpe_text, slider_max_val, state
215
-
216
-
217
-
218
 
219
 
220
- # run_btn.click(
221
- # on_run_clicked,
222
- # inputs=[model_type, image_input, text_input],
223
- # outputs=[orig_img, heatmap, bpe_display, index_slider],
224
- # ).then(
225
- # lambda max_val: (gr.update(visible=True), gr.update(visible=True, maximum=max_val, value=0), gr.update(visible=True), gr.update(visible=True)),
226
- # inputs=index_slider,
227
- # outputs=[prev_btn, index_slider, next_btn, bpe_display],
228
- # )
229
- # Gradio 按钮点击后的处理
230
- # Gradio 按钮点击后的处理
231
  run_btn.click(
232
  on_run_clicked,
233
- inputs=[model_type, image_input, text_input, 'state'],
234
- outputs=[orig_img, heatmap, bpe_display, index_slider, 'state'],
235
- _js="{state: { vis: [], bpe: []}}"
 
 
 
236
  )
237
 
238
  prev_btn.click(
@@ -246,16 +261,12 @@ with gr.Blocks(title="BPE Visualization Demo") as demo:
246
  )
247
 
248
 
249
- # index_slider.change(
250
- # update_slider_index,
251
- # inputs=index_slider,
252
- # outputs=[heatmap, bpe_display]
253
- # )
254
  index_slider.change(
255
- update_slider_index,
256
- inputs=[index_slider, 'state'],
257
- outputs=[heatmap, bpe_display, 'state']
258
- )
 
259
 
260
 
261
  if __name__ == "__main__":
 
25
  current_bpe = []
26
  current_index = 0
27
 
 
 
 
 
 
28
 
29
  def load_model(check_type):
30
  # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
54
 
55
  return model.to(device), tokenizer, transform, device
56
 
57
+ # def process_image(model, tokenizer, transform, device, check_type, image, text):
58
+ # global current_vis, current_bpe, current_index
59
+ # src_size = image.size
60
+ # if 'TokenOCR' in check_type:
61
+ # images, target_ratio = dynamic_preprocess(image, min_num=1, max_num=12,
62
+ # image_size=model.config.force_image_size,
63
+ # use_thumbnail=model.config.use_thumbnail,
64
+ # return_ratio=True)
65
+ # pixel_values = torch.stack([transform(img) for img in images]).to(device)
66
+ # else:
67
+ # pixel_values = torch.stack([transform(image)]).to(device)
68
+ # target_ratio = (1, 1)
69
+
70
+ # # 文本处理
71
+ # text += ' '
72
+ # input_ids = tokenizer(text)['input_ids'][1:]
73
+ # input_ids = torch.tensor(input_ids, device=device)
74
+
75
+ # # 获取嵌入
76
+ # with torch.no_grad():
77
+ # if 'R50' in check_type:
78
+ # text_embeds = model.language_embedding(input_ids)
79
+ # else:
80
+ # text_embeds = model.tok_embeddings(input_ids)
81
+
82
+ # vit_embeds, size1 = model.forward_tokenocr(pixel_values.to(torch.bfloat16).to(device))
83
+ # print("vit_embeds",vit_embeds)
84
+ # print("vit_embeds,shape",vit_embeds.shape)
85
+ # print("target_ratio",target_ratio)
86
+ # print("check_type",check_type)
87
+ # vit_embeds, size2 = post_process(vit_embeds, target_ratio, check_type)
88
+
89
+ # # 计算相似度
90
+ # text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
91
+ # vit_embeds = vit_embeds / vit_embeds.norm(dim=-1, keepdim=True)
92
+ # similarity = text_embeds @ vit_embeds.T
93
+ # resized_size = size1 if size1 is not None else size2
94
+
95
+ # # print(f"text_embeds shape: {text_embeds.shape}, numel: {text_embeds.numel()}") # text_embeds shape: torch.Size([4, 2048]), numel: 8192
96
+ # # print(f"vit_embeds shape: {vit_embeds.shape}, numel: {vit_embeds.numel()}") # vit_embeds shape: torch.Size([9728, 2048]), numel: 19922944
97
+ # # print(f"similarity shape: {similarity.shape}, numel: {similarity.numel()}")# similarity shape: torch.Size([4, 9728]), numel: 38912
98
+
99
+
100
+ # # 生成可视化
101
+ # attn_map = similarity.reshape(len(text_embeds), resized_size[0], resized_size[1])
102
+ # # attn_map = similarity.reshape(len(text_embeds), *target_ratio)
103
+ # all_bpe_strings = [tokenizer.decode(input_id) for input_id in input_ids]
104
+ # current_vis = generate_similiarity_map([image], attn_map,
105
+ # [tokenizer.decode([i]) for i in input_ids],
106
+ # [], target_ratio, src_size)
107
+
108
+ # current_bpe = [tokenizer.decode([i]) for i in input_ids]
109
+ # # current_bpe[-1] = 'Input text'
110
+ # current_bpe[-1] = text
111
+ # print("current_vis",len(current_vis))
112
+ # print("current_bpe",len(current_bpe))
113
+ # return image, current_vis[0], current_bpe[0]
114
+
115
  def process_image(model, tokenizer, transform, device, check_type, image, text):
116
  global current_vis, current_bpe, current_index
117
  src_size = image.size
118
+ # Ensure all processing is done on the correct device
119
+ image = image.to(device)
120
+
121
  if 'TokenOCR' in check_type:
122
  images, target_ratio = dynamic_preprocess(image, min_num=1, max_num=12,
123
  image_size=model.config.force_image_size,
 
128
  pixel_values = torch.stack([transform(image)]).to(device)
129
  target_ratio = (1, 1)
130
 
 
131
  text += ' '
132
  input_ids = tokenizer(text)['input_ids'][1:]
133
  input_ids = torch.tensor(input_ids, device=device)
134
+
 
135
  with torch.no_grad():
136
  if 'R50' in check_type:
137
  text_embeds = model.language_embedding(input_ids)
138
  else:
139
  text_embeds = model.tok_embeddings(input_ids)
140
+
141
+ vit_embeds, size1 = model.forward_tokenocr(pixel_values)
 
 
 
 
142
  vit_embeds, size2 = post_process(vit_embeds, target_ratio, check_type)
143
+
 
144
  text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
145
  vit_embeds = vit_embeds / vit_embeds.norm(dim=-1, keepdim=True)
146
  similarity = text_embeds @ vit_embeds.T
147
  resized_size = size1 if size1 is not None else size2
148
 
 
 
 
 
 
 
149
  attn_map = similarity.reshape(len(text_embeds), resized_size[0], resized_size[1])
 
150
  all_bpe_strings = [tokenizer.decode(input_id) for input_id in input_ids]
151
+ current_vis = generate_similiarity_map([image.cpu()], attn_map.cpu(),
152
  [tokenizer.decode([i]) for i in input_ids],
153
  [], target_ratio, src_size)
154
+
155
  current_bpe = [tokenizer.decode([i]) for i in input_ids]
 
156
  current_bpe[-1] = text
157
+ return image.cpu(), current_vis[0], current_bpe[0]
 
 
158
 
159
  # 事件处理函数
160
  def update_index(change):
 
166
  # 使用HTML标签来设置字体大小、颜色,加粗,并居中
167
  return f"<div style='text-align:center; font-size:20px;'><strong>Current BPE: <span style='color:red;'>{bpe}</span></strong></div>"
168
 
169
+ def update_slider_index(x):
170
+ global current_vis, current_bpe, current_index
171
+ print(f"x: {x}, current_vis length: {len(current_vis)}, current_bpe length: {len(current_bpe)}")
172
+ if 0 <= x < len(current_vis) and 0 <= x < len(current_bpe):
173
+ return current_vis[x], format_bpe_display(current_bpe[x])
 
 
 
 
 
 
 
 
 
174
  else:
175
+ return None, "索引超出范围"
 
 
176
 
177
 
178
 
 
230
  # return image, vis, bpe_text, slider_max_val
231
 
232
  @spaces.GPU
233
+ def on_run_clicked(model_type, image, text):
234
+ global current_vis, current_bpe, current_index
235
+ current_index = 0
236
  model, tokenizer, transform, device = load_model(model_type)
 
237
  image, vis, bpe = process_image(model, tokenizer, transform, device, model_type, image, text)
238
+ slider_max_val = len(current_bpe) - 1
239
+ bpe_text = format_bpe_display(bpe)
240
+ return image, vis, bpe_text, slider_max_val
 
 
 
 
 
 
241
 
242
 
 
 
 
 
 
 
 
 
 
 
 
243
  run_btn.click(
244
  on_run_clicked,
245
+ inputs=[model_type, image_input, text_input],
246
+ outputs=[orig_img, heatmap, bpe_display, index_slider],
247
+ ).then(
248
+ lambda max_val: (gr.update(visible=True), gr.update(visible=True, maximum=max_val, value=0), gr.update(visible=True), gr.update(visible=True)),
249
+ inputs=index_slider,
250
+ outputs=[prev_btn, index_slider, next_btn, bpe_display],
251
  )
252
 
253
  prev_btn.click(
 
261
  )
262
 
263
 
 
 
 
 
 
264
  index_slider.change(
265
+ update_slider_index,
266
+ inputs=index_slider,
267
+ outputs=[heatmap, bpe_display]
268
+ )
269
+
270
 
271
 
272
  if __name__ == "__main__":