# Hugging Face Spaces — status: Running
import os
import shutil
import tempfile
import uuid
from typing import List, Dict, Any, Optional

import librosa
import numpy as np
import torch
from fastapi import FastAPI, UploadFile, File, HTTPException, BackgroundTasks
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from librosa.sequence import dtw
from pydantic import BaseModel
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
# FastAPI application object; title/description/version feed the generated OpenAPI docs.
app = FastAPI(
    version="1.0.0",
    title="Quran Recitation Comparison API",
    description="API for comparing similarity between Quran recitations using Wav2Vec2 embeddings",
)

# Fully permissive CORS: any origin, any method, any header.
# NOTE(review): a wildcard origin combined with allow_credentials=True may let
# arbitrary sites make credentialed cross-origin calls; restrict allow_origins
# if this API ever relies on cookies or auth headers.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
# Model and processor handles; both stay None until initialize_model() fills them.
MODEL = PROCESSOR = None

# Scratch directory for uploaded audio, placed under the system temp directory
# and created at import time (no-op when it already exists).
UPLOAD_DIR = os.path.join(tempfile.gettempdir(), "quran_comparison_uploads")
os.makedirs(UPLOAD_DIR, exist_ok=True)
# Response models
# (Documented with comments rather than docstrings so the pydantic-generated
# schema descriptions stay unchanged.)


# Payload shape for a successful comparison.
class SimilarityResponse(BaseModel):
    # Numeric similarity score produced by interpret_similarity().
    similarity_score: float
    # Human-readable verdict text produced by interpret_similarity().
    interpretation: str


# Payload shape for an error message; not referenced by the code visible here.
class ErrorResponse(BaseModel):
    # Description of what went wrong.
    error: str
# Initialize model from environment variable
def initialize_model():
    """Load the wav2vec2 processor and CTC model into the ``PROCESSOR``/``MODEL`` globals.

    Configuration comes from the environment:
      * ``HF_TOKEN``   -- optional Hugging Face auth token (passed only when non-empty).
      * ``MODEL_NAME`` -- model id, default ``jonatasgrosman/wav2vec2-large-xlsr-53-arabic``.

    The model is moved to CUDA when ``torch.cuda.is_available()``, otherwise CPU,
    and switched to eval mode. The globals are assigned only after every step
    has succeeded, so a failure (e.g. while moving to the device) never leaves a
    half-initialised model visible to request handlers.

    Raises:
        Exception: whatever loading or device placement raised; it is printed
            and then re-raised unchanged.
    """
    global MODEL, PROCESSOR

    hf_token = os.environ.get("HF_TOKEN", None)
    model_name = os.environ.get("MODEL_NAME", "jonatasgrosman/wav2vec2-large-xlsr-53-arabic")
    # Forward the token only when one is set (an empty string counts as unset).
    # NOTE(review): newer transformers releases deprecate `use_auth_token` in
    # favour of `token`; switch once the minimum supported version allows it.
    auth_kwargs = {"use_auth_token": hf_token} if hf_token else {}

    try:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Loading model on device: {device}")
        processor = Wav2Vec2Processor.from_pretrained(model_name, **auth_kwargs)
        model = Wav2Vec2ForCTC.from_pretrained(model_name, **auth_kwargs).to(device)
        model.eval()  # inference only: disables dropout-style training behaviour
    except Exception as e:
        print(f"Error loading model: {e}")
        raise  # bare raise preserves the original traceback

    # Publish both handles together, only after loading fully succeeded.
    PROCESSOR, MODEL = processor, model
    print("Model loaded successfully")
