import json
import os
import time

import cv2
import more_itertools
import numpy as np
import torch
from PIL import Image
from tqdm import tqdm
from transformers import (
    ForcedEOSTokenLogitsProcessor,
    LogitsProcessor,
    MinNewTokensLengthLogitsProcessor,
)


class VisualLogitsProcessor(LogitsProcessor):
    """Logits processor that drives grounded (box-aware) caption generation.

    It steers decoding toward the special grounding control tokens
    (``<|#object#|>``, ``<|#previsual#|>``, ``<|#visual#|>`` ...) by writing
    a large logit (1000) into the score row, and forces EOS right after a
    previsual/visual token so the caller can interrupt generation and run
    bounding-box prediction.

    Assumes batch size 1: only row 0 of ``scores`` is inspected/modified.
    """

    def __init__(self, tokenizer):
        super().__init__()
        self.tokenizer = tokenizer
        # ["input_ids"][-1] keeps the last piece in case the tokenizer
        # splits a marker into several sub-tokens.
        self.object_token_id = tokenizer("<|#object#|>", add_special_tokens=False)["input_ids"][-1]
        self.prebox_token_id = tokenizer("<|#prebox#|>", add_special_tokens=False)["input_ids"][-1]
        self.box_token_id = tokenizer("<|#box#|>", add_special_tokens=False)["input_ids"][-1]
        self.previsual_token_id = tokenizer("<|#previsual#|>", add_special_tokens=False)["input_ids"][-1]
        self.visual_token_id = tokenizer("<|#visual#|>", add_special_tokens=False)["input_ids"][-1]
        self.eos_token_id = tokenizer.encode(tokenizer.eos_token)[-1]
        self.endofobject_token_id = tokenizer("<|#endofobject#|>", add_special_tokens=False)["input_ids"][-1]
        # Number of top-ranked candidates inspected when deciding whether
        # to force an <|#object#|> token.
        self.topk = 2

    def __call__(self, input_ids, scores):
        # Rank the vocabulary once per step (the original re-sorted the
        # same scores three times).
        ranked = scores.sort(descending=True).indices.tolist()[0]
        last = input_ids[0, -1]
        # Force <|#object#|> when it is a runner-up candidate, EOS is not
        # among the top candidates, and the object/endofobject token counts
        # balance.  NOTE(review): the `* 2` count relation presumably means
        # "every opened object has been closed" — confirm against the
        # prompt template.
        if (self.object_token_id in ranked[1:self.topk]
                and self.eos_token_id not in ranked[:self.topk]
                and (input_ids == self.object_token_id).sum() * 2
                == (input_ids == self.endofobject_token_id).sum()):
            scores[0, self.object_token_id] = 1000
        if last == self.object_token_id and input_ids[0, -2] != self.prebox_token_id:
            if (input_ids[0, :-1] == self.object_token_id).sum() != 0:
                # Not the first object in the sequence: request a previsual
                # (pre-box) step next.
                scores[0, self.previsual_token_id] = 1000
        elif last == self.previsual_token_id or last == self.visual_token_id:
            # Stop generation so the caller can run bbox prediction.
            scores[0, self.eos_token_id] = 1000
        elif last == self.endofobject_token_id and input_ids[0, -2] != self.box_token_id:
            # Object phrase finished without a box: request a visual token.
            scores[0, self.visual_token_id] = 1000
        return scores


def prepare_batch_images(batch, image_processor):
    """Stack the processed images of a batch into one tensor.

    Each image is processed and reshaped to (1, 1, 1, *feat) before the
    batch is concatenated along dim 0, giving (B, 1, 1, *feat).

    Returns None for an empty batch (preserves historical behavior).
    Collects into a list and concatenates once instead of the original
    quadratic incremental ``torch.cat``.
    """
    processed = [
        image_processor(b["image"]).unsqueeze(0).unsqueeze(1).unsqueeze(0)
        for b in batch
    ]
    if not processed:
        return None
    return torch.cat(processed, dim=0)
