Spaces:
Runtime error
Runtime error
import spacy | |
import torch | |
from tqdm import tqdm | |
import numpy as np | |
import itertools | |
nlp = spacy.load('en_core_web_md') | |
def get_iou(box1, box2):
    """Return the intersection-over-union of two axis-aligned boxes.

    Args:
        box1: Box in ``[x1, y1, x2, y2]`` format.
        box2: Box in ``[x1, y1, x2, y2]`` format.

    Returns:
        ``intersection / union`` of the two boxes, or ``0`` when the union
        area is not positive (e.g. both boxes are degenerate).
    """
    # Overlap extent along each axis, clamped at 0 for disjoint boxes.
    overlap_w = max(0, min(box1[2], box2[2]) - max(box1[0], box2[0]))
    overlap_h = max(0, min(box1[3], box2[3]) - max(box1[1], box2[1]))
    intersection = overlap_w * overlap_h
    area_box1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
    area_box2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
    union = area_box1 + area_box2 - intersection
    return intersection / union if union > 0 else 0
# def find_root(token): | |
# if token.pos_ == "VERB": | |
# return token | |
# while token.dep_ not in ["pobj", "nsubj", "ROOT", "npadvmod", "dobj", "det", "prep", "punct", "cc", "conj", "acl", "dep", "appos", "relcl", "advmod", "nmod", "attr"]: | |
# token = token.head | |
# return token | |
def find_root(token):
    """Return the token that heads the phrase *token* belongs to.

    A token whose ``pos_`` is ``"VERB"`` is returned unchanged. Otherwise the
    ``head`` link is followed for as long as the current token's ``dep_`` is
    ``"compound"`` or ``"amod"``, so modifiers resolve to the word they
    modify.

    Args:
        token: Object exposing ``pos_``, ``dep_`` and ``head`` (presumably a
            spaCy token, given how callers obtain it from ``nlp(text)``).

    Returns:
        The first token on the head chain that is not a compound/adjectival
        modifier, or *token* itself when it is a verb.
    """
    if token.pos_ == "VERB":
        return token
    while token.dep_ in ("compound", "amod"):
        token = token.head
    return token
def get_object_from_text(text, verbose=False):
    """Extract the noun phrases ("objects") mentioned in *text*, in order.

    A text made of exactly three space-separated pieces is short-circuited to
    ``[first, last]`` without running the parser (presumably a
    "<subject> <relation> <object>" triple -- confirm against callers).

    Otherwise the text is parsed with the module-level ``nlp`` pipeline and
    tokens are grouped by ``find_root``. A group is kept when its root is a
    ``NOUN``/``PROPN`` with dependency ``pobj`` or ``nsubj``; only the first
    ``nsubj`` root is considered. Each kept group is joined with spaces
    (``" '"`` collapsed to ``"'"``) and discarded if the joined string is not
    a substring of *text*.

    Args:
        text: Sentence to analyse.
        verbose: When true, print the parse and the reason for any rejection.

    Returns:
        List of phrase strings ordered by position in *text*, or ``[]`` when
        fewer than two phrases were found.
    """
    words = text.split(" ")
    if len(words) == 3:
        return [words[0], words[-1]]

    doc = nlp(text)
    if verbose:
        for tok in doc:
            print(tok.text, tok.pos_, tok.dep_, tok.head)

    # Resolve each token's root once and collect the token texts per root in
    # document order, instead of re-walking every token for every root.
    phrase_tokens = {}
    for token in doc:
        phrase_tokens.setdefault(find_root(token), []).append(token.text)

    roots = sorted(phrase_tokens, key=lambda token: token.idx)
    if verbose:
        print(roots)

    exprs = []
    first_nsubj = True
    for root in roots:
        if root.pos_ not in ("NOUN", "PROPN"):
            continue
        if root.dep_ not in ("pobj", "nsubj"):
            continue
        is_subject = root.dep_ == "nsubj"
        if is_subject and not first_nsubj:
            continue
        expr = " ".join(phrase_tokens[root]).replace(" '", "'")
        if expr in text:
            exprs.append(expr)
        elif verbose:
            # Re-joining tokens with spaces did not reproduce the source text.
            print("not in text error:", expr, "#", text)
        if is_subject:
            # Consumed even when the phrase itself was rejected above.
            first_nsubj = False

    if len(exprs) <= 1:
        if verbose:
            print("not enough exprs error:", exprs, "#", text)
        return []
    return exprs
