"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import json
import logging
import os
import time
from multiprocessing import Pool

import numpy as np
import requests
import tqdm
from lavis.common.utils import cleanup_dir, get_abs_path, get_cache_path
from omegaconf import OmegaConf

header_mzl = {
    "User-Agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/71.0.3578.98 Safari/537.36",
    # "User-Agent": "Googlebot-Image/1.0",  # Pretend to be googlebot
    # "X-Forwarded-For": "64.18.15.200",
}

header_gbot = {
    "User-Agent": "Googlebot-Image/1.0",  # Pretend to be googlebot
}

headers = [header_mzl, header_gbot]

# Setup
logging.basicConfig(filename="download_nocaps.log", filemode="w", level=logging.INFO)
requests.packages.urllib3.disable_warnings(
    requests.packages.urllib3.exceptions.InsecureRequestWarning
)


def download_file(url, filename):
    """Download `url` to `filename`, retrying with rotated User-Agent headers."""
    max_retries = 20
    cur_retries = 0

    header = headers[0]

    while cur_retries < max_retries:
        try:
            r = requests.get(url, headers=header, timeout=10)
            # treat HTTP errors (404, 503, ...) as failures instead of
            # silently saving the error page as an image
            r.raise_for_status()

            with open(filename, "wb") as f:
                f.write(r.content)

            break
        except Exception as e:
            logging.info(" ".join(repr(e).splitlines()))
            logging.error(url)
            cur_retries += 1

            # randomly sample a header from headers
            header = headers[np.random.randint(0, len(headers))]

    # throttle every request; back off further after repeated failures
    time.sleep(3 + cur_retries * 2)


def download_image_from_url_val(url):
    # storage_dir is a module-level global set in __main__; worker processes
    # inherit it via fork-based multiprocessing on Linux
    basename = os.path.basename(url)
    filename = os.path.join(storage_dir, "val", basename)

    download_file(url, filename)


def download_image_from_url_test(url):
    basename = os.path.basename(url)
    filename = os.path.join(storage_dir, "test", basename)

    download_file(url, filename)


if __name__ == "__main__":
    os.makedirs("tmp", exist_ok=True)

    # storage dir
    config_path = get_abs_path("configs/datasets/nocaps/defaults.yaml")

    storage_dir = OmegaConf.load(config_path).datasets.nocaps.build_info.images.storage
    storage_dir = get_cache_path(storage_dir)
    # make sure the storage dir exists
    os.makedirs(storage_dir, exist_ok=True)
    print("Storage dir:", storage_dir)

    # make sure the storage dir for val and test exists
    os.makedirs(os.path.join(storage_dir, "val"), exist_ok=True)
    os.makedirs(os.path.join(storage_dir, "test"), exist_ok=True)

    # download annotations
    val_url = "https://nocaps.s3.amazonaws.com/nocaps_val_4500_captions.json"
    tst_url = "https://s3.amazonaws.com/nocaps/nocaps_test_image_info.json"

    print("Downloading validation annotations from %s" % val_url)
    download_file(val_url, "tmp/nocaps_val_ann.json")
    print("Downloading testing annotations from %s" % tst_url)
    download_file(tst_url, "tmp/nocaps_tst_ann.json")

    # open annotations
    with open("tmp/nocaps_val_ann.json") as f:
        val_ann = json.load(f)
    with open("tmp/nocaps_tst_ann.json") as f:
        tst_ann = json.load(f)

    # collect image urls
    val_info = val_ann["images"]
    tst_info = tst_ann["images"]

    val_urls = [info["coco_url"] for info in val_info]
    tst_urls = [info["coco_url"] for info in tst_info]

    # setup multiprocessing
    # a large n_procs may cause the server to reject requests
    n_procs = 16

    with Pool(n_procs) as pool:
        print("Downloading validation images...")
        list(
            tqdm.tqdm(
                pool.imap(download_image_from_url_val, val_urls), total=len(val_urls)
            )
        )

    with Pool(n_procs) as pool:
        print("Downloading test images...")
        list(
            tqdm.tqdm(
                pool.imap(download_image_from_url_test, tst_urls), total=len(tst_urls)
            )
        )

    # clean tmp
    cleanup_dir("tmp")