Spaces:
Running
Running
File size: 6,699 Bytes
da764f1 538987f da764f1 538987f da764f1 538987f da764f1 538987f da764f1 538987f da764f1 538987f da764f1 538987f da764f1 606184e da764f1 538987f 606184e da764f1 538987f da764f1 538987f da764f1 538987f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 |
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Script to download the chorus detection model from HuggingFace.
This script checks if the model file exists locally, and if not, downloads it
from the specified HuggingFace repository.
"""
import logging
import os
import sys
from pathlib import Path
from typing import Optional, Union
# Root logging for this script: INFO and above, timestamped, tagged with logger name.
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger("model-downloader")

# Dump the runtime context that matters when a download misbehaves
# (working directory, import path and the model-selection env vars).
logger.info(f"Current working directory: {os.getcwd()}")
logger.info(f"Python path: {sys.path}")
logger.info(f"MODEL_REVISION: {os.environ.get('MODEL_REVISION')}")
logger.info(f"MODEL_HF_REPO: {os.environ.get('MODEL_HF_REPO')}")
logger.info(f"HF_MODEL_FILENAME: {os.environ.get('HF_MODEL_FILENAME')}")

# huggingface_hub is optional: when it cannot be imported, the script falls
# back to a plain HTTP download. HF_HUB_AVAILABLE records which path to take.
try:
    from huggingface_hub import hf_hub_download
except ImportError:
    HF_HUB_AVAILABLE = False
    logger.warning("huggingface_hub is not available, falling back to direct download")
else:
    HF_HUB_AVAILABLE = True
    logger.info("huggingface_hub is available")
import requests
from tqdm import tqdm
def download_file_with_progress(
    url: str,
    destination: Path,
    *,
    timeout: float = 60.0,
    chunk_size: int = 1024 * 1024,
) -> None:
    """Stream ``url`` to ``destination`` while showing a tqdm progress bar.

    The body is first written to a sibling ``<name>.part`` file and moved onto
    ``destination`` only after the whole response has been read. An interrupted
    download therefore never leaves a truncated file at ``destination``.

    Args:
        url: URL to download from.
        destination: Path to save the file to; parent directories are created.
        timeout: Seconds requests waits to connect / between received bytes,
            so a stalled server cannot hang the script indefinitely.
        chunk_size: Number of bytes requested per iteration of the body stream.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: On connection failures or timeouts.
    """
    destination = Path(destination)
    destination.parent.mkdir(parents=True, exist_ok=True)
    partial_path = destination.with_name(destination.name + ".part")

    # Using the response as a context manager releases the streamed
    # connection even if reading or writing fails.
    with requests.get(url, stream=True, timeout=timeout) as response:
        response.raise_for_status()
        # content-length may be absent (e.g. chunked transfer); 0 means unknown.
        total_size = int(response.headers.get('content-length', 0))

        logger.info(f"Downloading model from {url}")
        if total_size:
            logger.info(f"File size: {total_size / (1024*1024):.1f} MB")
        else:
            logger.info("File size: unknown")

        try:
            with open(partial_path, 'wb') as file, tqdm(
                desc=destination.name,
                total=total_size or None,  # None -> open-ended progress counter
                unit='iB',
                unit_scale=True,
                unit_divisor=1024,
            ) as bar:
                for chunk in response.iter_content(chunk_size):
                    bar.update(file.write(chunk))
            # Replaces any stale file at destination (atomic on POSIX).
            partial_path.replace(destination)
        finally:
            # Only still present if the download failed part-way.
            if partial_path.exists():
                partial_path.unlink()
def _default_model_dir() -> Path:
    """Return the directory to store the model in when the caller gives none.

    Outside a Hugging Face Space (``SPACE_ID`` unset) this is the relative
    ``models/CRNN``. Inside a Space, the first candidate that exists, or whose
    parent exists, is chosen; if none match, the relative ``models/CRNN`` is used.
    """
    if not os.environ.get("SPACE_ID"):
        return Path("models/CRNN")
    candidates = [
        Path("models/CRNN"),
        Path("/home/user/app/models/CRNN"),
        Path("/app/models/CRNN"),
        Path(os.getcwd()) / "models" / "CRNN",
    ]
    for candidate in candidates:
        if candidate.exists() or candidate.parent.exists():
            return candidate
    # Nothing exists yet; the caller creates this directory before downloading.
    return candidates[0]


def ensure_model_exists(
    model_filename: str = "best_model_V3.h5",
    repo_id: Optional[str] = None,
    model_dir: Optional[Union[str, Path]] = None,
    hf_model_filename: Optional[str] = None,
    revision: Optional[str] = None,
) -> Path:
    """Ensure the model file exists locally, downloading it if necessary.

    Arguments left as ``None`` fall back to the environment variables
    ``MODEL_HF_REPO``, ``HF_MODEL_FILENAME`` and ``MODEL_REVISION``, then to
    built-in defaults. The download uses ``huggingface_hub`` when it was
    importable, otherwise a direct HTTP request to the Hub's ``resolve`` URL.

    Args:
        model_filename: Local filename for the model inside ``model_dir``.
        repo_id: HuggingFace repository ID.
        model_dir: Directory to save the model to (str or Path). Defaults to
            ``models/CRNN`` or, inside a Space, the first usable candidate.
        hf_model_filename: Filename of the model in the HuggingFace repo; the
            downloaded file is moved to ``model_dir / model_filename``.
        revision: Hub revision passed to ``hf_hub_download`` / the resolve URL.

    Returns:
        Path to the model file. Inside a Space (``SPACE_ID`` set) this path is
        returned even when the download failed, so it may not exist.

    Raises:
        SystemExit: With status 1 when the download fails outside a Space.
    """
    if repo_id is None:
        repo_id = os.environ.get("MODEL_HF_REPO", "dennisvdang/chorus-detection")
    if hf_model_filename is None:
        hf_model_filename = os.environ.get("HF_MODEL_FILENAME", "chorus_detection_crnn.h5")
    if revision is None:
        # NOTE(review): this default is 64 hex chars, which looks like a file's
        # SHA-256 rather than a 40-char git commit / branch / tag the Hub
        # accepts as a revision -- confirm it resolves.
        revision = os.environ.get("MODEL_REVISION", "20e66eb3d0788373c3bdc5b28fa2f2587b0e475f3bbc47e8ab9ff0dbdbb2df32")

    model_dir = _default_model_dir() if model_dir is None else Path(model_dir)
    logger.info(f"Using model directory: {model_dir}")
    model_path = model_dir / model_filename

    space_id = os.environ.get("SPACE_ID")
    if space_id:
        logger.info(f"Running in Hugging Face Space: {space_id}")
        logger.info(f"Using model repo: {repo_id}")
        logger.info(f"Using model file: {hf_model_filename}")
        logger.info(f"Using revision: {revision}")

    # An empty file (e.g. left by an interrupted write) is not a usable model.
    if model_path.is_file() and model_path.stat().st_size > 0:
        logger.info(f"Model already exists at {model_path}")
        return model_path

    model_dir.mkdir(parents=True, exist_ok=True)
    logger.info(f"Model not found at {model_path}. Downloading...")

    try:
        if HF_HUB_AVAILABLE:
            logger.info(f"Downloading model from {repo_id}/{hf_model_filename}")
            downloaded_path = hf_hub_download(
                repo_id=repo_id,
                filename=hf_model_filename,
                local_dir=model_dir,
                # Request a real file in local_dir rather than a cache symlink.
                # NOTE(review): recent huggingface_hub versions deprecate this
                # argument; kept for compatibility with older releases.
                local_dir_use_symlinks=False,
                revision=revision,
            )
            # The downloaded file may differ from model_path in name or in
            # directory (a repo filename with subfolders), so compare full
            # absolute paths rather than just basenames.
            if os.path.abspath(downloaded_path) != os.path.abspath(model_path):
                model_path.parent.mkdir(parents=True, exist_ok=True)
                # replace() overwrites any stale file at model_path in one step.
                Path(downloaded_path).replace(model_path)
                logger.info(f"Renamed {downloaded_path} to {model_path}")
        else:
            # Direct download when huggingface_hub is not importable.
            huggingface_url = f"https://huggingface.co/{repo_id}/resolve/{revision}/{hf_model_filename}"
            download_file_with_progress(huggingface_url, model_path)

        logger.info(f"Successfully downloaded model to {model_path}")
        return model_path
    except Exception as e:
        # Top-level boundary for the download: record the traceback, then
        # either keep the Space running or abort the script.
        logger.exception(f"Failed to download model: {e}")
        if space_id:
            logger.warning("Continuing despite model download failure")
            return model_path
        sys.exit(1)
# Script entry point: fetch the model (if missing) when run directly.
if __name__ == "__main__":
    ensure_model_exists()