Spaces:
Running
Running
Create main.py
Browse files
main.py
ADDED
@@ -0,0 +1,241 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import librosa
|
4 |
+
import numpy as np
|
5 |
+
from typing import List, Dict, Any, Optional
|
6 |
+
from fastapi import FastAPI, UploadFile, File, HTTPException, BackgroundTasks
|
7 |
+
from fastapi.responses import JSONResponse
|
8 |
+
from fastapi.middleware.cors import CORSMiddleware
|
9 |
+
from pydantic import BaseModel
|
10 |
+
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
|
11 |
+
from librosa.sequence import dtw
|
12 |
+
import tempfile
|
13 |
+
import uuid
|
14 |
+
import shutil
|
15 |
+
|
16 |
+
# Initialize FastAPI app
|
17 |
+
app = FastAPI(
|
18 |
+
title="Quran Recitation Comparison API",
|
19 |
+
description="API for comparing similarity between Quran recitations using Wav2Vec2 embeddings",
|
20 |
+
version="1.0.0"
|
21 |
+
)
|
22 |
+
|
23 |
+
# Add CORS middleware
|
24 |
+
app.add_middleware(
|
25 |
+
CORSMiddleware,
|
26 |
+
allow_origins=["*"], # Allows all origins
|
27 |
+
allow_credentials=True,
|
28 |
+
allow_methods=["*"], # Allows all methods
|
29 |
+
allow_headers=["*"], # Allows all headers
|
30 |
+
)
|
31 |
+
|
32 |
+
# Global variables
|
33 |
+
MODEL = None
|
34 |
+
PROCESSOR = None
|
35 |
+
UPLOAD_DIR = os.path.join(tempfile.gettempdir(), "quran_comparison_uploads")
|
36 |
+
|
37 |
+
# Ensure upload directory exists
|
38 |
+
os.makedirs(UPLOAD_DIR, exist_ok=True)
|
39 |
+
|
40 |
+
# Response models
|
41 |
+
class SimilarityResponse(BaseModel):
|
42 |
+
similarity_score: float
|
43 |
+
interpretation: str
|
44 |
+
|
45 |
+
class ErrorResponse(BaseModel):
|
46 |
+
error: str
|
47 |
+
|
48 |
+
# Initialize model from environment variable
|
49 |
+
def initialize_model():
|
50 |
+
global MODEL, PROCESSOR
|
51 |
+
|
52 |
+
# Get HF token from environment variable
|
53 |
+
hf_token = os.environ.get("HF_TOKEN", None)
|
54 |
+
model_name = os.environ.get("MODEL_NAME", "jonatasgrosman/wav2vec2-large-xlsr-53-arabic")
|
55 |
+
|
56 |
+
try:
|
57 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
58 |
+
print(f"Loading model on device: {device}")
|
59 |
+
|
60 |
+
# Load model and processor
|
61 |
+
if hf_token:
|
62 |
+
PROCESSOR = Wav2Vec2Processor.from_pretrained(model_name, use_auth_token=hf_token)
|
63 |
+
MODEL = Wav2Vec2ForCTC.from_pretrained(model_name, use_auth_token=hf_token)
|
64 |
+
else:
|
65 |
+
PROCESSOR = Wav2Vec2Processor.from_pretrained(model_name)
|
66 |
+
MODEL = Wav2Vec2ForCTC.from_pretrained(model_name)
|
67 |
+
|
68 |
+
MODEL = MODEL.to(device)
|
69 |
+
MODEL.eval()
|
70 |
+
print("Model loaded successfully")
|
71 |
+
except Exception as e:
|
72 |
+
print(f"Error loading model: {e}")
|
73 |
+
raise e
|
74 |
+
|
75 |
+
# Load audio file
|
76 |
+
def load_audio(file_path, target_sr=16000, trim_silence=True, normalize=True):
|
77 |
+
"""Load and preprocess an audio file."""
|
78 |
+
try:
|
79 |
+
y, sr = librosa.load(file_path, sr=target_sr)
|
80 |
+
|
81 |
+
if normalize:
|
82 |
+
y = librosa.util.normalize(y)
|
83 |
+
|
84 |
+
if trim_silence:
|
85 |
+
y, _ = librosa.effects.trim(y, top_db=30)
|
86 |
+
|
87 |
+
return y
|
88 |
+
except Exception as e:
|
89 |
+
raise HTTPException(status_code=400, detail=f"Error loading audio: {e}")
|
90 |
+
|
91 |
+
# Get deep embedding
|
92 |
+
def get_deep_embedding(audio, sr=16000):
|
93 |
+
"""Extract frame-wise deep embeddings using the pretrained model."""
|
94 |
+
global MODEL, PROCESSOR
|
95 |
+
|
96 |
+
if MODEL is None or PROCESSOR is None:
|
97 |
+
raise HTTPException(status_code=500, detail="Model not initialized")
|
98 |
+
|
99 |
+
try:
|
100 |
+
device = next(MODEL.parameters()).device
|
101 |
+
input_values = PROCESSOR(
|
102 |
+
audio,
|
103 |
+
sampling_rate=sr,
|
104 |
+
return_tensors="pt"
|
105 |
+
).input_values.to(device)
|
106 |
+
|
107 |
+
with torch.no_grad():
|
108 |
+
outputs = MODEL(input_values, output_hidden_states=True)
|
109 |
+
|
110 |
+
hidden_states = outputs.hidden_states[-1]
|
111 |
+
embedding_seq = hidden_states.squeeze(0).cpu().numpy()
|
112 |
+
|
113 |
+
return embedding_seq
|
114 |
+
except Exception as e:
|
115 |
+
raise HTTPException(status_code=500, detail=f"Error extracting embeddings: {e}")
|
116 |
+
|
117 |
+
# Compute DTW distance
|
118 |
+
def compute_dtw_distance(features1, features2):
|
119 |
+
"""Compute the DTW distance between two sequences of features."""
|
120 |
+
try:
|
121 |
+
D, wp = dtw(X=features1, Y=features2, metric='euclidean')
|
122 |
+
distance = D[-1, -1]
|
123 |
+
normalized_distance = distance / len(wp)
|
124 |
+
return normalized_distance
|
125 |
+
except Exception as e:
|
126 |
+
raise HTTPException(status_code=500, detail=f"Error computing DTW distance: {e}")
|
127 |
+
|
128 |
+
# Interpret similarity
|
129 |
+
def interpret_similarity(norm_distance):
|
130 |
+
"""Interpret the normalized distance value."""
|
131 |
+
if norm_distance == 0:
|
132 |
+
result = "The recitations are identical based on the deep embeddings."
|
133 |
+
score = 100
|
134 |
+
elif norm_distance < 1:
|
135 |
+
result = "The recitations are extremely similar."
|
136 |
+
score = 95
|
137 |
+
elif norm_distance < 5:
|
138 |
+
result = "The recitations are very similar with minor differences."
|
139 |
+
score = 80
|
140 |
+
elif norm_distance < 10:
|
141 |
+
result = "The recitations show moderate similarity."
|
142 |
+
score = 60
|
143 |
+
elif norm_distance < 20:
|
144 |
+
result = "The recitations show some noticeable differences."
|
145 |
+
score = 40
|
146 |
+
else:
|
147 |
+
result = "The recitations are quite different."
|
148 |
+
score = max(0, 100 - norm_distance)
|
149 |
+
|
150 |
+
return result, score
|
151 |
+
|
152 |
+
# Clean up temporary files
|
153 |
+
def cleanup_temp_files(file_paths):
|
154 |
+
"""Remove temporary files."""
|
155 |
+
for file_path in file_paths:
|
156 |
+
if os.path.exists(file_path):
|
157 |
+
try:
|
158 |
+
os.remove(file_path)
|
159 |
+
except Exception as e:
|
160 |
+
print(f"Error removing temporary file {file_path}: {e}")
|
161 |
+
|
162 |
+
# API endpoints
|
163 |
+
@app.post("/compare", response_model=SimilarityResponse)
|
164 |
+
async def compare_recitations(
|
165 |
+
background_tasks: BackgroundTasks,
|
166 |
+
file1: UploadFile = File(...),
|
167 |
+
file2: UploadFile = File(...)
|
168 |
+
):
|
169 |
+
"""
|
170 |
+
Compare two Quran recitations and return similarity metrics.
|
171 |
+
|
172 |
+
- **file1**: First audio file
|
173 |
+
- **file2**: Second audio file
|
174 |
+
|
175 |
+
Returns:
|
176 |
+
- **similarity_score**: Score between 0-100 indicating similarity
|
177 |
+
- **interpretation**: Text interpretation of the similarity
|
178 |
+
"""
|
179 |
+
# Check if model is initialized
|
180 |
+
if MODEL is None or PROCESSOR is None:
|
181 |
+
raise HTTPException(status_code=500, detail="Model not initialized")
|
182 |
+
|
183 |
+
# Temporary file paths
|
184 |
+
temp_file1 = os.path.join(UPLOAD_DIR, f"{uuid.uuid4()}.wav")
|
185 |
+
temp_file2 = os.path.join(UPLOAD_DIR, f"{uuid.uuid4()}.wav")
|
186 |
+
|
187 |
+
try:
|
188 |
+
# Save uploaded files
|
189 |
+
with open(temp_file1, "wb") as f:
|
190 |
+
shutil.copyfileobj(file1.file, f)
|
191 |
+
|
192 |
+
with open(temp_file2, "wb") as f:
|
193 |
+
shutil.copyfileobj(file2.file, f)
|
194 |
+
|
195 |
+
# Load audio files
|
196 |
+
audio1 = load_audio(temp_file1)
|
197 |
+
audio2 = load_audio(temp_file2)
|
198 |
+
|
199 |
+
# Extract embeddings
|
200 |
+
embedding1 = get_deep_embedding(audio1)
|
201 |
+
embedding2 = get_deep_embedding(audio2)
|
202 |
+
|
203 |
+
# Compute DTW distance
|
204 |
+
norm_distance = compute_dtw_distance(embedding1.T, embedding2.T)
|
205 |
+
|
206 |
+
# Interpret results
|
207 |
+
interpretation, similarity_score = interpret_similarity(norm_distance)
|
208 |
+
|
209 |
+
# Add cleanup task
|
210 |
+
background_tasks.add_task(cleanup_temp_files, [temp_file1, temp_file2])
|
211 |
+
|
212 |
+
return {
|
213 |
+
"similarity_score": similarity_score,
|
214 |
+
"interpretation": interpretation
|
215 |
+
}
|
216 |
+
|
217 |
+
except Exception as e:
|
218 |
+
# Ensure files are cleaned up even in case of error
|
219 |
+
background_tasks.add_task(cleanup_temp_files, [temp_file1, temp_file2])
|
220 |
+
raise HTTPException(status_code=500, detail=str(e))
|
221 |
+
|
222 |
+
@app.get("/health")
|
223 |
+
async def health_check():
|
224 |
+
"""Health check endpoint."""
|
225 |
+
if MODEL is None or PROCESSOR is None:
|
226 |
+
return JSONResponse(
|
227 |
+
status_code=503,
|
228 |
+
content={"status": "error", "message": "Model not initialized"}
|
229 |
+
)
|
230 |
+
return {"status": "ok", "model_loaded": True}
|
231 |
+
|
232 |
+
# Initialize model on startup
|
233 |
+
@app.on_event("startup")
|
234 |
+
async def startup_event():
|
235 |
+
initialize_model()
|
236 |
+
|
237 |
+
# Run the FastAPI app
|
238 |
+
if __name__ == "__main__":
|
239 |
+
import uvicorn
|
240 |
+
port = int(os.environ.get("PORT", 7860)) # Default to port 7860 for Hugging Face Spaces
|
241 |
+
uvicorn.run("main:app", host="0.0.0.0", port=port, reload=False)
|