# Chorus-Detection / download_model.py
# Committed by: dennisvdang
# Last commit: "Refactor code and remove unnecessary files" (606184e)
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Script to download the chorus detection model from HuggingFace.
This script checks if the model file exists locally, and if not, downloads it
from the specified HuggingFace repository.
"""
import logging
import os
import sys
from pathlib import Path
from typing import Optional, Union
# Set up logging for the script: every record carries a timestamp, the logger
# name, the level and the message.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger("model-downloader")

# Record the runtime context (working directory, import path and the
# environment overrides read further down) to ease debugging.
logger.info(f"Current working directory: {os.getcwd()}")
logger.info(f"Python path: {sys.path}")
logger.info(f"MODEL_REVISION: {os.environ.get('MODEL_REVISION')}")
logger.info(f"MODEL_HF_REPO: {os.environ.get('MODEL_HF_REPO')}")
logger.info(f"HF_MODEL_FILENAME: {os.environ.get('HF_MODEL_FILENAME')}")

# huggingface_hub is the preferred download backend; when it cannot be
# imported, plain HTTP via requests (with a tqdm progress bar) is used instead.
try:
    from huggingface_hub import hf_hub_download
except ImportError:
    HF_HUB_AVAILABLE = False
    logger.warning("huggingface_hub is not available, falling back to direct download")
    import requests
    from tqdm import tqdm
else:
    HF_HUB_AVAILABLE = True
    logger.info("huggingface_hub is available")
def download_file_with_progress(
    url: str,
    destination: Path,
    chunk_size: int = 1024 * 1024,
    timeout: float = 30.0,
) -> None:
    """Download a file with a progress bar.

    The payload is streamed into a temporary ``<name>.part`` file next to
    ``destination`` and only moved into place once the transfer has finished,
    so a failed or interrupted download never leaves a truncated file at
    ``destination``.

    Args:
        url: URL to download from
        destination: Path to save the file to
        chunk_size: Number of bytes requested per read from the response stream
        timeout: Seconds to wait for the connection and for each read from the
            server before giving up

    Raises:
        requests.HTTPError: If the server answers with an error status
        requests.RequestException: On connection problems or timeouts
        OSError: If the file cannot be written or moved into place
    """
    # Imported locally so this helper does not depend on whether the
    # module-level fallback imports were executed.
    import requests
    from tqdm import tqdm

    # Create parent directories if they don't exist
    destination.parent.mkdir(parents=True, exist_ok=True)
    partial_path = destination.with_name(destination.name + ".part")

    logger.info(f"Downloading model from {url}")
    try:
        # The context manager releases the connection even if writing fails.
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            total_size = int(response.headers.get('content-length', 0))
            if total_size:
                logger.info(f"File size: {total_size / (1024*1024):.1f} MB")
            else:
                logger.info("File size: unknown (no Content-Length header)")

            with open(partial_path, 'wb') as file, tqdm(
                desc=destination.name,
                # Without a known size tqdm shows a plain byte counter.
                total=total_size or None,
                unit='iB',
                unit_scale=True,
                unit_divisor=1024,
            ) as bar:
                for data in response.iter_content(chunk_size):
                    bar.update(file.write(data))

        # Only a completely written file is ever visible under the final name.
        partial_path.replace(destination)
    finally:
        # After a successful move the temporary file is gone; anything still
        # present is the leftover of a failed or interrupted transfer.
        if partial_path.exists():
            partial_path.unlink()
def _default_model_dir() -> Path:
    """Return the model directory used when the caller did not supply one."""
    if not os.environ.get("SPACE_ID"):
        return Path("models/CRNN")

    # Running in a Hugging Face Space: probe several candidate locations and
    # take the first one that exists or whose parent directory exists.
    candidates = [
        Path("models/CRNN"),
        Path("/home/user/app/models/CRNN"),
        Path("/app/models/CRNN"),
        Path(os.getcwd()) / "models" / "CRNN",
    ]
    for candidate in candidates:
        if candidate.exists() or candidate.parent.exists():
            return candidate
    # None of them is present yet; use the first option (it is created later).
    return candidates[0]


def ensure_model_exists(
    model_filename: str = "best_model_V3.h5",
    repo_id: Optional[str] = None,
    model_dir: Optional[Union[str, Path]] = None,
    hf_model_filename: Optional[str] = None,
    revision: Optional[str] = None
) -> Path:
    """Ensure the model file exists, downloading it if necessary.

    Parameters left as ``None`` are taken from the environment
    (``MODEL_HF_REPO``, ``HF_MODEL_FILENAME``, ``MODEL_REVISION``) and fall
    back to built-in defaults when the variable is not set.

    Args:
        model_filename: Local filename for the model
        repo_id: HuggingFace repository ID
        model_dir: Directory to save the model to (string or Path)
        hf_model_filename: Filename of the model in the HuggingFace repo
        revision: Revision of the model repo to download from

    Returns:
        Path to the model file. When running in a Hugging Face Space
        (``SPACE_ID`` set) the path is returned even if the download failed,
        so it may not exist in that case.

    Raises:
        SystemExit: If the download fails outside of a Hugging Face Space.
    """
    # Get parameters from environment variables if not provided
    if repo_id is None:
        repo_id = os.environ.get("MODEL_HF_REPO", "dennisvdang/chorus-detection")
    if hf_model_filename is None:
        hf_model_filename = os.environ.get("HF_MODEL_FILENAME", "chorus_detection_crnn.h5")
    if revision is None:
        # NOTE(review): this default is a 64-character hex string, which looks
        # like a file digest rather than a branch, tag or commit id -- confirm
        # that the Hub accepts it as a revision.
        revision = os.environ.get("MODEL_REVISION", "20e66eb3d0788373c3bdc5b28fa2f2587b0e475f3bbc47e8ab9ff0dbdbb2df32")

    # Normalise to a Path; strings (and other path-likes) are accepted too.
    model_dir = _default_model_dir() if model_dir is None else Path(model_dir)
    logger.info(f"Using model directory: {model_dir}")
    model_path = model_dir / model_filename

    # Log environment info when running in HF Space
    if os.environ.get("SPACE_ID"):
        logger.info(f"Running in Hugging Face Space: {os.environ.get('SPACE_ID')}")
        logger.info(f"Using model repo: {repo_id}")
        logger.info(f"Using model file: {hf_model_filename}")
        logger.info(f"Using revision: {revision}")

    # Check if the model already exists
    if model_path.exists():
        logger.info(f"Model already exists at {model_path}")
        return model_path

    # Create model directory if it doesn't exist
    model_dir.mkdir(parents=True, exist_ok=True)
    logger.info(f"Model not found at {model_path}. Downloading...")
    try:
        if HF_HUB_AVAILABLE:
            # Use huggingface_hub to download the model
            logger.info(f"Downloading model from {repo_id}/{hf_model_filename}")
            downloaded_path = Path(hf_hub_download(
                repo_id=repo_id,
                filename=hf_model_filename,
                local_dir=model_dir,
                local_dir_use_symlinks=False,
                revision=revision
            ))
            # Compare full locations, not just file names: a repo file inside
            # a subfolder can share its basename with model_filename and would
            # otherwise never be moved to model_path.
            if downloaded_path.resolve() != model_path.resolve():
                model_path.parent.mkdir(parents=True, exist_ok=True)
                # replace() overwrites an existing target in a single step.
                downloaded_path.replace(model_path)
                logger.info(f"Renamed {downloaded_path} to {model_path}")
        else:
            # Fallback to direct download if huggingface_hub is not available
            huggingface_url = f"https://huggingface.co/{repo_id}/resolve/{revision}/{hf_model_filename}"
            download_file_with_progress(huggingface_url, model_path)
        logger.info(f"Successfully downloaded model to {model_path}")
        return model_path
    except Exception as e:
        logger.error(f"Failed to download model: {e}", exc_info=True)
        # A failed direct download may leave a truncated file at model_path;
        # remove it so the next call does not mistake it for a valid model.
        if not HF_HUB_AVAILABLE and model_path.exists():
            try:
                model_path.unlink()
            except OSError as cleanup_error:
                logger.warning(f"Could not remove incomplete model file {model_path}: {cleanup_error}")
        # Handle error more gracefully in production environment
        if os.environ.get("SPACE_ID"):
            logger.warning("Continuing despite model download failure")
            return model_path
        else:
            sys.exit(1)
# Script entry point: make sure the model file is present, using the default
# arguments (which fall back to environment variables inside the function).
if __name__ == "__main__":
    ensure_model_exists()