#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Script to download the chorus detection model from HuggingFace.

This script checks if the model file exists locally, and if not, downloads it
from the specified HuggingFace repository.
"""
import logging
import os
import sys
from pathlib import Path
from typing import Optional
# Configure root logging: timestamp, logger name, level, then the message.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger("model-downloader")

# Log a snapshot of the runtime environment (cwd, import path, and the
# environment variables that override the download defaults).
logger.info(f"Current working directory: {os.getcwd()}")
logger.info(f"Python path: {sys.path}")
for _env_var in ("MODEL_REVISION", "MODEL_HF_REPO", "HF_MODEL_FILENAME"):
    logger.info(f"{_env_var}: {os.environ.get(_env_var)}")
# Prefer huggingface_hub for downloads; when it cannot be imported, record
# that in HF_HUB_AVAILABLE and load the libraries used by the direct-download
# fallback instead.
try:
    from huggingface_hub import hf_hub_download
except ImportError:
    HF_HUB_AVAILABLE = False
    logger.warning("huggingface_hub is not available, falling back to direct download")
    import requests
    from tqdm import tqdm
else:
    HF_HUB_AVAILABLE = True
    logger.info("huggingface_hub is available")
def download_file_with_progress(
    url: str,
    destination: Path,
    timeout: float = 30.0,
    chunk_size: int = 1024 * 1024,
) -> None:
    """Download a file over HTTP with a progress bar.

    The payload is streamed into a temporary ``<name>.part`` file next to
    ``destination`` and moved into place only after the transfer finishes, so
    an interrupted download never leaves a truncated file at ``destination``.

    Args:
        url: URL to download from
        destination: Path to save the file to
        timeout: Seconds to wait for the server to connect or to send data
            before giving up (passed to ``requests.get``)
        chunk_size: Number of bytes requested per read from the response

    Raises:
        requests.RequestException: On connection errors, timeouts, or an HTTP
            error status.
        OSError: If the file cannot be written or moved into place.
    """
    # Imported locally so this function works no matter which import branch
    # the module took at load time.
    import requests
    from tqdm import tqdm

    # Create parent directories if they don't exist
    destination.parent.mkdir(parents=True, exist_ok=True)
    partial_path = destination.with_name(destination.name + ".part")

    logger.info("Downloading model from %s", url)
    try:
        # The context manager releases the connection even if streaming fails.
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()

            # 0 means the server did not announce a size.
            total_size = int(response.headers.get("content-length", 0))
            if total_size:
                logger.info("File size: %.1f MB", total_size / (1024 * 1024))
            else:
                logger.info("File size: unknown")

            with open(partial_path, "wb") as file, tqdm(
                desc=destination.name,
                total=total_size or None,  # None lets tqdm show a plain counter
                unit="iB",
                unit_scale=True,
                unit_divisor=1024,
            ) as bar:
                for data in response.iter_content(chunk_size):
                    bar.update(file.write(data))

        # Only a complete download ever becomes visible at `destination`.
        partial_path.replace(destination)
    finally:
        # After a successful move the partial file is gone; otherwise remove
        # the leftover so it cannot be mistaken for a finished download.
        if partial_path.exists():
            partial_path.unlink()
def _resolve_model_dir() -> Path:
    """Pick the default model directory for the current environment."""
    default_dir = Path("models/CRNN")
    if not os.environ.get("SPACE_ID"):
        return default_dir

    # Running in an HF Space: try several possible locations and take the
    # first one that exists or whose parent directory exists.
    candidates = [
        default_dir,
        Path("/home/user/app/models/CRNN"),
        Path("/app/models/CRNN"),
        Path(os.getcwd()) / "models" / "CRNN",
    ]
    for candidate in candidates:
        if candidate.exists() or candidate.parent.exists():
            return candidate
    # If none exist, use the first option; the caller creates it.
    return default_dir


def ensure_model_exists(
    model_filename: str = "best_model_V3.h5",
    repo_id: Optional[str] = None,
    model_dir: Optional[Path] = None,
    hf_model_filename: Optional[str] = None,
    revision: Optional[str] = None,
) -> Path:
    """Ensure the model file exists, downloading it if necessary.

    Arguments left as ``None`` fall back to environment variables and then to
    built-in defaults.

    Args:
        model_filename: Local filename for the model
        repo_id: HuggingFace repository ID (env: ``MODEL_HF_REPO``)
        model_dir: Directory to save the model to; a string or other
            path-like value is accepted and converted to ``Path``. When
            omitted, a default location is chosen based on whether
            ``SPACE_ID`` is set.
        hf_model_filename: Filename of the model in the HuggingFace repo
            (env: ``HF_MODEL_FILENAME``)
        revision: Revision of the repository to download from
            (env: ``MODEL_REVISION``).
            NOTE(review): the built-in default is a 64-character hex string,
            which looks like a SHA-256 digest rather than a 40-character git
            commit hash - confirm the Hub accepts it as a revision.

    Returns:
        Path to the model file. When ``SPACE_ID`` is set and the download
        fails, the path is still returned even though no file may exist there.

    Raises:
        SystemExit: With exit code 1 when the download fails and ``SPACE_ID``
            is not set.
    """
    # Get parameters from environment variables if not provided
    if repo_id is None:
        repo_id = os.environ.get("MODEL_HF_REPO", "dennisvdang/chorus-detection")
    if hf_model_filename is None:
        hf_model_filename = os.environ.get("HF_MODEL_FILENAME", "chorus_detection_crnn.h5")
    if revision is None:
        revision = os.environ.get(
            "MODEL_REVISION",
            "20e66eb3d0788373c3bdc5b28fa2f2587b0e475f3bbc47e8ab9ff0dbdbb2df32",
        )

    # Path() also normalises str / os.PathLike arguments.
    model_dir = _resolve_model_dir() if model_dir is None else Path(model_dir)
    logger.info("Using model directory: %s", model_dir)
    model_path = model_dir / model_filename

    # Log environment info when running in HF Space
    space_id = os.environ.get("SPACE_ID")
    if space_id:
        logger.info("Running in Hugging Face Space: %s", space_id)
        logger.info("Using model repo: %s", repo_id)
        logger.info("Using model file: %s", hf_model_filename)
        logger.info("Using revision: %s", revision)

    # is_file() rather than exists(): a directory at this path is not a model.
    if model_path.is_file():
        logger.info("Model already exists at %s", model_path)
        return model_path

    # Create model directory if it doesn't exist
    model_dir.mkdir(parents=True, exist_ok=True)
    logger.info("Model not found at %s. Downloading...", model_path)

    try:
        if HF_HUB_AVAILABLE:
            # Use huggingface_hub to download the model
            logger.info("Downloading model from %s/%s", repo_id, hf_model_filename)
            downloaded_path = Path(
                hf_hub_download(
                    repo_id=repo_id,
                    filename=hf_model_filename,
                    local_dir=model_dir,
                    local_dir_use_symlinks=False,
                    revision=revision,
                )
            )
            # Compare full paths, not just basenames: a same-named file that
            # landed in a subfolder still has to be moved to model_path.
            if downloaded_path.resolve() != model_path.resolve():
                # Path.replace overwrites any existing target in one step.
                downloaded_path.replace(model_path)
                logger.info("Renamed %s to %s", downloaded_path, model_path)
        else:
            # Fallback to direct download if huggingface_hub is not available
            huggingface_url = (
                f"https://huggingface.co/{repo_id}/resolve/{revision}/{hf_model_filename}"
            )
            download_file_with_progress(huggingface_url, model_path)
    except Exception as e:
        logger.error(f"Failed to download model: {e}", exc_info=True)
        # In a Space, keep the app starting up despite the missing model;
        # everywhere else a failed download aborts the process.
        if space_id:
            logger.warning("Continuing despite model download failure")
            return model_path
        sys.exit(1)

    logger.info("Successfully downloaded model to %s", model_path)
    return model_path
# Script entry point: fetch the model using environment variables / defaults.
if __name__ == "__main__":
    ensure_model_exists()