"""Retrieval datasets (COCO and Flickr30k) for image-text retrieval evaluation."""
import os
import re
import json
import numpy as np
from PIL import Image
from tqdm import tqdm
from torch.utils.data import Dataset
from torchvision.datasets.utils import download_url
from .constants import COCO_ROOT, FLICKR_ROOT
from .utils import AverageMeter
def pre_caption(caption, max_words=50):
    """Normalize a caption for retrieval.

    Lowercases the text, replaces a fixed punctuation set with spaces,
    collapses runs of whitespace to a single space, strips surrounding
    spaces, and truncates to at most ``max_words`` space-separated tokens.
    """
    # Lowercase first, then blank out punctuation characters.
    cleaned = re.sub(r"([.!\"()*#:;~])", ' ', caption.lower())
    # Collapse any run of two or more whitespace characters into one space.
    cleaned = re.sub(r"\s{2,}", ' ', cleaned)
    cleaned = cleaned.rstrip('\n').strip(' ')

    # Truncate the caption to the first max_words tokens.
    tokens = cleaned.split(' ')
    if len(tokens) > max_words:
        cleaned = ' '.join(tokens[:max_words])
    return cleaned
class COCO_Retrieval(Dataset):
    def __init__(self, image_preprocess=None, root_dir=COCO_ROOT, max_words=30, split="test",
                 image_perturb_fn=None, download=False):
        """
        COCO Retrieval Dataset.
        image_preprocess: image preprocessing function
        root_dir: The directory of the coco dataset. This directory should contain test2014 files.
        max_words: Cropping the caption to max_words.
        split: 'val' or 'test'
        image_perturb_fn: image perturbation function for patch permutation experiments.
        download: Whether to download the dataset if it does not exist.
        """
        self.root_dir = root_dir
        if not os.path.exists(root_dir):
            print("Directory for COCO could not be found!")
            if download:
                print("Downloading COCO now.")
                self.download()
            else:
                # FIX: "by letting" -> "by setting" in the user-facing message.
                raise RuntimeError("Please either download the dataset by setting `--download` or specify the correct directory.")

        urls = {'val': 'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json',
                'test': 'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json'}
        filenames = {'val': 'coco_karpathy_val.json', 'test': 'coco_karpathy_test.json'}
        download_url(urls[split], root_dir)

        # FIX: use a context manager so the annotation file handle is closed
        # (the original passed an unclosed open() result to json.load).
        with open(os.path.join(root_dir, filenames[split]), 'r') as f:
            self.annotation = json.load(f)

        self.image_preprocess = image_preprocess
        self.image_perturb_fn = image_perturb_fn
        self.image_root = root_dir

        # Flattened caption list plus bidirectional image<->text index maps
        # used by evaluate_scores.
        self.text = []
        self.image = []
        self.txt2img = {}
        self.img2txt = {}

        txt_id = 0
        for img_id, ann in enumerate(self.annotation):
            self.image.append(ann['image'])
            self.img2txt[img_id] = []
            for i, caption in enumerate(ann['caption']):
                self.text.append(pre_caption(caption, max_words))
                self.img2txt[img_id].append(txt_id)
                self.txt2img[txt_id] = img_id
                txt_id += 1

    def __len__(self):
        # One dataset item per image (captions are indexed separately).
        return len(self.annotation)

    def __getitem__(self, index):
        """Return the (optionally preprocessed/perturbed) image and its index."""
        image_path = os.path.join(self.image_root, self.annotation[index]['image'])
        image = Image.open(image_path).convert('RGB')

        if self.image_preprocess is not None:
            image = self.image_preprocess(image)

        if self.image_perturb_fn is not None:
            image = self.image_perturb_fn(image)

        return {"image": image, "idx": index}

    def download(self):
        """Download and extract the COCO val/test image archives into root_dir."""
        import subprocess
        os.makedirs(self.root_dir, exist_ok=True)
        # subprocess.call(["wget", "http://images.cocodataset.org/zips/train2014.zip"], cwd=self.root_dir)
        # subprocess.call(["unzip", "train2014.zip"], cwd=self.root_dir)

        subprocess.call(["wget", "http://images.cocodataset.org/zips/val2014.zip"], cwd=self.root_dir)
        subprocess.call(["unzip", "val2014.zip"], cwd=self.root_dir)

        subprocess.call(["wget", "http://images.cocodataset.org/zips/test2014.zip"], cwd=self.root_dir)
        subprocess.call(["unzip", "test2014.zip"], cwd=self.root_dir)

    def evaluate_scores(self, scores):
        """Compute Prec@1/Prec@5 for text retrieval (image->text) and image
        retrieval (text->image) from a similarity matrix.

        scores: either a tuple (scores_i2t, scores_t2i) or a single matrix
        used for both directions; rows index images, columns index texts.
        Returns a one-element list with the four precision metrics.
        """
        if isinstance(scores, tuple):
            scores_i2t = scores[0]
            scores_t2i = scores[1].T  # Make it N_ims x N_text
        else:
            scores_t2i = scores
            scores_i2t = scores

        print(f"COCO results across {scores_i2t.shape} samples. ")
        prec_at_1 = AverageMeter()
        prec_at_5 = AverageMeter()

        # Text retrieval: for each image, check whether any of its ground-truth
        # captions appears in the top-1 / top-5 scored captions.
        tqdm_iterator = tqdm(range(len(self.img2txt)))
        for i in tqdm_iterator:
            top5_captions = np.argsort(scores_i2t[i])[-5:]
            true_captions = self.img2txt[i]
            prec_at_1.update(len(set(true_captions) & set(top5_captions[-1:])) > 0)
            prec_at_5.update(len(set(true_captions) & set(top5_captions)) > 0)
            tqdm_iterator.set_description(f"Text Retrieval Prec@1: {prec_at_1.avg:.3f}, Prec@5: {prec_at_5.avg:.3f}")

        # Image Retrieval: for each caption, check whether its ground-truth
        # image appears in the top-1 / top-5 scored images.
        image_prec_at_1 = AverageMeter()
        image_prec_at_5 = AverageMeter()
        tqdm_iterator = tqdm(range(len(self.txt2img)))
        for i in tqdm_iterator:
            top5_images = np.argsort(scores_t2i[:, i])[-5:]
            true_image = self.txt2img[i]
            image_prec_at_1.update(true_image in top5_images[-1:])
            image_prec_at_5.update(true_image in top5_images)
            tqdm_iterator.set_description(f"Image Retrieval Prec@1: {image_prec_at_1.avg:.3f}, Prec@5: {image_prec_at_5.avg:.3f}")

        records = [{"ImagePrec@1": image_prec_at_1.avg, "ImagePrec@5": image_prec_at_5.avg, "TextPrec@1": prec_at_1.avg, "TextPrec@5": prec_at_5.avg}]
        return records
