chendl committed
Commit e939d7a · 1 Parent(s): df58d6d

update cap
multimodal/open_flamingo/chat/conversation.py CHANGED
@@ -19,7 +19,7 @@ import gradio as gr
 from huggingface_hub import hf_hub_download, login
 
 from open_flamingo.src.factory import create_model_and_transforms
-from open_flamingo.eval.task.caption import captioner
+from open_flamingo.eval.task.caption_chat import captioner
 
 class SeparatorStyle(Enum):
     """Different separator style."""
multimodal/open_flamingo/eval/task/caption_chat.py ADDED
@@ -0,0 +1,417 @@
+import torch
+import more_itertools
+from tqdm import tqdm
+import json
+import time
+import os
+from transformers import LogitsProcessor, MinNewTokensLengthLogitsProcessor, ForcedEOSTokenLogitsProcessor
+from PIL import Image
+# Assumed sources for names used below but not imported in the original file:
+# load_dataset("coco_caption") matches the LAVIS dataset builder (the ground-truth
+# annotations path below points into a LAVIS cache), and the caption metric
+# helpers live in open_flamingo's eval code.
+from lavis.datasets.builders import load_dataset
+from open_flamingo.eval.coco_metric import compute_cider, postprocess_captioning_generation
+
+class VisualLogitsProcessor(LogitsProcessor):
+    """Constrains decoding to the grounded-caption token grammar."""
+
+    def __init__(self, tokenizer):
+        super().__init__()
+        self.tokenizer = tokenizer
+        self.object_token_id = self.tokenizer("<|#object#|>", add_special_tokens=False)["input_ids"][-1]
+        self.prebox_token_id = self.tokenizer("<|#prebox#|>", add_special_tokens=False)["input_ids"][-1]
+        self.box_token_id = self.tokenizer("<|#box#|>", add_special_tokens=False)["input_ids"][-1]
+        self.previsual_token_id = self.tokenizer("<|#previsual#|>", add_special_tokens=False)["input_ids"][-1]
+        self.visual_token_id = self.tokenizer("<|#visual#|>", add_special_tokens=False)["input_ids"][-1]
+        self.eos_token_id = self.tokenizer.encode(self.tokenizer.eos_token)[-1]
+        self.endofobject_token_id = self.tokenizer("<|#endofobject#|>", add_special_tokens=False)["input_ids"][-1]
+        self.topk = 2
+
+    def __call__(self, input_ids, scores):
+        sorted_ids = scores.sort(descending=True).indices.tolist()[0]
+        # Promote <|#object#|> when it is the runner-up candidate, EOS is not
+        # among the top candidates, and every <|#object#|> so far is balanced
+        # by two <|#endofobject#|> tokens (text close plus box close).
+        if (
+            self.object_token_id in sorted_ids[1:self.topk]
+            and self.eos_token_id not in sorted_ids[:self.topk]
+            and (input_ids == self.object_token_id).sum() * 2 == (input_ids == self.endofobject_token_id).sum()
+        ):
+            scores[0, self.object_token_id] = 1000
+        if input_ids[0, -1] == self.object_token_id and input_ids[0, -2] != self.prebox_token_id:
+            if (input_ids[0, :-1] == self.object_token_id).sum() != 0:
+                # Force a previsual pass for each <|#object#|> after the first.
+                scores[0, self.previsual_token_id] = 1000
+        elif input_ids[0, -1] == self.previsual_token_id or input_ids[0, -1] == self.visual_token_id:
+            # Halt generation so the caller can run box prediction.
+            scores[0, self.eos_token_id] = 1000
+        elif input_ids[0, -1] == self.endofobject_token_id and input_ids[0, -2] != self.box_token_id:
+            # A closed object phrase must be grounded: force <|#visual#|>.
+            scores[0, self.visual_token_id] = 1000
+        return scores
+
+
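+# Illustrative decoding trace of the grammar enforced above (abbreviated):
+#   ... <|#object#|> a dog <|#endofobject#|>       -> force <|#visual#|>
+#   ... <|#visual#|>                               -> force EOS; caller predicts a box
+#   ... <|#visual#|> <|#box#|> <|#endofobject#|>   -> caption continues
+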
+def prepare_batch_images(batch, image_processor):
+    # Stack processed images into a (batch, T_img=1, frames=1, C, H, W) tensor,
+    # the layout the Flamingo vision pathway expects.
+    batch_images = None
+    for b in batch:
+        b_image = image_processor(b["image"]).unsqueeze(0).unsqueeze(1).unsqueeze(0)
+        if batch_images is None:
+            batch_images = b_image
+        else:
+            batch_images = torch.cat([batch_images, b_image], dim=0)
+    return batch_images
+
+
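+# Usage sketch for prepare_batch_images (hypothetical file name; image_processor
+# is the CLIP-style transform returned by the model factory):
+#   batch_images = prepare_batch_images(
+#       [{"image": Image.open("example.jpg")}], image_processor
+#   )  # -> tensor of shape (1, 1, 1, C, H, W)
+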
+def captioner(
+    model,
+    tokenizer,
+    image_ori,
+    batch_images,
+    input_ids,
+    attention_mask,
+    image_start_index_list,
+    image_nums,
+    added_bbox_list,
+    debug=False,
+):
+    """Generate a grounded caption for the chat demo.
+
+    Alternates between text generation and box prediction until the caption is
+    complete. Returns the caption string and the image with the predicted
+    pre-boxes drawn on it (None if no pre-box was predicted).
+    """
+    visual_logits_processor = VisualLogitsProcessor(tokenizer)
+    model.eval()
+    media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
+    bos_token_id = tokenizer(tokenizer.bos_token, add_special_tokens=False)["input_ids"][-1]
+    previsual_token_id = tokenizer("<|#previsual#|>", add_special_tokens=False)["input_ids"][-1]
+    visual_token_id = tokenizer("<|#visual#|>", add_special_tokens=False)["input_ids"][-1]
+    box_token = "<|#box#|>"
+    prebox_token = "<|#prebox#|>"
+    endofobject_token = "<|#endofobject#|>"
+    object_token = "<|#object#|>"
+    ori_prompt_length = len(input_ids[0])
+    have_prebox = False
+    out_image = None
+    while True:
+        if debug:
+            print("input--->", tokenizer.decode(input_ids[0]))
+        # The model emits BOS to pause generation, so BOS doubles as EOS here.
+        p1 = MinNewTokensLengthLogitsProcessor(
+            prompt_length_to_skip=input_ids.shape[-1],
+            min_new_tokens=5,
+            eos_token_id=bos_token_id,
+        )
+        with torch.inference_mode():
+            outputs = model.generate(
+                batch_images,
+                input_ids,
+                attention_mask=attention_mask,
+                max_new_tokens=20,
+                num_beams=1,
+                image_start_index_list=image_start_index_list,
+                image_nums=image_nums,
+                added_bbox_list=added_bbox_list if len(added_bbox_list) != 0 else None,
+                logits_processor_list=[p1, visual_logits_processor],
+            )
+        if debug:
+            print("outputs--->", tokenizer.decode(outputs[0]))
+        if outputs[0, -2] in [previsual_token_id, visual_token_id] and outputs[0, -1] == bos_token_id:
+            prompt = tokenizer.decode(outputs.clone()[0])
+            is_visual = (outputs[0, -2] == visual_token_id)
+            # Re-encode the generated text (minus the trailing BOS) and run a
+            # forward pass to predict a box for the pending (pre)visual token.
+            batch_text = tokenizer.batch_decode(outputs[:, :-1])
+            encodings = tokenizer(
+                batch_text,
+                padding="longest",
+                truncation=True,
+                return_tensors="pt",
+                max_length=2000,
+            )
+            input_ids = encodings["input_ids"]
+            attention_mask = encodings["attention_mask"]
+            image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
+            image_start_index_list = [[x] for x in image_start_index_list]
+            image_nums = [1] * len(input_ids)
+            if debug:
+                print("get the visual bbox--->", tokenizer.decode(input_ids[0]))
+            with torch.no_grad():
+                outputs = model(
+                    vision_x=batch_images,
+                    lang_x=input_ids,
+                    attention_mask=attention_mask,
+                    image_nums=image_nums,
+                    image_start_index_list=image_start_index_list,
+                    added_bbox_list=added_bbox_list if len(added_bbox_list) != 0 else None,
+                    add_box=added_bbox_list is not None and len(added_bbox_list) != 0,
+                )
+            boxes = outputs["boxes"]
+            scores = outputs["scores"]
+            if boxes is not None:
+                if is_visual:
+                    if have_prebox:
+                        # The pre-box was provisional; drop it now that the
+                        # object gets a proper box.
+                        added_bbox_list.pop()
+                        prompt = prompt.replace("<|#previsual#|><|#prebox#|><|#object#|>", "")
+                        have_prebox = False
+                        if debug:
+                            print("find previsual and remove it--->", prompt)
+                    first_box = boxes[scores.argmax()]
+                    added_bbox_list += [torch.tensor(first_box).unsqueeze(0) / 224]
+                    prompt = prompt[:-len(tokenizer.eos_token)]
+                    prompt += box_token + endofobject_token
+                    if debug:
+                        print("after inserting visual---->", prompt)
+                else:
+                    import numpy as np
+                    import cv2
+                    # Draw the candidate pre-boxes on the original image.
+                    open_cv_image = np.array(image_ori)
+                    open_cv_image = open_cv_image[:, :, ::-1].copy()
+                    for i, pre_box in enumerate(boxes):
+                        open_cv_image = cv2.rectangle(open_cv_image, pre_box[:2].astype(int), pre_box[2:].astype(int), (0, 255, 0), i + 1)
+                    out_image = Image.fromarray(cv2.cvtColor(open_cv_image, cv2.COLOR_BGR2RGB))
+                    pre_box = boxes[scores.argmax()]
+                    added_bbox_list += [torch.tensor(pre_box).unsqueeze(0).cuda() / 224]
+                    prompt = prompt[:-len(tokenizer.eos_token)]
+                    prompt += prebox_token + object_token
+                    have_prebox = True
+                    if debug:
+                        print("after inserting previsual---->", prompt)
+            else:
+                if debug:
+                    import pdb; pdb.set_trace()
+                # No box predicted: drop the unresolved (pre)visual token.
+                prompt = tokenizer.decode(input_ids[0, :-1])
+        else:
+            break
+    outputs = outputs[:, ori_prompt_length:]
+    caption = postprocess_captioning_generation(
+        tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
+    ).replace('"', "")
+    return caption, out_image
+
+
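+# The eval prompt below has the form (pad tokens reserve the image embedding
+# slots; shown for an illustrative vis_embed_size of 3):
+#   <bos><|#image#|><pad><pad><pad><|#endofimage#|>
+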
+def evaluate_coco_flickr(
+    model,
+    tokenizer,
+    image_processor,
+    batch_size,
+    is_flickr=False,
+    vis_embed_size=None,
+    rank=0,
+    world_size=1,
+    id=0,
+    debug=False,
+):
+    """Evaluate a model on the COCO (or Flickr30K) captioning dataset.
+
+    Returns:
+        float: CIDEr score
+    """
+    visual_logits_processor = VisualLogitsProcessor(tokenizer)
+    coco_dataset = load_dataset("coco_caption")
+    eval_dataset = coco_dataset["test"]
+    model.eval().cuda()
+    predictions = dict()
+    lang_encoder_name = model.lang_encoder.__class__.__name__.lower()
+    media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
+    bos_token_id = tokenizer(tokenizer.bos_token, add_special_tokens=False)["input_ids"][-1]
+    previsual_token_id = tokenizer("<|#previsual#|>", add_special_tokens=False)["input_ids"][-1]
+    visual_token_id = tokenizer("<|#visual#|>", add_special_tokens=False)["input_ids"][-1]
+    box_token = "<|#box#|>"
+    prebox_token = "<|#prebox#|>"
+    endofobject_token = "<|#endofobject#|>"
+    object_token = "<|#object#|>"
+    cnt = 0
+    if world_size > 1:
+        torch.distributed.barrier()
+    desc = "Running inference Flickr30K" if is_flickr else "Running inference COCO"
+    for ii, batch in enumerate(more_itertools.chunked(
+        tqdm(eval_dataset, desc=desc, disable=(rank != 0)), batch_size
+    )):
+        # Shard batches across ranks.
+        if ii % world_size != rank:
+            continue
+        cnt += len(batch)
+        batch_images = prepare_batch_images(
+            batch=batch,
+            image_processor=image_processor,
+        ).cuda()
+        prompt = f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>"
+        added_bbox_list = []
+        batch_text = [prompt for _ in batch]
+        encodings = tokenizer(
+            batch_text,
+            padding="longest",
+            truncation=True,
+            return_tensors="pt",
+            max_length=2000,
+        )
+        ori_prompt_length = len(encodings["input_ids"][0])
+        have_prebox = False
+        while True:
+            # Re-encode the (possibly extended) prompt each round.
+            batch_text = [prompt for _ in batch]
+            encodings = tokenizer(
+                batch_text,
+                padding="longest",
+                truncation=True,
+                return_tensors="pt",
+                max_length=2000,
+            )
+            input_ids = encodings["input_ids"].cuda()
+            attention_mask = encodings["attention_mask"].cuda()
+            image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
+            image_start_index_list = [[x] for x in image_start_index_list]
+            image_nums = [1] * len(input_ids)
+            if debug:
+                print("input--->", tokenizer.decode(input_ids[0]))
+            # The model emits BOS to pause generation, so BOS doubles as EOS here.
+            p1 = MinNewTokensLengthLogitsProcessor(
+                prompt_length_to_skip=input_ids.shape[-1],
+                min_new_tokens=5,
+                eos_token_id=bos_token_id,
+            )
+            with torch.inference_mode(), torch.cuda.amp.autocast(dtype=torch.float16):
+                outputs = model.generate(
+                    batch_images,
+                    input_ids,
+                    attention_mask=attention_mask,
+                    max_new_tokens=20,
+                    num_beams=1,
+                    image_start_index_list=image_start_index_list,
+                    image_nums=image_nums,
+                    added_bbox_list=added_bbox_list if len(added_bbox_list) != 0 else None,
+                    logits_processor_list=[p1, visual_logits_processor],
+                )
+            if debug:
+                print("outputs--->", tokenizer.decode(outputs[0]))
+            if outputs[0, -2] in [previsual_token_id, visual_token_id] and outputs[0, -1] == bos_token_id:
+                prompt = tokenizer.decode(outputs.clone()[0])
+                is_visual = (outputs[0, -2] == visual_token_id)
+                batch_text = tokenizer.batch_decode(outputs[:, :-1])
+                encodings = tokenizer(
+                    batch_text,
+                    padding="longest",
+                    truncation=True,
+                    return_tensors="pt",
+                    max_length=2000,
+                )
+                input_ids = encodings["input_ids"].cuda()
+                attention_mask = encodings["attention_mask"].cuda()
+                image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
+                image_start_index_list = [[x] for x in image_start_index_list]
+                image_nums = [1] * len(input_ids)
+                if debug:
+                    print("get the visual bbox--->", tokenizer.decode(input_ids[0]))
+                with torch.cuda.amp.autocast(dtype=torch.float16), torch.no_grad():
+                    outputs = model(
+                        vision_x=batch_images,
+                        lang_x=input_ids,
+                        attention_mask=attention_mask,
+                        image_nums=image_nums,
+                        image_start_index_list=image_start_index_list,
+                        added_bbox_list=added_bbox_list if len(added_bbox_list) != 0 else None,
+                        add_box=added_bbox_list is not None and len(added_bbox_list) != 0,
+                    )
+                boxes = outputs["boxes"]
+                scores = outputs["scores"]
+                if boxes is not None:
+                    if is_visual:
+                        if have_prebox:
+                            added_bbox_list.pop()
+                            prompt = prompt.replace("<|#previsual#|><|#prebox#|><|#object#|>", "")
+                            have_prebox = False
+                            if debug:
+                                print("find previsual and remove it--->", prompt)
+                        first_box = boxes[scores.argmax()]
+                        added_bbox_list += [torch.tensor(first_box).unsqueeze(0).cuda() / 224]
+                        prompt = prompt[:-len(tokenizer.eos_token)]
+                        prompt += box_token + endofobject_token
+                        if debug:
+                            print("after inserting visual---->", prompt)
+                    else:
+                        if debug:
+                            import numpy as np
+                            import cv2
+                            # Dump the candidate pre-boxes for inspection.
+                            open_cv_image = np.array(batch[0]["image"])
+                            open_cv_image = open_cv_image[:, :, ::-1].copy()
+                            for i, pre_box in enumerate(boxes):
+                                open_cv_image = cv2.rectangle(open_cv_image, pre_box[:2].astype(int), pre_box[2:].astype(int), (0, 255, 0), i + 1)
+                            cv2.imwrite("Atest.png", open_cv_image)
+                        pre_box = boxes[scores.argmax()]
+                        added_bbox_list += [torch.tensor(pre_box).unsqueeze(0).cuda() / 224]
+                        prompt = prompt[:-len(tokenizer.eos_token)]
+                        prompt += prebox_token + object_token
+                        have_prebox = True
+                        if debug:
+                            print("after inserting previsual---->", prompt)
+                else:
+                    if debug:
+                        import pdb; pdb.set_trace()
+                    # No box predicted: drop the unresolved (pre)visual token.
+                    prompt = tokenizer.decode(input_ids[0, :-1])
+            else:
+                break
+        outputs = outputs[:, ori_prompt_length:]
+        new_predictions = [
+            postprocess_captioning_generation(out).replace('"', "")
+            for out in tokenizer.batch_decode(outputs, skip_special_tokens=True)
+        ]
+        if rank == 0:
+            tqdm.write(new_predictions[0])
+        for i, sample in enumerate(batch):
+            predictions[int(sample["image_id"])] = {
+                "caption": new_predictions[i],
+            }
+    results_path = (
+        f"flickrresults_{lang_encoder_name}_{rank}_{id}.json"
+        if is_flickr
+        else f"cocoresults_{lang_encoder_name}_{rank}_{id}.json"
+    )
+    with open(results_path, "w") as f:
+        f.write(
+            json.dumps(
+                [
+                    {"image_id": k, "caption": predictions[k]["caption"]}
+                    for k in predictions
+                ],
+                indent=2,
+            )
+        )
+    print("save to", results_path)
+    del predictions
+    time.sleep(10)
+    if world_size > 1:
+        torch.distributed.barrier()
+    if rank == 0:
+        print(f"evaluate on rank {rank}. world size is {world_size}")
+        predictions = []
+        for rank_i in range(world_size):
+            part_results_path = (
+                f"flickrresults_{lang_encoder_name}_{rank_i}_{id}.json"
+                if is_flickr
+                else f"cocoresults_{lang_encoder_name}_{rank_i}_{id}.json"
+            )
+            print("load", part_results_path)
+            predictions.extend(json.load(open(part_results_path)))
+            os.remove(part_results_path)
+        print("num:", len(predictions))
+        results_path = (
+            f"flickrresults_{lang_encoder_name}.json"
+            if is_flickr
+            else f"cocoresults_{lang_encoder_name}.json"
+        )
+        json.dump(predictions, open(results_path, "w"), indent=2)
+
+        metrics = compute_cider(
+            result_path=results_path,
+            annotations_path="/gpfs/u/home/LMCG/LMCGljnn/scratch/.cache/lavis/coco_gt/coco_karpathy_test_gt.json",
+        )
+        metrics["CIDEr"] *= 100
+        os.makedirs("eval_results", exist_ok=True)
+        acc = metrics["CIDEr"]
+        with open(os.path.join("eval_results", f"cococap_{model.expr_name}_{model.step_num}_{int(time.time())}_{acc}"), "w") as f:
+            f.write(json.dumps(predictions, indent=2))
+
+        # Delete the temporary merged results file.
+        os.remove(results_path)
+    else:
+        metrics = {}
+        metrics["CIDEr"] = 0.0
+
+    return metrics["CIDEr"]
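
For reference, a minimal sketch of how this evaluation entry point might be driven on a single GPU; the factory call and its arguments are placeholders for checkpoint-specific values, not part of this commit:

    from open_flamingo.src.factory import create_model_and_transforms
    from open_flamingo.eval.task.caption_chat import evaluate_coco_flickr

    # Hypothetical setup: fill in the checkpoint-specific factory arguments.
    model, image_processor, tokenizer = create_model_and_transforms(...)

    cider = evaluate_coco_flickr(
        model=model,
        tokenizer=tokenizer,
        image_processor=image_processor,
        batch_size=1,        # the grounding loop tracks a single prompt string
        vis_embed_size=256,  # placeholder; must match the checkpoint
    )
    print("CIDEr:", cider)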