File size: 2,835 Bytes
d3c8b75
5002339
 
 
d3c8b75
5002339
 
d3c8b75
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5002339
1ff26a6
5002339
 
 
d3c8b75
5002339
 
d3c8b75
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d5ba31a
d3c8b75
 
eb855bd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import os
import zipfile

import hydra
from loguru import logger
import requests
from tqdm.rich import tqdm


def download_file(url, dest_path, timeout=60):
    """
    Downloads a file from a specified URL to a destination path with progress logging.

    The response body is streamed in 1 MiB chunks into a temporary
    ``<dest_path>.part`` file, which is renamed onto ``dest_path`` only after
    the whole body has been written. An interrupted or failed download
    therefore never leaves a truncated file at ``dest_path``.

    Args:
        url: Address of the file to fetch.
        dest_path: Path where the fully downloaded file is stored.
        timeout: Value forwarded to ``requests.get`` as its ``timeout``
            argument, so a stalled connection raises instead of blocking
            forever.

    Raises:
        requests.HTTPError: If the server answers with an error status
            (raised by ``raise_for_status``).
        OSError: If the temporary file cannot be written or renamed.
    """
    file_name = os.path.basename(dest_path)
    partial_path = f"{dest_path}.part"
    logger.info(f"Downloading {file_name}...")
    try:
        with requests.get(url, stream=True, timeout=timeout) as r:
            r.raise_for_status()
            # The header may be missing; 0 is then used as the progress total.
            total_length = int(r.headers.get("content-length", 0))
            with open(partial_path, "wb") as f, tqdm(
                total=total_length, unit="iB", unit_scale=True, desc=file_name, leave=True
            ) as bar:
                for chunk in r.iter_content(chunk_size=1024 * 1024):
                    f.write(chunk)
                    bar.update(len(chunk))
        # Publish the file only once it is complete and closed.
        os.replace(partial_path, dest_path)
    finally:
        # After a successful rename the partial file no longer exists; on any
        # failure (including Ctrl-C) this removes the incomplete download.
        if os.path.exists(partial_path):
            os.remove(partial_path)
    logger.info("Download complete!")


def unzip_file(zip_path, extract_to):
    """
    Unzips a ZIP file to a specified directory and then deletes the archive.

    Args:
        zip_path: Path of the ZIP archive to extract.
        extract_to: Directory that receives the archive's contents.

    Raises:
        zipfile.BadZipFile: If ``zip_path`` is not a valid ZIP archive. The
            invalid archive is deleted before the error is re-raised, so a
            later run does not find and reuse the broken file.
    """
    logger.info(f"Unzipping {os.path.basename(zip_path)}...")
    try:
        with zipfile.ZipFile(zip_path, "r") as zip_ref:
            zip_ref.extractall(extract_to)
    except zipfile.BadZipFile:
        logger.error(f"{zip_path} is not a valid ZIP archive; removing it.")
        os.remove(zip_path)
        raise
    os.remove(zip_path)
    logger.info(f"Removed {zip_path}")


def check_files(directory, expected_count):
    """
    Reports whether `directory` directly contains exactly `expected_count` files.

    Only regular files (or symlinks resolving to files) located immediately
    inside `directory` are counted; sub-directories and their contents are
    ignored.
    """
    with os.scandir(directory) as entries:
        file_total = sum(1 for entry in entries if entry.is_file())
    return file_total == expected_count


@hydra.main(config_path="../config/data", config_name="download", version_base=None)
def prepare_dataset(download_cfg):
    """
    Downloads and extracts every image dataset listed in the configuration.

    For each entry in ``download_cfg.images.datasets`` the archive is
    extracted into ``<path>/<dataset_type>/images``. A dataset whose target
    directory already holds the expected number of files is skipped; an
    archive already present under ``download_cfg.path`` is reused instead of
    being downloaded again. After extraction the file count is checked once
    more and a mismatch is logged as an error.

    Args:
        download_cfg: Configuration object supplied by hydra; this function
            reads ``path``, ``images.base_url`` and ``images.datasets``.
    """
    data_dir = download_cfg.path
    base_url = download_cfg.images.base_url
    datasets = download_cfg.images.datasets

    for dataset_type in datasets:
        # NOTE(review): relies on every dataset entry holding exactly two
        # values, ordered as (archive file name, expected file count) --
        # confirm against the config file.
        file_name, expected_files = datasets[dataset_type].values()
        url = f"{base_url}{file_name}"
        local_zip_path = os.path.join(data_dir, file_name)
        extract_to = os.path.join(data_dir, dataset_type, "images")

        # Ensure the extraction directory exists
        os.makedirs(extract_to, exist_ok=True)

        # Check if the correct number of files exists
        if check_files(extract_to, expected_files):
            logger.info(f"Dataset {dataset_type} already verified.")
            continue

        if os.path.exists(local_zip_path):
            logger.info(f"Dataset {dataset_type} already downloaded.")
        else:
            download_file(url, local_zip_path)

        unzip_file(local_zip_path, extract_to)

        # Additional verification post extraction
        if not check_files(extract_to, expected_files):
            logger.error(f"Error in verifying the {dataset_type} dataset after extraction.")


# Script entry point: runs only when the file is executed directly.
if __name__ == "__main__":
    # Imported here rather than at module level, so merely importing this
    # module does not require tools.log_helper.
    from tools.log_helper import custom_logger

    # custom_logger is defined outside this file; presumably it configures
    # logging -- TODO confirm. It is called before the dataset preparation starts.
    custom_logger()
    prepare_dataset()