# YOLO/utils/get_dataset.py
# Auto-downloads and extracts the datasets described by the Hydra download config.
import os
import zipfile
import requests
from hydra import main
from loguru import logger
from tqdm import tqdm
def download_file(url, destination):
    """
    Download a file from `url` to `destination`, streaming in 1 MB chunks.

    A tqdm progress bar is sized from the Content-Length header (0 when the
    server omits it, which leaves the bar without a known total).

    Args:
        url: Remote URL to fetch.
        destination: Local path to write the downloaded bytes to.

    Raises:
        requests.HTTPError: if the server responds with an error status.
    """
    logger.info(f"Downloading {os.path.basename(destination)}...")
    with requests.get(url, stream=True) as response:
        response.raise_for_status()
        total_size = int(response.headers.get("content-length", 0))
        # Context managers guarantee the progress bar and output file are
        # closed even if a chunk read/write raises mid-download (the original
        # leaked the tqdm bar on failure).
        with tqdm(
            total=total_size, unit="iB", unit_scale=True, desc=os.path.basename(destination), leave=True
        ) as progress, open(destination, "wb") as file:
            for data in response.iter_content(chunk_size=1024 * 1024):  # 1 MB chunks
                file.write(data)
                progress.update(len(data))
    logger.info("Download completed.")
def unzip_file(source, destination):
    """
    Extract the ZIP archive at `source` into `destination`.

    The archive file is deleted only after extraction finishes, so an
    extraction error leaves the downloaded zip on disk for a retry.

    Args:
        source: Path to the .zip file to extract.
        destination: Directory that receives the archive contents.
    """
    logger.info(f"Unzipping {os.path.basename(source)}...")
    archive = zipfile.ZipFile(source, "r")
    with archive:
        archive.extractall(destination)
    os.remove(source)
    logger.info(f"Removed {source}.")
def check_files(directory, expected_count=None):
    """
    Verify the files present directly inside `directory` (non-recursive).

    Args:
        directory: Path to inspect. A missing directory counts as empty.
        expected_count: Exact number of files required, or None to only
            require that at least one file exists.

    Returns:
        bool: True when the file count equals `expected_count`, or — when
        `expected_count` is None — when the directory holds at least one file.
    """
    # Bug fix: the original crashed with FileNotFoundError on a fresh setup
    # because os.listdir was called before the dataset directory existed.
    if not os.path.isdir(directory):
        return False
    files = [f for f in os.listdir(directory) if os.path.isfile(os.path.join(directory, f))]
    return len(files) == expected_count if expected_count is not None else bool(files)
@main(config_path="../config/data", config_name="download", version_base=None)
def prepare_dataset(cfg):
    """
    Download and extract every dataset listed in the Hydra config.

    For each dataset group in ``cfg.datasets`` (keyed by data type), every
    entry other than ``base_url`` names one zip archive. A dataset that is
    already present with the expected file count is skipped; otherwise the
    zip is downloaded (if absent), extracted, and the result re-verified.

    Args:
        cfg: Hydra config providing ``save_path`` and ``datasets``.
    """
    data_dir = cfg.save_path
    for data_type, settings in cfg.datasets.items():
        base_url = settings["base_url"]
        for dataset_type, dataset_args in settings.items():
            if dataset_type == "base_url":
                continue  # Skip the base_url entry
            file_name = f"{dataset_args.get('file_name', dataset_type)}.zip"
            url = f"{base_url}{file_name}"
            local_zip_path = os.path.join(data_dir, file_name)
            # Annotations extract directly into the data root; image sets go
            # into a per-type subdirectory.
            extract_to = os.path.join(data_dir, data_type) if data_type != "annotations" else data_dir
            final_place = os.path.join(extract_to, dataset_type)
            os.makedirs(extract_to, exist_ok=True)
            # Bug fix: guard with exists() — on a fresh setup final_place is
            # absent and check_files would raise FileNotFoundError from
            # os.listdir before anything was ever downloaded.
            if os.path.exists(final_place) and check_files(final_place, dataset_args.get("file_num")):
                logger.info(f"Dataset {dataset_type} already verified.")
                continue
            if not os.path.exists(local_zip_path):
                download_file(url, local_zip_path)
            unzip_file(local_zip_path, extract_to)
            if not (os.path.exists(final_place) and check_files(final_place, dataset_args.get("file_num"))):
                logger.error(f"Error verifying the {dataset_type} dataset after extraction.")
if __name__ == "__main__":
    # Allow running this file directly as a script: put the repo root on
    # sys.path so the project-local `tools` package can be imported.
    import sys
    sys.path.append("./")
    from tools.log_helper import custom_logger
    # Install the project's logging configuration before handing control to
    # the Hydra entry point.
    custom_logger()
    prepare_dataset()