|
import os |
|
import zipfile |
|
|
|
import requests |
|
from hydra import main |
|
from loguru import logger |
|
from tqdm import tqdm |
|
|
|
|
|
def download_file(url, destination):
    """
    Download a file from *url* to *destination*, streaming in 1 MiB chunks.

    Progress is reported with a tqdm bar sized from the Content-Length
    header (0 when the server omits it, which leaves the bar indeterminate).

    Args:
        url: Source URL to fetch.
        destination: Local file path to write; its basename labels the logs.

    Raises:
        requests.HTTPError: if the server responds with an error status.
        requests.exceptions.Timeout: if connecting or reading stalls.
    """
    file_label = os.path.basename(destination)
    logger.info(f"Downloading {file_label}...")
    # timeout=(connect, read): requests has NO default timeout, so without
    # this a dead connection would hang the script forever.
    with requests.get(url, stream=True, timeout=(10, 60)) as response:
        response.raise_for_status()
        total_size = int(response.headers.get("content-length", 0))
        # tqdm as a context manager guarantees the bar is closed even if
        # the download raises partway through (the original leaked it).
        with tqdm(total=total_size, unit="iB", unit_scale=True, desc=file_label, leave=True) as progress:
            with open(destination, "wb") as file:
                for data in response.iter_content(chunk_size=1024 * 1024):
                    file.write(data)
                    progress.update(len(data))
    logger.info("Download completed.")
|
|
|
|
|
def unzip_file(source, destination):
    """
    Extract the ZIP archive at *source* into *destination*, then delete the
    archive file to reclaim disk space.
    """
    archive_label = os.path.basename(source)
    logger.info(f"Unzipping {archive_label}...")
    with zipfile.ZipFile(source, "r") as archive:
        archive.extractall(destination)
    # The archive is closed at this point, so it is safe to delete.
    os.remove(source)
    logger.info(f"Removed {source}.")
|
|
|
|
|
def check_files(directory, expected_count=None):
    """
    Verify that *directory* contains the expected number of regular files.

    Args:
        directory: Path to inspect; a missing directory counts as unverified.
        expected_count: Exact number of files required, or None to accept
            any non-empty directory.

    Returns:
        True when the file count equals expected_count (or, with
        expected_count=None, when at least one file exists); False
        otherwise, including when the directory does not exist.
    """
    # Guard against a not-yet-created directory: the caller probes the
    # final extraction path before anything has been downloaded, and
    # os.listdir would raise FileNotFoundError on that first run.
    if not os.path.isdir(directory):
        return False
    files = [f for f in os.listdir(directory) if os.path.isfile(os.path.join(directory, f))]
    return len(files) == expected_count if expected_count is not None else bool(files)
|
|
|
|
|
@main(config_path="../config/data", config_name="download", version_base=None)
def prepare_dataset(cfg):
    """
    Download and unpack every dataset described by the Hydra config.

    For each configured dataset the function first checks whether the
    extracted files already look complete; if not, it downloads the ZIP
    (unless a local copy is present), extracts it, and re-verifies.
    """
    data_dir = cfg.save_path
    for data_type, settings in cfg.datasets.items():
        url_prefix = settings["base_url"]
        # Every key in `settings` except the shared "base_url" names a dataset.
        dataset_entries = ((name, args) for name, args in settings.items() if name != "base_url")
        for name, args in dataset_entries:
            zip_name = f"{args.get('file_name', name)}.zip"
            download_url = f"{url_prefix}{zip_name}"
            zip_path = os.path.join(data_dir, zip_name)
            # Annotation archives unpack directly into the data root;
            # everything else goes into a per-type subdirectory.
            extract_dir = data_dir if data_type == "annotations" else os.path.join(data_dir, data_type)
            target_dir = os.path.join(extract_dir, name)

            os.makedirs(extract_dir, exist_ok=True)
            if check_files(target_dir, args.get("file_num")):
                logger.info(f"Dataset {name} already verified.")
                continue

            # Reuse a previously downloaded archive when one is present.
            if not os.path.exists(zip_path):
                download_file(download_url, zip_path)
            unzip_file(zip_path, extract_dir)

            if not check_files(target_dir, args.get("file_num")):
                logger.error(f"Error verifying the {name} dataset after extraction.")
|
|
|
|
|
if __name__ == "__main__":
    import sys

    # Make the repository root importable so the local `tools` package
    # resolves when this file is run directly as a script; the append must
    # happen before the `tools.log_helper` import below.
    sys.path.append("./")
    from tools.log_helper import custom_logger

    # Configure project-wide logging, then hand control to Hydra, which
    # parses CLI overrides and invokes prepare_dataset with the config.
    custom_logger()
    prepare_dataset()
|
|