import itertools

import numpy as np
import spacy
import torch
from tqdm import tqdm

nlp = spacy.load('en_core_web_md')


def get_iou(box1, box2):
    # box1 and box2 should be in the format [x1, y1, x2, y2]
    intersection = max(0, min(box1[2], box2[2]) - max(box1[0], box2[0])) * \
        max(0, min(box1[3], box2[3]) - max(box1[1], box2[1]))
    area_box1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
    area_box2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
    union = area_box1 + area_box2 - intersection
    iou = intersection / union if union > 0 else 0
    return iou
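
# A minimal sanity check for get_iou (illustrative values only, not from any
# dataset): two unit squares overlapping in a 0.5 x 1 strip intersect in
# area 0.5 and cover area 1.5, so IoU = 0.5 / 1.5 = 1/3.
def _demo_get_iou():
    box_a = [0.0, 0.0, 1.0, 1.0]
    box_b = [0.5, 0.0, 1.5, 1.0]
    iou = get_iou(box_a, box_b)
    assert abs(iou - 1 / 3) < 1e-6
    return iou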

# def find_root(token):
#     if token.pos_ == "VERB":
#         return token
#     while token.dep_ not in ["pobj", "nsubj", "ROOT", "npadvmod", "dobj", "det", "prep", "punct", "cc", "conj", "acl", "dep", "appos", "relcl", "advmod", "nmod", "attr"]:
#         token = token.head
#     return token


def find_root(token):
    # Climb from a modifier ("compound"/"amod") up to the noun it modifies;
    # verbs are returned as-is.
    if token.pos_ == "VERB":
        return token
    while token.dep_ in ["compound", "amod"]:
        token = token.head
    return token


def get_object_from_text(text, verbose=False):
    # Extract noun-phrase referring expressions (subjects and prepositional
    # objects) from `text`. Returns [] when fewer than two are found.
    if len(text.split(" ")) == 3:
        text = text.split(" ")
        return [text[0], text[-1]]
    doc = nlp(text)
    if verbose:
        for TT in doc:
            print(TT.text, TT.pos_, TT.dep_, TT.head)
    roots = set()
    for token in doc:
        roots.add(find_root(token))
    exprs = []
    roots = sorted(list(roots), key=lambda token: token.idx)
    first_nsubj = True
    if verbose:
        print(roots)
    for root in roots:
        if root.pos_ not in ["NOUN", "PROPN"]:
            continue
        if root.dep_ not in ["pobj", "nsubj"]:
            continue
        # Only keep the first subject; later subjects belong to other clauses.
        if not first_nsubj and root.dep_ in ["nsubj"]:
            continue
        exprs.append([])
        for token in doc:
            if find_root(token) == root:
                exprs[-1].append(token.text)
        exprs[-1] = " ".join(exprs[-1]).replace(" '", "'")
        if exprs[-1] not in text:
            if verbose:
                print("not in text error:", exprs[-1], "#", text)
            exprs.pop()
        if first_nsubj and root.dep_ in ["nsubj"]:
            first_nsubj = False
    if len(exprs) <= 1:
        if verbose:
            print("not enough exprs error:", exprs, "#", text)
        return []
    return exprs


def is_correct(input_ids, logits, tokenizer, object: str, topk=5, N=10):
    # Locate the tokens of `object` in `input_ids`, rank every combination of
    # the top-N predicted tokens at those positions by the product of their
    # probabilities, and report the gold answer's rank plus top-1/top-k hits.
    answer_id = torch.tensor(tokenizer(f" {object}", add_special_tokens=False)["input_ids"]).to(input_ids.device)
    answer_begin_idx = (input_ids == answer_id[0]).nonzero()
    answer_idx = None
    for (batch_idx, IDX) in answer_begin_idx:
        try:
            if (input_ids[batch_idx, IDX:IDX + len(answer_id)] == answer_id).all():
                # Shift left by one: logits at position i predict token i + 1.
                answer_idx = list(range(IDX - 1, IDX + len(answer_id) - 1))
        except RuntimeError:
            # Shape mismatch when the candidate span runs past the sequence end.
            pass
    if answer_idx is None:
        return np.inf, False, False
    res = logits[0, answer_idx].softmax(-1).sort(descending=True)
    values = res.values
    indices = res.indices
    chosen_ids = list(itertools.product(*([list(range(N))] * len(answer_idx))))
    probs = []
    for ids in chosen_ids:
        prob = 1.0
        for i, id in enumerate(ids):
            prob *= values[i, id]
        probs.append((prob.item(), ids))
    probs.sort(reverse=True)
    answer_pos = tuple(id_array.tolist().index(idx) for id_array, idx in zip(indices, answer_id))
    ranking = [p[1] for p in probs]
    try:
        r = ranking.index(answer_pos)
        return r, r < 1, r < topk
    except ValueError:
        # The gold answer fell outside the top-N candidates at some position.
        return np.inf, False, False
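
# Toy illustration of the ranking scheme used in is_correct (invented numbers,
# two answer positions, N=2): candidate sequences are scored by the product of
# their per-position probabilities, so the products 0.42, 0.28, 0.18, 0.12
# below order the candidates (0, 0), (1, 0), (0, 1), (1, 1).
def _demo_sequence_ranking():
    values = [[0.6, 0.4], [0.7, 0.3]]  # sorted top-2 probabilities per position
    combos = sorted(
        itertools.product(range(2), repeat=2),
        key=lambda ids: -(values[0][ids[0]] * values[1][ids[1]]),
    )
    assert combos == [(0, 0), (1, 0), (0, 1), (1, 1)]
    return combos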

def get_bbox(visual_box_list, batch_images, prompt, model, tokenizer, media_token_id, prebox_token_id, debug=False, return_all=False):
    assert isinstance(prompt, list) and len(prompt) == 1 and isinstance(prompt[0], str)
    encodings = tokenizer(
        prompt,
        padding="longest",
        truncation=True,
        return_tensors="pt",
        max_length=2000,
    )
    input_ids = encodings["input_ids"]
    attention_mask = encodings["attention_mask"]
    image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
    image_start_index_list = [[x] for x in image_start_index_list]
    image_nums = [1] * len(input_ids)
    vision_x = batch_images.cuda()
    lang_x = input_ids.cuda()
    attention_mask = attention_mask.cuda()
    model.debug_id = 0
    # `with A and B:` only enters one context manager; chain them with a comma
    # so both inference_mode and autocast are actually active.
    with torch.inference_mode(), torch.cuda.amp.autocast(dtype=torch.float16):
        outputs = model(
            vision_x=vision_x,
            lang_x=lang_x,
            attention_mask=attention_mask,
            labels=None,
            image_nums=image_nums,
            image_start_index_list=image_start_index_list,
            added_bbox_list=visual_box_list,
            add_box=visual_box_list is not None,
            relations=None,
            debug_mode=False,
        )
    boxes = outputs["boxes"]
    scores = outputs["scores"]
    if debug:
        import pdb; pdb.set_trace()
    if return_all:
        return boxes, scores
    if len(scores) == 0:
        return None, None
    else:
        return boxes[scores.argmax()], scores.max()
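
# Sketch of the prompt layout get_bbox expects, for orientation only (`tok`,
# `batch_images`, and the model objects are stand-ins for values produced
# elsewhere in this codebase):
#
#     prompt = [f"{tok.bos_token}<|#image#|>{tok.pad_token * vis_embed_size}"
#               f"<|#endofimage#|><|#object#|>the dog<|#endofobject#|><|#visual#|>"]
#     box, score = get_bbox(None, batch_images, prompt, model, tok,
#                           media_token_id, prebox_token_id)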

def _eval_text_image(text, image, model, tokenizer, image_processor, vis_embed_size, media_token_id, prebox_token_id, debug=False, objects=None):
    batch_images = image_processor(image).unsqueeze(0).unsqueeze(1).unsqueeze(0)
    if objects is None:
        objects = get_object_from_text(text)
    if len(objects) == 0:
        return None, None, None
    if debug:
        tqdm.write(text)
        tqdm.write(f"{objects}")
    # Ground the first object and splice its box token into the text.
    first_idx = text.find(objects[0])
    if first_idx == 0:
        first_text = f"<|#object#|>{objects[0]}<|#endofobject#|><|#visual#|>"
    else:
        first_text = text[:first_idx - 1] + f"<|#object#|> {objects[0]}<|#endofobject#|><|#visual#|>"
    if debug:
        tqdm.write(first_text)
    prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token * vis_embed_size}<|#endofimage#|>{first_text}"]
    first_box, first_score = get_bbox(None, batch_images, prompt, model, tokenizer, media_token_id, prebox_token_id, return_all=False)
    if not model.valid and debug:
        import pdb; pdb.set_trace()
    if first_box is not None:
        added_bbox_list = [torch.tensor(first_box).unsqueeze(0).cuda() / 224]
        text = first_text + "<|#box#|><|#endofobject#|>" + text[first_idx + len(objects[0]):]
    else:
        added_bbox_list = []
    final_ranks = []
    is_top1_list = []
    is_top5_list = []
    for kk, object in enumerate(objects):
        if kk == 0:
            continue
        # Walk forward to the occurrence of the current object. t_i lags temp
        # by one because of the objects[1:] slice, so objects[t_i] is the
        # previous object; the offset is intentional, not a bug.
        idx = text.find(objects[0])
        for t_i, temp in enumerate(objects[1:kk + 1]):
            idx = text.find(temp, idx + len(objects[t_i]))
            # Skip matches that fall inside special tokens such as <|#box#|>.
            while idx + len(temp) != len(text) and (text[idx - 1] == "#" or text[idx + len(temp)] == "#"):
                idx = text.find(temp, idx + len(temp))
        this_text = text[:idx - 1] + "<|#object#|><|#previsual#|>"
        if debug:
            tqdm.write(this_text)
        prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token * vis_embed_size}<|#endofimage#|>{this_text}"]
        pre_boxes, pre_scores = get_bbox(added_bbox_list, batch_images, prompt, model, tokenizer, media_token_id, prebox_token_id, return_all=True)
        if not model.valid and debug:
            import pdb; pdb.set_trace()
        logits_list = []
        this_text = this_text + f"<|#prebox#|><|#object#|> {object}<|#endofobject#|>"
        for pre_box, pre_score in zip(pre_boxes, pre_scores):
            prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token * vis_embed_size}<|#endofimage#|>{this_text}"]
            encodings = tokenizer(
                prompt,
                padding="longest",
                truncation=True,
                return_tensors="pt",
                max_length=512,
            )
            input_ids = encodings["input_ids"]
            attention_mask = encodings["attention_mask"]
            image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
            image_start_index_list = [[x] for x in image_start_index_list]
            image_nums = [1] * len(input_ids)
            vision_x = batch_images.cuda()
            lang_x = input_ids.cuda()
            attention_mask = attention_mask.cuda()
            this_added_bbox_list = added_bbox_list + [torch.tensor(pre_box).unsqueeze(0).cuda() / 224]
            # As in get_bbox, chain the context managers with a comma, not `and`.
            with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.float16):
                outputs = model(
                    vision_x=vision_x,
                    lang_x=lang_x,
                    attention_mask=attention_mask,
                    image_nums=image_nums,
                    image_start_index_list=image_start_index_list,
                    added_bbox_list=this_added_bbox_list,
                    add_box=this_added_bbox_list is not None and len(this_added_bbox_list) != 0,
                    relations=None,
                )
            if not model.valid and debug:
                import pdb; pdb.set_trace()
            logits_list.append([pre_score, outputs.logits])
            if debug:
                answer_start_idx = (lang_x == tokenizer("<|#object#|>", add_special_tokens=False)["input_ids"][-1]).nonzero()[-1][1]
                logits = outputs["logits"][0, answer_start_idx:]
                tqdm.write(tokenizer.decode(logits[0].sort(descending=True).indices.tolist()[:10]))
        # Marginalize over the candidate pre-boxes: weight each candidate's
        # token distribution by its box score and sum.
        pre_scores = np.array([x[0] for x in logits_list])
        final_probs = 0.0
        for score, (_, logits) in zip(pre_scores, logits_list):
            final_probs += score * logits.softmax(-1)
        assert input_ids.shape[:2] == final_probs.shape[:2]
        _rank, is_top1, is_top5 = is_correct(input_ids, final_probs, tokenizer, object, topk=5)
        final_ranks.append(_rank)
        is_top1_list.append(is_top1)
        is_top5_list.append(is_top5)
        # Ground the current object and splice its box token into the text
        # before moving on to the next object.
        this_text = text[:idx - 1] + f"<|#object#|> {object}<|#endofobject#|><|#visual#|>"
        if debug:
            tqdm.write(this_text)
        prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token * vis_embed_size}<|#endofimage#|>{this_text}"]
        this_box, this_score = get_bbox(added_bbox_list, batch_images, prompt, model, tokenizer, media_token_id, prebox_token_id, return_all=False)
        if not model.valid and debug:
            import pdb; pdb.set_trace()
        if this_box is not None:
            added_bbox_list += [torch.tensor(this_box).unsqueeze(0).cuda() / 224]
            text = this_text + "<|#box#|><|#endofobject#|>" + text[idx + len(object):]
    return final_ranks, is_top1_list, is_top5_list


if __name__ == "__main__":
    # print(get_object_from_text("there is a cookie. there is a bear. white oreo cookie is next to the teddy bear. car runs on the traffic road. there is a tree.", verbose=False))
    print(get_object_from_text("President speaks to an American at a business office", verbose=True))
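
# A minimal sketch of how the full evaluation might be driven. `load_model`
# and the image path are hypothetical placeholders; this file does not define
# how the model, tokenizer, or special token ids are constructed:
#
#     from PIL import Image
#     model, tokenizer, image_processor, vis_embed_size, media_token_id, prebox_token_id = load_model(...)
#     ranks, top1, top5 = _eval_text_image(
#         "white cookie is next to the teddy bear",
#         Image.open("example.jpg"),
#         model, tokenizer, image_processor, vis_embed_size,
#         media_token_id, prebox_token_id,
#     )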