jchwenger's picture
Upload 351 files (#2)
d9272c6 verified
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import json
import logging
import os
import time
from multiprocessing import Pool
import numpy as np
import requests
import tqdm
from lavis.common.utils import cleanup_dir, get_abs_path, get_cache_path
from omegaconf import OmegaConf
header_mzl = {
"User-Agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/71.0.3578.98 Safari/537.36",
# "User-Agent": "Googlebot-Image/1.0", # Pretend to be googlebot
# "X-Forwarded-For": "64.18.15.200",
}
header_gbot = {
"User-Agent": "Googlebot-Image/1.0", # Pretend to be googlebot
}
headers = [header_mzl, header_gbot]
# Setup
logging.basicConfig(filename="download_nocaps.log", filemode="w", level=logging.INFO)
requests.packages.urllib3.disable_warnings(
requests.packages.urllib3.exceptions.InsecureRequestWarning
)
def download_file(url, filename):
max_retries = 20
cur_retries = 0
header = headers[0]
while cur_retries < max_retries:
try:
r = requests.get(url, headers=header, timeout=10)
with open(filename, "wb") as f:
f.write(r.content)
break
except Exception as e:
logging.info(" ".join(repr(e).splitlines()))
logging.error(url)
cur_retries += 1
# random sample a header from headers
header = headers[np.random.randint(0, len(headers))]
time.sleep(3 + cur_retries * 2)
def download_image_from_url_val(url):
basename = os.path.basename(url)
filename = os.path.join(storage_dir, "val", basename)
download_file(url, filename)
def download_image_from_url_test(url):
basename = os.path.basename(url)
filename = os.path.join(storage_dir, "test", basename)
download_file(url, filename)
if __name__ == "__main__":
os.makedirs("tmp", exist_ok=True)
# storage dir
config_path = get_abs_path("configs/datasets/nocaps/defaults.yaml")
storage_dir = OmegaConf.load(config_path).datasets.nocaps.build_info.images.storage
storage_dir = get_cache_path(storage_dir)
# make sure the storage dir exists
os.makedirs(storage_dir, exist_ok=True)
print("Storage dir:", storage_dir)
# make sure the storage dir for val and test exists
os.makedirs(os.path.join(storage_dir, "val"), exist_ok=True)
os.makedirs(os.path.join(storage_dir, "test"), exist_ok=True)
# download annotations
val_url = "https://nocaps.s3.amazonaws.com/nocaps_val_4500_captions.json"
tst_url = "https://s3.amazonaws.com/nocaps/nocaps_test_image_info.json"
print("Downloading validation annotations from %s" % val_url)
download_file(val_url, "tmp/nocaps_val_ann.json")
print("Downloading testing annotations from %s" % tst_url)
download_file(tst_url, "tmp/nocaps_tst_ann.json")
# open annotations
val_ann = json.load(open("tmp/nocaps_val_ann.json"))
tst_ann = json.load(open("tmp/nocaps_tst_ann.json"))
# collect image urls
val_info = val_ann["images"]
tst_info = tst_ann["images"]
val_urls = [info["coco_url"] for info in val_info]
tst_urls = [info["coco_url"] for info in tst_info]
# setup multiprocessing
# large n_procs possibly causes server to reject requests
n_procs = 16
with Pool(n_procs) as pool:
print("Downloading validation images...")
list(
tqdm.tqdm(
pool.imap(download_image_from_url_val, val_urls), total=len(val_urls)
)
)
with Pool(n_procs) as pool:
print("Downloading test images...")
list(
tqdm.tqdm(
pool.imap(download_image_from_url_test, tst_urls), total=len(tst_urls)
)
)
# clean tmp
cleanup_dir("tmp")