def is_correct(input_ids, logits, tokenizer, object: str, topk=5, N=10):
    """Rank the token sequence for *object* among the model's top candidates.

    ``f" {object}"`` is tokenized and located inside ``input_ids``; when it
    occurs several times the last occurrence is used. For each answer token
    the distribution at the preceding position of ``logits[0]`` is taken
    (softmax applied here), the ``N`` most likely tokens per position are
    combined into every possible sequence, and the sequences are ordered by
    the product of their per-position probabilities.

    Args:
        input_ids: 2-D tensor of token ids (batch, sequence).
        logits: Tensor indexed as ``logits[0, position]`` giving per-token
            scores over the vocabulary; only batch element 0 is read.
        tokenizer: Callable returning a mapping with an ``"input_ids"`` list.
        object: Answer text (name kept for caller compatibility even though
            it shadows the builtin).
        topk: Rank threshold for the third return value.
        N: Number of top candidates considered at each answer position.

    Returns:
        ``(rank, rank < 1, rank < topk)``. ``(np.inf, False, False)`` is
        returned when the answer is not present in ``input_ids`` or any of
        its tokens falls outside the top ``N`` at its position.
    """
    not_found = (np.inf, False, False)

    answer_id = torch.tensor(
        tokenizer(f" {object}", add_special_tokens=False)["input_ids"]
    ).to(input_ids.device)
    n_answer = len(answer_id)
    if n_answer == 0:
        return not_found

    answer_idx = None
    for batch_idx, start in (input_ids == answer_id[0]).nonzero().tolist():
        if start == 0:
            # No preceding position predicts a token at index 0; using
            # start - 1 here would silently wrap around to the last position.
            continue
        window = input_ids[batch_idx, start:start + n_answer]
        # Explicit length check: a truncated window near the end of the
        # sequence must not be broadcast against the full answer.
        if len(window) == n_answer and bool((window == answer_id).all()):
            # Position p - 1 holds the prediction for the token at p.
            answer_idx = list(range(start - 1, start + n_answer - 1))
    if answer_idx is None:
        return not_found

    res = logits[0, answer_idx].softmax(-1).sort(descending=True)
    values = res.values
    indices = res.indices
    n_candidates = min(N, values.shape[-1])

    # Rank of each answer token within its own position's sorted vocabulary.
    answer_pos = []
    for sorted_ids, token in zip(indices, answer_id.tolist()):
        hits = (sorted_ids == token).nonzero()
        if hits.numel() == 0:
            return not_found
        pos = int(hits[0, 0])
        if pos >= n_candidates:
            # Outside the enumerated candidates, so it can never be ranked.
            return not_found
        answer_pos.append(pos)
    answer_pos = tuple(answer_pos)

    probs = []
    for ids in itertools.product(range(n_candidates), repeat=n_answer):
        prob = 1.0
        for step, rank_at_step in enumerate(ids):
            prob *= values[step, rank_at_step]
        probs.append((prob.item(), ids))
    probs.sort(reverse=True)
    ranking = [ids for _, ids in probs]

    r = ranking.index(answer_pos)
    return r, r < 1, r < topk
def get_bbox(visual_box_list, batch_images, prompt, model, tokenizer, media_token_id, prebox_token_id, debug=False, return_all=False):
    """Run the model on a single prompt and return its predicted box(es).

    Args:
        visual_box_list: Boxes forwarded to the model as ``added_bbox_list``;
            ``None`` disables ``add_box``.
        batch_images: Image tensor; moved to CUDA before the forward pass.
        prompt: List containing exactly one prompt string.
        model: Callable model whose output exposes ``"boxes"`` and
            ``"scores"``.
        tokenizer: Tokenizer used to encode *prompt*.
        media_token_id: Token id marking an image; the position right after
            each occurrence is passed as the image start index.
        prebox_token_id: Accepted for interface compatibility; not used here.
        debug: Drop into ``pdb`` after the forward pass.
        return_all: Return every box/score instead of only the best one.

    Returns:
        ``(boxes, scores)`` when *return_all* is true; otherwise the
        highest-scoring ``(box, score)`` pair, or ``(None, None)`` when the
        model produced no scores.
    """
    assert isinstance(prompt, list) and len(prompt) == 1 and isinstance(prompt[0], str)
    encodings = tokenizer(
        prompt,
        padding="longest",
        truncation=True,
        return_tensors="pt",
        max_length=2000,
    )
    input_ids = encodings["input_ids"]
    attention_mask = encodings["attention_mask"]
    image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
    image_start_index_list = [[x] for x in image_start_index_list]
    image_nums = [1] * len(input_ids)
    vision_x = batch_images.cuda()
    lang_x = input_ids.cuda()
    attention_mask = attention_mask.cuda()

    model.debug_id = 0
    # Both managers must be listed comma-separated: `with a() and b():`
    # evaluates the `and` first and enters only the second manager.
    with torch.inference_mode(), torch.cuda.amp.autocast(dtype=torch.float16):
        outputs = model(
            vision_x=vision_x,
            lang_x=lang_x,
            attention_mask=attention_mask,
            labels=None,
            image_nums=image_nums,
            image_start_index_list=image_start_index_list,
            added_bbox_list=visual_box_list,
            add_box=visual_box_list is not None,
            relations=None,
            debug_mode=False,
        )
    boxes = outputs["boxes"]
    scores = outputs["scores"]
    if debug:
        import pdb; pdb.set_trace()
    if return_all:
        return boxes, scores
    if len(scores) == 0:
        return None, None
    return boxes[scores.argmax()], scores.max()
