# Detic/tools/download_cc.py
# Uploaded by AK391 (commit 159f437)
# Copyright (c) Facebook, Inc. and its affiliates.
import argparse
import json
import os
import subprocess

import numpy as np
from PIL import Image
def _download_image(url, dest):
    """Download ``url`` to ``dest`` with wget.

    The URL comes straight from the annotation TSV, so it is passed as its
    own argv element rather than interpolated into a shell command. With a
    shell string, a URL containing ``&``, ``;`` or spaces would be split or
    reinterpreted by the shell, and the image would not be written to
    ``dest``.
    """
    # Best effort: wget's exit status is ignored. A failed download is
    # skipped later, when ``dest`` is missing or cannot be decoded.
    subprocess.run(['wget', url, '-O', dest], check=False)


def _read_image_size(path):
    """Return ``(height, width)`` of the image stored at ``path``.

    The image is fully decoded (``convert("RGB")``), so truncated or corrupt
    downloads raise here instead of being recorded with a bogus size.
    """
    with Image.open(path) as img:
        width, height = img.convert("RGB").size
    return height, width


def main(argv=None):
    """Build an image-info JSON for Conceptual Captions (CC3M).

    Reads ``--ann``, a TSV of ``caption<TAB>url`` lines. The image on line
    ``i`` (0-based) is stored as ``<save_image_path>/<i + 1>.jpg``; it is
    downloaded first unless ``--not_download_image`` is given. Every image
    that exists and decodes is recorded with its size and caption.

    The output written to ``--out_path`` has three keys:
    ``categories`` (copied from ``--cat_info``), ``images`` and an empty
    ``annotations`` list.

    Args:
        argv: Optional argument list for ``argparse``. ``None`` means
            ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--ann', default='datasets/cc3m/Train_GCC-training.tsv')
    parser.add_argument('--save_image_path', default='datasets/cc3m/training/')
    parser.add_argument('--cat_info', default='datasets/lvis/lvis_v1_val.json')
    parser.add_argument('--out_path', default='datasets/cc3m/train_image_info.json')
    parser.add_argument('--not_download_image', action='store_true')
    args = parser.parse_args(argv)

    with open(args.cat_info, 'r', encoding='utf-8') as cat_file:
        categories = json.load(cat_file)['categories']

    os.makedirs(args.save_image_path, exist_ok=True)

    images = []
    with open(args.ann, encoding='utf-8') as ann_file:
        for i, line in enumerate(ann_file):
            # Strip only the line terminator: slicing off the last character
            # would truncate the URL on a final line with no trailing newline.
            line = line.rstrip('\r\n')
            if not line:
                # A stray blank line must not crash the run after hours of
                # downloading (nothing is saved until the very end).
                continue
            cap, path = line.split('\t')
            print(i, cap, path)

            # Image ids are 1-based line numbers, so ids stay stable even
            # when some lines are skipped.
            image_id = i + 1
            file_name = '{}.jpg'.format(image_id)
            image_path = os.path.join(args.save_image_path, file_name)

            if not args.not_download_image:
                _download_image(path, image_path)
            if not os.path.isfile(image_path):
                continue
            try:
                h, w = _read_image_size(image_path)
            except Exception:
                # Deliberately broad: PIL signals bad or empty downloads with
                # several exception types (UnidentifiedImageError, OSError,
                # DecompressionBombError, ...). Unlike a bare ``except:``, this
                # still lets KeyboardInterrupt stop the run.
                continue

            images.append({
                'id': image_id,
                'file_name': file_name,
                'height': h,
                'width': w,
                'captions': [cap],
            })

    data = {'categories': categories, 'images': images, 'annotations': []}
    for k, v in data.items():
        print(k, len(v))
    print('Saving to', args.out_path)
    with open(args.out_path, 'w') as out_file:
        json.dump(data, out_file)


if __name__ == '__main__':
    main()