File size: 4,205 Bytes
1197f7d
 
230a441
1197f7d
 
 
230a441
1197f7d
b5fa3f1
 
1197f7d
 
 
 
 
 
 
 
230a441
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1197f7d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5958998
1197f7d
 
 
1504257
97e9dcb
 
1197f7d
 
5958998
 
1197f7d
 
 
 
 
 
24b85bd
1197f7d
490e893
1197f7d
 
 
 
 
 
 
 
 
 
2cbea31
dc1a0cf
2cbea31
 
 
230a441
dc1a0cf
 
 
 
 
230a441
dc1a0cf
230a441
 
 
 
1197f7d
 
 
 
24b85bd
1197f7d
 
230a441
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import os
import zipfile
from typing import Optional

import requests
from loguru import logger
from rich.progress import BarColumn, Progress, TextColumn, TimeRemainingColumn

from yolo.config.config import DatasetConfig


def download_file(url, destination):
    """
    Stream the resource at *url* into *destination* (1 MB chunks), showing a
    rich progress bar sized from the Content-Length header (0 if absent).
    """
    with requests.get(url, stream=True) as response:
        response.raise_for_status()
        # Header may be missing on chunked responses; fall back to 0.
        total_size = int(response.headers.get("content-length", 0))
        columns = (
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            "[progress.percentage]{task.percentage:>3.1f}%",
            "•",
            "{task.completed}/{task.total} bytes",
            "•",
            TimeRemainingColumn(),
        )
        with Progress(*columns) as progress:
            label = f"📥 Downloading {os.path.basename(destination)}..."
            task = progress.add_task(label, total=total_size)
            with open(destination, "wb") as out_file:
                for chunk in response.iter_content(chunk_size=1024 * 1024):  # 1 MB chunks
                    out_file.write(chunk)
                    progress.update(task, advance=len(chunk))
    logger.info("✅ Download completed.")


def unzip_file(source, destination):
    """
    Extract the ZIP archive at *source* into *destination*, then delete the
    archive file itself.
    """
    logger.info(f"Unzipping {os.path.basename(source)}...")
    archive = zipfile.ZipFile(source, "r")
    with archive:
        archive.extractall(destination)
    # The archive is no longer needed once its contents are on disk.
    os.remove(source)
    logger.info(f"Removed {source}.")


def check_files(directory, expected_count=None):
    """
    Check the regular-file population of *directory*.

    With expected_count given, return True iff exactly that many regular
    files are present (subdirectories are ignored). With expected_count of
    None, return True iff the directory contains at least one regular file.
    """
    file_total = 0
    for entry in os.listdir(directory):
        if os.path.isfile(os.path.join(directory, entry)):
            file_total += 1
    if expected_count is None:
        return file_total > 0
    return file_total == expected_count


def prepare_dataset(dataset_cfg: DatasetConfig, task: str) -> None:
    """
    Prepares dataset by downloading and unzipping if necessary.

    Args:
        dataset_cfg: Config exposing `path` (root data directory) and
            `auto_download`, a mapping of {data_type: settings}. Each settings
            dict carries a "base_url" plus per-split entries whose values are
            presumably dicts with optional "file_name" and "file_num" keys —
            TODO confirm schema against the config files.
        task: Split selector; `dataset_cfg.get(task, task)` resolves which
            dataset_type (besides "annotations") gets prepared.
    """
    # TODO: do EDA of dataset
    data_dir = dataset_cfg.path
    for data_type, settings in dataset_cfg.auto_download.items():
        base_url = settings["base_url"]
        # NOTE(review): this loop also yields the "base_url" key itself; the
        # filter below skips it in practice (it is neither "annotations" nor
        # a task split) — confirm no split is ever literally named "base_url".
        for dataset_type, dataset_args in settings.items():
            # Keep only the "annotations" entry and the split matching `task`.
            if dataset_type != "annotations" and dataset_cfg.get(task, task) != dataset_type:
                continue
            # Archive name defaults to the split name when "file_name" is absent.
            file_name = f"{dataset_args.get('file_name', dataset_type)}.zip"
            url = f"{base_url}{file_name}"
            local_zip_path = os.path.join(data_dir, file_name)
            # Annotations unzip at the data root; splits unzip under their data_type dir.
            extract_to = os.path.join(data_dir, data_type) if data_type != "annotations" else data_dir
            final_place = os.path.join(extract_to, dataset_type)

            os.makedirs(final_place, exist_ok=True)
            # Skip download entirely when the extracted file count already matches.
            if check_files(final_place, dataset_args.get("file_num")):
                logger.info(f"✅ Dataset {dataset_type: <12} already verified.")
                continue

            # Reuse a previously downloaded archive if one is still on disk.
            if not os.path.exists(local_zip_path):
                download_file(url, local_zip_path)
            unzip_file(local_zip_path, extract_to)

            # Post-extraction sanity check; logs an error but does not raise.
            if not check_files(final_place, dataset_args.get("file_num")):
                logger.error(f"Error verifying the {dataset_type} dataset after extraction.")


def prepare_weight(download_link: Optional[str] = None, weight_path: str = "v9-c.pt") -> None:
    """
    Ensure the weight file at *weight_path* exists, downloading it if absent.

    Args:
        download_link: Base URL the weight file name is appended to; defaults
            to the official v1.0-alpha release assets URL.
        weight_path: Destination path for the weight file; its basename is
            used as the remote file name.
    """
    weight_name = os.path.basename(weight_path)
    if download_link is None:
        download_link = "https://github.com/WongKinYiu/yolov9mit/releases/download/v1.0-alpha/"
    weight_link = f"{download_link}{weight_name}"

    # BUG FIX: os.path.dirname() returns "" for a bare file name such as the
    # default "v9-c.pt", and os.makedirs("") raises FileNotFoundError. Only
    # create a directory when there actually is one, and tolerate it existing.
    weight_dir = os.path.dirname(weight_path)
    if weight_dir:
        os.makedirs(weight_dir, exist_ok=True)

    if os.path.exists(weight_path):
        logger.info(f"Weight file '{weight_path}' already exists.")
        # BUG FIX: previously fell through and re-downloaded despite the log.
        return

    try:
        download_file(weight_link, weight_path)
    except requests.exceptions.RequestException as e:
        # Best-effort: a missing weight is reported, not fatal.
        logger.warning(f"Failed to download the weight file: {e}")


if __name__ == "__main__":
    import sys

    # Make the repository root importable before the project-local import
    # below — order matters here.
    sys.path.append("./")
    from utils.logging_utils import custom_logger

    custom_logger()
    # Default invocation: fetch v9-c.pt into the current working directory.
    prepare_weight()