TongkunGuan commited on
Commit
aa84990
·
verified ·
1 Parent(s): b6f1806

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -230
app.py CHANGED
@@ -1,222 +1,3 @@
1
- # import os
2
- # import argparse
3
- # import numpy as np
4
- # from PIL import Image
5
- # import torch
6
- # import torchvision.transforms as T
7
- # from transformers import AutoTokenizer
8
- # import gradio as gr
9
- # from resnet50 import build_model
10
- # from utils import generate_similiarity_map, post_process, load_tokenizer, build_transform_R50
11
- # from utils import IMAGENET_MEAN, IMAGENET_STD
12
- # from internvl.train.dataset import dynamic_preprocess
13
- # from internvl.model.internvl_chat import InternVLChatModel
14
- # import spaces
15
-
16
- # # 模型配置
17
- # CHECKPOINTS = {
18
- # "TokenFD_4096_English_seg": "TongkunGuan/TokenFD_4096_English_seg",
19
- # "TokenFD_2048_Bilingual_seg": "TongkunGuan/TokenFD_2048_Bilingual_seg",
20
- # }
21
-
22
- # # 全局变量
23
- # HF_TOKEN = os.getenv("HF_TOKEN")
24
- # current_vis = []
25
- # current_bpe = []
26
- # current_index = 0
27
-
28
-
29
- # def load_model(check_type):
30
- # # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
31
- # device = torch.device("cuda")
32
- # if check_type == 'R50':
33
- # tokenizer = load_tokenizer('tokenizer_path')
34
- # model = build_model(argparse.Namespace()).eval()
35
- # model.load_state_dict(torch.load(CHECKPOINTS['R50'], map_location='cpu')['model'])
36
- # transform = build_transform_R50(normalize_type='imagenet')
37
-
38
- # elif check_type == 'R50_siglip':
39
- # tokenizer = load_tokenizer('tokenizer_path')
40
- # model = build_model(argparse.Namespace()).eval()
41
- # model.load_state_dict(torch.load(CHECKPOINTS['R50_siglip'], map_location='cpu')['model'])
42
- # transform = build_transform_R50(normalize_type='imagenet')
43
-
44
- # elif 'TokenFD' in check_type:
45
- # model_path = CHECKPOINTS[check_type]
46
- # tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, use_fast=False, use_auth_token=HF_TOKEN)
47
- # model = InternVLChatModel.from_pretrained(model_path, torch_dtype=torch.bfloat16).eval()
48
- # transform = T.Compose([
49
- # T.Lambda(lambda img: img.convert('RGB')),
50
- # T.Resize((224, 224)),
51
- # T.ToTensor(),
52
- # T.Normalize(IMAGENET_MEAN, IMAGENET_STD)
53
- # ])
54
-
55
- # return model.to(device), tokenizer, transform, device
56
-
57
- # def process_image(model, tokenizer, transform, device, check_type, image, text):
58
- # global current_vis, current_bpe, current_index
59
- # src_size = image.size
60
- # if 'TokenOCR' in check_type:
61
- # images, target_ratio = dynamic_preprocess(image, min_num=1, max_num=12,
62
- # image_size=model.config.force_image_size,
63
- # use_thumbnail=model.config.use_thumbnail,
64
- # return_ratio=True)
65
- # pixel_values = torch.stack([transform(img) for img in images]).to(device)
66
- # else:
67
- # pixel_values = torch.stack([transform(image)]).to(device)
68
- # target_ratio = (1, 1)
69
-
70
- # # 文本处理
71
- # text += ' '
72
- # input_ids = tokenizer(text)['input_ids'][1:]
73
- # input_ids = torch.tensor(input_ids, device=device)
74
-
75
- # # 获取嵌入
76
- # with torch.no_grad():
77
- # if 'R50' in check_type:
78
- # text_embeds = model.language_embedding(input_ids)
79
- # else:
80
- # text_embeds = model.tok_embeddings(input_ids)
81
-
82
- # vit_embeds, size1 = model.forward_tokenocr(pixel_values.to(torch.bfloat16).to(device))
83
- # print("vit_embeds",vit_embeds)
84
- # print("vit_embeds,shape",vit_embeds.shape)
85
- # print("target_ratio",target_ratio)
86
- # print("check_type",check_type)
87
- # vit_embeds, size2 = post_process(vit_embeds, target_ratio, check_type)
88
-
89
- # # 计算相似度
90
- # text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
91
- # vit_embeds = vit_embeds / vit_embeds.norm(dim=-1, keepdim=True)
92
- # similarity = text_embeds @ vit_embeds.T
93
- # resized_size = size1 if size1 is not None else size2
94
-
95
- # # print(f"text_embeds shape: {text_embeds.shape}, numel: {text_embeds.numel()}") # text_embeds shape: torch.Size([4, 2048]), numel: 8192
96
- # # print(f"vit_embeds shape: {vit_embeds.shape}, numel: {vit_embeds.numel()}") # vit_embeds shape: torch.Size([9728, 2048]), numel: 19922944
97
- # # print(f"similarity shape: {similarity.shape}, numel: {similarity.numel()}")# similarity shape: torch.Size([4, 9728]), numel: 38912
98
-
99
-
100
- # # 生成可视化
101
- # attn_map = similarity.reshape(len(text_embeds), resized_size[0], resized_size[1])
102
- # # attn_map = similarity.reshape(len(text_embeds), *target_ratio)
103
- # all_bpe_strings = [tokenizer.decode(input_id) for input_id in input_ids]
104
- # current_vis = generate_similiarity_map([image], attn_map,
105
- # [tokenizer.decode([i]) for i in input_ids],
106
- # [], target_ratio, src_size)
107
-
108
- # current_bpe = [tokenizer.decode([i]) for i in input_ids]
109
- # # current_bpe[-1] = 'Input text'
110
- # current_bpe[-1] = text
111
- # print("current_vis",len(current_vis))
112
- # print("current_bpe",len(current_bpe))
113
- # return image, current_vis[0], current_bpe[0]
114
-
115
- # # 事件处理函数
116
- # def update_index(change):
117
- # global current_vis, current_bpe, current_index
118
- # current_index = max(0, min(len(current_vis) - 1, current_index + change))
119
- # return current_vis[current_index], format_bpe_display(current_bpe[current_index])
120
-
121
- # def format_bpe_display(bpe):
122
- # # 使用HTML标签来设置字体大小、颜色,加粗,并居中
123
- # return f"<div style='text-align:center; font-size:20px;'><strong>Current BPE: <span style='color:red;'>{bpe}</span></strong></div>"
124
-
125
- # def update_slider_index(x):
126
- # print(f"x: {x}, current_vis length: {len(current_vis)}, current_bpe length: {len(current_bpe)}")
127
- # if 0 <= x < len(current_vis) and 0 <= x < len(current_bpe):
128
- # return current_vis[x], format_bpe_display(current_bpe[x])
129
- # else:
130
- # return None, "索引超出范围"
131
-
132
- # # Gradio界面
133
- # with gr.Blocks(title="BPE Visualization Demo") as demo:
134
- # gr.Markdown("## BPE Visualization Demo - TokenFD基座模型能力可视化")
135
-
136
- # with gr.Row():
137
- # with gr.Column(scale=0.5):
138
- # model_type = gr.Dropdown(
139
- # choices=["TokenFD_4096_English_seg", "TokenFD_2048_Bilingual_seg", "R50", "R50_siglip"],
140
- # label="Select model type",
141
- # value="TokenOCR_4096_English_seg" # 设置默认值为第一个选项
142
- # )
143
- # image_input = gr.Image(label="Upload images", type="pil")
144
- # text_input = gr.Textbox(label="Input text")
145
-
146
- # run_btn = gr.Button("RUN")
147
-
148
- # gr.Examples(
149
- # examples=[
150
- # [os.path.join("examples", "examples0.jpg"), "Veterans and Benefits"],
151
- # [os.path.join("examples", "examples1.jpg"), "Refreshers"],
152
- # [os.path.join("examples", "examples2.png"), "Vision Transformer"]
153
- # ],
154
- # inputs=[image_input, text_input],
155
- # label="Sample input"
156
- # )
157
-
158
- # with gr.Column(scale=2):
159
- # gr.Markdown("<p style='font-size:20px;'><span style='color:red;'>If the input text is not included in the image</span>, the attention map will show a lot of noise (the actual response value is very low), since we normalize the attention map according to the relative value.</p>")
160
-
161
- # with gr.Row():
162
- # orig_img = gr.Image(label="Original picture", interactive=False)
163
- # heatmap = gr.Image(label="BPE visualization", interactive=False)
164
-
165
- # with gr.Row() as controls:
166
- # prev_btn = gr.Button("⬅ Last", visible=False)
167
- # index_slider = gr.Slider(0, 1, value=0, step=1, label="BPE index", visible=False)
168
- # next_btn = gr.Button("⮕ Next", visible=False)
169
-
170
- # bpe_display = gr.Markdown("Current BPE: ", visible=False)
171
-
172
- # # 事件处理
173
- # @spaces.GPU
174
- # def on_run_clicked(model_type, image, text):
175
- # global current_vis, current_bpe, current_index
176
- # current_index = 0 # Reset index when new image is processed
177
- # image, vis, bpe = process_image(*load_model(model_type), model_type, image, text)
178
- # # Update the slider range and set value to 0
179
- # slider_max_val = len(current_bpe) - 1
180
- # bpe_text = format_bpe_display(bpe)
181
- # print("current_vis",len(current_vis))
182
- # print("current_bpe",len(current_bpe))
183
- # return image, vis, bpe_text, slider_max_val
184
-
185
- # run_btn.click(
186
- # on_run_clicked,
187
- # inputs=[model_type, image_input, text_input],
188
- # outputs=[orig_img, heatmap, bpe_display, index_slider],
189
- # ).then(
190
- # lambda max_val: (gr.update(visible=True), gr.update(visible=True, maximum=max_val, value=0), gr.update(visible=True), gr.update(visible=True)),
191
- # inputs=index_slider,
192
- # outputs=[prev_btn, index_slider, next_btn, bpe_display],
193
- # )
194
-
195
- # prev_btn.click(
196
- # lambda: (*update_index(-1), current_index),
197
- # outputs=[heatmap, bpe_display, index_slider]
198
- # )
199
-
200
- # next_btn.click(
201
- # lambda: (*update_index(1), current_index),
202
- # outputs=[heatmap, bpe_display, index_slider]
203
- # )
204
-
205
- # # index_slider.change(
206
- # # lambda x: (current_vis[x], format_bpe_display(current_bpe[x])) if 0<=x<len(current_vis else (None,"Invaild")
207
- # # inputs=index_slider,
208
- # # outputs=[heatmap, bpe_display]
209
- # # )
210
-
211
- # index_slider.change(
212
- # update_slider_index,
213
- # inputs=index_slider,
214
- # outputs=[heatmap, bpe_display]
215
- # )
216
-
217
- # if __name__ == "__main__":
218
- # demo.launch()
219
-
220
  import os
