from lavis.datasets.builders import load_dataset
import torch
import more_itertools
from tqdm import tqdm
from coco_metric import compute_cider, postprocess_captioning_generation
import json
import time
import os
from transformers import LogitsProcessor, MinNewTokensLengthLogitsProcessor, ForcedEOSTokenLogitsProcessor
from PIL import Image


class VisualLogitsProcessor(LogitsProcessor):
    """Steers decoding through the grounding special tokens.

    When the model is about to emit an object/visual/previsual marker, this
    processor either boosts the corresponding token or forces EOS so the caller
    can stop, run bounding-box prediction, and resume generation.
    """

    def __init__(self, tokenizer):
        super().__init__()
        self.tokenizer = tokenizer
        self.object_token_id = self.tokenizer("<|#object#|>", add_special_tokens=False)["input_ids"][-1]
        self.prebox_token_id = self.tokenizer("<|#prebox#|>", add_special_tokens=False)["input_ids"][-1]
        self.box_token_id = self.tokenizer("<|#box#|>", add_special_tokens=False)["input_ids"][-1]
        self.previsual_token_id = self.tokenizer("<|#previsual#|>", add_special_tokens=False)["input_ids"][-1]
        self.visual_token_id = self.tokenizer("<|#visual#|>", add_special_tokens=False)["input_ids"][-1]
        self.eos_token_id = self.tokenizer.encode(self.tokenizer.eos_token)[-1]
        self.endofobject_token_id = self.tokenizer("<|#endofobject#|>", add_special_tokens=False)["input_ids"][-1]
        self.topk = 2

    def __call__(self, input_ids, scores):
        # print("decoding===>", self.tokenizer.decode(scores.sort(descending=True).indices.tolist()[0][:self.topk]))
        # import pdb; pdb.set_trace()
        topk_ids = scores.sort(descending=True).indices.tolist()[0]
        # Encourage opening a new <|#object#|> span when it is already among the top
        # candidates, EOS is not imminent, and the <|#object#|>/<|#endofobject#|>
        # counts so far are consistent (two end markers per object marker).
        if (
            self.object_token_id in topk_ids[1:self.topk]
            and self.eos_token_id not in topk_ids[:self.topk]
            and (input_ids == self.object_token_id).sum() * 2 == (input_ids == self.endofobject_token_id).sum()
        ):
            scores[0, self.object_token_id] = 1000
        if input_ids[0, -1] == self.object_token_id and input_ids[0, -2] != self.prebox_token_id:
            if (input_ids[0, :-1] == self.object_token_id).sum() != 0:
                # print("generate a previsual token next")
                scores[0, self.previsual_token_id] = 1000
        elif input_ids[0, -1] == self.previsual_token_id or input_ids[0, -1] == self.visual_token_id:
            # print("stop to run bbox generation for " + "previsual" if input_ids[0, -1] == self.previsual_token_id else "visual")
            # Force EOS so the caller can run box prediction before resuming.
            scores[0, self.eos_token_id] = 1000
        elif input_ids[0, -1] == self.endofobject_token_id and input_ids[0, -2] != self.box_token_id:
            # print("generate a visual token next")
            scores[0, self.visual_token_id] = 1000
        return scores


def prepare_batch_images(batch, image_processor):
    """Stack per-sample processed images into a (B, 1, 1, C, H, W) tensor."""
    batch_images = None
    for b in batch:
        b_image = image_processor(b["image"]).unsqueeze(0).unsqueeze(1).unsqueeze(0)
        if batch_images is None:
            batch_images = b_image
        else:
            batch_images = torch.cat([batch_images, b_image], dim=0)
    return batch_images
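

# ---------------------------------------------------------------------------
# Minimal sanity-check sketch for `prepare_batch_images` (illustrative only).
# `_fake_image_processor` is a stand-in for the real image processor, which is
# provided by the model's preprocessing pipeline; the assertion documents the
# (B, 1, 1, C, H, W) layout produced by the three unsqueeze calls above.
# ---------------------------------------------------------------------------
def _demo_prepare_batch_images():
    def _fake_image_processor(_img):
        # Stand-in: a real processor would return a normalized C x H x W tensor.
        return torch.zeros(3, 224, 224)

    batch = [{"image": Image.new("RGB", (224, 224))} for _ in range(2)]
    images = prepare_batch_images(batch, _fake_image_processor)
    assert images.shape == (2, 1, 1, 3, 224, 224)
    return images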


def captioner(
    model,
    tokenizer,
    image_ori,
    batch_images,
    input_ids,
    attention_mask,
    image_start_index_list,
    image_nums,
    added_bbox_list,
    debug=False,
):
    """Generate a grounded caption for a single image.

    Generation is run in segments: whenever the model emits a
    <|#previsual#|>/<|#visual#|> token, decoding stops, the model is queried for
    the corresponding bounding box, the box is appended to ``added_bbox_list``,
    and generation resumes.

    Returns:
        tuple: (post-processed caption string, PIL image with candidate
        pre-boxes drawn, or None if no visualization was produced)
    """
    visual_logits_processor = VisualLogitsProcessor(tokenizer)
    model.eval()
    # model.eval().cuda()
    lang_encoder_name = model.lang_encoder.__class__.__name__.lower()
    media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
    endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
    pad_token_id = tokenizer(tokenizer.pad_token, add_special_tokens=False)["input_ids"][-1]
    bos_token_id = tokenizer(tokenizer.bos_token, add_special_tokens=False)["input_ids"][-1]
    previsual_token_id = tokenizer("<|#previsual#|>", add_special_tokens=False)["input_ids"][-1]
    visual_token_id = tokenizer("<|#visual#|>", add_special_tokens=False)["input_ids"][-1]
    box_token = "<|#box#|>"
    prebox_token = "<|#prebox#|>"
    endofobject_token = "<|#endofobject#|>"
    object_token = "<|#object#|>"
    ori_prompt_length = len(input_ids[0])
    have_prebox = False
    out_image = None
    while True:
        if debug:
            print("input--->", tokenizer.decode(input_ids[0]))
        p1 = MinNewTokensLengthLogitsProcessor(
            prompt_length_to_skip=input_ids.shape[-1],
            min_new_tokens=5,
            eos_token_id=bos_token_id,
        )
        with torch.inference_mode():
            outputs = model.generate(
                batch_images,
                input_ids,
                attention_mask=attention_mask,
                max_new_tokens=20,
                # min_new_tokens=8,
                num_beams=1,
                # length_penalty=0,
                image_start_index_list=image_start_index_list,
                image_nums=image_nums,
                added_bbox_list=added_bbox_list if len(added_bbox_list) != 0 else None,
                logits_processor_list=[p1, visual_logits_processor],
            )
        if debug:
            print("outputs--->", tokenizer.decode(outputs[0]))
        if outputs[0, -2] in [previsual_token_id, visual_token_id] and outputs[0, -1] == bos_token_id:
            prompt = tokenizer.decode(outputs.clone()[0])
            is_visual = (outputs[0, -2] == visual_token_id)
            batch_text = tokenizer.batch_decode(outputs[:, :-1])
            encodings = tokenizer(
                batch_text,
                padding="longest",
                truncation=True,
                return_tensors="pt",
                max_length=2000,
            )
            input_ids = encodings["input_ids"]
            attention_mask = encodings["attention_mask"]
            image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
            image_start_index_list = [[x] for x in image_start_index_list]
            image_nums = [1] * len(input_ids)
            if debug:
                print("get the visual bbox--->", tokenizer.decode(input_ids[0]))
            with torch.no_grad():
                outputs = model(
                    vision_x=batch_images,
                    lang_x=input_ids,
                    attention_mask=attention_mask,
                    image_nums=image_nums,
                    image_start_index_list=image_start_index_list,
                    added_bbox_list=added_bbox_list if len(added_bbox_list) != 0 else None,
                    add_box=added_bbox_list is not None and len(added_bbox_list) != 0,
                )
            boxes = outputs["boxes"]
            scores = outputs["scores"]
            # if not model.valid:
            #     import pdb; pdb.set_trace()
            if boxes is not None:
                if is_visual:
                    if have_prebox:
                        # Drop the provisional previsual/prebox span now that a
                        # definitive visual box is available.
                        added_bbox_list.pop()
                        prompt = prompt.replace("<|#previsual#|><|#prebox#|><|#object#|>", "")
                        have_prebox = False
                        if debug:
                            print("find previsual and remove it--->", prompt)
                    first_box = boxes[scores.argmax()]
                    added_bbox_list += [torch.tensor(first_box).unsqueeze(0) / 224]
                    prompt = prompt[:-len(tokenizer.eos_token)]
                    prompt += box_token + endofobject_token
                    if debug:
                        print("after inserting visual---->", prompt)
                else:
                    import numpy as np
                    import cv2
                    # Draw the candidate pre-boxes for visualization.
                    open_cv_image = np.array(image_ori)
                    open_cv_image = open_cv_image[:, :, ::-1].copy()
                    for i, pre_box in enumerate(boxes):
                        open_cv_image = cv2.rectangle(open_cv_image, pre_box[:2].astype(int), pre_box[2:].astype(int), (0, 255, 0), i + 1)
                    out_image = Image.fromarray(cv2.cvtColor(open_cv_image, cv2.COLOR_BGR2RGB))
                    # exit()
                    pre_box = boxes[scores.argmax()]
                    added_bbox_list += [torch.tensor(pre_box).unsqueeze(0).cuda() / 224]
                    prompt = prompt[:-len(tokenizer.eos_token)]
                    prompt += prebox_token + object_token
                    have_prebox = True
                    if debug:
                        print("after inserting previsual---->", prompt)
            else:
                if debug:
                    import pdb; pdb.set_trace()
                prompt = tokenizer.decode(outputs[0, :-2])
        else:
            break
    outputs = outputs[:, ori_prompt_length:]
    outputs = postprocess_captioning_generation(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]).replace('"', "")
    # new_predictions = [
    #     postprocess_captioning_generation(out).replace('"', "")
    #     for out in tokenizer.batch_decode(outputs, skip_special_tokens=True)
    # ]
    # import pdb; pdb.set_trace()
    return outputs, out_image
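

# ---------------------------------------------------------------------------
# A minimal sketch of how `captioner` is expected to be driven. The prompt
# construction mirrors `evaluate_coco_flickr` below; `model`, `tokenizer`,
# `image_processor`, and `vis_embed_size` are assumed to come from the repo's
# own checkpoint-loading code (not shown in this file).
# ---------------------------------------------------------------------------
def _demo_captioner_call(model, tokenizer, image_processor, image, vis_embed_size, debug=False):
    media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
    batch = [{"image": image}]
    batch_images = prepare_batch_images(batch, image_processor).cuda()
    prompt = f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token * vis_embed_size}<|#endofimage#|>"
    encodings = tokenizer([prompt], padding="longest", truncation=True, return_tensors="pt", max_length=2000)
    input_ids = encodings["input_ids"].cuda()
    attention_mask = encodings["attention_mask"].cuda()
    image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
    image_start_index_list = [[x] for x in image_start_index_list]
    image_nums = [1] * len(input_ids)
    caption, out_image = captioner(
        model, tokenizer, image, batch_images, input_ids, attention_mask,
        image_start_index_list, image_nums, added_bbox_list=[], debug=debug,
    )
    return caption, out_image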


def evaluate_coco_flickr(
    model,
    tokenizer,
    image_processor,
    batch_size,
    is_flickr=False,
    vis_embed_size=None,
    rank=0,
    world_size=1,
    id=0,
    debug=False,
):
    """Evaluate a model on the COCO (or Flickr30k) caption test split.

    Returns:
        float: CIDEr score
    """
    visual_logits_processor = VisualLogitsProcessor(tokenizer)
    coco_dataset = load_dataset("coco_caption")
    eval_dataset = coco_dataset["test"]
    model.eval().cuda()
    predictions = dict()
    lang_encoder_name = model.lang_encoder.__class__.__name__.lower()
    media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
    endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
    pad_token_id = tokenizer(tokenizer.pad_token, add_special_tokens=False)["input_ids"][-1]
    bos_token_id = tokenizer(tokenizer.bos_token, add_special_tokens=False)["input_ids"][-1]
    previsual_token_id = tokenizer("<|#previsual#|>", add_special_tokens=False)["input_ids"][-1]
    visual_token_id = tokenizer("<|#visual#|>", add_special_tokens=False)["input_ids"][-1]
    box_token = "<|#box#|>"
    prebox_token = "<|#prebox#|>"
    endofobject_token = "<|#endofobject#|>"
    object_token = "<|#object#|>"
    cnt = 0
    if world_size > 1:
        torch.distributed.barrier()
    desc = "Running inference Flickr30" if is_flickr else "Running inference COCO"
    for ii, batch in enumerate(more_itertools.chunked(
        tqdm(eval_dataset, desc=desc, disable=(rank != 0)), batch_size
    )):
        if ii % world_size != rank:
            continue
        cnt += len(batch)
        # Debug-only override that replaces the first sample with a fixed local image:
        # batch[0]["image"] = Image.open("/gpfs/u/home/LMCG/LMCGljnn/scratch/images/img3.jpg").resize((224, 224))
        batch_images = prepare_batch_images(
            batch=batch,
            image_processor=image_processor,
        ).cuda()
        prompt = f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>"
        added_bbox_list = []
        batch_text = [prompt for _ in batch]
        encodings = tokenizer(
            batch_text,
            padding="longest",
            truncation=True,
            return_tensors="pt",
            max_length=2000,
        )
        ori_prompt_length = len(encodings["input_ids"][0])
        have_prebox = False
        while True:
            batch_text = [prompt for _ in batch]
            encodings = tokenizer(
                batch_text,
                padding="longest",
                truncation=True,
                return_tensors="pt",
                max_length=2000,
            )
            input_ids = encodings["input_ids"].cuda()
            attention_mask = encodings["attention_mask"].cuda()
            image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
            image_start_index_list = [[x] for x in image_start_index_list]
            image_nums = [1] * len(input_ids)
            if debug:
                print("input--->", tokenizer.decode(input_ids[0]))
            p1 = MinNewTokensLengthLogitsProcessor(
                prompt_length_to_skip=input_ids.shape[-1],
                min_new_tokens=5,
                eos_token_id=bos_token_id,
            )
            with torch.inference_mode(), torch.cuda.amp.autocast(dtype=torch.float16):
                outputs = model.generate(
                    batch_images,
                    input_ids,
                    attention_mask=attention_mask,
                    max_new_tokens=20,
                    # min_new_tokens=8,
                    num_beams=1,
                    # length_penalty=0,
                    image_start_index_list=image_start_index_list,
                    image_nums=image_nums,
                    added_bbox_list=added_bbox_list if len(added_bbox_list) != 0 else None,
                    logits_processor_list=[p1, visual_logits_processor],
                )
            if debug:
                print("outputs--->", tokenizer.decode(outputs[0]))
            if outputs[0, -2] in [previsual_token_id, visual_token_id] and outputs[0, -1] == bos_token_id:
                prompt = tokenizer.decode(outputs.clone()[0])
                is_visual = (outputs[0, -2] == visual_token_id)
                batch_text = tokenizer.batch_decode(outputs[:, :-1])
                encodings = tokenizer(
                    batch_text,
                    padding="longest",
                    truncation=True,
                    return_tensors="pt",
                    max_length=2000,
                )
                input_ids = encodings["input_ids"].cuda()
                attention_mask = encodings["attention_mask"].cuda()
                image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
                image_start_index_list = [[x] for x in image_start_index_list]
                image_nums = [1] * len(input_ids)
                if debug:
                    print("get the visual bbox--->", tokenizer.decode(input_ids[0]))
                with torch.cuda.amp.autocast(dtype=torch.float16), torch.no_grad():
                    outputs = model(
                        vision_x=batch_images,
                        lang_x=input_ids,
                        attention_mask=attention_mask,
                        image_nums=image_nums,
                        image_start_index_list=image_start_index_list,
                        added_bbox_list=added_bbox_list if len(added_bbox_list) != 0 else None,
                        add_box=added_bbox_list is not None and len(added_bbox_list) != 0,
                    )
                boxes = outputs["boxes"]
                scores = outputs["scores"]
                # if not model.valid:
                #     import pdb; pdb.set_trace()
                if boxes is not None:
                    if is_visual:
                        if have_prebox:
                            # Drop the provisional previsual/prebox span now that a
                            # definitive visual box is available.
                            added_bbox_list.pop()
                            prompt = prompt.replace("<|#previsual#|><|#prebox#|><|#object#|>", "")
                            have_prebox = False
                            if debug:
                                print("find previsual and remove it--->", prompt)
                        first_box = boxes[scores.argmax()]
                        added_bbox_list += [torch.tensor(first_box).unsqueeze(0).cuda() / 224]
                        prompt = prompt[:-len(tokenizer.eos_token)]
                        prompt += box_token + endofobject_token
                        if debug:
                            print("after inserting visual---->", prompt)
                    else:
                        import numpy as np
                        import cv2
                        # Debug visualization of the candidate pre-boxes.
                        open_cv_image = np.array(batch[0]["image"])
                        open_cv_image = open_cv_image[:, :, ::-1].copy()
                        for i, pre_box in enumerate(boxes):
                            open_cv_image = cv2.rectangle(open_cv_image, pre_box[:2].astype(int), pre_box[2:].astype(int), (0, 255, 0), i + 1)
                        cv2.imwrite("Atest.png", open_cv_image)
                        # exit()  # debug-only early exit
                        pre_box = boxes[scores.argmax()]
                        added_bbox_list += [torch.tensor(pre_box).unsqueeze(0).cuda() / 224]
                        prompt = prompt[:-len(tokenizer.eos_token)]
                        prompt += prebox_token + object_token
                        have_prebox = True
                        if debug:
                            print("after inserting previsual---->", prompt)
                else:
                    if debug:
                        import pdb; pdb.set_trace()
                    prompt = tokenizer.decode(outputs[0, :-2])
            else:
                break
        outputs = outputs[:, ori_prompt_length:]
        new_predictions = [
            postprocess_captioning_generation(out).replace('"', "")
            for out in tokenizer.batch_decode(outputs, skip_special_tokens=True)
        ]
        # import pdb; pdb.set_trace()
        if rank == 0:
            tqdm.write(new_predictions[0])
        for i, sample in enumerate(batch):
            predictions[int(sample["image_id"])] = {
                "caption": new_predictions[i],
            }
        # print(new_predictions)
        # exit()  # debug-only early exit
    results_path = (
        f"flickrresults_{lang_encoder_name}_{rank}_{id}.json"
        if is_flickr
        else f"cocoresults_{lang_encoder_name}_{rank}_{id}.json"
    )
    with open(results_path, "w") as f:
        f.write(
            json.dumps(
                [
                    {"image_id": k, "caption": predictions[k]["caption"]}
                    for k in predictions
                ],
                indent=2,
            )
        )
    print("save to", results_path)
    del predictions
    time.sleep(10)
    if world_size > 1:
        torch.distributed.barrier()
    if rank == 0:
        print(f"evaluate on rank {rank}. world size is {world_size}")
        predictions = []
        for rank_i in range(world_size):
            part_results_path = (
                f"flickrresults_{lang_encoder_name}_{rank_i}_{id}.json"
                if is_flickr
                else f"cocoresults_{lang_encoder_name}_{rank_i}_{id}.json"
            )
            print("load", part_results_path)
            predictions.extend(json.load(open(part_results_path)))
            os.remove(part_results_path)
        print("num:", len(predictions))
        results_path = (
            f"flickrresults_{lang_encoder_name}.json"
            if is_flickr
            else f"cocoresults_{lang_encoder_name}.json"
        )
        json.dump(predictions, open(results_path, "w"), indent=2)
        metrics = compute_cider(
            result_path=results_path,
            annotations_path="/gpfs/u/home/LMCG/LMCGljnn/scratch/.cache/lavis/coco_gt/coco_karpathy_test_gt.json",
        )
        metrics["CIDEr"] *= 100
        os.makedirs("eval_results", exist_ok=True)
        acc = metrics["CIDEr"]
        with open(os.path.join("eval_results", f"cococap_{model.expr_name}_{model.step_num}_{int(time.time())}_{acc}"), "w") as f:
            f.write(json.dumps(predictions, indent=2))
        # delete the temporary file
        os.remove(results_path)
    else:
        metrics = {}
        metrics["CIDEr"] = 0.0
    return metrics["CIDEr"]
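

# ---------------------------------------------------------------------------
# A hedged sketch of a driver for `evaluate_coco_flickr` under a torchrun-style
# launch. The model, tokenizer, image processor, and vis_embed_size are assumed
# to come from the repo's own checkpoint-loading code (not shown in this file);
# rank/world_size are read from the standard distributed-launch environment
# variables.
# ---------------------------------------------------------------------------
def _demo_evaluate_coco_flickr(model, tokenizer, image_processor, vis_embed_size, is_flickr=False):
    rank = int(os.environ.get("RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    cider = evaluate_coco_flickr(
        model=model,
        tokenizer=tokenizer,
        image_processor=image_processor,
        batch_size=1,  # the grounded decoding loop above effectively handles one sample at a time
        is_flickr=is_flickr,
        vis_embed_size=vis_embed_size,
        rank=rank,
        world_size=world_size,
        id=0,
    )
    if rank == 0:
        print("CIDEr:", cider)
    return cider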