TongkunGuan committed
Commit b6f1806 · verified · 1 Parent(s): 77bca7f

Update app.py

Files changed (1):
  1. app.py +228 -72

app.py CHANGED
@@ -1,3 +1,222 @@
+ # import os
+ # import argparse
+ # import numpy as np
+ # from PIL import Image
+ # import torch
+ # import torchvision.transforms as T
+ # from transformers import AutoTokenizer
+ # import gradio as gr
+ # from resnet50 import build_model
+ # from utils import generate_similiarity_map, post_process, load_tokenizer, build_transform_R50
+ # from utils import IMAGENET_MEAN, IMAGENET_STD
+ # from internvl.train.dataset import dynamic_preprocess
+ # from internvl.model.internvl_chat import InternVLChatModel
+ # import spaces
+
+ # # Model configuration
+ # CHECKPOINTS = {
+ #     "TokenFD_4096_English_seg": "TongkunGuan/TokenFD_4096_English_seg",
+ #     "TokenFD_2048_Bilingual_seg": "TongkunGuan/TokenFD_2048_Bilingual_seg",
+ # }
+
+ # # Global variables
+ # HF_TOKEN = os.getenv("HF_TOKEN")
+ # current_vis = []
+ # current_bpe = []
+ # current_index = 0
+
+
+ # def load_model(check_type):
+ #     # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ #     device = torch.device("cuda")
+ #     if check_type == 'R50':
+ #         tokenizer = load_tokenizer('tokenizer_path')
+ #         model = build_model(argparse.Namespace()).eval()
+ #         model.load_state_dict(torch.load(CHECKPOINTS['R50'], map_location='cpu')['model'])
+ #         transform = build_transform_R50(normalize_type='imagenet')
+
+ #     elif check_type == 'R50_siglip':
+ #         tokenizer = load_tokenizer('tokenizer_path')
+ #         model = build_model(argparse.Namespace()).eval()
+ #         model.load_state_dict(torch.load(CHECKPOINTS['R50_siglip'], map_location='cpu')['model'])
+ #         transform = build_transform_R50(normalize_type='imagenet')
+
+ #     elif 'TokenFD' in check_type:
+ #         model_path = CHECKPOINTS[check_type]
+ #         tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, use_fast=False, use_auth_token=HF_TOKEN)
+ #         model = InternVLChatModel.from_pretrained(model_path, torch_dtype=torch.bfloat16).eval()
+ #         transform = T.Compose([
+ #             T.Lambda(lambda img: img.convert('RGB')),
+ #             T.Resize((224, 224)),
+ #             T.ToTensor(),
+ #             T.Normalize(IMAGENET_MEAN, IMAGENET_STD)
+ #         ])
+
+ #     return model.to(device), tokenizer, transform, device
+
+ # def process_image(model, tokenizer, transform, device, check_type, image, text):
+ #     global current_vis, current_bpe, current_index
+ #     src_size = image.size
+ #     if 'TokenOCR' in check_type:
+ #         images, target_ratio = dynamic_preprocess(image, min_num=1, max_num=12,
+ #                                                   image_size=model.config.force_image_size,
+ #                                                   use_thumbnail=model.config.use_thumbnail,
+ #                                                   return_ratio=True)
+ #         pixel_values = torch.stack([transform(img) for img in images]).to(device)
+ #     else:
+ #         pixel_values = torch.stack([transform(image)]).to(device)
+ #         target_ratio = (1, 1)
+
+ #     # Text processing
+ #     text += ' '
+ #     input_ids = tokenizer(text)['input_ids'][1:]
+ #     input_ids = torch.tensor(input_ids, device=device)
+
+ #     # Get embeddings
+ #     with torch.no_grad():
+ #         if 'R50' in check_type:
+ #             text_embeds = model.language_embedding(input_ids)
+ #         else:
+ #             text_embeds = model.tok_embeddings(input_ids)
+
+ #         vit_embeds, size1 = model.forward_tokenocr(pixel_values.to(torch.bfloat16).to(device))
+ #         print("vit_embeds", vit_embeds)
+ #         print("vit_embeds.shape", vit_embeds.shape)
+ #         print("target_ratio", target_ratio)
+ #         print("check_type", check_type)
+ #         vit_embeds, size2 = post_process(vit_embeds, target_ratio, check_type)
+
+ #     # Compute similarity
+ #     text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
+ #     vit_embeds = vit_embeds / vit_embeds.norm(dim=-1, keepdim=True)
+ #     similarity = text_embeds @ vit_embeds.T
+ #     resized_size = size1 if size1 is not None else size2
+
+ #     # print(f"text_embeds shape: {text_embeds.shape}, numel: {text_embeds.numel()}")  # text_embeds shape: torch.Size([4, 2048]), numel: 8192
+ #     # print(f"vit_embeds shape: {vit_embeds.shape}, numel: {vit_embeds.numel()}")  # vit_embeds shape: torch.Size([9728, 2048]), numel: 19922944
+ #     # print(f"similarity shape: {similarity.shape}, numel: {similarity.numel()}")  # similarity shape: torch.Size([4, 9728]), numel: 38912
+
+
+ #     # Generate the visualization
+ #     attn_map = similarity.reshape(len(text_embeds), resized_size[0], resized_size[1])
+ #     # attn_map = similarity.reshape(len(text_embeds), *target_ratio)
+ #     all_bpe_strings = [tokenizer.decode(input_id) for input_id in input_ids]
+ #     current_vis = generate_similiarity_map([image], attn_map,
+ #                                            [tokenizer.decode([i]) for i in input_ids],
+ #                                            [], target_ratio, src_size)
+
+ #     current_bpe = [tokenizer.decode([i]) for i in input_ids]
+ #     # current_bpe[-1] = 'Input text'
+ #     current_bpe[-1] = text
+ #     print("current_vis", len(current_vis))
+ #     print("current_bpe", len(current_bpe))
+ #     return image, current_vis[0], current_bpe[0]
+
+ # # Event handler functions
+ # def update_index(change):
+ #     global current_vis, current_bpe, current_index
+ #     current_index = max(0, min(len(current_vis) - 1, current_index + change))
+ #     return current_vis[current_index], format_bpe_display(current_bpe[current_index])
+
+ # def format_bpe_display(bpe):
+ #     # Use HTML to set the font size and color, bold the text, and center it
+ #     return f"<div style='text-align:center; font-size:20px;'><strong>Current BPE: <span style='color:red;'>{bpe}</span></strong></div>"
+
+ # def update_slider_index(x):
+ #     print(f"x: {x}, current_vis length: {len(current_vis)}, current_bpe length: {len(current_bpe)}")
+ #     if 0 <= x < len(current_vis) and 0 <= x < len(current_bpe):
+ #         return current_vis[x], format_bpe_display(current_bpe[x])
+ #     else:
+ #         return None, "Index out of range"
+
+ # # Gradio UI
+ # with gr.Blocks(title="BPE Visualization Demo") as demo:
+ #     gr.Markdown("## BPE Visualization Demo - visualizing the TokenFD foundation model's capabilities")
+
+ #     with gr.Row():
+ #         with gr.Column(scale=0.5):
+ #             model_type = gr.Dropdown(
+ #                 choices=["TokenFD_4096_English_seg", "TokenFD_2048_Bilingual_seg", "R50", "R50_siglip"],
+ #                 label="Select model type",
+ #                 value="TokenOCR_4096_English_seg"  # Default to the first option
+ #             )
+ #             image_input = gr.Image(label="Upload images", type="pil")
+ #             text_input = gr.Textbox(label="Input text")
+
+ #             run_btn = gr.Button("RUN")
+
+ #             gr.Examples(
+ #                 examples=[
+ #                     [os.path.join("examples", "examples0.jpg"), "Veterans and Benefits"],
+ #                     [os.path.join("examples", "examples1.jpg"), "Refreshers"],
+ #                     [os.path.join("examples", "examples2.png"), "Vision Transformer"]
+ #                 ],
+ #                 inputs=[image_input, text_input],
+ #                 label="Sample input"
+ #             )
+
+ #         with gr.Column(scale=2):
+ #             gr.Markdown("<p style='font-size:20px;'><span style='color:red;'>If the input text is not included in the image</span>, the attention map will show a lot of noise (the actual response value is very low), since we normalize the attention map according to the relative value.</p>")
+
+ #             with gr.Row():
+ #                 orig_img = gr.Image(label="Original picture", interactive=False)
+ #                 heatmap = gr.Image(label="BPE visualization", interactive=False)
+
+ #             with gr.Row() as controls:
+ #                 prev_btn = gr.Button("⬅ Last", visible=False)
+ #                 index_slider = gr.Slider(0, 1, value=0, step=1, label="BPE index", visible=False)
+ #                 next_btn = gr.Button("⮕ Next", visible=False)
+
+ #             bpe_display = gr.Markdown("Current BPE: ", visible=False)
+
+ #     # Event handling
+ #     @spaces.GPU
+ #     def on_run_clicked(model_type, image, text):
+ #         global current_vis, current_bpe, current_index
+ #         current_index = 0  # Reset index when new image is processed
+ #         image, vis, bpe = process_image(*load_model(model_type), model_type, image, text)
+ #         # Update the slider range and set value to 0
+ #         slider_max_val = len(current_bpe) - 1
+ #         bpe_text = format_bpe_display(bpe)
+ #         print("current_vis", len(current_vis))
+ #         print("current_bpe", len(current_bpe))
+ #         return image, vis, bpe_text, slider_max_val
+
+ #     run_btn.click(
+ #         on_run_clicked,
+ #         inputs=[model_type, image_input, text_input],
+ #         outputs=[orig_img, heatmap, bpe_display, index_slider],
+ #     ).then(
+ #         lambda max_val: (gr.update(visible=True), gr.update(visible=True, maximum=max_val, value=0), gr.update(visible=True), gr.update(visible=True)),
+ #         inputs=index_slider,
+ #         outputs=[prev_btn, index_slider, next_btn, bpe_display],
+ #     )
+
+ #     prev_btn.click(
+ #         lambda: (*update_index(-1), current_index),
+ #         outputs=[heatmap, bpe_display, index_slider]
+ #     )
+
+ #     next_btn.click(
+ #         lambda: (*update_index(1), current_index),
+ #         outputs=[heatmap, bpe_display, index_slider]
+ #     )
+
+ #     # index_slider.change(
+ #     #     lambda x: (current_vis[x], format_bpe_display(current_bpe[x])) if 0 <= x < len(current_vis) else (None, "Invalid"),
+ #     #     inputs=index_slider,
+ #     #     outputs=[heatmap, bpe_display]
+ #     # )
+
+ #     index_slider.change(
+ #         update_slider_index,
+ #         inputs=index_slider,
+ #         outputs=[heatmap, bpe_display]
+ #     )
+
+ # if __name__ == "__main__":
+ #     demo.launch()
+
  import os
  import argparse
  import numpy as np
@@ -23,12 +242,9 @@ CHECKPOINTS = {
  HF_TOKEN = os.getenv("HF_TOKEN")
  current_vis = []
  current_bpe = []
- current_index = 0
-

  def load_model(check_type):
-     # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-     device = torch.device("cuda")
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
      if check_type == 'R50':
          tokenizer = load_tokenizer('tokenizer_path')
          model = build_model(argparse.Namespace()).eval()
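Editor's note: the replacement line above restores the standard CUDA-availability guard so the app can still start on CPU-only hardware instead of failing at torch.device("cuda"). A minimal, self-contained sketch of the pattern (nn.Linear is only a stand-in module):

    import torch
    import torch.nn as nn

    # Pick the GPU when present, otherwise fall back to CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = nn.Linear(8, 8).to(device)       # the app moves its real model the same way
    x = torch.randn(1, 8, device=device)
    y = model(x)                             # runs on whichever device was selected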
@@ -55,7 +271,7 @@ def load_model(check_type):
      return model.to(device), tokenizer, transform, device

  def process_image(model, tokenizer, transform, device, check_type, image, text):
-     global current_vis, current_bpe, current_index
+     global current_vis, current_bpe
      src_size = image.size
      if 'TokenOCR' in check_type:
          images, target_ratio = dynamic_preprocess(image, min_num=1, max_num=12,
@@ -80,10 +296,6 @@ def process_image(model, tokenizer, transform, device, check_type, image, text):
              text_embeds = model.tok_embeddings(input_ids)

          vit_embeds, size1 = model.forward_tokenocr(pixel_values.to(torch.bfloat16).to(device))
-         print("vit_embeds", vit_embeds)
-         print("vit_embeds.shape", vit_embeds.shape)
-         print("target_ratio", target_ratio)
-         print("check_type", check_type)
          vit_embeds, size2 = post_process(vit_embeds, target_ratio, check_type)

      # Compute similarity
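For context, the text side of process_image (unchanged by this commit) is a plain per-token embedding lookup. A self-contained sketch of the same steps; bert-base-uncased is only a stand-in for the app's tokenizer and embedding table:

    import torch
    from transformers import AutoTokenizer, AutoModel

    tok = AutoTokenizer.from_pretrained("bert-base-uncased")
    mdl = AutoModel.from_pretrained("bert-base-uncased").eval()

    input_ids = tok("Vision Transformer ")["input_ids"][1:]   # drop the leading special token, as the app does
    input_ids = torch.tensor(input_ids)
    with torch.no_grad():
        text_embeds = mdl.get_input_embeddings()(input_ids)   # (num_tokens, hidden_dim)
    print([tok.decode([i]) for i in input_ids])               # the per-token BPE strings shown in the UI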
@@ -92,43 +304,20 @@ def process_image(model, tokenizer, transform, device, check_type, image, text):
      similarity = text_embeds @ vit_embeds.T
      resized_size = size1 if size1 is not None else size2

-     # print(f"text_embeds shape: {text_embeds.shape}, numel: {text_embeds.numel()}")  # text_embeds shape: torch.Size([4, 2048]), numel: 8192
-     # print(f"vit_embeds shape: {vit_embeds.shape}, numel: {vit_embeds.numel()}")  # vit_embeds shape: torch.Size([9728, 2048]), numel: 19922944
-     # print(f"similarity shape: {similarity.shape}, numel: {similarity.numel()}")  # similarity shape: torch.Size([4, 9728]), numel: 38912
-
-
      # Generate the visualization
      attn_map = similarity.reshape(len(text_embeds), resized_size[0], resized_size[1])
-     # attn_map = similarity.reshape(len(text_embeds), *target_ratio)
      all_bpe_strings = [tokenizer.decode(input_id) for input_id in input_ids]
      current_vis = generate_similiarity_map([image], attn_map,
                                             [tokenizer.decode([i]) for i in input_ids],
                                             [], target_ratio, src_size)

      current_bpe = [tokenizer.decode([i]) for i in input_ids]
-     # current_bpe[-1] = 'Input text'
      current_bpe[-1] = text
-     print("current_vis", len(current_vis))
-     print("current_bpe", len(current_bpe))
-     return image, current_vis[0], current_bpe[0]
+     return image, current_vis, current_bpe
-
- # Event handler functions
- def update_index(change):
-     global current_vis, current_bpe, current_index
-     current_index = max(0, min(len(current_vis) - 1, current_index + change))
-     return current_vis[current_index], format_bpe_display(current_bpe[current_index])

  def format_bpe_display(bpe):
-     # Use HTML to set the font size and color, bold the text, and center it
      return f"<div style='text-align:center; font-size:20px;'><strong>Current BPE: <span style='color:red;'>{bpe}</span></strong></div>"

- def update_slider_index(x):
-     print(f"x: {x}, current_vis length: {len(current_vis)}, current_bpe length: {len(current_bpe)}")
-     if 0 <= x < len(current_vis) and 0 <= x < len(current_bpe):
-         return current_vis[x], format_bpe_display(current_bpe[x])
-     else:
-         return None, "Index out of range"

  # Gradio UI
  with gr.Blocks(title="BPE Visualization Demo") as demo:
      gr.Markdown("## BPE Visualization Demo - visualizing the TokenFD foundation model's capabilities")
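The heat maps come from cosine similarity between the normalized token embeddings and the ViT patch embeddings, reshaped to the patch grid, which is what the retained context lines above compute. A self-contained sketch with dummy shapes (random tensors stand in for real features):

    import torch

    num_tokens, grid_h, grid_w, dim = 4, 16, 16, 2048
    text_embeds = torch.randn(num_tokens, dim)
    vit_embeds = torch.randn(grid_h * grid_w, dim)

    # L2-normalize both sides so the dot product is a cosine similarity in [-1, 1]
    text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
    vit_embeds = vit_embeds / vit_embeds.norm(dim=-1, keepdim=True)
    similarity = text_embeds @ vit_embeds.T                     # (num_tokens, num_patches)
    attn_map = similarity.reshape(num_tokens, grid_h, grid_w)   # one heat map per BPE token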
@@ -138,7 +327,7 @@ with gr.Blocks(title="BPE Visualization Demo") as demo:
              model_type = gr.Dropdown(
                  choices=["TokenFD_4096_English_seg", "TokenFD_2048_Bilingual_seg", "R50", "R50_siglip"],
                  label="Select model type",
-                 value="TokenOCR_4096_English_seg"  # Default to the first option
+                 value="TokenFD_4096_English_seg"
              )
              image_input = gr.Image(label="Upload images", type="pil")
              text_input = gr.Textbox(label="Input text")
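Note on the default above: gr.Dropdown expects value to be one of choices (recent Gradio versions warn when it is not), so the default has to use a listed TokenFD_* checkpoint name rather than the stale TokenOCR_* spelling. A minimal sketch:

    import gradio as gr

    # `value` should match an entry in `choices`; otherwise Gradio logs a
    # warning and the dropdown starts without a valid selection.
    model_type = gr.Dropdown(
        choices=["TokenFD_4096_English_seg", "TokenFD_2048_Bilingual_seg", "R50", "R50_siglip"],
        label="Select model type",
        value="TokenFD_4096_English_seg",
    )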
@@ -162,57 +351,24 @@ with gr.Blocks(title="BPE Visualization Demo") as demo:
                  orig_img = gr.Image(label="Original picture", interactive=False)
                  heatmap = gr.Image(label="BPE visualization", interactive=False)

-             with gr.Row() as controls:
-                 prev_btn = gr.Button("⬅ Last", visible=False)
-                 index_slider = gr.Slider(0, 1, value=0, step=1, label="BPE index", visible=False)
-                 next_btn = gr.Button("⮕ Next", visible=False)
-
              bpe_display = gr.Markdown("Current BPE: ", visible=False)

      # Event handling
      @spaces.GPU
      def on_run_clicked(model_type, image, text):
-         global current_vis, current_bpe, current_index
-         current_index = 0  # Reset index when new image is processed
+         global current_vis, current_bpe
          image, vis, bpe = process_image(*load_model(model_type), model_type, image, text)
-         # Update the slider range and set value to 0
-         slider_max_val = len(current_bpe) - 1
          bpe_text = format_bpe_display(bpe)
-         print("current_vis", len(current_vis))
-         print("current_bpe", len(current_bpe))
-         return image, vis, bpe_text, slider_max_val
+         return image, vis[0], bpe_text

      run_btn.click(
          on_run_clicked,
          inputs=[model_type, image_input, text_input],
-         outputs=[orig_img, heatmap, bpe_display, index_slider],
+         outputs=[orig_img, heatmap, bpe_display],
      ).then(
-         lambda max_val: (gr.update(visible=True), gr.update(visible=True, maximum=max_val, value=0), gr.update(visible=True), gr.update(visible=True)),
-         inputs=index_slider,
-         outputs=[prev_btn, index_slider, next_btn, bpe_display],
+         lambda: gr.update(visible=True),
+         outputs=[bpe_display],
      )
-
-     prev_btn.click(
-         lambda: (*update_index(-1), current_index),
-         outputs=[heatmap, bpe_display, index_slider]
-     )
-
-     next_btn.click(
-         lambda: (*update_index(1), current_index),
-         outputs=[heatmap, bpe_display, index_slider]
-     )
-
-     # index_slider.change(
-     #     lambda x: (current_vis[x], format_bpe_display(current_bpe[x])) if 0 <= x < len(current_vis) else (None, "Invalid"),
-     #     inputs=index_slider,
-     #     outputs=[heatmap, bpe_display]
-     # )
-
-     index_slider.change(
-         update_slider_index,
-         inputs=index_slider,
-         outputs=[heatmap, bpe_display]
-     )

  if __name__ == "__main__":
      demo.launch()
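The simplified wiring keeps the standard Gradio chain: .click() writes the handler's outputs into the components, then .then() reveals the status element once the run finishes. A self-contained sketch of the same pattern (toy handler, not the app's model code):

    import gradio as gr

    with gr.Blocks() as demo:
        inp = gr.Textbox(label="Input text")
        btn = gr.Button("RUN")
        out = gr.Textbox(label="Result")
        note = gr.Markdown("Current BPE: ready", visible=False)

        btn.click(
            lambda text: text.upper(),        # stand-in for on_run_clicked
            inputs=[inp],
            outputs=[out],
        ).then(
            lambda: gr.update(visible=True),  # flip visibility only after the first step completes
            outputs=[note],
        )

    demo.launch()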
 