"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import json
import logging
import os
import time
from multiprocessing import Pool
import numpy as np
import requests
import tqdm
from lavis.common.utils import cleanup_dir, get_abs_path, get_cache_path
from omegaconf import OmegaConf
# Two alternative header sets; download_file() rotates between them on
# failure, since a different User-Agent sometimes gets past per-agent
# rate limiting or blocking.
header_mzl = {
    # Plain desktop-browser User-Agent.
    "User-Agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/71.0.3578.98 Safari/537.36",
    # "User-Agent": "Googlebot-Image/1.0",  # Pretend to be googlebot
    # "X-Forwarded-For": "64.18.15.200",
}
header_gbot = {
    "User-Agent": "Googlebot-Image/1.0",  # Pretend to be googlebot
}
headers = [header_mzl, header_gbot]
# Setup: overwrite the log file on each run, and silence urllib3's
# InsecureRequestWarning globally.
# NOTE(review): no request in this file passes verify=False, so it is
# unclear why the warning needs disabling — confirm before removing.
logging.basicConfig(filename="download_nocaps.log", filemode="w", level=logging.INFO)
requests.packages.urllib3.disable_warnings(
    requests.packages.urllib3.exceptions.InsecureRequestWarning
)
def download_file(url, filename):
    """Download ``url`` to ``filename``, retrying with rotating headers.

    Retries up to 20 times with a linearly growing sleep between
    attempts; on each failure a header set is re-sampled at random from
    the module-level ``headers`` list.  Gives up silently (logging only)
    once retries are exhausted — callers get no exception either way.

    :param url: source URL to fetch.
    :param filename: destination path; overwritten if it exists.
    """
    max_retries = 20
    cur_retries = 0
    header = headers[0]
    while cur_retries < max_retries:
        try:
            r = requests.get(url, headers=header, timeout=10)
            # Treat HTTP errors (403/404/5xx) as failures so we retry
            # instead of silently saving the error page body to disk as
            # if it were the requested file (original bug: status was
            # never checked before writing r.content).
            r.raise_for_status()
            with open(filename, "wb") as f:
                f.write(r.content)
            break
        except Exception as e:
            # Collapse multi-line exception reprs into one log record.
            logging.info(" ".join(repr(e).splitlines()))
            logging.error(url)
            cur_retries += 1
            # Randomly re-sample a header set from headers.
            header = headers[np.random.randint(0, len(headers))]
            # Linear backoff: 5s, 7s, 9s, ...
            time.sleep(3 + cur_retries * 2)
def download_image_from_url_val(url):
    """Fetch one validation-split image into ``<storage_dir>/val``.

    The destination filename is the basename of the URL; ``storage_dir``
    is a module global set in the ``__main__`` section.
    """
    download_file(url, os.path.join(storage_dir, "val", os.path.basename(url)))
def download_image_from_url_test(url):
    """Fetch one test-split image into ``<storage_dir>/test``.

    The destination filename is the basename of the URL; ``storage_dir``
    is a module global set in the ``__main__`` section.
    """
    download_file(url, os.path.join(storage_dir, "test", os.path.basename(url)))
if __name__ == "__main__":
    # Scratch dir for the annotation JSON files; removed at the end.
    os.makedirs("tmp", exist_ok=True)

    # Resolve the image storage dir from the dataset config.
    config_path = get_abs_path("configs/datasets/nocaps/defaults.yaml")
    storage_dir = OmegaConf.load(config_path).datasets.nocaps.build_info.images.storage
    storage_dir = get_cache_path(storage_dir)

    # Make sure the storage dir and its per-split subdirs exist.
    os.makedirs(storage_dir, exist_ok=True)
    print("Storage dir:", storage_dir)
    os.makedirs(os.path.join(storage_dir, "val"), exist_ok=True)
    os.makedirs(os.path.join(storage_dir, "test"), exist_ok=True)

    # Download annotations.
    val_url = "https://nocaps.s3.amazonaws.com/nocaps_val_4500_captions.json"
    tst_url = "https://s3.amazonaws.com/nocaps/nocaps_test_image_info.json"

    print("Downloading validation annotations from %s" % val_url)
    download_file(val_url, "tmp/nocaps_val_ann.json")

    print("Downloading testing annotations from %s" % tst_url)
    download_file(tst_url, "tmp/nocaps_tst_ann.json")

    # Parse annotations with context managers so the handles are closed
    # deterministically (original json.load(open(...)) leaked them).
    with open("tmp/nocaps_val_ann.json") as f:
        val_ann = json.load(f)
    with open("tmp/nocaps_tst_ann.json") as f:
        tst_ann = json.load(f)

    # Collect one image URL per annotation record, per split.
    val_urls = [info["coco_url"] for info in val_ann["images"]]
    tst_urls = [info["coco_url"] for info in tst_ann["images"]]

    # Setup multiprocessing.
    # A large n_procs possibly causes the server to reject requests.
    n_procs = 16

    # imap keeps input order and streams results, so wrapping it in
    # tqdm gives a live progress bar; list() just drains the iterator.
    with Pool(n_procs) as pool:
        print("Downloading validation images...")
        list(
            tqdm.tqdm(
                pool.imap(download_image_from_url_val, val_urls), total=len(val_urls)
            )
        )

    with Pool(n_procs) as pool:
        print("Downloading test images...")
        list(
            tqdm.tqdm(
                pool.imap(download_image_from_url_test, tst_urls), total=len(tst_urls)
            )
        )

    # Clean up the temporary annotation files.
    cleanup_dir("tmp")