|
import zipfile |
|
from pathlib import Path |
|
from typing import Optional |
|
|
|
import requests |
|
from rich.progress import BarColumn, Progress, TextColumn, TimeRemainingColumn |
|
|
|
from yolo.config.config import DatasetConfig |
|
from yolo.utils.logger import logger |
|
|
|
|
|
def download_file(url, destination: Path): |
|
""" |
|
Downloads a file from the specified URL to the destination path with progress logging. |
|
""" |
|
with requests.get(url, stream=True) as response: |
|
response.raise_for_status() |
|
total_size = int(response.headers.get("content-length", 0)) |
|
with Progress( |
|
TextColumn("[progress.description]{task.description}"), |
|
BarColumn(), |
|
"[progress.percentage]{task.percentage:>3.1f}%", |
|
"β’", |
|
"{task.completed}/{task.total} bytes", |
|
"β’", |
|
TimeRemainingColumn(), |
|
) as progress: |
|
task = progress.add_task(f"π₯ Downloading {destination.name }...", total=total_size) |
|
with open(destination, "wb") as file: |
|
for data in response.iter_content(chunk_size=1024 * 1024): |
|
file.write(data) |
|
progress.update(task, advance=len(data)) |
|
logger.info(":white_check_mark: Download completed.") |
|
|
|
|
|
def unzip_file(source: Path, destination: Path): |
|
""" |
|
Extracts a ZIP file to the specified directory and removes the ZIP file after extraction. |
|
""" |
|
logger.info(f"Unzipping {source.name}...") |
|
with zipfile.ZipFile(source, "r") as zip_ref: |
|
zip_ref.extractall(destination) |
|
source.unlink() |
|
logger.info(f"Removed {source}.") |
|
|
|
|
|
def check_files(directory, expected_count=None): |
|
""" |
|
Returns True if the number of files in the directory matches expected_count, False otherwise. |
|
""" |
|
files = [f.name for f in Path(directory).iterdir() if f.is_file()] |
|
return len(files) == expected_count if expected_count is not None else bool(files) |
|
|
|
|
|
def prepare_dataset(dataset_cfg: DatasetConfig, task: str): |
|
""" |
|
Prepares dataset by downloading and unzipping if necessary. |
|
""" |
|
|
|
data_dir = Path(dataset_cfg.path) |
|
for data_type, settings in dataset_cfg.auto_download.items(): |
|
base_url = settings["base_url"] |
|
for dataset_type, dataset_args in settings.items(): |
|
if dataset_type != "annotations" and dataset_cfg.get(task, task) != dataset_type: |
|
continue |
|
file_name = f"{dataset_args.get('file_name', dataset_type)}.zip" |
|
url = f"{base_url}{file_name}" |
|
local_zip_path = data_dir / file_name |
|
extract_to = data_dir / data_type if data_type != "annotations" else data_dir |
|
final_place = extract_to / dataset_type |
|
|
|
final_place.mkdir(parents=True, exist_ok=True) |
|
if check_files(final_place, dataset_args.get("file_num")): |
|
logger.info(f":white_check_mark: Dataset {dataset_type: <12} already verified.") |
|
continue |
|
|
|
if not local_zip_path.exists(): |
|
download_file(url, local_zip_path) |
|
unzip_file(local_zip_path, extract_to) |
|
|
|
if not check_files(final_place, dataset_args.get("file_num")): |
|
logger.error(f"Error verifying the {dataset_type} dataset after extraction.") |
|
|
|
|
|
def prepare_weight(download_link: Optional[str] = None, weight_path: Path = Path("v9-c.pt")): |
|
weight_name = weight_path.name |
|
if download_link is None: |
|
download_link = "https://github.com/WongKinYiu/yolov9mit/releases/download/v1.0-alpha/" |
|
weight_link = f"{download_link}{weight_name}" |
|
|
|
if not weight_path.parent.is_dir(): |
|
weight_path.parent.mkdir(parents=True, exist_ok=True) |
|
|
|
if weight_path.exists(): |
|
logger.info(f"Weight file '{weight_path}' already exists.") |
|
try: |
|
download_file(weight_link, weight_path) |
|
except requests.exceptions.RequestException as e: |
|
logger.warning(f"Failed to download the weight file: {e}") |
|
|