class Flickr30k_Retrieval(Dataset):
    def __init__(self, image_preprocess, split, root_dir=FLICKR_ROOT, max_words=30,
                 image_perturb_fn=None, *args, **kwargs):
        '''
        Flickr30k dataset for retrieval.
        image_preprocess: image preprocessing function
        root_dir: The directory of the Flickr30k dataset.
        max_words: Cropping the caption to max_words.
        split: 'val' or 'test'
        image_perturb_fn: image perturbation function for patch permutation experiments.
        Extra *args/**kwargs (e.g. download=...) are accepted for signature
        compatibility but ignored; the images must be downloaded manually.
        '''
        urls = {'val': 'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_val.json',
                'test': 'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_test.json'}
        filenames = {'val': 'flickr30k_val.json', 'test': 'flickr30k_test.json'}

        if not os.path.exists(root_dir):
            print("Directory for Flickr30k could not be found!")
            flickr_url = "https://forms.illinois.edu/sec/229675"
            raise RuntimeError(f"You need to manually sign up and download the dataset from {flickr_url} and place it in the `root_dir`.")

        download_url(urls[split], root_dir)

        # FIX: use a context manager so the annotation file handle is closed
        # (the original passed an unclosed open() result to json.load).
        with open(os.path.join(root_dir, filenames[split]), 'r') as f:
            self.annotation = json.load(f)

        self.image_preprocess = image_preprocess
        self.image_perturb_fn = image_perturb_fn
        self.root_dir = root_dir

        # Flattened caption list plus bidirectional image<->text index maps
        # used by evaluate_scores.
        self.text = []
        self.image = []
        self.txt2img = {}
        self.img2txt = {}

        txt_id = 0
        for img_id, ann in enumerate(self.annotation):
            self.image.append(ann['image'])
            self.img2txt[img_id] = []
            for i, caption in enumerate(ann['caption']):
                self.text.append(pre_caption(caption, max_words))
                self.img2txt[img_id].append(txt_id)
                self.txt2img[txt_id] = img_id
                txt_id += 1

    def __len__(self):
        # One dataset item per image (captions are indexed separately).
        return len(self.annotation)

    def __getitem__(self, index):
        """Return the (optionally preprocessed/perturbed) image and its index."""
        image_path = os.path.join(self.root_dir, self.annotation[index]['image'])
        image = Image.open(image_path).convert('RGB')

        if self.image_preprocess is not None:
            image = self.image_preprocess(image)

        if self.image_perturb_fn is not None:
            image = self.image_perturb_fn(image)

        return {"image": image, "idx": index}

    def evaluate_scores(self, scores):
        """Compute Prec@1/Prec@5 for text retrieval (image->text) and image
        retrieval (text->image) from a similarity matrix.

        scores: either a tuple (scores_i2t, scores_t2i) or a single matrix
        used for both directions; rows index images, columns index texts.
        Returns a one-element list with the four precision metrics.
        """
        if isinstance(scores, tuple):
            scores_i2t = scores[0]
            scores_t2i = scores[1].T  # Make it N_ims x N_text
        else:
            scores_t2i = scores
            scores_i2t = scores

        print(f"Flickr30k Retrieval results across {scores_i2t.shape} samples. ")
        prec_at_1 = AverageMeter()
        prec_at_5 = AverageMeter()

        # Text retrieval: for each image, check whether any of its ground-truth
        # captions appears in the top-1 / top-5 scored captions.
        tqdm_iterator = tqdm(range(len(self.img2txt)))
        for i in tqdm_iterator:
            top5_captions = np.argsort(scores_i2t[i])[-5:]
            true_captions = self.img2txt[i]
            prec_at_1.update(len(set(true_captions) & set(top5_captions[-1:])) > 0)
            prec_at_5.update(len(set(true_captions) & set(top5_captions)) > 0)
            tqdm_iterator.set_description(f"Text Retrieval Prec@1: {prec_at_1.avg:.3f}, Prec@5: {prec_at_5.avg:.3f}")

        # Image Retrieval: for each caption, check whether its ground-truth
        # image appears in the top-1 / top-5 scored images.
        image_prec_at_1 = AverageMeter()
        image_prec_at_5 = AverageMeter()
        tqdm_iterator = tqdm(range(len(self.txt2img)))
        for i in tqdm_iterator:
            top5_images = np.argsort(scores_t2i[:, i])[-5:]
            true_image = self.txt2img[i]
            image_prec_at_1.update(true_image in top5_images[-1:])
            image_prec_at_5.update(true_image in top5_images)
            tqdm_iterator.set_description(f"Image Retrieval Prec@1: {image_prec_at_1.avg:.3f}, Prec@5: {image_prec_at_5.avg:.3f}")

        records = [{"ImagePrec@1": image_prec_at_1.avg, "ImagePrec@5": image_prec_at_5.avg, "TextPrec@1": prec_at_1.avg, "TextPrec@5": prec_at_5.avg}]
        return records

    def download(self):
        """Flickr30k requires a manual sign-up; automated download is unsupported."""
        raise NotImplementedError("Flickr30k dataset is not available for download.")
def get_coco_retrieval(image_preprocess, image_perturb_fn, text_perturb_fn, max_words=30, download=False, root_dir=COCO_ROOT, split="test"):
    """Factory for COCO_Retrieval.

    NOTE(review): text_perturb_fn is accepted for a uniform factory signature
    but is not forwarded to the dataset — confirm this is intentional.
    """
    return COCO_Retrieval(
        root_dir=root_dir,
        split=split,
        image_preprocess=image_preprocess,
        image_perturb_fn=image_perturb_fn,
        max_words=max_words,
        download=download,
    )
def get_flickr30k_retrieval(image_preprocess, image_perturb_fn, text_perturb_fn, max_words=30, download=False, root_dir=FLICKR_ROOT, split="test"):
    """Factory for Flickr30k_Retrieval.

    NOTE(review): text_perturb_fn is accepted for a uniform factory signature
    but is not forwarded to the dataset — confirm this is intentional.
    """
    return Flickr30k_Retrieval(
        root_dir=root_dir,
        split=split,
        image_preprocess=image_preprocess,
        image_perturb_fn=image_perturb_fn,
        max_words=max_words,
        download=download,
    )