def _wrap_prompt(tokenizer, vis_embed_size, body):
    """Return a one-element prompt list: BOS, image placeholder span, then *body*."""
    return [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>{body}"]


def _box_to_tensor(box):
    """Turn one box into a CUDA tensor with a leading batch axis, divided by 224."""
    # 224 is presumably the model's input image side length -- TODO confirm.
    return torch.tensor(box).unsqueeze(0).cuda() / 224


def _eval_text_image(text, image, model, tokenizer, image_processor, vis_embed_size, media_token_id, prebox_token_id, debug=False, objects=None):
    """Score how well the model predicts each later object in *text*.

    The first object is grounded with ``get_bbox`` and its box is appended to
    the running box list. For every following object the function:

    1. asks the model for candidate boxes *before* the object is named
       (``<|#previsual#|>`` prompt),
    2. runs the model once per candidate box with the object name appended,
       and mixes the resulting softmax distributions weighted by the
       candidate scores,
    3. ranks the true object tokens with ``is_correct``,
    4. grounds the object itself (``<|#visual#|>`` prompt) and splices the
       box markers into *text* for the next iteration.

    Args:
        text: Caption to evaluate.
        image: Image passed to *image_processor*.
        model: Model called by ``get_bbox`` and for the scoring passes; its
            ``valid`` attribute is only consulted when *debug* is set.
        tokenizer: Tokenizer providing ``bos_token`` and ``pad_token``.
        image_processor: Callable turning *image* into a tensor.
        vis_embed_size: Number of pad tokens reserved for the image.
        media_token_id: Token id marking an image in the prompt.
        prebox_token_id: Forwarded to ``get_bbox``.
        debug: Emit intermediate prompts via ``tqdm.write`` and enable the
            ``pdb`` breakpoints.
        objects: Optional pre-extracted object phrases; defaults to
            ``get_object_from_text(text)``.

    Returns:
        ``(final_ranks, is_top1_list, is_top5_list)`` with one entry per
        object after the first, or ``(None, None, None)`` when no objects
        are available. An object that cannot be located in the text, or for
        which the model proposes no candidate box, is recorded as
        ``(np.inf, False, False)``.
    """
    batch_images = image_processor(image).unsqueeze(0).unsqueeze(1).unsqueeze(0)
    if objects is None:
        objects = get_object_from_text(text)
    if len(objects) == 0:
        return None, None, None
    if debug:
        tqdm.write(text)
        tqdm.write(f"{objects}")

    # --- Ground the first object -------------------------------------------
    first_idx = text.find(objects[0])
    if first_idx == 0:
        first_text = f"<|#object#|>{objects[0]}<|#endofobject#|><|#visual#|>"
    else:
        # first_idx - 1 drops the space before the object; it is re-added
        # after the <|#object#|> marker.
        first_text = text[:first_idx-1] + f"<|#object#|> {objects[0]}<|#endofobject#|><|#visual#|>"
    if debug:
        tqdm.write(first_text)
    prompt = _wrap_prompt(tokenizer, vis_embed_size, first_text)
    first_box, first_score = get_bbox(None, batch_images, prompt, model, tokenizer, media_token_id, prebox_token_id, return_all=False)
    if not model.valid and debug:
        import pdb; pdb.set_trace()
    if first_box is not None:
        added_bbox_list = [_box_to_tensor(first_box)]
        text = first_text + "<|#box#|><|#endofobject#|>" + text[first_idx+len(objects[0]):]
    else:
        added_bbox_list = []

    final_ranks = []
    is_top1_list = []
    is_top5_list = []
    for kk, obj in enumerate(objects):
        if kk == 0:
            continue

        # --- Locate this object in the (already annotated) text -------------
        idx = text.find(objects[0])
        for t_i, temp in enumerate(objects[1:kk+1]):
            # objects[1:] is enumerated from 0, so objects[t_i] is the object
            # *before* temp: the search resumes just past it. Not a bug.
            idx = text.find(temp, idx + len(objects[t_i]))
            # Skip hits that are part of a marker such as <|#object#|>.
            # The idx != -1 guard stops the loop once the text is exhausted.
            while idx != -1 and idx+len(temp) != len(text) and (text[idx-1] == "#" or text[idx+len(temp)] == "#"):
                idx = text.find(temp, idx + len(temp))
            if idx == -1:
                break
        if idx == -1:
            # The object is not present in the text; slicing with -1 would
            # build a corrupted prompt, so count it as a miss instead.
            final_ranks.append(np.inf)
            is_top1_list.append(False)
            is_top5_list.append(False)
            continue

        # --- Candidate boxes before the object is named ---------------------
        this_text = text[:idx-1] + "<|#object#|><|#previsual#|>"
        if debug:
            tqdm.write(this_text)
        prompt = _wrap_prompt(tokenizer, vis_embed_size, this_text)
        pre_boxes, pre_scores = get_bbox(added_bbox_list, batch_images, prompt, model, tokenizer, media_token_id,
                                         prebox_token_id, return_all=True)
        if not model.valid and debug:
            import pdb; pdb.set_trace()

        # --- Score the object name under every candidate box ----------------
        this_text = this_text + f"<|#prebox#|><|#object#|> {obj}<|#endofobject#|>"
        # The prompt is identical for all candidates, so encode it once.
        prompt = _wrap_prompt(tokenizer, vis_embed_size, this_text)
        encodings = tokenizer(
            prompt,
            padding="longest",
            truncation=True,
            return_tensors="pt",
            max_length=512,
        )
        input_ids = encodings["input_ids"]
        attention_mask = encodings["attention_mask"]
        image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
        image_start_index_list = [[x] for x in image_start_index_list]
        image_nums = [1] * len(input_ids)
        vision_x = batch_images.cuda()
        lang_x = input_ids.cuda()
        attention_mask = attention_mask.cuda()

        logits_list = []
        for pre_box, pre_score in zip(pre_boxes, pre_scores):
            this_added_bbox_list = added_bbox_list + [_box_to_tensor(pre_box)]
            # Comma-separated so that BOTH managers are entered; `a and b`
            # would enter only the second one.
            with torch.cuda.amp.autocast(dtype=torch.float16), torch.no_grad():
                outputs = model(
                    vision_x=vision_x,
                    lang_x=lang_x,
                    attention_mask=attention_mask,
                    image_nums=image_nums,
                    image_start_index_list=image_start_index_list,
                    added_bbox_list=this_added_bbox_list,
                    # The list always holds at least the candidate box.
                    add_box=True,
                    relations=None,
                )
            if not model.valid and debug:
                import pdb; pdb.set_trace()
            logits_list.append([pre_score, outputs.logits])
            if debug:
                answer_start_idx = (lang_x == tokenizer("<|#object#|>", add_special_tokens=False)["input_ids"][-1]).nonzero()[-1][1]
                debug_logits = outputs["logits"][0, answer_start_idx:]
                tqdm.write(tokenizer.decode(debug_logits[0].sort(descending=True).indices.tolist()[:10]))

        if logits_list:
            # Mixture of per-candidate distributions, weighted by box score.
            pre_scores = np.array([x[0] for x in logits_list])
            final_probs = 0.0
            for score, (_, logits) in zip(pre_scores, logits_list):
                final_probs += score * logits.softmax(-1)
            assert input_ids.shape[:2] == final_probs.shape[:2]
            _rank, is_top1, is_top5 = is_correct(input_ids, final_probs, tokenizer, obj, topk=5)
        else:
            # No candidate box was proposed, so there is nothing to score.
            _rank, is_top1, is_top5 = np.inf, False, False
        final_ranks.append(_rank)
        is_top1_list.append(is_top1)
        is_top5_list.append(is_top5)

        # --- Ground the object itself and extend the annotated text ---------
        this_text = text[:idx-1] + f"<|#object#|> {obj}<|#endofobject#|><|#visual#|>"
        if debug:
            tqdm.write(this_text)
        prompt = _wrap_prompt(tokenizer, vis_embed_size, this_text)
        this_box, this_score = get_bbox(added_bbox_list, batch_images, prompt, model, tokenizer, media_token_id, prebox_token_id, return_all=False)
        if not model.valid and debug:
            import pdb; pdb.set_trace()
        if this_box is not None:
            added_bbox_list += [_box_to_tensor(this_box)]
            text = this_text + "<|#box#|><|#endofobject#|>" + text[idx+len(obj):]
    return final_ranks, is_top1_list, is_top5_list
if __name__ == "__main__":
    # Manual smoke test: print the objects extracted from a sample sentence,
    # together with the verbose parse output.
    # Alternative, longer multi-sentence input:
    # print(get_object_from_text("there is a cookie. there is a bear. white orio cookie is next to the teddy bear. car runs on the traffic road. there is a tree.", verbose=False))
    print(get_object_from_text("President speaks to an American at a business office",verbose=True))