File size: 3,961 Bytes
1197f7d 230a441 1197f7d 230a441 1197f7d b5fa3f1 1197f7d 230a441 1197f7d b5fa3f1 1197f7d b5fa3f1 1197f7d 24b85bd 1197f7d 490e893 1197f7d 230a441 1197f7d 24b85bd 1197f7d 230a441 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 |
import os
import zipfile
from typing import Optional
import requests
from loguru import logger
from rich.progress import BarColumn, Progress, TextColumn, TimeRemainingColumn
from yolo.config.config import DatasetConfig
def download_file(url, destination):
    """
    Stream the file at ``url`` into ``destination``, rendering a progress bar.

    Raises:
        requests.exceptions.HTTPError: if the server answers with an error status.
    """
    with requests.get(url, stream=True) as response:
        response.raise_for_status()
        # Servers that omit Content-Length yield a total of 0 (indeterminate bar).
        expected_bytes = int(response.headers.get("content-length", 0))

        progress_columns = (
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            "[progress.percentage]{task.percentage:>3.1f}%",
            "•",
            "{task.completed}/{task.total} bytes",
            "•",
            TimeRemainingColumn(),
        )
        with Progress(*progress_columns) as progress:
            task = progress.add_task(
                f"📥 Downloading {os.path.basename(destination)}...",
                total=expected_bytes,
            )
            with open(destination, "wb") as out_file:
                # 1 MB chunks keep memory usage flat even for large archives.
                for chunk in response.iter_content(chunk_size=1024 * 1024):
                    out_file.write(chunk)
                    progress.update(task, advance=len(chunk))
    logger.info("✅ Download completed.")
def unzip_file(source, destination):
    """
    Extract a ZIP archive into ``destination`` and delete the archive afterwards.
    """
    logger.info(f"Unzipping {os.path.basename(source)}...")
    with zipfile.ZipFile(source, "r") as archive:
        archive.extractall(destination)
    # Once the contents are on disk the archive only wastes space.
    os.remove(source)
    logger.info(f"Removed {source}.")
def check_files(directory, expected_count=None):
    """
    Check the plain files directly inside ``directory`` (subdirectories ignored).

    Returns:
        bool: with ``expected_count`` given, whether exactly that many files
        are present; otherwise, whether the directory contains any file at all.
    """
    file_count = sum(
        1
        for entry in os.listdir(directory)
        if os.path.isfile(os.path.join(directory, entry))
    )
    if expected_count is None:
        return file_count > 0
    return file_count == expected_count
def prepare_dataset(cfg: DatasetConfig):
    """
    Ensure every configured dataset split is present locally, downloading and
    extracting the corresponding ZIP archives when verification fails.
    """
    data_dir = cfg.path
    for data_type, settings in cfg.auto_download.items():
        base_url = settings["base_url"]
        for dataset_type, dataset_args in settings.items():
            # "base_url" sits alongside the dataset entries; it is not one.
            if dataset_type == "base_url":
                continue

            file_name = f"{dataset_args.get('file_name', dataset_type)}.zip"
            local_zip_path = os.path.join(data_dir, file_name)
            # Annotations unpack straight into the dataset root; every other
            # type gets its own subdirectory.
            if data_type == "annotations":
                extract_to = data_dir
            else:
                extract_to = os.path.join(data_dir, data_type)
            final_place = os.path.join(extract_to, dataset_type)
            os.makedirs(final_place, exist_ok=True)

            # Skip downloading when the expected file count already matches.
            if check_files(final_place, dataset_args.get("file_num")):
                logger.info(f"✅ Dataset {dataset_type: <12} already verified.")
                continue

            if not os.path.exists(local_zip_path):
                download_file(f"{base_url}{file_name}", local_zip_path)
            unzip_file(local_zip_path, extract_to)

            if not check_files(final_place, dataset_args.get("file_num")):
                logger.error(f"Error verifying the {dataset_type} dataset after extraction.")
def prepare_weight(downlaod_link: Optional[str] = None, weight_name: str = "v9-c.pt"):
    """
    Download a pretrained weight file unless it is already present locally.

    Args:
        downlaod_link: Base URL for the release assets; defaults to the
            official yolov9mit v1.0-alpha release. NOTE: the parameter name
            keeps its historical misspelling so existing keyword callers
            do not break.
        weight_name: Name of the weight file; also used as the local
            destination path (relative to the current working directory).
    """
    if downlaod_link is None:
        downlaod_link = "https://github.com/WongKinYiu/yolov9mit/releases/download/v1.0-alpha/"
    weight_link = f"{downlaod_link}{weight_name}"

    if os.path.exists(weight_name):
        logger.info(f"Weight file '{weight_name}' already exists.")
        # Bug fix: previously execution fell through and re-downloaded,
        # silently overwriting the existing file despite the log message.
        return

    try:
        download_file(weight_link, weight_name)
    except requests.exceptions.RequestException as e:
        # Best-effort: a missing weight is recoverable (e.g. user supplies one),
        # so warn instead of raising.
        logger.warning(f"Failed to download the weight file: {e}")
# Script entry point: set up logging, then fetch the default weight file.
if __name__ == "__main__":
    import sys

    # Make the repository root importable so `utils.logging_utils` resolves
    # when this file is executed directly rather than as part of the package.
    sys.path.append("./")
    from utils.logging_utils import custom_logger

    custom_logger()
    prepare_weight()
|