|
|
|
import os |
|
import json |
|
import argparse |
|
from PIL import Image |
|
import numpy as np |
|
|
|
if __name__ == '__main__':
    # Build a COCO-style image-info JSON for CC3M: optionally download each
    # image listed in the caption/URL TSV, keep every one that can be opened
    # and decoded, and record its id, file name, size and caption. The
    # category list is copied from --cat_info; 'annotations' is left empty.
    import subprocess

    parser = argparse.ArgumentParser()
    parser.add_argument('--ann', default='datasets/cc3m/Train_GCC-training.tsv')
    parser.add_argument('--save_image_path', default='datasets/cc3m/training/')
    parser.add_argument('--cat_info', default='datasets/lvis/lvis_v1_val.json')
    parser.add_argument('--out_path', default='datasets/cc3m/train_image_info.json')
    parser.add_argument('--not_download_image', action='store_true')
    args = parser.parse_args()

    with open(args.cat_info, 'r') as cat_file:
        categories = json.load(cat_file)['categories']

    images = []
    os.makedirs(args.save_image_path, exist_ok=True)

    # NOTE(review): assumes the TSV is UTF-8 (as released for CC3M); stated
    # explicitly so caption decoding does not depend on the platform locale.
    with open(args.ann, encoding='utf-8') as ann_file:
        for i, line in enumerate(ann_file):
            # Strip only the line terminator: the old `line[:-1]` cut the last
            # URL character on a final line without a trailing newline and
            # left '\r' behind on CRLF files.
            line = line.rstrip('\r\n')
            if not line:
                continue
            cap, path = line.split('\t')
            print(i, cap, path)

            # Ids follow the 1-based line number, so blank lines leave gaps
            # rather than shifting later ids.
            image_id = i + 1
            file_name = '{}.jpg'.format(image_id)
            image_path = os.path.join(args.save_image_path, file_name)

            if not args.not_download_image:
                # URLs come from an external file: pass them as argv (no
                # shell) so ';', '&', quotes or '$()' cannot run commands, and
                # put '--' first so a URL beginning with '-' is not parsed as
                # a wget option. A failed download is tolerated; the open
                # below skips whatever did not arrive intact.
                subprocess.run(['wget', '-O', image_path, '--', path])

            try:
                with Image.open(image_path) as img:
                    # convert() forces a full decode, which rejects truncated
                    # or corrupt downloads; PIL's size is (width, height).
                    w, h = img.convert('RGB').size
            except Exception as e:
                # Best-effort skip of missing/unreadable images. Unlike the
                # old bare `except:`, this lets KeyboardInterrupt/SystemExit
                # stop the run.
                print('Skipping {}: {}'.format(image_path, e))
                continue

            image_info = {
                'id': image_id,
                'file_name': file_name,
                'height': h,
                'width': w,
                'captions': [cap],
            }
            images.append(image_info)

    data = {'categories': categories, 'images': images, 'annotations': []}
    for k, v in data.items():
        print(k, len(v))
    print('Saving to', args.out_path)
    with open(args.out_path, 'w') as out_file:
        json.dump(data, out_file)
|
|