221
  import argparse
222
  import numpy as np
@@ -240,8 +21,9 @@ CHECKPOINTS = {
240
 
241
  # 全局变量
242
  HF_TOKEN = os.getenv("HF_TOKEN")
243
- current_vis = []
244
- current_bpe = []
 
245
 
246
  def load_model(check_type):
247
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -271,7 +53,7 @@ def load_model(check_type):
271
  return model.to(device), tokenizer, transform, device
272
 
273
  def process_image(model, tokenizer, transform, device, check_type, image, text):
274
- global current_vis, current_bpe
275
  src_size = image.size
276
  if 'TokenOCR' in check_type:
277
  images, target_ratio = dynamic_preprocess(image, min_num=1, max_num=12,
@@ -313,11 +95,17 @@ def process_image(model, tokenizer, transform, device, check_type, image, text):
313
 
314
  current_bpe = [tokenizer.decode([i]) for i in input_ids]
315
  current_bpe[-1] = text
316
- return image, current_vis, current_bpe
 
317
 
318
  def format_bpe_display(bpe):
319
  return f"<div style='text-align:center; font-size:20px;'><strong>Current BPE: <span style='color:red;'>{bpe}</span></strong></div>"
320
 
 
 
 
 
 
321
  # Gradio界面
322
  with gr.Blocks(title="BPE Visualization Demo") as demo:
323
  gr.Markdown("## BPE Visualization Demo - TokenFD基座模型能力可视化")
@@ -351,23 +139,33 @@ with gr.Blocks(title="BPE Visualization Demo") as demo:
351
  orig_img = gr.Image(label="Original picture", interactive=False)
352
  heatmap = gr.Image(label="BPE visualization", interactive=False)
353
 
354
- bpe_display = gr.Markdown("Current BPE: ", visible=False)
 
 
 
 
355
 
356
  # 事件处理
357
  @spaces.GPU
358
  def on_run_clicked(model_type, image, text):
359
- global current_vis, current_bpe
360
  image, vis, bpe = process_image(*load_model(model_type), model_type, image, text)
361
- bpe_text = format_bpe_display(bpe)
362
- return image, vis[0], bpe_text
363
 
364
  run_btn.click(
365
  on_run_clicked,
366
  inputs=[model_type, image_input, text_input],
367
  outputs=[orig_img, heatmap, bpe_display],
368
- ).then(
369
- lambda: (gr.update(visible=True)),
370
- outputs=[bpe_display],
 
 
 
 
 
 
 
371
  )
372
 
373
  if __name__ == "__main__":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  import argparse
3
  import numpy as np
 
21
 
22
  # 全局变量
23
  HF_TOKEN = os.getenv("HF_TOKEN")
24
+ current_vis = [] # 存储所有 heatmap
25
+ current_bpe = [] # 存储所有 BPE
26
+ current_index = 0 # 当前显示的 heatmap 和 BPE 的索引
27
 
28
  def load_model(check_type):
29
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
53
  return model.to(device), tokenizer, transform, device
54
 
55
  def process_image(model, tokenizer, transform, device, check_type, image, text):
56
+ global current_vis, current_bpe, current_index
57
  src_size = image.size
58
  if 'TokenOCR' in check_type:
59
  images, target_ratio = dynamic_preprocess(image, min_num=1, max_num=12,
 
95
 
96
  current_bpe = [tokenizer.decode([i]) for i in input_ids]
97
  current_bpe[-1] = text
98
+ current_index = 0 # 重置索引
99
+ return image, current_vis[current_index], format_bpe_display(current_bpe[current_index])
100
 
101
  def format_bpe_display(bpe):
102
  return f"<div style='text-align:center; font-size:20px;'><strong>Current BPE: <span style='color:red;'>{bpe}</span></strong></div>"
103
 
104
+ def update_index(change):
105
+ global current_vis, current_bpe, current_index
106
+ current_index = max(0, min(len(current_vis) - 1, current_index + change))
107
+ return current_vis[current_index], format_bpe_display(current_bpe[current_index])
108
+
109
  # Gradio界面
110
  with gr.Blocks(title="BPE Visualization Demo") as demo:
111
  gr.Markdown("## BPE Visualization Demo - TokenFD基座模型能力可视化")
 
139
  orig_img = gr.Image(label="Original picture", interactive=False)
140
  heatmap = gr.Image(label="BPE visualization", interactive=False)
141
 
142
+ with gr.Row():
143
+ prev_btn = gr.Button("⬅ Previous")
144
+ next_btn = gr.Button("Next ⮕")
145
+
146
+ bpe_display = gr.Markdown("Current BPE: ", visible=True)
147
 
148
  # 事件处理
149
  @spaces.GPU
150
  def on_run_clicked(model_type, image, text):
151
+ global current_vis, current_bpe, current_index
152
  image, vis, bpe = process_image(*load_model(model_type), model_type, image, text)
153
+ return image, vis, bpe
 
154
 
155
  run_btn.click(
156
  on_run_clicked,
157
  inputs=[model_type, image_input, text_input],
158
  outputs=[orig_img, heatmap, bpe_display],
159
+ )
160
+
161
+ prev_btn.click(
162
+ lambda: update_index(-1),
163
+ outputs=[heatmap, bpe_display]
164
+ )
165
+
166
+ next_btn.click(
167
+ lambda: update_index(1),
168
+ outputs=[heatmap, bpe_display]
169
  )
170
 
171
  if __name__ == "__main__":