import random
import os
import time
import json
from tqdm import trange

from PIL import Image, ImageFile
import copy

import cv2
import base64
import numpy as np
import pyarrow as pa
import logging

import glob
from io import BytesIO
import jsonlines

import torch
from torch.utils.data import Dataset
from uniperceiver.functional import read_np, dict_as_tensor, boxes_to_locfeats
from collections import defaultdict

from uniperceiver.datasets.zipreader import ZipReader
import errno
from uniperceiver.datasets.circular_cached_loader import CircularCachedInputIterator

from uniperceiver.tokenization import ClipTokenizer

from ..build import DATASETS_REGISTRY

from uniperceiver.config import configurable
import pickle
from uniperceiver.utils import comm

from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.data import create_transform
from torchvision import transforms
from uniperceiver.datasets.custom_transforms import clip_transforms

__all__ = ["ImageTextPairDataset"]

memorycache = False

def makedirsExist(path):
    try:
        os.makedirs(path, exist_ok=True)
    except OSError as e:
        if e.errno == errno.EEXIST:
            print('Directory not created.')
        else:
            raise


def _smart_join(str_or_list, delim):
    if isinstance(str_or_list, str):
        return str_or_list
    else:
        return delim.join(str_or_list)
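
# Behavior sketch for `_smart_join` (deterministic, follows directly from the
# code above):
#   _smart_join('a dog runs', '\t')            -> 'a dog runs'
#   _smart_join(['a dog', 'runs fast'], '\t')  -> 'a dog\truns fast'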


@DATASETS_REGISTRY.register()
class ImageTextPairDataset(Dataset):

    @configurable
    def __init__(self, cfg, stage, ann_file, image_set, root_path, data_path, s3_path,
                 feats_folder,
                 dataset_name,
                 data_percentage,
                 seq_per_img,
                 tokenizer, tokenizer_name,
                 seq_len=64,
                 mask_prob=(0.15, 0.8), repl_prob=0.1,
                 task_type=True,
                 transform=None, test_mode=False,
                 zip_mode=False,
                 cache_mode=False,
                 cache_origin_image=False,
                 cache_local_rank=0, cache_local_size=1,
                 circular_cache_mode=False,
                 ignore_db_cache=True,
                 aspect_grouping=False,
                 use_ceph=False,
                 tcs_conf_path='',
                 random_caption=False,
                 max_length=-1,
                 as_numpy_as_possible=False,
                 use_node_distirbuted_sampler=False,
                 **kwargs):
        """
        Image-text pair dataset (Conceptual Captions-style jsonl annotations;
        SBU, Visual Genome, MSCOCO and Flickr30k are handled by the dedicated
        loaders below).

        :param ann_file: annotation jsonl (or json) file
        :param image_set: image folder name, e.g., 'vcr1images'
        :param root_path: root path to cache database loaded from annotation file
        :param data_path: path to the dataset images
        :param transform: transform
        :param test_mode: test mode means no labels available
        :param zip_mode: reading images and metadata in zip archive
        :param cache_mode: cache whole dataset to RAM first, then __getitem__ reads from RAM
        :param ignore_db_cache: ignore previous cached database, reload it from annotation file
        :param tokenizer: default is ClipTokenizer (see from_config)
        :param aspect_grouping: whether to group images by aspect ratio
        :param kwargs:
        """
        super(ImageTextPairDataset, self).__init__()

        assert not test_mode
        assert not (cache_mode and circular_cache_mode)

        self.mask_prob = mask_prob
        self.repl_prob = repl_prob
        self.seq_len = seq_len
        self.task_type = task_type
        self.cfg = cfg
        self.stage = stage
        self.dataset_name = dataset_name
        self.feats_folder = feats_folder
        self.seq_per_img = seq_per_img
        assert self.seq_per_img == 1
        self.data_percentage = data_percentage

        self.data_path = data_path
        self.root_path = root_path
        self.ann_file = ann_file
        self.image_set = image_set
        self.transform = transform
        self.test_mode = test_mode
        self.zip_mode = zip_mode
        self.cache_mode = cache_mode
        self.cache_origin_image = cache_origin_image
        self.cache_local_rank = cache_local_rank
        self.cache_local_size = cache_local_size
        self.circular_cache_mode = circular_cache_mode
        self.ignore_db_cache = ignore_db_cache
        self.aspect_grouping = aspect_grouping
        self.cache_dir = os.path.join(self.data_path, 'cache')
        self.use_node_distirbuted_sampler = (use_node_distirbuted_sampler or cache_mode)
        if not os.path.exists(self.cache_dir):
            makedirsExist(self.cache_dir)

        self.initialized = False

        self.tokenizer = tokenizer
        self.tokenizer_name = tokenizer_name
        self.use_clip_tokenizer = tokenizer_name == 'clip'

        self.zipreader = ZipReader()

        self.use_ceph = use_ceph
        self.tcs_conf_path = tcs_conf_path
        if use_ceph:
            self.data_path = s3_path
            from uniperceiver.datasets.tcsreader import TCSLoader
            self.tcs_loader = TCSLoader(tcs_conf_path)
        else:
            self.data_path = feats_folder

        if comm.is_main_process():
            print(f"data_path for Dataset {self.dataset_name} with task {self.task_type}: {self.data_path}")

        self.random_caption = random_caption

        if self.dataset_name == 'VG':
            self.load_VG(self.cfg)
        elif self.dataset_name in ['MSCOCO', 'FLICKR']:
            self.load_COCO_flickr(self.cfg)
        else:
            self.load_database()

        if self.circular_cache_mode:
            chunk_dir = os.path.join(self.data_path, '{}_chunks'.format(image_set))
            self.chunk_path_list = glob.glob(os.path.join(chunk_dir, '*.pa'))

        if self.aspect_grouping:
            assert False, "aspect grouping is not supported currently!"
            self.group_ids = self.group_aspect(self.database)

        self.as_numpy_as_possible = as_numpy_as_possible
        self.max_length = max_length

        self.task_info = {
            'task_type': self.cfg.DATASETS.TASK_TYPE,
            'dataset_name': self.cfg.DATASETS.DATASET_NAME,
            'batch_size': self.cfg.DATALOADER.TRAIN_BATCH_SIZE if self.stage == 'train' else self.cfg.DATALOADER.TEST_BATCH_SIZE,
            'sampling_weight': self.cfg.DATALOADER.SAMPLING_WEIGHT,
        }

    @classmethod
    def from_config(cls, cfg, stage: str = "train"):

        if 'SLURM_PROCID' in os.environ:
            tcs_conf_path = cfg.DATALOADER.get("TCS_CONF_PATH", "petreloss.config")
        else:
            tcs_conf_path = "slurm_tools/petreloss_local.config"
        anno_filename = cfg.DATALOADER.get("ANNO_FILENAME", "train_spacy.json")
        if cfg.DATALOADER.USE_CEPH and cfg.DATALOADER.S3_ANNO_FOLDER is not None:
            anno_folder = cfg.DATALOADER.S3_ANNO_FOLDER
        else:
            anno_folder = cfg.DATALOADER.ANNO_FOLDER
        if cfg.DATASETS.DATASET_NAME == 'MSCOCO':
            anno_files = {
                "train": [os.path.join(anno_folder, "captions_train113k.json"),
                          os.path.join(anno_folder, "captions_val5k.json")],
                "test": os.path.join(anno_folder, "captions_test5k.json"),
            }
        elif cfg.DATASETS.DATASET_NAME == 'FLICKR':
            anno_files = {
                "train": [os.path.join(anno_folder, "all_data_final_train_2014.jsonline"),
                          os.path.join(anno_folder, "all_data_final_val_set0_2014.jsonline")],
                "test": os.path.join(anno_folder, "all_data_final_test_set0_2014.jsonline"),
            }
        else:
            anno_files = {
                "train": os.path.join(anno_folder, anno_filename),
                "val": os.path.join(anno_folder, anno_filename),
                "test": os.path.join(anno_folder, anno_filename),
            }
        if getattr(cfg.DATALOADER, 'TRANSFORM', None) == 'clip_transforms':
            transform = clip_transforms(stage, img_size=cfg.MODEL.IMG_INPUT_SIZE)
        else:
            transform = build_transform(is_train=(stage == 'train'))

        ret = {
            'cfg': cfg,
            'stage': stage,
            'ann_file': anno_files[stage],
            'seq_per_img': 1,
            'image_set': stage,
            'root_path': cfg.DATALOADER.ANNO_FOLDER,
            'data_path': cfg.DATALOADER.FEATS_FOLDER,
            's3_path': cfg.DATALOADER.S3_PATH,
            'feats_folder': cfg.DATALOADER.FEATS_FOLDER,
            'dataset_name': cfg.DATASETS.DATASET_NAME,
            'data_percentage': cfg.DATALOADER.DATA_PERCENTAGE,
            'seq_len': cfg.MODEL.MAX_SEQ_LEN,
            'task_type': cfg.DATASETS.TASK_TYPE,
            'transform': transform,
            'zip_mode': cfg.DATALOADER.ZIP_MODE,
            'cache_mode': cfg.DATALOADER.CACHE_MODE,
            'cache_origin_image': cfg.DATALOADER.CACHE_ORIGIN_IMAGE,
            'cache_local_rank': comm.get_local_rank(),
            'cache_local_size': comm.get_local_size(),
            'circular_cache_mode': cfg.DATALOADER.CIRCULAR_CACHE_MODE,
            'use_ceph': getattr(cfg.DATALOADER, 'USE_CEPH', False),
            'tcs_conf_path': tcs_conf_path,
            'random_caption': cfg.DATALOADER.RANDOM_CAPTION,
            'as_numpy_as_possible': cfg.DATALOADER.AS_NUMPY_AS_POSSIBLE,
            'use_node_distirbuted_sampler': cfg.DATALOADER.SAMPLER == 'NodeDistributed',
            'tokenizer': ClipTokenizer(),
            'tokenizer_name': "clip",
        }

        return ret
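
    # Construction sketch: `@configurable` lets `ImageTextPairDataset(cfg, stage=...)`
    # be built from `from_config`, and the class is registered in
    # DATASETS_REGISTRY under its own name. Assuming a detectron2-style
    # registry (the exact builder entry point is an assumption):
    #
    #   dataset_cls = DATASETS_REGISTRY.get('ImageTextPairDataset')
    #   dataset = dataset_cls(cfg, stage='train')
    #   sample = dataset[0]   # dict with 'input_sample', 'target_*', 'task_info'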

    def _init_memcached(self):
        pass

    def load_img_info(self, anno_file):
        id2path = {}
        with jsonlines.open(anno_file) as reader:
            for annotation in reader:
                image_id = annotation["id"]
                id2path[image_id] = annotation["img_path"]

        return id2path

    def load_COCO_flickr(self, cfg):

        self.idx2name = dict()
        self.name2idx = dict()
        if isinstance(self.ann_file, list):
            imageinfo = list()
            self.id2path = dict()
            for anno_file in self.ann_file:
                if self.dataset_name == 'MSCOCO':
                    imageinfo.extend(json.load(open(anno_file))["images"])
                else:
                    id2path = self.load_img_info(anno_file)
                    self.id2path.update(id2path)
        else:
            if self.dataset_name == 'MSCOCO':
                imageinfo = json.load(open(self.ann_file))["images"]
            else:
                self.id2path = self.load_img_info(self.ann_file)

        if self.dataset_name == 'MSCOCO':
            for info in imageinfo:
                self.idx2name[info['id']] = {
                    "split": info['file_path'],
                    "name": info['file_name']}
                self.name2idx[info['file_name']] = info['id']

        if self.stage == "test":
            if self.dataset_name == 'MSCOCO':
                cache_path = os.path.join(
                    os.path.dirname(self.ann_file), "cache",
                    "mscoco_caption_w_testcap_%s.pkl" % (self.stage)
                )
            else:
                cache_path = os.path.join(
                    self.root_path, "cache",
                    "RetrievalFlickr30k_raw_%s_%s_%d.pkl" % (self.tokenizer_name, self.stage, self.seq_len)
                )

            if not os.path.exists(os.path.dirname(cache_path)):
                os.makedirs(os.path.dirname(cache_path))
            if not os.path.exists(cache_path):
                datalist = self.load_raw_data(cfg, self.ann_file)
                pickle.dump(datalist, open(cache_path, "wb"))
            datalist = pickle.load(open(cache_path, "rb"))
        else:
            datalist = list()
            assert self.stage == "train", "no validation set for now"
            for i, stage in enumerate(["train", "val"]):
                if self.dataset_name == 'MSCOCO':
                    cache_path = os.path.join(
                        os.path.dirname(self.ann_file[i]), "cache",
                        "mscoco_caption_w_testcap_%s.pkl" % (stage)
                    )
                else:
                    cache_path = os.path.join(
                        self.root_path, "cache",
                        "RetrievalFlickr30k_raw_%s_%s_%d.pkl" % (self.tokenizer_name, stage, self.seq_len)
                    )
                if not os.path.exists(os.path.dirname(cache_path)):
                    os.makedirs(os.path.dirname(cache_path))
                if not os.path.exists(cache_path):
                    datalist_part = self.load_raw_data(cfg, self.ann_file[i])
                    pickle.dump(datalist_part, open(cache_path, "wb"))
                datalist_part = pickle.load(open(cache_path, "rb"))
                datalist.extend(datalist_part)

        if self.data_percentage < 1.0 and self.stage == 'train':
            datalist = random.sample(datalist, k=int(self.data_percentage * len(datalist)))

        self.database = pa.array(datalist)
        if comm.is_main_process():
            import sys
            print(f"!!! Dataset {self.dataset_name} with task {self.task_type}:")
            print('!!! length of datalist: ', len(datalist))
            print('!!! size of datalist: ', sys.getsizeof(datalist))
            print('!!! size of pa database: ', sys.getsizeof(self.database))
        del datalist

    def load_raw_data(self, cfg, anno_file):
        datalist = []
        if self.dataset_name == 'MSCOCO':
            annoinfo = json.load(open(anno_file))
            captions_train = sorted(annoinfo['annotations'], key=lambda x: x['id'])
            image_caption_info = defaultdict(list)
            for cap_info in captions_train:
                image_caption_info[cap_info['image_id']].append(cap_info['caption'])

            for im_id, caps in image_caption_info.items():
                datalist.append(
                    {
                        "image_id": im_id,
                        "captions": caps,
                    }
                )
        else:
            with jsonlines.open(anno_file) as reader:
                for annotation in reader:
                    sentences = annotation["sentences"]
                    image_id = annotation["id"]
                    datalist.append({"image_id": image_id, "imagename": annotation["img_path"], "captions": sentences})

        return datalist
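
    # Shape of the returned entries (values are illustrative, not real ids):
    #   MSCOCO: {'image_id': 139, 'captions': ['a cat on a sofa', ...]}
    #   FLICKR: {'image_id': 1007129816, 'imagename': 'xxx.jpg', 'captions': [...]}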

    def load_VG(self, cfg):
        cache_path = os.path.join(
            os.path.dirname(self.ann_file), "cache",
            "vg_caption_spe_raw_%s.pkl" % (self.stage)
        )
        if not os.path.exists(os.path.dirname(cache_path)):
            os.makedirs(os.path.dirname(cache_path))
        if not os.path.exists(cache_path):
            _temp_list = []
            if self.use_ceph:
                anno_file = os.path.join('s3://visual_genome/annotations', os.path.basename(self.ann_file))
                annotations = json.load(BytesIO(self.tcs_loader.client.get(anno_file)))
            else:
                annotations = json.load(open(self.ann_file))

            for im_id, annoinfo in annotations['phrase'].items():
                _temp_list.append(
                    {
                        "image_id": im_id,
                        "captions": annoinfo,
                        'path': annotations['subset'][im_id],
                    }
                )
            pickle.dump(_temp_list, open(cache_path, "wb"))
        else:
            _temp_list = pickle.load(open(cache_path, "rb"))
        self.database = pa.array(_temp_list)

        if comm.is_main_process():
            import sys
            print(f"!!! Dataset {self.dataset_name} with task {self.task_type}:")
            print('!!! length of _temp_list: ', len(_temp_list))
            print('!!! size of _temp_list: ', sys.getsizeof(_temp_list))
            print('!!! size of pa database: ', sys.getsizeof(self.database))
        del _temp_list

    def load_database(self):

        if self.random_caption:
            cache_filename = 'spe_cache_random_caption_' + os.path.basename(self.ann_file).replace('.', "_") \
                + "_" + str(self.cache_local_rank) + "_" + str(self.cache_local_size) + '.pkl'
        else:
            cache_filename = 'spe_cache_' + os.path.basename(self.ann_file).replace('.', "_") \
                + "_" + str(self.cache_local_rank) + "_" + str(self.cache_local_size) + '.pkl'

        cache_file = os.path.join(self.cache_dir, cache_filename)

        if not os.path.exists(cache_file):
            _temp_list = []
            self.img_path_to_index = {}
            if self.use_ceph:
                f = BytesIO(self.tcs_loader.client.get(self.ann_file))
            else:
                f = open(self.ann_file, 'r')
            if self.dataset_name == 'SBU':
                annofile = json.load(f)
            else:
                annofile = f
            for i, l in enumerate(annofile):
                if self.use_node_distirbuted_sampler and ((i % self.cache_local_size) != self.cache_local_rank):
                    _temp_list.append(None)
                    continue
                l = l.strip()
                if l == '':
                    continue
                if self.dataset_name == 'SBU':
                    # for SBU, `annofile` is a dict and `l` iterates over image paths
                    self.img_path_to_index[l] = i
                    _temp_list.append([l, annofile[l]])
                else:
                    _data = json.loads(l)
                    if not self.zip_mode:
                        _data['image'] = _data['image'].replace('.zip@', '')
                    self.img_path_to_index[_data['image']] = i
                    if self.random_caption:
                        _temp_list.append([_data['image'], _smart_join(_data['caption'], '\t'), _data['title'], _data['description']])
                    else:
                        _temp_list.append([_data['image'], _smart_join(_data['caption'], '\t')])

            f.close()

            pickle.dump({
                "path_to_indext": self.img_path_to_index,
                "temp_list": _temp_list,
            }, open(cache_file, "wb"), protocol=4)
        else:
            cachedata = pickle.load(open(cache_file, "rb"))
            self.img_path_to_index, _temp_list = cachedata['path_to_indext'], cachedata['temp_list']

        self.database = pa.array(_temp_list)
        if comm.is_main_process():
            import sys
            print(f"!!! Dataset {self.dataset_name} with task {self.task_type}:")
            print('!!! length of _temp_list: ', len(_temp_list))
            print('!!! size of _temp_list: ', sys.getsizeof(_temp_list))
            print('!!! size of pa database: ', sys.getsizeof(self.database))
        del _temp_list

    @property
    def data_names(self):
        return ['image', 'im_info', 'text', 'mlm_labels']

    def __getitem__(self, index):
        for i_try in range(100):
            try:
                image_path = None
                image_id = None
                idb = None
                if self.dataset_name in ['VG', 'MSCOCO', 'FLICKR']:
                    self.dataset_dict = self.database[index].as_py()
                    image_id = self.dataset_dict['image_id']
                    if self.dataset_name == 'VG':
                        imagepath = self.dataset_dict['path']
                        image_path = os.path.join(self.data_path, imagepath)
                    elif self.dataset_name == 'FLICKR':
                        image_path = os.path.join(self.data_path, self.id2path[image_id])
                    else:
                        image_split = self.idx2name[int(image_id)]['split']
                        image_name = self.idx2name[int(image_id)]['name']
                        image_path = os.path.join(self.data_path, image_split, image_name)
                else:
                    _idb = self.database[index]
                    idb = {'image': str(_idb[0]).strip('./'), 'caption': str(_idb[1]).split('\t')}
                    if self.random_caption:
                        idb['title'] = [_idb[2].as_py()]
                        idb['description'] = [_idb[3].as_py()]
                return self._data_transform(idb, index=index, as_numpy_as_possible=self.as_numpy_as_possible, image_path=image_path, image_id=image_id)
            except Exception as e:
                print(
                    "Failed to load image from idb {} with error {}; trial {}".format(
                        self.database[index], e, i_try
                    )
                )
                # move on to the next non-empty entry (entries owned by other
                # nodes are stored as None under the node-distributed sampler)
                index = (index + 1) % len(self.database)
                while self.database[index].as_py() is None:
                    index = (index + 1) % len(self.database)
                continue

    def _data_transform(self, idb, index=None, as_numpy_as_possible=False, fail_image_fill=(0.0, 0.0, 0.0), image_path=None, image_id=None):

        if self.dataset_name in ['VG', 'MSCOCO', 'FLICKR']:
            image = self._load_image(image_path)
        else:
            if index is None:
                index = self.img_path_to_index[idb['image']]

            image = self.get_image(idb, index=index)
        if isinstance(image, Image.Image):
            w0, h0 = image.size
        elif isinstance(image, np.ndarray):
            h0, w0, c_ = image.shape
            assert c_ == 3
        else:
            raise NotImplementedError

        if self.transform is not None:
            image = self.transform(image)

        if image_id is not None:
            img_sample_info = {
                'id': image_id,
                'path': image_path
            }
        else:
            img_sample_info = {
                'id': index
            }
        ret = {
            'input_sample': [{
                'data': image,
                'invalid_mask': None,
                'modality': 'image',
                'data_type': 'input',
                'sample_info': copy.deepcopy(img_sample_info)
            }]
        }

        self.target_set = self.cfg.DATASETS.TARGET_SET

        mlm_labels = None
        u_mask_type = None
        if self.task_type == 'image_caption' and self.stage != 'train':
            ret.update({
                'target_set': copy.deepcopy(self.target_set),
                'target_sample': [],
                'target_idx': [],
                'task_info': copy.deepcopy(self.task_info)
            })
            dict_as_tensor(ret)
            return ret

        if self.task_type == 'image_retrieval' and self.stage != 'train':
            captions = [caption + " <|endoftext|>" for caption in self.dataset_dict['captions']]
            caption_tokens_raw = [self.tokenizer.encode(caption) for caption in captions]
            if self.dataset_name in ['MSCOCO', 'FLICKR']:
                # truncate to seq_len but always keep the final <|endoftext|> token
                caption_tokens = [caption_token[:(self.seq_len - 1)] + [caption_token[-1]]
                                  if len(caption_token) > self.seq_len else caption_token
                                  for caption_token in caption_tokens_raw]
            return self.package_item(ret, caption_tokens, mlm_labels, u_mask_type)

        if self.random_caption:
            # randomly pick title or description as the caption, falling back
            # when the chosen field tokenizes to nothing
            if len(idb['title']) == 0:
                caption = idb['description']
                if len(self.tokenizer.encode(' '.join(caption))) == 0:
                    caption = ['image']
            else:
                if random.random() < 0.5:
                    caption = idb['title']
                    if len(self.tokenizer.encode(' '.join(caption))) == 0:
                        caption = idb['description']
                        if len(self.tokenizer.encode(' '.join(caption))) == 0:
                            caption = ['image']
                else:
                    caption = idb['description']
                    if len(self.tokenizer.encode(' '.join(caption))) == 0:
                        caption = idb['title']
                        if len(self.tokenizer.encode(' '.join(caption))) == 0:
                            caption = ['image']
        else:
            if self.dataset_name == 'VG':
                caption = random.sample(self.dataset_dict['captions'], self.seq_per_img)[0]
                while len(caption) < 1:
                    caption = random.sample(self.dataset_dict['captions'], self.seq_per_img)[0]
                if caption and caption.lower()[-1] in "qwertyuiopasdfghjklzxcvbnm1234567890":
                    caption = caption + "."
            elif self.dataset_name in ['MSCOCO', 'FLICKR']:
                caption = random.sample(self.dataset_dict['captions'], self.seq_per_img)[0]
            else:
                caption = idb['caption']
                if caption and caption[-1] and caption[-1].lower()[-1] in "1234567890qwertyuiopasdfghjklzxcvbnm":
                    caption.append(".")

        # replace anonymized <PERSON> placeholders (only hits list-style captions)
        for i_, tok in enumerate(caption):
            if '<PERSON>' in tok:
                tok = tok.replace('<PERSON>', 'person')
                caption[i_] = tok

        if self.task_type == 'mlm':
            u_mask_type = 1
        elif self.task_type == 'image_caption':
            u_mask_type = 0

        if self.dataset_name in ['VG', 'MSCOCO', 'FLICKR']:
            caption = caption + " <|endoftext|>"
        else:
            caption = caption + ["<|endoftext|>"]

        if self.task_type == 'mlm':
            if self.dataset_name in ['VG', 'MSCOCO', 'FLICKR']:
                caption_tokens = self.tokenizer.basic_tokenize(caption)
            else:
                if self.use_clip_tokenizer:
                    caption_tokens = self.tokenizer.basic_tokenize(' '.join(caption))
                else:
                    caption_tokens = self.tokenizer.basic_tokenizer.tokenize(' '.join(caption))
            caption_tokens, mlm_labels = self.random_word_wwm(caption_tokens)
        elif self.task_type == 'image_caption':
            if self.dataset_name in ['VG', 'MSCOCO', 'FLICKR']:
                caption_tokens = self.tokenizer.encode(caption)
                mlm_labels = self.tokenizer.encode("<|spe|>") * len(caption_tokens)
            else:
                caption_tokens = self.tokenizer.encode(' '.join(caption))
                mlm_labels = self.tokenizer.encode("<|spe|>") * len(caption_tokens)
        else:
            if self.dataset_name in ['VG', 'MSCOCO', 'FLICKR']:
                caption_tokens = self.tokenizer.encode(caption)
            else:
                caption_tokens = self.tokenizer.encode(' '.join(caption))
            mlm_labels = [-1] * len(caption_tokens)

        text = caption_tokens

        if len(text) > self.seq_len:
            # truncate to seq_len but always keep the final <|endoftext|> token
            text_len_keep = self.seq_len
            text = text[:(text_len_keep - 1)] + [text[-1]]
            if self.task_type == 'image_caption' or self.task_type == 'mlm':
                mlm_labels = mlm_labels[:(text_len_keep - 1)] + [mlm_labels[-1]]

        if as_numpy_as_possible:
            text = np.array(text)
            mlm_labels = np.array(mlm_labels)

        return self.package_item(ret, text, mlm_labels, u_mask_type)

    def package_item(self, ret, text, mlm_labels, u_mask_type):

        if self.task_type == 'image_retrieval':
            if self.stage == 'train':
                ret.update({
                    'target_sample': [{
                        'data': [np.array(text, dtype=np.int64)],
                        'modality': 'text',
                        'data_type': 'target',
                        'invalid_mask': None,
                        'sample_info': {}
                    }],
                    'target_idx': [],
                    'target_set': [],
                    'task_info': copy.deepcopy(self.task_info)
                })
            else:
                image_id = ret['input_sample'][0]['sample_info']['id']
                ret['input_sample'][0]['sample_info']['id'] = (image_id, [image_id] * len(text))
                ret.update({
                    'target_sample': [{
                        'data': [np.array(single_text, dtype=np.int64) for single_text in text],
                        'modality': 'text',
                        'invalid_mask': None,
                        'data_type': 'target',
                        'sample_info': {
                            'sample_alone': True,
                        }
                    }],
                    'target_idx': [],
                    'target_set': [],
                    'task_info': copy.deepcopy(self.task_info)
                })

        elif self.task_type == 'mlm':
            raise NotImplementedError('masked language modeling on image-text pairs is not supported for now.')

        elif self.task_type == 'image_caption':
            source = np.array(text, dtype=np.int64)
            source2 = np.array(mlm_labels, dtype=np.int64)
            ret['input_sample'].append({
                'data': [source, source2],
                'invalid_mask': None,
                'modality': 'text',
                'data_type': 'input',
                'sample_info': {
                    'text_spe_cat': True,
                }
            })
            ret.update({
                'target_sample': [],
                'target_idx': [np.array(text, dtype=np.int64)],
                'target_set': copy.deepcopy(self.target_set),
                'task_info': copy.deepcopy(self.task_info)
            })
        else:
            raise NotImplementedError

        dict_as_tensor(ret)

        return ret
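
    # Sketch of a packaged `image_caption` item after `dict_as_tensor`
    # (structure follows the code above; values are placeholders):
    #
    #   {'input_sample': [{'data': <image tensor>, 'modality': 'image', ...},
    #                     {'data': [<token ids>, <spe ids>], 'modality': 'text', ...}],
    #    'target_sample': [],
    #    'target_idx': [<token ids>],
    #    'target_set': [...],
    #    'task_info': {'task_type': ..., 'dataset_name': ..., 'batch_size': ..., 'sampling_weight': ...}}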

    def random_word_wwm(self, tokens):
        output_tokens = []
        output_label = []

        for i, token in enumerate(tokens):
            if self.use_clip_tokenizer:
                sub_tokens = self.tokenizer.encode_basic_tokenized_token(token)
            else:
                sub_tokens = self.tokenizer.wordpiece_tokenizer.tokenize(token)
            prob = random.random()

            if prob < 0.15:
                prob /= 0.15

                # 80% of the time, replace with the mask token
                if prob < 0.8:
                    for sub_token in sub_tokens:
                        if self.use_clip_tokenizer:
                            output_tokens.append(self.tokenizer.encoder["<|spe|>"])
                        else:
                            output_tokens.append("[MASK]")
                # 10% of the time, replace with a random token
                elif prob < 0.9:
                    for sub_token in sub_tokens:
                        if self.use_clip_tokenizer:
                            output_tokens.append(random.choice(list(range(len(self.tokenizer.encoder)))))
                        else:
                            output_tokens.append(random.choice(list(self.tokenizer.vocab.keys())))
                # 10% of the time, keep the original token
                else:
                    for sub_token in sub_tokens:
                        output_tokens.append(sub_token)

                for sub_token in sub_tokens:
                    if self.use_clip_tokenizer:
                        output_label.append(sub_token)
                    else:
                        try:
                            output_label.append(self.tokenizer.vocab[sub_token])
                        except KeyError:
                            output_label.append(self.tokenizer.vocab["[UNK]"])
                            logging.warning("Cannot find sub_token '{}' in vocab. Using [UNK] instead".format(sub_token))
            else:
                for sub_token in sub_tokens:
                    # no masking: keep the token, label -1 (excluded from the loss)
                    output_tokens.append(sub_token)
                    output_label.append(-1)

        return output_tokens, output_label
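
    # Whole-word masking sketch: each basic token is selected with prob 0.15;
    # all of its sub-tokens share one fate (80% mask / 10% random / 10% kept)
    # and their original ids become the labels, -1 elsewhere. Illustrative run
    # with the clip tokenizer, where id(.) stands for a token id:
    #
    #   tokens        = ['a', 'red', 'car']        # suppose only 'red' is selected
    #   output_tokens = [id('a'), id('<|spe|>'), id('car')]
    #   output_label  = [-1, id('red'), -1]        # -1 entries are ignored by the loss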

    def cache_images(self, resize_to=(224, 224)):
        assert not self.zip_mode
        encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), 95]
        barray = bytearray()
        cursor = []
        c_ = 0
        for i in trange(len(self.database)):
            if i % self.cache_local_size != self.cache_local_rank:
                cursor.append(c_)
                continue
            idb = self.database[i]
            if self.cache_origin_image:
                try:
                    with open(os.path.join(self.data_path, idb['image']), 'rb') as f:
                        im = f.read()
                except Exception:
                    print("Failed to cache image {}, cache zero bytes!".format(idb['image']))
                    im = bytes()
            else:
                im = cv2.imread(os.path.join(self.data_path, idb['image']), cv2.IMREAD_COLOR)
                if im is None:
                    print("Failed to cache image {}, cache zero image!".format(idb['image']))
                    w, h = resize_to
                    im = np.zeros((h, w, 3), dtype=np.uint8)
                else:
                    im = cv2.resize(im, resize_to)
                _, im = cv2.imencode('.jpg', im, encode_param)
                im = im.tobytes()
            barray += im
            cursor.append(c_)
            c_ += len(im)
        cursor.append(c_)

        return barray, cursor
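
    # Layout sketch: `barray` concatenates the encoded bytes of the images owned
    # by this local rank, and `cursor` holds len(database) + 1 offsets, so image
    # i can be recovered (empty slice if owned by another rank) with:
    #
    #   im_bytes = barray[cursor[i]:cursor[i + 1]]
    #   im = cv2.imdecode(np.frombuffer(im_bytes, dtype=np.uint8), cv2.IMREAD_COLOR)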

    def get_image(self, idb, index=None):
        if index is None:
            index = self.img_path_to_index[idb['image']]
        if self.circular_cache_mode:
            im = idb['image_augmented']
        else:
            im = self._load_image(os.path.join(self.data_path, idb['image']))
        return im

    @staticmethod
    def b64_decode(string):
        return base64.decodebytes(string.encode())

    @staticmethod
    def group_aspect(database):
        print('grouping aspect...')
        t = time.time()

        widths = torch.as_tensor([idb['width'] for idb in database])
        heights = torch.as_tensor([idb['height'] for idb in database])

        # group 0: horizontal (w >= h), group 1: vertical
        group_ids = torch.zeros(len(database))
        horz = widths >= heights
        vert = ~horz
        group_ids[horz] = 0
        group_ids[vert] = 1

        print('Done (t={:.2f}s)'.format(time.time() - t))

        return group_ids

    def __len__(self):
        length = len(self.database)
        if self.max_length > 0:
            length = min(self.max_length, length)
        return length

    def _load_image(self, path):
        if '.zip@' in path:
            return self.zipreader.imread(path).convert('RGB')
        else:
            if self.use_ceph:
                return self.tcs_loader(path).convert('RGB')
            elif not memorycache:
                with open(path, 'rb') as f:
                    return Image.open(f).convert('RGB')
            else:
                raise NotImplementedError

    def _load_json(self, path):
        if '.zip@' in path:
            f = self.zipreader.read(path)
            return json.loads(f.decode())
        else:
            with open(path, 'r') as f:
                return json.load(f)


def build_transform(is_train,
                    input_size=224,
                    color_jitter=0.4,
                    auto_augment='rand-m9-mstd0.5-inc1',
                    train_interpolation='bicubic',
                    re_prob=0.25,
                    re_mode='pixel',
                    re_count=1
                    ):
    if is_train:
        transform = create_transform(
            input_size=input_size,
            is_training=True,
            color_jitter=color_jitter,
            auto_augment=auto_augment,
            interpolation=train_interpolation,
            re_prob=re_prob,
            re_mode=re_mode,
            re_count=re_count
        )

        return transform

    t = []
    size = int((256 / 224) * input_size)
    t.append(
        transforms.Resize(size, interpolation=3),  # 3 = PIL.Image.BICUBIC
    )
    t.append(transforms.CenterCrop(input_size))

    t.append(transforms.ToTensor())
    t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))
    return transforms.Compose(t)
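
# Usage sketch for the eval-time pipeline (file path is illustrative):
#   transform = build_transform(is_train=False, input_size=224)
#   tensor = transform(Image.open('example.jpg').convert('RGB'))  # 3x224x224 float tensor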