""" |
|
download.py |
|
|
|
Utility functions for downloading and extracting various datasets to (local) disk. |
|
""" |
|
|
|
import os |
|
import shutil |
|
from pathlib import Path |
|
from typing import Dict, List, TypedDict |
|
from zipfile import ZipFile |
|
|
|
import requests |
|
from PIL import Image |
|
from rich.progress import BarColumn, DownloadColumn, MofNCompleteColumn, Progress, TextColumn, TransferSpeedColumn |
|
from tqdm import tqdm |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|

DatasetComponent = TypedDict(
    "DatasetComponent",
    {"name": str, "extract": bool, "extract_type": str, "url": str, "do_rename": bool},
    total=False,
)

DATASET_REGISTRY: Dict[str, List[DatasetComponent]] = {
    "llava-v1.5-instruct": [
        {
            "name": "coco/train2017",
            "extract": True,
            "extract_type": "directory",
            "url": "http://images.cocodataset.org/zips/train2017.zip",
            "do_rename": True,
        },
        {
            "name": "gqa/images",
            "extract": True,
            "extract_type": "directory",
            "url": "https://downloads.cs.stanford.edu/nlp/data/gqa/images.zip",
            "do_rename": True,
        },
        {
            "name": "ocr_vqa/images",
            "extract": True,
            "extract_type": "directory",
            "url": "https://hf-mirror.com/datasets/qnguyen3/ocr_vqa/resolve/main/ocr_vqa.zip",
            "do_rename": True,
        },
        {
            "name": "textvqa/train_images",
            "extract": True,
            "extract_type": "directory",
            "url": "https://dl.fbaipublicfiles.com/textvqa/images/train_val_images.zip",
            "do_rename": True,
        },
        {
            "name": "vg/VG_100K",
            "extract": True,
            "extract_type": "directory",
            "url": "https://cs.stanford.edu/people/rak248/VG_100K_2/images.zip",
            "do_rename": True,
        },
        {
            "name": "vg/VG_100K_2",
            "extract": True,
            "extract_type": "directory",
            "url": "https://cs.stanford.edu/people/rak248/VG_100K_2/images2.zip",
            "do_rename": True,
        },
    ]
}
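
# Registering an additional dataset amounts to appending `DatasetComponent` entries
# under a fresh key. The sketch below is illustrative only; the key, component name,
# and URL are hypothetical placeholders, not a real dataset:
#
#   DATASET_REGISTRY["my-dataset"] = [
#       {
#           "name": "my-dataset/images",
#           "extract": True,
#           "extract_type": "directory",
#           "url": "https://example.com/my-dataset-images.zip",  # hypothetical
#           "do_rename": True,
#       },
#   ]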


def convert_to_jpg(image_dir: Path) -> None:
    """Handle OCR-VQA images specifically; iterate through `image_dir`, converting all GIFs/PNGs to JPG."""
    print(f"Converting all images in `{image_dir}` to JPG")

    for image_fn in tqdm(list(image_dir.iterdir())):
        # Skip files that are already JPGs, or whose JPG counterpart already exists.
        if image_fn.suffix in {".jpg", ".jpeg"} or (jpg_fn := image_dir / f"{image_fn.stem}.jpg").exists():
            continue

        if image_fn.suffix == ".gif":
            # For GIFs, convert only the first frame.
            gif = Image.open(image_fn)
            gif.seek(0)
            gif.convert("RGB").save(jpg_fn)
        elif image_fn.suffix == ".png":
            Image.open(image_fn).convert("RGB").save(jpg_fn)
        else:
            raise ValueError(f"Unexpected image format `{image_fn.suffix}`")
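
# Usage sketch: once the OCR-VQA component has been downloaded, its GIF/PNG images
# can be normalized in place. The path below is an assumption, derived from the
# registry layout and the `root_dir` used under `__main__`:
#
#   convert_to_jpg(Path("data/download/llava-v1.5-instruct/ocr_vqa/images"))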


def download_with_progress(url: str, download_dir: Path, chunk_size_bytes: int = 1024) -> Path:
    """Utility function for downloading files from the internet, with a handy Rich-based progress bar."""
    dest_path = download_dir / Path(url).name
    if dest_path.exists():
        return dest_path

    print(f"Downloading {dest_path} from `{url}`")
    response = requests.get(url, stream=True)
    response.raise_for_status()

    # If the server does not report a `content-length`, fall back to an indeterminate bar.
    total = int(response.headers["content-length"]) if "content-length" in response.headers else None
    with Progress(
        TextColumn("[bold]{task.description} - {task.fields[fname]}"),
        BarColumn(bar_width=None),
        "[progress.percentage]{task.percentage:>3.1f}%",
        "•",
        DownloadColumn(),
        "•",
        TransferSpeedColumn(),
        transient=True,
    ) as dl_progress:
        dl_tid = dl_progress.add_task("Downloading", fname=dest_path.name, total=total)
        with open(dest_path, "wb") as f:
            # `f.write` returns the number of bytes written, which drives the progress bar.
            for data in response.iter_content(chunk_size=chunk_size_bytes):
                dl_progress.advance(dl_tid, f.write(data))

    return dest_path
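
# Usage sketch: downloads land at `download_dir / <archive name>` and are skipped if
# the file already exists. The URL and directory below are taken from the registry
# above, purely for illustration:
#
#   zip_path = download_with_progress(
#       "http://images.cocodataset.org/zips/train2017.zip",
#       Path("data/download/llava-v1.5-instruct"),
#   )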


def extract_with_progress(archive_path: Path, download_dir: Path, extract_type: str, cleanup: bool = False) -> Path:
    """Utility function for extracting compressed archives, with a handy Rich-based progress bar."""
    assert archive_path.suffix == ".zip", "Only `.zip` compressed archives are supported for now!"
    print(f"Extracting {archive_path.name} to `{download_dir}`")

    with Progress(
        TextColumn("[bold]{task.description} - {task.fields[aname]}"),
        BarColumn(bar_width=None),
        "[progress.percentage]{task.percentage:>3.1f}%",
        "•",
        MofNCompleteColumn(),
        transient=True,
    ) as ext_progress:
        with ZipFile(archive_path) as zf:
            ext_tid = ext_progress.add_task("Extracting", aname=archive_path.name, total=len(members := zf.infolist()))

            # Extract the first member up front; its on-disk path is what we return (for a
            # "directory" archive, this is the top-level directory the remaining members live in).
            extract_path = Path(zf.extract(members[0], download_dir))
            if extract_type == "file":
                assert len(members) == 1, f"Archive `{archive_path}` with extract type `{extract_type}` has > 1 member!"
            elif extract_type == "directory":
                for member in members[1:]:
                    zf.extract(member, download_dir)
                    ext_progress.advance(ext_tid)
            else:
                raise ValueError(f"Extract type `{extract_type}` for archive `{archive_path}` is not defined!")

    if cleanup:
        archive_path.unlink()

    return extract_path
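
# Usage sketch, continuing the example above (paths illustrative; passing
# `cleanup=True` deletes the archive once extraction succeeds):
#
#   image_dir = extract_with_progress(
#       zip_path, Path("data/download/llava-v1.5-instruct"), extract_type="directory"
#   )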


def download_extract(dataset_id: str, root_dir: Path) -> None:
    """Download all files for a given dataset (querying registry above), extracting archives if necessary."""
    os.makedirs(download_dir := root_dir / "download" / dataset_id, exist_ok=True)

    # Download any dataset components not already present on disk.
    dl_tasks = [d for d in DATASET_REGISTRY[dataset_id] if not (download_dir / d["name"]).exists()]
    for dl_task in dl_tasks:
        dl_path = download_with_progress(dl_task["url"], download_dir)

        # Extract any compressed archives; for single-file archives, use the enclosing directory.
        if dl_task["extract"]:
            dl_path = extract_with_progress(dl_path, download_dir, dl_task["extract_type"])
            dl_path = dl_path.parent if dl_path.is_file() else dl_path

        # Rename the extracted path to the registry name, creating intermediate directories
        # (e.g., `coco/` for `coco/train2017`) as needed.
        if dl_task["do_rename"]:
            (dest := download_dir / dl_task["name"]).parent.mkdir(parents=True, exist_ok=True)
            shutil.move(dl_path, dest)
if __name__ == "__main__": |
|
import sys |
|
from pathlib import Path |
|
|
|
|
|
root_dir = Path("./data") |
|
os.makedirs(root_dir, exist_ok=True) |
|
|
|
|
|
for dataset_id in DATASET_REGISTRY.keys(): |
|
print(f"开始下载数据集: {dataset_id}") |
|
download_extract(dataset_id, root_dir) |
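
    # NOTE: OCR-VQA ships its images as a mix of GIFs/PNGs; `convert_to_jpg` above can
    # normalize them after download. The exact path is an assumption based on the
    # registry layout; uncomment to run it:
    # convert_to_jpg(root_dir / "download" / "llava-v1.5-instruct" / "ocr_vqa" / "images")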
print("所有数据集下载完成!") |
|
|