import base64
import json
import random
from io import BytesIO

import numpy as np
import torch
from PIL import Image, ImageDraw

from .base_dataset import BaseDataset, recalculate_box_and_verify_if_valid
from .tsv import TSVFile


def decode_base64_to_pillow(image_b64):
    """Decode a base64-encoded image string into an RGB PIL image."""
    return Image.open(BytesIO(base64.b64decode(image_b64))).convert('RGB')


def decode_tensor_from_string(arr_str, use_tensor=True):
    """Decode a base64-encoded float32 buffer into a numpy array (or a torch tensor)."""
    arr = np.frombuffer(base64.b64decode(arr_str), dtype='float32')
    if use_tensor:
        arr = torch.from_numpy(arr)
    return arr


def decode_item(item):
    """Decode one TSV payload: a JSON dict holding a base64-encoded image and,
    for each annotation, base64-encoded text/image embeddings."""
    item = json.loads(item)
    item['image'] = decode_base64_to_pillow(item['image'])

    for anno in item['annos']:
        anno['image_embedding_before'] = decode_tensor_from_string(anno['image_embedding_before'])
        anno['text_embedding_before'] = decode_tensor_from_string(anno['text_embedding_before'])
        anno['image_embedding_after'] = decode_tensor_from_string(anno['image_embedding_after'])
        anno['text_embedding_after'] = decode_tensor_from_string(anno['text_embedding_after'])
    return item
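
# For reference, a minimal sketch of the encoding side that decode_item expects.
# This is a hypothetical illustration, not part of this module; `pil_image` and
# `emb` (a float32 tensor) are assumed to exist:
#
#   buffer = BytesIO()
#   pil_image.save(buffer, format='JPEG')
#   item = {
#       'image': base64.b64encode(buffer.getvalue()).decode(),
#       'annos': [{'bbox': [x, y, w, h],
#                  'text_embedding_before': base64.b64encode(emb.numpy().tobytes()).decode(),
#                  ...}],
#   }
#   payload = json.dumps(item)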
|
|
|
def check_unique(images, fields):
    """Assert that every given field is unique across all images."""
    for field in fields:
        values = [img_info[field] for img_info in images]
        assert len(set(values)) == len(values), field

def clean_data(data):
    """Drop unused fields and rename 'id' to 'data_id' (in place)."""
    for data_info in data:
        data_info.pop("original_img_id", None)
        data_info.pop("original_id", None)
        data_info.pop("sentence_id", None)
        data_info.pop("dataset_name", None)
        data_info.pop("data_source", None)
        data_info["data_id"] = data_info.pop("id")


def clean_annotations(annotations):
    """Drop unused fields and rename 'image_id' to 'data_id' (in place)."""
    for anno_info in annotations:
        anno_info.pop("iscrowd", None)
        anno_info.pop("category_id", None)
        anno_info.pop("area", None)
        anno_info["data_id"] = anno_info.pop("image_id")

def draw_box(img, boxes):
    """Draw red rectangles for the given xyxy boxes (modifies img in place)."""
    draw = ImageDraw.Draw(img)
    for box in boxes:
        draw.rectangle([box[0], box[1], box[2], box[3]], outline="red", width=2)
    return img


def xyhw2xyxy(box):
    """Convert an (x0, y0, w, h) box into (x0, y0, x1, y1)."""
    x0, y0, w, h = box
    return [x0, y0, x0 + w, y0 + h]

def make_a_sentence(obj_names, clean=False):
    """Join object names into a comma-separated caption,
    e.g. ['dog', 'cat'] -> 'dog, cat'."""
    if clean:
        # strip the '-other' suffix from stuff categories such as 'wall-other'
        obj_names = [name[:-len("-other")] if name.endswith("-other") else name
                     for name in obj_names]
    return ", ".join(obj_names)

def mask_for_random_drop_text_or_image_feature(masks, random_drop_embedding):
    """
    The input masks indicate how many valid grounding tokens this image has,
    e.g., 1,1,1,1,0,0,0,0,0,0,...

    If random_drop_embedding == 'both', we randomly drop either the image or
    the text feature for each token, but we always make sure at least one
    feature is used per token. In other words, the following masks are not
    valid (the second object would have no feature at all):
        image: 1,0,1,1,0,0,0,0,0
        text:  1,0,0,0,0,0,0,0,0

    If random_drop_embedding == 'image', we randomly drop the image feature
    and always keep the text one.
    """
    N = masks.shape[0]

    if random_drop_embedding == 'both':
        temp_mask = torch.ones(2, N)
        for i in range(N):
            if random.uniform(0, 1) < 0.5:  # drop one feature for this token
                idx = random.sample([0, 1], 1)[0]  # either image (0) or text (1), never both
                temp_mask[idx, i] = 0
        image_masks = temp_mask[0] * masks
        text_masks = temp_mask[1] * masks
    elif random_drop_embedding == 'image':
        image_masks = masks * (torch.rand(N) > 0.5) * 1
        text_masks = masks
    else:
        raise ValueError(f"unsupported random_drop_embedding: {random_drop_embedding}")

    return image_masks, text_masks
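
# A minimal usage sketch (illustrative only): with 4 valid tokens out of 6,
# 'both' mode keeps at least one of the two features alive for every valid token:
#
#   masks = torch.tensor([1., 1., 1., 1., 0., 0.])
#   image_masks, text_masks = mask_for_random_drop_text_or_image_feature(masks, 'both')
#   assert torch.all((image_masks + text_masks)[masks.bool()] >= 1)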
|
|
|
|
|
|
|
|
|
|
|
def project(x, projection_matrix):
    """
    x (Batch, 768) should be the penultimate feature of CLIP (before projection).
    projection_matrix (768, 768) is the CLIP projection matrix; it should be the
    weight.data of the Linear layer defined in CLIP (out_dim, in_dim), hence the
    transpose below.
    This function returns the CLIP feature (without normalization).
    """
    return x @ torch.transpose(projection_matrix, 0, 1)


def inv_project(y, projection_matrix):
    """
    y (Batch, 768) should be the CLIP feature (after projection).
    projection_matrix (768, 768) is the CLIP projection matrix; it should be the
    weight.data of the Linear layer defined in CLIP (out_dim, in_dim).
    This function returns the CLIP penultimate feature.

    Note: to recover the correct penultimate feature, the input y must not be
    normalized. If it is normalized, the result will be scaled by the (unknown)
    CLIP feature norm.
    """
    return y @ torch.transpose(torch.linalg.inv(projection_matrix), 0, 1)
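
# A minimal round-trip sketch (illustrative only; `W` is a random stand-in for
# the real CLIP projection weight, in float64 so the inverse stays accurate):
#
#   W = torch.randn(768, 768, dtype=torch.float64)
#   x = torch.randn(4, 768, dtype=torch.float64)   # penultimate features
#   y = project(x, W)                              # projected CLIP features
#   assert torch.allclose(inv_project(y, W), x, atol=1e-6)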
|
|
|
|
|
|
|
|
|
class TSVDataset(BaseDataset):
    def __init__(self,
                 tsv_path,
                 which_embedder='clip',
                 which_layer=['after', 'after'],  # [text layer, image layer]
                 prob_use_caption=1,
                 random_drop_embedding='none',
                 image_size=256,
                 min_box_size=0.01,
                 max_boxes_per_data=8,
                 max_images=None,
                 random_crop=False,
                 random_flip=True,
                 ):
        image_root = "a placeholder path as we are using tsv here"
        super().__init__(image_root, random_crop, random_flip, image_size)
        self.tsv_path = tsv_path
        self.which_embedder = which_embedder
        self.prob_use_caption = prob_use_caption
        self.random_drop_embedding = random_drop_embedding
        self.min_box_size = min_box_size
        self.max_boxes_per_data = max_boxes_per_data
        self.max_images = max_images

        assert which_layer in [['after', 'after'], ['before', 'after_renorm'], ['before', 'after_reproject']]
        assert random_drop_embedding in ['none', 'both', 'image']
        self.which_layer_text = which_layer[0]
        self.which_layer_image = which_layer[1]

        # pre-saved CLIP projection matrix, needed for the 'after_reproject'
        # mapping below
        self.projection_matrix = torch.load('projection_matrix')

        # TSV file holding the base64-encoded images and annotations
        self.tsv_file = TSVFile(self.tsv_path)

        if which_embedder == 'bert':
            self.embedding_len = 1280
        elif which_embedder == 'clip':
            self.embedding_len = 768
        else:
            raise ValueError(f"unsupported embedder: {which_embedder}")
|
|
|
    def total_images(self):
        return len(self)

    def get_item_from_tsv(self, index):
        _, item = self.tsv_file[index]  # each row is (key, json_payload)
        item = decode_item(item)
        return item
|
|
|
|
|
    def mapping(self, image_embedding):
        if self.which_layer_image == 'after':
            # use the stored projected CLIP image feature as-is
            return image_embedding
        elif self.which_layer_image == 'after_renorm':
            # rescale the feature by a fixed factor of 28.7
            return image_embedding * 28.7
        elif self.which_layer_image == 'after_reproject':
            # re-project with the CLIP projection matrix, renormalize to unit
            # norm, then rescale to magnitude 28.7
            image_embedding = project(image_embedding.unsqueeze(0), self.projection_matrix.T)
            image_embedding = image_embedding.squeeze(0)
            image_embedding = image_embedding / image_embedding.norm()
            image_embedding = image_embedding * 28.7
            return image_embedding
        else:
            raise ValueError(f"unsupported which_layer_image: {self.which_layer_image}")
|
|
|
|
|
|
|
    def __getitem__(self, index):
        if self.max_boxes_per_data > 99:
            raise ValueError("Are you sure you want such a large number of boxes per image?")

        raw_item = self.get_item_from_tsv(index)
        is_det = raw_item.get('is_det', False)  # detection data has category names instead of a caption

        out = {}

        out['id'] = raw_item['data_id']
        image = raw_item['image']
        image_tensor, trans_info = self.transform_image(image)
        out["image"] = image_tensor

        annos = raw_item['annos']

        areas = []
        all_boxes = []
        all_masks = []  # 1 for a valid box, 0 otherwise
        all_text_embeddings = []
        all_image_embeddings = []
        if is_det:
            all_category_names = []

        text_embedding_name = 'text_embedding_before' if self.which_layer_text == 'before' else 'text_embedding_after'
        image_embedding_name = 'image_embedding_after'

        for anno in annos:
            x, y, w, h = anno['bbox']
            valid, (x0, y0, x1, y1) = recalculate_box_and_verify_if_valid(x, y, w, h, trans_info, self.image_size, self.min_box_size)

            if valid:
                areas.append((x1 - x0) * (y1 - y0))
                all_boxes.append(torch.tensor([x0, y0, x1, y1]) / self.image_size)  # normalize to [0, 1]
                all_masks.append(1)
                all_text_embeddings.append(anno[text_embedding_name])
                all_image_embeddings.append(self.mapping(anno[image_embedding_name]))
                if is_det:
                    all_category_names.append(anno["category_name"])

        # keep at most max_boxes_per_data boxes, preferring the largest ones
        wanted_idxs = torch.tensor(areas).sort(descending=True)[1]
        wanted_idxs = wanted_idxs[0:self.max_boxes_per_data]

        boxes = torch.zeros(self.max_boxes_per_data, 4)
        masks = torch.zeros(self.max_boxes_per_data)
        text_embeddings = torch.zeros(self.max_boxes_per_data, self.embedding_len)
        image_embeddings = torch.zeros(self.max_boxes_per_data, self.embedding_len)
        if is_det:
            category_names = []
        for i, idx in enumerate(wanted_idxs):
            boxes[i] = all_boxes[idx]
            masks[i] = all_masks[idx]
            text_embeddings[i] = all_text_embeddings[idx]
            image_embeddings[i] = all_image_embeddings[idx]
            if is_det:
                category_names.append(all_category_names[idx])

        if self.random_drop_embedding != 'none':
            image_masks, text_masks = mask_for_random_drop_text_or_image_feature(masks, self.random_drop_embedding)
        else:
            image_masks = masks
            text_masks = masks

        out["boxes"] = boxes
        out["masks"] = masks
        out["image_masks"] = image_masks
        out["text_masks"] = text_masks
        out["text_embeddings"] = text_embeddings
        out["image_embeddings"] = image_embeddings

        # With probability prob_use_caption, provide the caption (built from
        # category names for detection data); otherwise use an empty caption.
        if random.uniform(0, 1) < self.prob_use_caption:
            if is_det:
                out["caption"] = make_a_sentence(category_names)
            else:
                out["caption"] = raw_item["caption"]
        else:
            out["caption"] = ""

        return out
|
|
|
|
|
|
|
    def __len__(self):
        return len(self.tsv_file)
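

if __name__ == "__main__":
    # Minimal smoke test (illustrative only): assumes a 'data.tsv' file in the
    # format above and the 'projection_matrix' file both exist. All per-item
    # tensors have fixed shapes, so the default collate_fn works. Run it as a
    # module (e.g. `python -m ...`), since this file uses relative imports.
    from torch.utils.data import DataLoader

    dataset = TSVDataset('data.tsv', which_embedder='clip', max_boxes_per_data=8)
    loader = DataLoader(dataset, batch_size=4, shuffle=True)
    batch = next(iter(loader))
    print(batch['image'].shape, batch['boxes'].shape)  # e.g. (4, 3, 256, 256) and (4, 8, 4)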
|
|
|
|
|
|