# Load audio file
def load_audio(file_path, target_sr=16000, trim_silence=True, normalize=True):
    """Decode an audio file with librosa and apply optional preprocessing.

    Args:
        file_path: Path of the audio file to read.
        target_sr: Sampling rate librosa resamples to on load.
        trim_silence: When true, cut leading/trailing audio quieter than 30 dB
            below the peak (``librosa.effects.trim(top_db=30)``).
        normalize: When true, peak-normalise with ``librosa.util.normalize``
            (applied before trimming).

    Returns:
        The processed sample array.

    Raises:
        HTTPException: status 400 wrapping any decoding/processing failure.
    """
    try:
        signal, _ = librosa.load(file_path, sr=target_sr)
        signal = librosa.util.normalize(signal) if normalize else signal
        if not trim_silence:
            return signal
        trimmed, _ = librosa.effects.trim(signal, top_db=30)
        return trimmed
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Error loading audio: {e}")
# Get deep embedding
def get_deep_embedding(audio, sr=16000):
    """Return the last hidden layer of the loaded wav2vec2 model for ``audio``.

    The audio is run through ``PROCESSOR`` (as a PyTorch batch), placed on the
    same device as the model's parameters, and passed through ``MODEL`` with
    gradients disabled. The final entry of ``hidden_states`` is squeezed on
    axis 0 and returned as a NumPy array on the CPU (presumably
    (frames, hidden_size) for a single clip — TODO confirm).

    Raises:
        HTTPException: 500 if the model has not been initialised, or wrapping
            any failure during feature extraction / inference.
    """
    if PROCESSOR is None or MODEL is None:
        raise HTTPException(status_code=500, detail="Model not initialized")
    try:
        target_device = next(MODEL.parameters()).device
        features = PROCESSOR(audio, sampling_rate=sr, return_tensors="pt")
        batch = features.input_values.to(target_device)
        with torch.no_grad():
            model_output = MODEL(batch, output_hidden_states=True)
        last_layer = model_output.hidden_states[-1]
        return last_layer.squeeze(0).cpu().numpy()
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error extracting embeddings: {e}")
# Compute DTW distance
def compute_dtw_distance(features1, features2):
    """Return the DTW alignment cost of two feature sequences, normalised by path length.

    Uses ``librosa.sequence.dtw`` with a Euclidean metric; the total cost is
    the bottom-right cell of the accumulated cost matrix, divided by the number
    of steps in the warping path so longer recordings are not penalised.

    Raises:
        HTTPException: 500 wrapping any failure inside the DTW computation.
    """
    try:
        cost_matrix, warping_path = dtw(X=features1, Y=features2, metric='euclidean')
        return cost_matrix[-1, -1] / len(warping_path)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error computing DTW distance: {e}")
# Interpret similarity
def interpret_similarity(norm_distance):
    """Map a normalised DTW distance to ``(interpretation, score)``.

    Smaller distances mean more similar recitations. Distances fall into fixed
    tiers (exactly 0, < 1, < 5, < 10, < 20); anything at or above 20 lands in a
    final tier whose score decays as ``100 - distance`` but is clamped to at
    most 40 and at least 0.

    Args:
        norm_distance: Path-length-normalised DTW distance (expected >= 0).

    Returns:
        A ``(text, score)`` tuple. The score never increases as the distance
        grows.
    """
    if norm_distance == 0:
        return "The recitations are identical based on the deep embeddings.", 100

    # (exclusive upper bound, interpretation, score), checked in order.
    tiers = (
        (1, "The recitations are extremely similar.", 95),
        (5, "The recitations are very similar with minor differences.", 80),
        (10, "The recitations show moderate similarity.", 60),
        (20, "The recitations show some noticeable differences.", 40),
    )
    for upper_bound, text, score in tiers:
        if norm_distance < upper_bound:
            return text, score

    # Final tier: clamp the 100 - distance decay to the previous tier's score.
    # Unclamped, a distance of 20 scored 80, so "quite different" recitations
    # outranked ones with merely "noticeable differences".
    return "The recitations are quite different.", min(40, max(0, 100 - norm_distance))
# Clean up temporary files
def cleanup_temp_files(file_paths):
    """Best-effort removal of temporary files.

    Paths that no longer exist are skipped silently. Any other OS error is
    printed and the remaining paths are still processed, so one failure never
    aborts the cleanup.

    Args:
        file_paths: Iterable of filesystem paths to delete.
    """
    for file_path in file_paths:
        # EAFP: attempt the delete directly instead of exists()-then-remove(),
        # which raced with any other cleanup of the same path.
        try:
            os.remove(file_path)
        except FileNotFoundError:
            continue
        except OSError as e:
            print(f"Error removing temporary file {file_path}: {e}")
# API endpoints


