chendl committed
Commit d42bf88 · 1 Parent(s): 768ab84

update cap

app.py CHANGED
@@ -2,7 +2,7 @@ import os
 import sys
 from pathlib import Path
 # os.system("cd transformers && pip install .")
-os.system("cd multimodal && pip install .")
+os.system("cd multimodal && pip install -e .")
 os.system("cd multimodal/YOLOX && pip install .")
 import numpy as np
 import torch
@@ -233,21 +233,42 @@ def upload_img(gr_img, text_input, chat_state, chatbot):
     path = build_image(gr_img)
     chatbot = chatbot + [[(path,), None]]
     llm_message = chat.upload_img(gr_img, chat_state, img_list)
-    return gr.update(interactive=False), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update(
+    return gr.update(interactive=False), gr.Textbox(placeholder='Type and press Enter', interactive=True), gr.update(
         value="Start Chatting", interactive=False), chat_state, img_list, chatbot
 
 
-
-def gradio_ask(user_message, chatbot, chat_state,radio):
+def gradio_ask(user_message, chatbot, chat_state, radio):
     # if len(user_message) == 0:
     #     return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state
 
-
-    chat.ask(user_message, chat_state,radio,model_name)
+    chat.ask(user_message, chat_state, radio, model_name)
     chatbot = chatbot + [[user_message, None]]
     return chatbot, chat_state
 
 
+def generate_ans(user_message, chatbot, chat_state, img_list, radio, text, num_beams, temperature):
+    # if len(user_message) == 0:
+    #     return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state
+
+    chat.ask(user_message, chat_state, radio, model_name)
+    chatbot = chatbot + [[user_message, None]]
+    # return chatbot, chat_state
+    image = None
+    llm_message, image = \
+        chat.answer(conv=chat_state, img_list=img_list, max_new_tokens=300, num_beams=1, temperature=temperature,
+                    max_length=2000, radio=radio, text_input=text, model_name=model_name)
+
+    chatbot[-1][1] = llm_message
+    if chat_state[-1]["from"] == "gpt":
+        chat_state[-1]["value"] = llm_message
+    if image == None:
+        return "", chatbot, chat_state, img_list
+    else:
+        path = build_image(image)
+        chatbot = chatbot + [[None, (path,)]]
+        return "", chatbot, chat_state, img_list
+
+
 def gradio_answer(chatbot, chat_state, img_list, radio, text, num_beams, temperature):
     image = None
     llm_message, image = \
@@ -325,10 +346,15 @@ with gr.Blocks() as demo:
     # submit_button.click(gradio_ask, [text_input, chatbot, chat_state,radio], [chatbot, chat_state]).then(
     #     gradio_answer, [chatbot, chat_state, img_list, radio, text_input,num_beams, temperature], [text_input,chatbot, chat_state, img_list]
     # )
-    text_input.submit(gradio_ask, [text_input, chatbot, chat_state, radio], [chatbot, chat_state]).then(
-        gradio_answer, [chatbot, chat_state, img_list, radio, text_input, num_beams, temperature],
-        [text_input, chatbot, chat_state, img_list]
-    )
+
+    text_input.submit(generate_ans,
+                      [text_input, chatbot, chat_state, img_list, radio, text_input, num_beams, temperature],
+                      [text_input, chatbot, chat_state, img_list])
+
+    # text_input.submit(gradio_ask, [text_input, chatbot, chat_state, radio], [chatbot, chat_state]).then(
+    #     gradio_answer, [chatbot, chat_state, img_list, radio, text_input, num_beams, temperature],
+    #     [text_input, chatbot, chat_state, img_list]
+    # )
     clear.click(gradio_reset, [chat_state, img_list], [chatbot, image, text_input, upload_button, chat_state, img_list],
                 queue=False)
 
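For orientation, the pattern this change adopts can be sketched in isolation: one submit handler records the user turn and produces the reply in a single callback, instead of chaining gradio_ask and gradio_answer with .then(). The Chat object and model call below are replaced by a trivial echo stand-in, so only the Gradio wiring is meant to correspond to app.py.

import gradio as gr

# Hypothetical stand-in for the repo's chat/model object; only the call pattern matters here.
def fake_answer(message):
    return f"echo: {message}"

def generate_ans(user_message, chatbot, chat_state):
    # Record the user turn, then immediately generate the reply in the same callback,
    # mirroring how app.py folds gradio_ask + gradio_answer into one handler.
    chat_state = chat_state + [{"from": "human", "value": user_message}]
    reply = fake_answer(user_message)
    chat_state = chat_state + [{"from": "gpt", "value": reply}]
    chatbot = chatbot + [[user_message, reply]]
    # Return "" as the first output so the textbox is cleared after submit.
    return "", chatbot, chat_state

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    chat_state = gr.State([])
    text_input = gr.Textbox(placeholder="Type and press Enter")
    # A single submit event drives the whole turn, instead of .submit(...).then(...).
    text_input.submit(generate_ans, [text_input, chatbot, chat_state],
                      [text_input, chatbot, chat_state])

if __name__ == "__main__":
    demo.launch()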
multimodal/open_flamingo/chat/conversation.py CHANGED
@@ -519,72 +519,4 @@ class Chat:
         # return mixed_embs
 
 
-def evaluate_exp(
-    model,
-    tokenizer,
-    image_processor,
-    vis_embed_size=None,
-    rank=0,
-    world_size=1,
-    id=0,
-    add_visual=True,
-):
-    media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
-    box_token_id = tokenizer("<|#box#|>", add_special_tokens=False)["input_ids"][-1]
-    endofobject_token_id = tokenizer("<|#endofobject#|>", add_special_tokens=False)["input_ids"][-1]
-    endofattr_token_id = tokenizer("<|#endofattr#|>", add_special_tokens=False)["input_ids"][-1]
-    endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
-    visual_token_id = tokenizer("<|#visual#|>", add_special_tokens=False)["input_ids"][-1]
-    previsual_token_id = tokenizer("<|#previsual#|>", add_special_tokens=False)["input_ids"][-1]
-    prebox_token_id = tokenizer("<|#prebox#|>", add_special_tokens=False)["input_ids"][-1]
-    size = image_processor.size["shortest_edge"]
-    model.eval()
-    # "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/cdl/tmp_img/chat_vis/chat19.png"
-    image_path = input("Please enter the image path: ")
-    image = Image.open(image_path).convert("RGB")
-    image = image.resize((size, size))
-    print(f"image size: {image.size}")
-    batch_images = preprocess_image(image, image_processor).unsqueeze(0).unsqueeze(1).unsqueeze(0)
-    conversation = []
-    human_sentence = None
-    while True:
-        human_sentence = input("### Human: ")
-        if human_sentence == "#end#":
-            break
-        conversation.append({
-            "from": "human",
-            "value": human_sentence,
-        })
-        conversation.append({
-            "from": "gpt",
-            "value": "",
-        })
-        text = preprocess_conv(conversation).strip()
-        caption = f"<|#image#|>{tokenizer.pad_token * vis_embed_size}<|#endofimage#|>{text}"
-        encodings = tokenizer(
-            caption,
-            padding="longest",
-            truncation=True,
-            return_tensors="pt",
-            max_length=2000,
-        )
-        input_ids = encodings["input_ids"].to("cuda")
-        attention_mask = encodings["attention_mask"].to("cuda")
-        image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
-        image_start_index_list = [[x] for x in image_start_index_list]
-        image_nums = [1] * len(input_ids)
-        with torch.no_grad() and torch.cuda.amp.autocast(dtype=torch.float16):
-            outputs = model.generate(
-                batch_images,
-                input_ids,
-                attention_mask=attention_mask,
-                max_new_tokens=100,
-                # min_new_tokens=8,
-                num_beams=1,
-                image_start_index_list=image_start_index_list,
-                image_nums=image_nums,
-            )
-        print(f"### Assistant: {tokenizer.decode(outputs[0, input_ids.shape[1]:], skip_special_tokens=True).strip()}")
-
-
 
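The removed evaluate_exp helper built its conversation as a list of {"from", "value"} dicts, which is the same shape app.py's new generate_ans inspects via chat_state[-1]["from"] == "gpt". A minimal sketch of that structure follows; preprocess_conv here is a hypothetical stand-in for the repository's real prompt formatter.

# Conversation state as a list of {"from", "value"} turns, matching the
# conversation.append(...) calls in the removed evaluate_exp and the
# chat_state[-1]["from"] == "gpt" check in app.py's generate_ans.
chat_state = [
    {"from": "human", "value": "What is in the image?"},
    {"from": "gpt", "value": ""},  # empty slot the model's reply fills in
]

def preprocess_conv(conversation):
    # Hypothetical formatter: join the turns into a single prompt string.
    return "".join(f"### {turn['from']}: {turn['value']}\n" for turn in conversation)

llm_message = "A dog on a couch."
if chat_state[-1]["from"] == "gpt":
    chat_state[-1]["value"] = llm_message

print(preprocess_conv(chat_state))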
multimodal/open_flamingo/eval/task/caption_chat.py CHANGED
@@ -1,12 +1,14 @@
+
 import torch
 import more_itertools
 from tqdm import tqdm
 import json
 import time
 import os
+import numpy as np
 from transformers import LogitsProcessor, MinNewTokensLengthLogitsProcessor, ForcedEOSTokenLogitsProcessor
 from PIL import Image
-
+import cv2
 
 class VisualLogitsProcessor(LogitsProcessor):
     def __init__(self, tokenizer):
@@ -24,10 +26,7 @@ class VisualLogitsProcessor(LogitsProcessor):
     def __call__(self, input_ids, scores):
         # print("decoding===>", self.tokenizer.decode(scores.sort(descending=True).indices.tolist()[0][:self.topk]))
         # import pdb; pdb.set_trace()
-        if self.object_token_id in scores.sort(descending=True).indices.tolist()[0][
-                1:self.topk] and self.eos_token_id not in \
-                scores.sort(descending=True).indices.tolist()[0][:self.topk] and (
-                input_ids == self.object_token_id).sum() * 2 == (input_ids == self.endofobject_token_id).sum():
+        if self.object_token_id in scores.sort(descending=True).indices.tolist()[0][1:self.topk] and self.eos_token_id not in scores.sort(descending=True).indices.tolist()[0][:self.topk] and (input_ids == self.object_token_id).sum() * 2 == (input_ids == self.endofobject_token_id).sum():
             scores[0, self.object_token_id] = 1000
         if input_ids[0, -1] == self.object_token_id and input_ids[0, -2] != self.prebox_token_id:
             if (input_ids[0, :-1] == self.object_token_id).sum() != 0:
@@ -53,13 +52,165 @@ def prepare_batch_images(batch, image_processor):
     return batch_images
 
 
+# def captioner(
+#         model, tokenizer, image_ori, batch_images, input_ids, attention_mask, image_start_index_list, image_nums,
+#         added_bbox_list, debug=True):
+#     """Evaluate a model on COCO dataset.
+#     Returns:
+#         float: CIDEr score
+#
+#     """
+#     visual_logits_processor = VisualLogitsProcessor(tokenizer)
+#     model.eval()
+#     # model.eval().cuda()
+#     lang_encoder_name = model.lang_encoder.__class__.__name__.lower()
+#     media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
+#     endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
+#     pad_token_id = tokenizer(tokenizer.pad_token, add_special_tokens=False)["input_ids"][-1]
+#     bos_token_id = tokenizer(tokenizer.bos_token, add_special_tokens=False)["input_ids"][-1]
+#     previsual_token_id = tokenizer("<|#previsual#|>", add_special_tokens=False)["input_ids"][-1]
+#     visual_token_id = tokenizer("<|#visual#|>", add_special_tokens=False)["input_ids"][-1]
+#     box_token = "<|#box#|>"
+#     prebox_token = "<|#prebox#|>"
+#     endofobject_token = "<|#endofobject#|>"
+#     object_token = "<|#object#|>"
+#     ori_prompt_length = len(input_ids[0])
+#     have_prebox = False
+#     prompt = None
+#     out_image = None
+#     no_end = True
+#     for i in range(500):
+#         if no_end:
+#             batch_images = batch_images
+#             if prompt == None:
+#                 input_ids = input_ids
+#                 attention_mask = attention_mask
+#             else:
+#                 encodings = tokenizer(
+#                     [prompt],
+#                     padding="longest",
+#                     truncation=True,
+#                     return_tensors="pt",
+#                     max_length=2000,
+#                 )
+#                 attention_mask = encodings["attention_mask"]
+#                 input_ids = encodings["input_ids"]
+#                 image_start_index_list = image_start_index_list
+#                 image_nums = image_nums
+#             if debug:
+#                 print("input--->", tokenizer.decode(input_ids[0]))
+#             p1 = MinNewTokensLengthLogitsProcessor(
+#                 prompt_length_to_skip=input_ids.shape[-1],
+#                 min_new_tokens=5,
+#                 eos_token_id=bos_token_id,
+#             )
+#             with torch.inference_mode():
+#                 outputs = model.generate(
+#                     batch_images,
+#                     input_ids,
+#                     attention_mask=attention_mask,
+#                     max_new_tokens=20,
+#                     # min_new_tokens=8,
+#                     num_beams=1,
+#                     # length_penalty=0,
+#                     image_start_index_list=image_start_index_list,
+#                     image_nums=image_nums,
+#                     added_bbox_list=added_bbox_list if len(added_bbox_list) != 0 else None,
+#                     logits_processor_list=[p1, visual_logits_processor],
+#                 )
+#             if debug:
+#                 print("outputs--->", tokenizer.decode(outputs[0]))
+#             input_ids = encodings["input_ids"]
+#             attention_mask = encodings["attention_mask"]
+#             image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
+#             image_start_index_list = [[x] for x in image_start_index_list]
+#             image_nums = [1] * len(input_ids)
+#             if debug:
+#                 print("get the visual bbox--->", tokenizer.decode(input_ids[0]))
+#             with torch.no_grad():
+#                 outputs = model(
+#                     vision_x=batch_images,
+#                     lang_x=input_ids,
+#                     attention_mask=attention_mask,
+#                     image_nums=image_nums,
+#                     image_start_index_list=image_start_index_list,
+#                     added_bbox_list=added_bbox_list if len(added_bbox_list) != 0 else None,
+#                     add_box=added_bbox_list is not None and len(added_bbox_list) != 0,
+#                 )
+#             boxes = outputs["boxes"]
+#             scores = outputs["scores"]
+#             if debug:
+#                 print("box num---->", len(boxes))
+#             # if not model.valid:
+#             #     import pdb; pdb.set_trace()
+#             if boxes is not None:
+#                 if is_visual:
+#                     if have_prebox:
+#                         added_bbox_list.pop()
+#                         prompt = prompt.replace("<|#previsual#|><|#prebox#|><|#object#|>", "")
+#                         have_prebox = False
+#                         if debug:
+#                             print("find previsual and remove it--->", prompt)
+#                     first_box = boxes[scores.argmax()]
+#                     added_bbox_list += [torch.tensor(first_box).unsqueeze(0) / 224]
+#                     prompt = prompt[:-len(tokenizer.eos_token)]
+#                     prompt += box_token + endofobject_token
+#                     if debug:
+#                         print("after inserting visual---->", prompt)
+#
+#                 else:
+#                     import numpy as np
+#                     import cv2
+#
+#                     # exit()
+#                     pre_box = boxes[scores.argmax()]
+#                     added_bbox_list += [torch.tensor(pre_box).unsqueeze(0) / 224]
+#                     prompt = prompt[:-len(tokenizer.eos_token)]
+#                     prompt += prebox_token + object_token
+#                     have_prebox = True
+#                     if debug:
+#                         print("after inserting previsual---->", prompt)
+#             else:
+#                 # if debug:
+#                 #     import pdb;pdb.set_trace()
+#                 prompt = tokenizer.decode(outputs.clone()[0])
+#                 if debug:
+#                     print("before else---->", prompt)
+#                 prompt = tokenizer.decode(outputs[0, :-2].clone()[0])
+#                 if debug:
+#                     print("after else---->", prompt)
+#
+#         else:
+#             no_end = False
+#             # break
+#     # print("outputs--->", tokenizer.decode(outputs[0]))
+#     outputs = outputs[:, ori_prompt_length:]
+#     outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0].replace('"', "")
+#     open_cv_image = np.array(image_ori)
+#     open_cv_image = open_cv_image[:, :, ::-1].copy()
+#     width = image_ori.width
+#     height = image_ori.height
+#     for i, pre_box in enumerate(added_bbox_list):
+#         open_cv_image = cv2.rectangle(open_cv_image, np.array(pre_box[0][:2]*[width,height]).astype(int), np.array(pre_box[0][2:]*[width,height]).astype(int),
+#                                       (0, 255, 0), i + 1)
+#     out_image = Image.fromarray(cv2.cvtColor(open_cv_image, cv2.COLOR_BGR2RGB))
+#     # new_predictions = [
+#     #     postprocess_captioning_generation(out).replace('"', "")
+#     #     for out in tokenizer.batch_decode(outputs, skip_special_tokens=True)
+#     #     ]
+#     # import pdb; pdb.set_trace()
+#
+#     return outputs, out_image
+
+
+
+
 def captioner(
         model, tokenizer, image_ori, batch_images, input_ids, attention_mask, image_start_index_list, image_nums,
         added_bbox_list, debug=True):
     """Evaluate a model on COCO dataset.
     Returns:
         float: CIDEr score
-
     """
     visual_logits_processor = VisualLogitsProcessor(tokenizer)
     model.eval()
@@ -80,125 +231,131 @@ def captioner(
     prompt = None
     out_image = None
     no_end = True
-    while no_end:
-        batch_images = batch_images
-        if prompt == None:
-            input_ids = input_ids
-            attention_mask = attention_mask
-        else:
-            encodings = tokenizer(
-                [prompt],
-                padding="longest",
-                truncation=True,
-                return_tensors="pt",
-                max_length=2000,
-            )
-            attention_mask = encodings["attention_mask"]
-            input_ids = encodings["input_ids"]
-            image_start_index_list = image_start_index_list
-            image_nums = image_nums
-        if debug:
-            print("input--->", tokenizer.decode(input_ids[0]))
-        p1 = MinNewTokensLengthLogitsProcessor(
-            prompt_length_to_skip=input_ids.shape[-1],
-            min_new_tokens=5,
-            eos_token_id=bos_token_id,
-        )
-        with torch.inference_mode():
-            outputs = model.generate(
-                batch_images,
-                input_ids,
-                attention_mask=attention_mask,
-                max_new_tokens=20,
-                # min_new_tokens=8,
-                num_beams=1,
-                # length_penalty=0,
-                image_start_index_list=image_start_index_list,
-                image_nums=image_nums,
-                added_bbox_list=added_bbox_list if len(added_bbox_list) != 0 else None,
-                logits_processor_list=[p1, visual_logits_processor],
-            )
-        if debug:
-            print("outputs--->", tokenizer.decode(outputs[0]))
-        if outputs[0, -2] in [previsual_token_id, visual_token_id] and outputs[0, -1] == bos_token_id:
-            prompt = tokenizer.decode(outputs.clone()[0])
-            is_visual = (outputs[0, -2] == visual_token_id)
-            batch_text = tokenizer.batch_decode(outputs[:, :-1])
-            encodings = tokenizer(
-                batch_text,
-                padding="longest",
-                truncation=True,
-                return_tensors="pt",
-                max_length=2000,
-            )
-            input_ids = encodings["input_ids"]
-            attention_mask = encodings["attention_mask"]
-            image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
-            image_start_index_list = [[x] for x in image_start_index_list]
-            image_nums = [1] * len(input_ids)
+    for i in range(100):
+        if no_end:
+            batch_images = batch_images
+            if prompt == None:
+                input_ids = input_ids
+                attention_mask = attention_mask
+            else:
+                encodings = tokenizer(
+                    [prompt],
+                    padding="longest",
+                    truncation=True,
+                    return_tensors="pt",
+                    max_length=2000,
+                )
+                attention_mask = encodings["attention_mask"]
+                input_ids = encodings["input_ids"]
+                image_start_index_list = image_start_index_list
+                image_nums = image_nums
             if debug:
-                print("get the visual bbox--->", tokenizer.decode(input_ids[0]))
-            with torch.no_grad():
-                outputs = model(
-                    vision_x=batch_images,
-                    lang_x=input_ids,
+                print("input--->", tokenizer.decode(input_ids[0]))
+            p1 = MinNewTokensLengthLogitsProcessor(
+                prompt_length_to_skip=input_ids.shape[-1],
+                min_new_tokens=5,
+                eos_token_id=bos_token_id,
+            )
+            with torch.inference_mode():
+                outputs = model.generate(
+                    batch_images,
+                    input_ids,
                     attention_mask=attention_mask,
-                    image_nums=image_nums,
+                    max_new_tokens=20,
+                    # min_new_tokens=8,
+                    num_beams=1,
+                    # length_penalty=0,
                     image_start_index_list=image_start_index_list,
+                    image_nums=image_nums,
                     added_bbox_list=added_bbox_list if len(added_bbox_list) != 0 else None,
-                    add_box=added_bbox_list is not None and len(added_bbox_list) != 0,
+                    logits_processor_list=[p1, visual_logits_processor],
                 )
-                boxes = outputs["boxes"]
-                scores = outputs["scores"]
             if debug:
-                print("box num---->", len(boxes))
-            # if not model.valid:
-            #     import pdb; pdb.set_trace()
-            if boxes is not None:
-                if is_visual:
-                    if have_prebox:
-                        added_bbox_list.pop()
-                        prompt = prompt.replace("<|#previsual#|><|#prebox#|><|#object#|>", "")
-                        have_prebox = False
+                print("outputs--->", tokenizer.decode(outputs[0]))
+            if outputs[0, -2] in [previsual_token_id, visual_token_id] and outputs[0, -1] == bos_token_id:
+                prompt = tokenizer.decode(outputs.clone()[0])
+                is_visual = (outputs[0, -2] == visual_token_id)
+                batch_text = tokenizer.batch_decode(outputs[:, :-1])
+                encodings = tokenizer(
+                    batch_text,
+                    padding="longest",
+                    truncation=True,
+                    return_tensors="pt",
+                    max_length=2000,
+                )
+                input_ids = encodings["input_ids"]
+                attention_mask = encodings["attention_mask"]
+                image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
+                image_start_index_list = [[x] for x in image_start_index_list]
+                image_nums = [1] * len(input_ids)
+                if debug:
+                    print("get the visual bbox--->", tokenizer.decode(input_ids[0]))
+                with torch.no_grad():
+                    outputs = model(
+                        vision_x=batch_images,
+                        lang_x=input_ids,
+                        attention_mask=attention_mask,
+                        image_nums=image_nums,
+                        image_start_index_list=image_start_index_list,
+                        added_bbox_list=added_bbox_list if len(added_bbox_list) != 0 else None,
+                        add_box=added_bbox_list is not None and len(added_bbox_list) != 0,
+                    )
+                boxes = outputs["boxes"]
+                scores = outputs["scores"]
+                if debug:
+                    print("box num---->", len(boxes))
+                # if not model.valid:
+                #     import pdb; pdb.set_trace()
+                if boxes is not None:
+                    if is_visual:
+                        if have_prebox:
+                            added_bbox_list.pop()
+                            prompt = prompt.replace("<|#previsual#|><|#prebox#|><|#object#|>", "")
+                            have_prebox = False
+                            if debug:
+                                print("find previsual and remove it--->", prompt)
+                        first_box = boxes[scores.argmax()]
+                        added_bbox_list += [torch.tensor(first_box).unsqueeze(0) / 224]
+                        prompt = prompt[:-len(tokenizer.eos_token)]
+                        prompt += box_token + endofobject_token
                         if debug:
-                            print("find previsual and remove it--->", prompt)
-                    first_box = boxes[scores.argmax()]
-                    added_bbox_list += [torch.tensor(first_box).unsqueeze(0) / 224]
-                    prompt = prompt[:-len(tokenizer.eos_token)]
-                    prompt += box_token + endofobject_token
-                    if debug:
-                        print("after inserting visual---->", prompt)
+                            print("after inserting visual---->", prompt)
 
-                else:
-                    import numpy as np
-                    import cv2
+                    else:
+                        import numpy as np
+                        import cv2
 
-                    # exit()
-                    pre_box = boxes[scores.argmax()]
-                    added_bbox_list += [torch.tensor(pre_box).unsqueeze(0).cuda() / 224]
-                    prompt = prompt[:-len(tokenizer.eos_token)]
-                    prompt += prebox_token + object_token
-                    have_prebox = True
-                    if debug:
-                        print("after inserting previsual---->", prompt)
-            else:
-                # if debug:
-                #     import pdb;pdb.set_trace()
-                prompt = tokenizer.decode(outputs.clone()[0])
-                if debug:
-                    print("before else---->", prompt)
-                prompt = tokenizer.decode(outputs[0, :-2].clone()[0])
-                if debug:
-                    print("after else---->", prompt)
-        else:
-            no_end = False
+                        # exit()
+                        pre_box = boxes[scores.argmax()]
+                        added_bbox_list += [torch.tensor(pre_box).unsqueeze(0) / 224]
+                        prompt = prompt[:-len(tokenizer.eos_token)]
+                        prompt += prebox_token + object_token
+                        have_prebox = True
+                        if debug:
+                            print("after inserting previsual---->", prompt)
+                else:
+                    # if debug:
+                    #     import pdb;pdb.set_trace()
+                    prompt = tokenizer.decode(outputs.clone()[0])
+                    if debug:
+                        print("before else---->", prompt)
+                    prompt = tokenizer.decode(outputs[0, :-2].clone()[0])
+                    if debug:
+                        print("after else---->", prompt)
            else:
+                no_end = False
     outputs = outputs[:, ori_prompt_length:]
     outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0].replace('"', "")
     open_cv_image = np.array(image_ori)
     open_cv_image = open_cv_image[:, :, ::-1].copy()
+    width = image_ori.width
+    height = image_ori.height
     for i, pre_box in enumerate(added_bbox_list):
-        open_cv_image = cv2.rectangle(open_cv_image, (pre_box[:2] * 224).astype(int), (pre_box[2:] * 224).astype(int),
+        print(pre_box)
+        open_cv_image = cv2.rectangle(open_cv_image, (np.array(pre_box[0][:2]) * [width, height]).astype(int),
+                                      (np.array(pre_box[0][2:]) * [width, height]).astype(int),
                                       (0, 255, 0), i + 1)
+
     out_image = Image.fromarray(cv2.cvtColor(open_cv_image, cv2.COLOR_BGR2RGB))
     # new_predictions = [
     #     postprocess_captioning_generation(out).replace('"', "")
@@ -206,6 +363,4 @@ def captioner(
     #     ]
     # import pdb; pdb.set_trace()
 
-    return outputs, out_image
-
-
+    return outputs, out_image
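For reference, the new box-drawing step at the end of captioner can be exercised on its own: normalized [x1, y1, x2, y2] boxes are scaled by the image's width and height (rather than the fixed 224 used before) and drawn with OpenCV, then converted back to a PIL image. The image size and box values below are made up for illustration.

import numpy as np
import cv2
from PIL import Image

# A blank 640x480 test image and one made-up box in normalized [x1, y1, x2, y2] form.
image_ori = Image.new("RGB", (640, 480), color=(200, 200, 200))
added_bbox_list = [np.array([[0.25, 0.30, 0.75, 0.80]])]

open_cv_image = np.array(image_ori)[:, :, ::-1].copy()  # PIL RGB -> OpenCV BGR
width, height = image_ori.width, image_ori.height
for i, pre_box in enumerate(added_bbox_list):
    # Scale normalized coordinates to pixel coordinates before drawing.
    scaled = np.array(pre_box[0]) * [width, height, width, height]
    pt1 = (int(scaled[0]), int(scaled[1]))
    pt2 = (int(scaled[2]), int(scaled[3]))
    open_cv_image = cv2.rectangle(open_cv_image, pt1, pt2, (0, 255, 0), i + 1)

out_image = Image.fromarray(cv2.cvtColor(open_cv_image, cv2.COLOR_BGR2RGB))
out_image.save("boxes.png")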