def captioner(
        model, tokenizer, image_ori, batch_images, input_ids, attention_mask,
        image_start_index_list, image_nums, added_bbox_list, debug=True):
    """Generate a grounded caption for one image.

    Alternates between text generation and box prediction: whenever the
    model emits a <|#previsual#|>/<|#visual#|> control token followed by
    <bos>, the partial sequence is re-run through the model's forward pass
    to predict a bounding box, the box is appended to ``added_bbox_list``
    (mutated in place — callers see the new boxes), and generation resumes
    from the augmented prompt.

    Args:
        model: grounding-capable captioning model (project type).
        tokenizer: its tokenizer, carrying the <|#...#|> special tokens.
        image_ori: original PIL image, used only to draw the boxes.
        batch_images: preprocessed vision tensor fed to the model.
        input_ids: encoded initial prompt, shape (1, seq).
        attention_mask: matching attention mask.
        image_start_index_list: image placement metadata for the model.
        image_nums: number of images per sample.
        added_bbox_list: list of normalized box tensors; extended in place.
        debug: when True, print intermediate prompts and boxes.

    Returns:
        tuple: (caption string with double quotes stripped, PIL image with
        the predicted boxes drawn).  The original docstring claiming a
        CIDEr score was inaccurate.
    """
    visual_logits_processor = VisualLogitsProcessor(tokenizer)
    model.eval()
    media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
    bos_token_id = tokenizer(tokenizer.bos_token, add_special_tokens=False)["input_ids"][-1]
    previsual_token_id = tokenizer("<|#previsual#|>", add_special_tokens=False)["input_ids"][-1]
    visual_token_id = tokenizer("<|#visual#|>", add_special_tokens=False)["input_ids"][-1]
    box_token = "<|#box#|>"
    prebox_token = "<|#prebox#|>"
    endofobject_token = "<|#endofobject#|>"
    object_token = "<|#object#|>"
    ori_prompt_length = len(input_ids[0])
    have_prebox = False
    prompt = None
    # Hard cap on grounding rounds guards against a never-terminating loop.
    for _ in range(100):
        if prompt is not None:
            # Re-encode the prompt augmented with the latest box tokens.
            encodings = tokenizer(
                [prompt],
                padding="longest",
                truncation=True,
                return_tensors="pt",
                max_length=2000,
            )
            attention_mask = encodings["attention_mask"]
            input_ids = encodings["input_ids"]
        if debug:
            print("input--->", tokenizer.decode(input_ids[0]))
        # eos_token_id is set to <bos> here: this suppresses an early <bos>
        # for the first 5 new tokens rather than a conventional EOS.
        p1 = MinNewTokensLengthLogitsProcessor(
            prompt_length_to_skip=input_ids.shape[-1],
            min_new_tokens=5,
            eos_token_id=bos_token_id,
        )
        with torch.inference_mode():
            outputs = model.generate(
                batch_images,
                input_ids,
                attention_mask=attention_mask,
                max_new_tokens=20,
                num_beams=1,
                image_start_index_list=image_start_index_list,
                image_nums=image_nums,
                added_bbox_list=added_bbox_list if len(added_bbox_list) != 0 else None,
                logits_processor_list=[p1, visual_logits_processor],
            )
        if debug:
            print("outputs--->", tokenizer.decode(outputs[0]))
        if not (outputs[0, -2] in [previsual_token_id, visual_token_id]
                and outputs[0, -1] == bos_token_id):
            # Generation finished without requesting a box: leave the loop
            # and post-process (replaces the original idle `no_end` flag).
            break
        prompt = tokenizer.decode(outputs.clone()[0])
        is_visual = (outputs[0, -2] == visual_token_id)
        # Drop the trailing <bos> and re-encode to locate the image token.
        batch_text = tokenizer.batch_decode(outputs[:, :-1])
        encodings = tokenizer(
            batch_text,
            padding="longest",
            truncation=True,
            return_tensors="pt",
            max_length=2000,
        )
        input_ids = encodings["input_ids"]
        attention_mask = encodings["attention_mask"]
        image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
        image_start_index_list = [[x] for x in image_start_index_list]
        image_nums = [1] * len(input_ids)
        if debug:
            print("get the visual bbox--->", tokenizer.decode(input_ids[0]))
        with torch.no_grad():
            forward_out = model(
                vision_x=batch_images,
                lang_x=input_ids,
                attention_mask=attention_mask,
                image_nums=image_nums,
                image_start_index_list=image_start_index_list,
                added_bbox_list=added_bbox_list if len(added_bbox_list) != 0 else None,
                add_box=added_bbox_list is not None and len(added_bbox_list) != 0,
            )
        boxes = forward_out["boxes"]
        scores = forward_out["scores"]
        if debug:
            print("box num---->", len(boxes))
        if boxes is not None:
            if is_visual:
                if have_prebox:
                    # A prebox was speculatively inserted earlier; retract
                    # it before committing the definitive box.
                    added_bbox_list.pop()
                    prompt = prompt.replace("<|#previsual#|><|#prebox#|><|#object#|>", "")
                    have_prebox = False
                    if debug:
                        print("find previsual and remove it--->", prompt)
                first_box = boxes[scores.argmax()]
                # Boxes come out in 224x224 pixel space; normalize to [0, 1].
                added_bbox_list += [torch.tensor(first_box).unsqueeze(0) / 224]
                prompt = prompt[:-len(tokenizer.eos_token)]
                prompt += box_token + endofobject_token
                if debug:
                    print("after inserting visual---->", prompt)
            else:
                pre_box = boxes[scores.argmax()]
                added_bbox_list += [torch.tensor(pre_box).unsqueeze(0) / 224]
                prompt = prompt[:-len(tokenizer.eos_token)]
                prompt += prebox_token + object_token
                have_prebox = True
                if debug:
                    print("after inserting previsual---->", prompt)
        else:
            # No box predicted: continue from the generated sequence minus
            # the two trailing control tokens.  NOTE(review): the original
            # decoded the forward output (a dict-like, `.clone()` would
            # fail) and then a single scalar token — this restores the
            # apparent intent; confirm against the training prompt format.
            prompt = tokenizer.decode(outputs[0, :-2])
            if debug:
                print("after else---->", prompt)
    # Strip the original prompt and decode only the newly generated caption.
    caption = tokenizer.batch_decode(
        outputs[:, ori_prompt_length:], skip_special_tokens=True
    )[0].replace('"', "")
    # Draw the accumulated boxes on a BGR copy of the original image.
    open_cv_image = np.array(image_ori)[:, :, ::-1].copy()
    width = image_ori.width
    height = image_ori.height
    for i, pre_box in enumerate(added_bbox_list):
        if debug:
            print(pre_box)
        open_cv_image = cv2.rectangle(
            open_cv_image,
            (np.array(pre_box[0][:2]) * [width, height]).astype(int),
            (np.array(pre_box[0][2:]) * [width, height]).astype(int),
            (0, 255, 0),
            i + 1,  # line thickness grows with box index to tell boxes apart
        )
    out_image = Image.fromarray(cv2.cvtColor(open_cv_image, cv2.COLOR_BGR2RGB))
    return caption, out_image