import os
import zipfile

import requests
from hydra import main
from loguru import logger
from tqdm import tqdm
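
# NOTE: a minimal sketch of the Hydra config this script expects at
# ../config/data/download.yaml, inferred from how `cfg` is read below.
# The keys and values here are illustrative (COCO-style) assumptions,
# not the project's authoritative config:
#
#   save_path: data/coco
#   datasets:
#     images:
#       base_url: http://images.cocodataset.org/zips/
#       train2017:
#         file_num: 118287
#       val2017:
#         file_num: 5000
#     annotations:
#       base_url: http://images.cocodataset.org/annotations/
#       annotations:
#         file_name: annotations_trainval2017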


def download_file(url, destination):
    """
    Downloads a file from the specified URL to the destination path with progress logging.
    """
    logger.info(f"Downloading {os.path.basename(destination)}...")
    with requests.get(url, stream=True) as response:
        response.raise_for_status()
        total_size = int(response.headers.get("content-length", 0))
        progress = tqdm(total=total_size, unit="iB", unit_scale=True, desc=os.path.basename(destination), leave=True)

        with open(destination, "wb") as file:
            for data in response.iter_content(chunk_size=1024 * 1024):  # 1 MB chunks
                file.write(data)
                progress.update(len(data))
        progress.close()
    logger.info("Download completed.")


def unzip_file(source, destination):
    """
    Extracts a ZIP file to the specified directory and removes the ZIP file after extraction.
    """
    logger.info(f"Unzipping {os.path.basename(source)}...")
    with zipfile.ZipFile(source, "r") as zip_ref:
        zip_ref.extractall(destination)
    os.remove(source)  # Delete the archive once extraction succeeds, to save disk space
    logger.info(f"Removed {source}.")


def check_files(directory, expected_count=None):
    """
    Returns True if the directory contains the expected number of files,
    or any files at all when expected_count is None. This is a shallow,
    count-based check; file contents are not verified.
    """
    if not os.path.isdir(directory):
        return False  # Nothing has been extracted here yet
    files = [f for f in os.listdir(directory) if os.path.isfile(os.path.join(directory, f))]
    return len(files) == expected_count if expected_count is not None else bool(files)


@main(config_path="../config/data", config_name="download", version_base=None)
def prepare_dataset(cfg):
    """
    Downloads and extracts every configured dataset archive, skipping any
    that is already present and verified.
    """
    data_dir = cfg.save_path
    for data_type, settings in cfg.datasets.items():
        base_url = settings["base_url"]
        # Every key other than base_url describes one downloadable archive.
        for dataset_type, dataset_args in settings.items():
            if dataset_type == "base_url":
                continue  # Skip the shared base_url entry
            file_name = f"{dataset_args.get('file_name', dataset_type)}.zip"
            url = f"{base_url}{file_name}"
            local_zip_path = os.path.join(data_dir, file_name)
            # Annotation archives extract straight into data_dir; everything else
            # is grouped under a folder named after its data_type.
            extract_to = os.path.join(data_dir, data_type) if data_type != "annotations" else data_dir
            # Each archive is assumed to unpack into a folder named after its entry.
            final_place = os.path.join(extract_to, dataset_type)

            os.makedirs(extract_to, exist_ok=True)
            if check_files(final_place, dataset_args.get("file_num")):
                logger.info(f"Dataset {dataset_type} already verified.")
                continue

            if not os.path.exists(local_zip_path):
                download_file(url, local_zip_path)
            # NOTE: a partial archive left by an interrupted run is reused as-is
            # here and would fail to extract; delete it manually to re-download.
            unzip_file(local_zip_path, extract_to)

            if not check_files(final_place, dataset_args.get("file_num")):
                logger.error(f"Error verifying the {dataset_type} dataset after extraction.")
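

# Usage sketch: Hydra exposes every config key as a command-line override, so
# (assuming this file is saved as download.py, an illustrative name) one can run:
#
#   python download.py save_path=custom/data/dir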


if __name__ == "__main__":
    import sys

    # Make the project root importable so that tools.log_helper resolves
    # when this file is executed directly as a script.
    sys.path.append("./")
    from tools.log_helper import custom_logger

    custom_logger()
    prepare_dataset()