def _save_upload(upload, destination):
    """Copy an uploaded file's contents to ``destination`` on disk."""
    with open(destination, "wb") as out:
        shutil.copyfileobj(upload.file, out)


def _run_comparison(file1, file2, temp_file1, temp_file2):
    """Blocking pipeline: persist both uploads, embed them and score their alignment.

    Returns the ``(interpretation, similarity_score)`` tuple produced by
    ``interpret_similarity``. HTTPExceptions raised by the helpers propagate.
    """
    _save_upload(file1, temp_file1)
    _save_upload(file2, temp_file2)
    audio1 = load_audio(temp_file1)
    audio2 = load_audio(temp_file2)
    embedding1 = get_deep_embedding(audio1)
    embedding2 = get_deep_embedding(audio2)
    # Transposed before DTW (assumes embeddings are (frames, dim) and that
    # librosa's dtw aligns along the last axis — TODO confirm).
    norm_distance = compute_dtw_distance(embedding1.T, embedding2.T)
    return interpret_similarity(norm_distance)


# NOTE(review): no route decorator is visible for this handler, so FastAPI never
# registers it. Confirm the intended path (e.g. an @app.post(...) decorator) and restore it.
async def compare_recitations(
    background_tasks: BackgroundTasks,
    file1: UploadFile = File(...),
    file2: UploadFile = File(...),
):
    """
    Compare two Quran recitations and return similarity metrics.
    - **file1**: First audio file
    - **file2**: Second audio file
    Returns:
    - **similarity_score**: Score between 0-100 indicating similarity
    - **interpretation**: Text interpretation of the similarity
    Errors:
    - 400 if an audio file cannot be decoded
    - 500 if the model is not loaded or processing fails
    """
    if MODEL is None or PROCESSOR is None:
        raise HTTPException(status_code=500, detail="Model not initialized")

    temp_files = [
        os.path.join(UPLOAD_DIR, f"{uuid.uuid4()}.wav"),
        os.path.join(UPLOAD_DIR, f"{uuid.uuid4()}.wav"),
    ]
    try:
        # Decoding, inference and DTW are CPU-bound and synchronous; running
        # them in the threadpool keeps the event loop responsive meanwhile.
        interpretation, similarity_score = await run_in_threadpool(
            _run_comparison, file1, file2, *temp_files
        )
    except HTTPException:
        # Keep the helper's own status code (e.g. 400 for bad audio) instead of
        # re-wrapping it as a 500. Clean up inline: background tasks are not run
        # when the handler raises.
        cleanup_temp_files(temp_files)
        raise
    except Exception as e:
        cleanup_temp_files(temp_files)
        raise HTTPException(status_code=500, detail=str(e)) from e

    # Success: the response goes out first, the files are removed afterwards.
    background_tasks.add_task(cleanup_temp_files, temp_files)
    return {
        "similarity_score": similarity_score,
        "interpretation": interpretation,
    }
# NOTE(review): no route decorator is visible for this handler, so it is not
# registered with `app` — confirm the intended path and restore it.
async def health_check():
    """Report whether the model and processor globals are populated.

    Returns ``{"status": "ok", "model_loaded": True}`` once both are set;
    otherwise a 503 JSONResponse carrying an error message.
    """
    model_ready = MODEL is not None and PROCESSOR is not None
    if model_ready:
        return {"status": "ok", "model_loaded": True}
    return JSONResponse(
        status_code=503,
        content={"status": "error", "message": "Model not initialized"},
    )
# Initialize model on startup.
# Without this registration the handler never runs: MODEL/PROCESSOR stay None
# and every request fails with "Model not initialized".
# (Newer FastAPI releases prefer lifespan handlers; on_event still works.)
@app.on_event("startup")
async def startup_event():
    """Load the model and processor once, when the application starts up."""
    initialize_model()
# Run the FastAPI app
if __name__ == "__main__":
    import uvicorn

    # Default to port 7860 for Hugging Face Spaces.
    port = int(os.environ.get("PORT", 7860))
    # Pass the app object itself: the old "main:app" import string only
    # resolved when this file was literally named main.py, and an import
    # string is only required for reload/workers, which are disabled here.
    uvicorn.run(app, host="0.0.0.0", port=port, reload=False)