ODR's picture
Duplicate from haotiz/glip-zeroshot-demo
12e55f6
import torch
import torch.distributed as dist
import time
from torchvision.ops import nms
import random
import numpy as np
from PIL import Image, ImageDraw
import pdb
from maskrcnn_benchmark.structures.bounding_box import BoxList
from .modulated_coco import ConvertCocoPolysToMask
from .tsv import ODTSVDataset, TSVYamlDataset
from .od_to_grounding import sanity_check_target_after_processing
from copy import deepcopy
class PseudoData(TSVYamlDataset):
    """Caption/box dataset read through ``TSVYamlDataset``.

    Every sample is fetched with ``TSVYamlDataset.__getitem__`` and its
    annotation dict is converted into a ``BoxList`` target that carries the
    caption and the per-box ``tokens_positive`` spans.

    In training mode the annotation dict is read for the keys ``"bboxes"``,
    ``"tokens_positive"``, ``"caption"``, ``"img_w"`` and ``"img_h"``;
    ``"scores"`` is additionally read when ``further_screen`` is on or the
    caption format is ``"v2"``, and ``"text"`` when ``diver_box_for_vqa`` is on.
    """

    def __init__(self,
                 yaml_file,
                 transforms,
                 return_tokens,
                 return_masks,
                 tokenizer,
                 caption_min_box=1,
                 replace_clean_label=False,
                 further_screen=False,
                 caption_conf=0.5,
                 caption_nms=-1,
                 pack_random_caption_number=0,
                 inference_caption=False,
                 sample_negative_for_grounding_data=-1,
                 random_pack_prob=-1.0,
                 no_random_pack_probability=0.0,
                 safeguard_positive_caption=True,
                 mlm_obj_for_only_positive=False,
                 caption_format_version="v1",
                 local_debug=False,
                 max_query_len=256,
                 diver_box_for_vqa=False,
                 **kwargs
                 ):
        """Store the configuration and build the annotation converter.

        Args:
            yaml_file: Dataset description forwarded to the base class. If it
                contains the substring ``"qa"``, ``diver_box_for_vqa`` must be
                truthy.
            transforms: Optional callable ``(img, target) -> (img, target)``
                applied at the end of ``__getitem__``; ``None`` disables it.
            return_tokens: Forwarded to ``ConvertCocoPolysToMask``.
            return_masks: Forwarded to ``ConvertCocoPolysToMask``.
            tokenizer: Forwarded to ``ConvertCocoPolysToMask``; its
                ``mask_token`` is also used by ``divert_boxes``.
            caption_min_box: Samples with fewer boxes than this are replaced
                by a randomly chosen other sample.
            replace_clean_label: Forwarded as the third positional argument
                of the base-class constructor.
            further_screen: When true, boxes are filtered by score and
                optionally by NMS before building the target.
            caption_conf: Boxes are kept only if their score is strictly
                greater than this value (``further_screen`` only).
            caption_nms: IoU threshold for NMS; NMS runs only when this is
                greater than 0 (``further_screen`` only).
            inference_caption: When true, only the caption is extracted from
                the annotation and the box list is emptied.
            caption_format_version: ``"v2"`` annotations are flattened with
                ``convert_anno_from_yiling_to_ours`` first.
            max_query_len: Forwarded to ``ConvertCocoPolysToMask``.
            diver_box_for_vqa: When true, ``divert_boxes`` rewrites the
                caption and spans before the target is built.
            pack_random_caption_number, sample_negative_for_grounding_data,
            random_pack_prob, no_random_pack_probability,
            safeguard_positive_caption, mlm_obj_for_only_positive,
            local_debug: Stored on the instance; not read by any method of
                this class.
            **kwargs: Accepted and ignored.
        """
        super().__init__(yaml_file, None, replace_clean_label)
        self.yaml_file = yaml_file
        self._transforms = transforms
        self.max_query_len = max_query_len
        self.prepare = ConvertCocoPolysToMask(return_masks=return_masks,
                                              return_tokens=return_tokens,
                                              tokenizer=tokenizer,
                                              max_query_len=max_query_len)
        self.diver_box_for_vqa = diver_box_for_vqa
        if "qa" in self.yaml_file:
            # QA-style yaml files must be used with box diversion enabled.
            assert self.diver_box_for_vqa, \
                "diver_box_for_vqa must be enabled for 'qa' yaml files"
        self.tokenizer = tokenizer
        self.caption_min_box = caption_min_box
        self.replace_clean_label = replace_clean_label
        self.further_screen = further_screen
        self.pack_random_caption_number = pack_random_caption_number
        self.caption_format_version = caption_format_version
        self.caption_conf = caption_conf
        self.caption_nms = caption_nms
        self.inference_caption = inference_caption
        self.sample_negative_for_grounding_data = sample_negative_for_grounding_data
        self.random_pack_prob = random_pack_prob
        self.no_random_pack_probability = no_random_pack_probability
        self.safeguard_positive_caption = safeguard_positive_caption
        self.mlm_obj_for_only_positive = mlm_obj_for_only_positive
        self.local_debug = local_debug

        # Fall back to rank 0 when torch.distributed is unavailable or no
        # process group has been initialised (e.g. single-process runs),
        # without swallowing unrelated exceptions.
        if dist.is_available() and dist.is_initialized():
            self.rank = dist.get_rank()
        else:
            self.rank = 0

    def __len__(self):
        """Return the number of samples reported by the base class."""
        return super().__len__()

    @staticmethod
    def check_for_overlap(range1, range2):
        """Return True if two ``[start, end]`` ranges intersect.

        Both ends are treated as inclusive: ranges that merely touch
        (``range1[1] == range2[0]``) count as overlapping.
        """
        return range1[0] <= range2[1] and range2[0] <= range1[1]

    @staticmethod
    def _as_box_tensor(bboxes):
        """Convert a list of boxes to a tensor, keeping an empty list 2-D.

        ``torch.as_tensor([])`` is 1-D with shape ``(0,)``; it is reshaped to
        ``(0, 4)`` so that row-wise consumers still see an N x 4 layout.
        Non-empty inputs are returned exactly as ``torch.as_tensor`` gives
        them.
        """
        boxes = torch.as_tensor(bboxes)
        if boxes.numel() == 0:
            boxes = boxes.reshape(0, 4)
        return boxes

    def divert_boxes(self, anno):
        """Replace the answer part of the caption by a mask token.

        The caption is cut after ``len(anno["text"]) + 1`` characters and the
        tokenizer's mask token (or the literal ``'answer'`` if the tokenizer
        has none) is appended. Every ``tokens_positive`` span that overlaps
        the removed answer range is redirected to the span of the appended
        mask token.

        Args:
            anno: Annotation dict with ``"text"``, ``"caption"``, ``"bboxes"``
                and ``"tokens_positive"``. It is modified in place.

        Returns:
            ``(question, anno)`` where ``question`` is the new caption and
            ``anno`` is the same (mutated) dict with ``anno["caption"]`` set
            to ``question``.
        """
        # Character range of the answer inside the original caption.
        answer_start = len(anno['text']) + 1  # +1 for the space
        answer_end = len(anno["caption"])
        answer_span = [answer_start, answer_end]

        question = anno["caption"][:answer_start]
        mask_start = len(question)

        mask_token = self.tokenizer.mask_token
        if mask_token is None:
            mask_token = 'answer'
        question += mask_token
        mask_end = len(question)

        # Redirect every span that touched the answer to the mask token.
        for box_idx in range(len(anno["bboxes"])):
            spans = anno["tokens_positive"][box_idx]
            for span_idx, span in enumerate(spans):
                if self.check_for_overlap(span, answer_span):
                    spans[span_idx] = [mask_start, mask_end]

        anno["caption"] = question
        return question, anno

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(img, target, idx)``.

        In training mode (``inference_caption`` false) the annotation is
        optionally flattened (``"v2"``), screened by score/NMS
        (``further_screen``), turned into a clipped ``BoxList`` and, if
        ``diver_box_for_vqa`` is set, rewritten by ``divert_boxes``. Samples
        that end up with fewer than ``caption_min_box`` boxes are replaced by
        a uniformly random other sample (recursive retry).

        Returns:
            ``(img, target, idx)``; ``target`` carries every entry returned
            by ``self.prepare`` as an extra field.
        """
        img, anno, _, _ = super().__getitem__(idx)

        if self.inference_caption:
            # NOTE(review): as the statement order reads, this branch defines
            # neither ``target`` nor ``greenlight_span_for_masked_lm_objective``,
            # both of which are used further down, so inference mode would
            # raise UnboundLocalError — confirm the intended inference target.
            if isinstance(anno, list):
                caption = anno[0]["caption"]  # inference mode for bing
            elif len(anno) == 1:
                caption = anno["caption"]  # inference mode for googlecc
            else:
                caption = " ".join(anno["captions"])
            anno = []
        else:
            if self.caption_format_version == "v2":
                anno = self.convert_anno_from_yiling_to_ours(anno)

            if self.further_screen:
                conf = self.caption_conf
                nms_thre = self.caption_nms

                bboxes = self._as_box_tensor(anno["bboxes"]).float()
                scores = torch.as_tensor(anno["scores"])
                tokens_positive = anno["tokens_positive"]

                # Keep only boxes whose score is strictly above the threshold.
                keep = scores > conf
                scores = scores[keep]
                bboxes = bboxes[keep]
                tokens_positive = [
                    spans for index, spans in enumerate(tokens_positive)
                    if keep[index]
                ]
                assert (len(tokens_positive) == len(bboxes) == len(scores))

                if len(bboxes) < self.caption_min_box:  # Retry triggered!
                    return self[np.random.choice(len(self))]

                if nms_thre > 0:
                    keep = nms(boxes=bboxes, scores=scores, iou_threshold=nms_thre)
                    scores = scores[keep]
                    bboxes = bboxes[keep]
                    tokens_positive = [tokens_positive[i] for i in keep]
                    assert (len(tokens_positive) == len(bboxes) == len(scores))

                # Write back
                anno["bboxes"] = bboxes.tolist()
                anno["scores"] = scores.tolist()
                anno["tokens_positive"] = tokens_positive

            boxes = self._as_box_tensor(anno["bboxes"])
            if len(boxes) < self.caption_min_box:  # Retry triggered!
                return self[np.random.choice(len(self))]

            target = BoxList(boxes, (anno["img_w"], anno["img_h"]), mode="xyxy")
            # NOTE(review): if clip_to_image drops boxes here, position i in
            # ``target`` may no longer line up with
            # ``anno["tokens_positive"][i]`` used below — confirm.
            target = target.clip_to_image(remove_empty=True)

            if self.diver_box_for_vqa:
                # Mutates anno: rewrites "caption" and "tokens_positive".
                _, anno = self.divert_boxes(anno=anno)
            caption = anno["caption"]

            # The whole caption is eligible for the masked-LM objective.
            greenlight_span_for_masked_lm_objective = [(0, len(caption))]

            areas = target.area()
            anno = [
                {
                    "area": areas[i],
                    "iscrowd": 0,
                    "image_id": idx,
                    "category_id": 1,  # following vg and others
                    "id": None,
                    "bbox": target.bbox[i].numpy().tolist(),
                    "tokens_positive": anno["tokens_positive"][i],
                }
                for i in range(len(target))
            ]

        annotations = {
            "image_id": idx,
            "annotations": anno,
            "caption": caption,
            "greenlight_span_for_masked_lm_objective":
                greenlight_span_for_masked_lm_objective,
        }
        img, annotations = self.prepare(img, annotations, box_format="xyxy")

        if self._transforms is not None:
            img, target = self._transforms(img, target)

        # add additional property
        for ann in annotations:
            target.add_field(ann, annotations[ann])

        # The real image id is available through self.get_img_id(idx) should
        # it ever need to be attached to the target; it is not looked up here
        # because the value was never used.

        sanity_check_target_after_processing(target)
        return img, target, idx

    def convert_anno_from_yiling_to_ours(self, anno):
        """Flatten a per-entity ("v2") annotation into a per-box one.

        On input ``anno["bboxes"][i]`` and ``anno["scores"][i]`` are lists
        holding the boxes/scores of entity ``i``, and
        ``anno["tokens_positive"][i]`` the spans of that entity. On output
        the three entries are flat, parallel lists with one element per box;
        every box of an entity references that entity's span list (the same
        list object, not a copy).

        Args:
            anno: Annotation dict; modified in place.

        Returns:
            The same ``anno`` dict.
        """
        flat_boxes = []
        flat_tokens_positive = []
        flat_scores = []
        for entity_idx, entity_boxes in enumerate(anno["bboxes"]):
            for box_idx, box in enumerate(entity_boxes):
                flat_boxes.append(box)
                # Assume this box corresponds to all the token spans of
                # its entity.
                flat_tokens_positive.append(anno["tokens_positive"][entity_idx])
                flat_scores.append(anno["scores"][entity_idx][box_idx])
        anno["bboxes"] = flat_boxes
        anno["tokens_positive"] = flat_tokens_positive
        anno["scores"] = flat_scores
        return anno

    def get_raw_image(self, idx):
        """Return only the image of sample ``idx`` as the base class loads it."""
        image, *_ = super().__getitem__(idx)
        return image

    def get_img_id(self, idx):
        """Return the image id stored in the label TSV for sample ``idx``.

        The id is the first column of the label row at the line number that
        ``get_line_no(idx)`` yields. Returns ``None`` when the dataset has no
        label TSV.
        """
        line_no = self.get_line_no(idx)
        if self.label_tsv is None:
            return None
        row = self.label_tsv.seek(line_no)
        return row[0]