Hammad712 commited on
Commit
46a11a0
·
verified ·
1 Parent(s): e6bd01f

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +241 -0
main.py ADDED
@@ -0,0 +1,241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import librosa
4
+ import numpy as np
5
+ from typing import List, Dict, Any, Optional
6
+ from fastapi import FastAPI, UploadFile, File, HTTPException, BackgroundTasks
7
+ from fastapi.responses import JSONResponse
8
+ from fastapi.middleware.cors import CORSMiddleware
9
+ from pydantic import BaseModel
10
+ from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
11
+ from librosa.sequence import dtw
12
+ import tempfile
13
+ import uuid
14
+ import shutil
15
+
16
+ # Initialize FastAPI app
17
+ app = FastAPI(
18
+ title="Quran Recitation Comparison API",
19
+ description="API for comparing similarity between Quran recitations using Wav2Vec2 embeddings",
20
+ version="1.0.0"
21
+ )
22
+
23
+ # Add CORS middleware
24
+ app.add_middleware(
25
+ CORSMiddleware,
26
+ allow_origins=["*"], # Allows all origins
27
+ allow_credentials=True,
28
+ allow_methods=["*"], # Allows all methods
29
+ allow_headers=["*"], # Allows all headers
30
+ )
31
+
32
+ # Global variables
33
+ MODEL = None
34
+ PROCESSOR = None
35
+ UPLOAD_DIR = os.path.join(tempfile.gettempdir(), "quran_comparison_uploads")
36
+
37
+ # Ensure upload directory exists
38
+ os.makedirs(UPLOAD_DIR, exist_ok=True)
39
+
40
+ # Response models
41
+ class SimilarityResponse(BaseModel):
42
+ similarity_score: float
43
+ interpretation: str
44
+
45
+ class ErrorResponse(BaseModel):
46
+ error: str
47
+
48
+ # Initialize model from environment variable
49
+ def initialize_model():
50
+ global MODEL, PROCESSOR
51
+
52
+ # Get HF token from environment variable
53
+ hf_token = os.environ.get("HF_TOKEN", None)
54
+ model_name = os.environ.get("MODEL_NAME", "jonatasgrosman/wav2vec2-large-xlsr-53-arabic")
55
+
56
+ try:
57
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
58
+ print(f"Loading model on device: {device}")
59
+
60
+ # Load model and processor
61
+ if hf_token:
62
+ PROCESSOR = Wav2Vec2Processor.from_pretrained(model_name, use_auth_token=hf_token)
63
+ MODEL = Wav2Vec2ForCTC.from_pretrained(model_name, use_auth_token=hf_token)
64
+ else:
65
+ PROCESSOR = Wav2Vec2Processor.from_pretrained(model_name)
66
+ MODEL = Wav2Vec2ForCTC.from_pretrained(model_name)
67
+
68
+ MODEL = MODEL.to(device)
69
+ MODEL.eval()
70
+ print("Model loaded successfully")
71
+ except Exception as e:
72
+ print(f"Error loading model: {e}")
73
+ raise e
74
+
75
+ # Load audio file
76
+ def load_audio(file_path, target_sr=16000, trim_silence=True, normalize=True):
77
+ """Load and preprocess an audio file."""
78
+ try:
79
+ y, sr = librosa.load(file_path, sr=target_sr)
80
+
81
+ if normalize:
82
+ y = librosa.util.normalize(y)
83
+
84
+ if trim_silence:
85
+ y, _ = librosa.effects.trim(y, top_db=30)
86
+
87
+ return y
88
+ except Exception as e:
89
+ raise HTTPException(status_code=400, detail=f"Error loading audio: {e}")
90
+
91
+ # Get deep embedding
92
+ def get_deep_embedding(audio, sr=16000):
93
+ """Extract frame-wise deep embeddings using the pretrained model."""
94
+ global MODEL, PROCESSOR
95
+
96
+ if MODEL is None or PROCESSOR is None:
97
+ raise HTTPException(status_code=500, detail="Model not initialized")
98
+
99
+ try:
100
+ device = next(MODEL.parameters()).device
101
+ input_values = PROCESSOR(
102
+ audio,
103
+ sampling_rate=sr,
104
+ return_tensors="pt"
105
+ ).input_values.to(device)
106
+
107
+ with torch.no_grad():
108
+ outputs = MODEL(input_values, output_hidden_states=True)
109
+
110
+ hidden_states = outputs.hidden_states[-1]
111
+ embedding_seq = hidden_states.squeeze(0).cpu().numpy()
112
+
113
+ return embedding_seq
114
+ except Exception as e:
115
+ raise HTTPException(status_code=500, detail=f"Error extracting embeddings: {e}")
116
+
117
+ # Compute DTW distance
118
+ def compute_dtw_distance(features1, features2):
119
+ """Compute the DTW distance between two sequences of features."""
120
+ try:
121
+ D, wp = dtw(X=features1, Y=features2, metric='euclidean')
122
+ distance = D[-1, -1]
123
+ normalized_distance = distance / len(wp)
124
+ return normalized_distance
125
+ except Exception as e:
126
+ raise HTTPException(status_code=500, detail=f"Error computing DTW distance: {e}")
127
+
128
+ # Interpret similarity
129
+ def interpret_similarity(norm_distance):
130
+ """Interpret the normalized distance value."""
131
+ if norm_distance == 0:
132
+ result = "The recitations are identical based on the deep embeddings."
133
+ score = 100
134
+ elif norm_distance < 1:
135
+ result = "The recitations are extremely similar."
136
+ score = 95
137
+ elif norm_distance < 5:
138
+ result = "The recitations are very similar with minor differences."
139
+ score = 80
140
+ elif norm_distance < 10:
141
+ result = "The recitations show moderate similarity."
142
+ score = 60
143
+ elif norm_distance < 20:
144
+ result = "The recitations show some noticeable differences."
145
+ score = 40
146
+ else:
147
+ result = "The recitations are quite different."
148
+ score = max(0, 100 - norm_distance)
149
+
150
+ return result, score
151
+
152
+ # Clean up temporary files
153
+ def cleanup_temp_files(file_paths):
154
+ """Remove temporary files."""
155
+ for file_path in file_paths:
156
+ if os.path.exists(file_path):
157
+ try:
158
+ os.remove(file_path)
159
+ except Exception as e:
160
+ print(f"Error removing temporary file {file_path}: {e}")
161
+
162
+ # API endpoints
163
+ @app.post("/compare", response_model=SimilarityResponse)
164
+ async def compare_recitations(
165
+ background_tasks: BackgroundTasks,
166
+ file1: UploadFile = File(...),
167
+ file2: UploadFile = File(...)
168
+ ):
169
+ """
170
+ Compare two Quran recitations and return similarity metrics.
171
+
172
+ - **file1**: First audio file
173
+ - **file2**: Second audio file
174
+
175
+ Returns:
176
+ - **similarity_score**: Score between 0-100 indicating similarity
177
+ - **interpretation**: Text interpretation of the similarity
178
+ """
179
+ # Check if model is initialized
180
+ if MODEL is None or PROCESSOR is None:
181
+ raise HTTPException(status_code=500, detail="Model not initialized")
182
+
183
+ # Temporary file paths
184
+ temp_file1 = os.path.join(UPLOAD_DIR, f"{uuid.uuid4()}.wav")
185
+ temp_file2 = os.path.join(UPLOAD_DIR, f"{uuid.uuid4()}.wav")
186
+
187
+ try:
188
+ # Save uploaded files
189
+ with open(temp_file1, "wb") as f:
190
+ shutil.copyfileobj(file1.file, f)
191
+
192
+ with open(temp_file2, "wb") as f:
193
+ shutil.copyfileobj(file2.file, f)
194
+
195
+ # Load audio files
196
+ audio1 = load_audio(temp_file1)
197
+ audio2 = load_audio(temp_file2)
198
+
199
+ # Extract embeddings
200
+ embedding1 = get_deep_embedding(audio1)
201
+ embedding2 = get_deep_embedding(audio2)
202
+
203
+ # Compute DTW distance
204
+ norm_distance = compute_dtw_distance(embedding1.T, embedding2.T)
205
+
206
+ # Interpret results
207
+ interpretation, similarity_score = interpret_similarity(norm_distance)
208
+
209
+ # Add cleanup task
210
+ background_tasks.add_task(cleanup_temp_files, [temp_file1, temp_file2])
211
+
212
+ return {
213
+ "similarity_score": similarity_score,
214
+ "interpretation": interpretation
215
+ }
216
+
217
+ except Exception as e:
218
+ # Ensure files are cleaned up even in case of error
219
+ background_tasks.add_task(cleanup_temp_files, [temp_file1, temp_file2])
220
+ raise HTTPException(status_code=500, detail=str(e))
221
+
222
+ @app.get("/health")
223
+ async def health_check():
224
+ """Health check endpoint."""
225
+ if MODEL is None or PROCESSOR is None:
226
+ return JSONResponse(
227
+ status_code=503,
228
+ content={"status": "error", "message": "Model not initialized"}
229
+ )
230
+ return {"status": "ok", "model_loaded": True}
231
+
232
+ # Initialize model on startup
233
+ @app.on_event("startup")
234
+ async def startup_event():
235
+ initialize_model()
236
+
237
+ # Run the FastAPI app
238
+ if __name__ == "__main__":
239
+ import uvicorn
240
+ port = int(os.environ.get("PORT", 7860)) # Default to port 7860 for Hugging Face Spaces
241
+ uvicorn.run("main:app", host="0.0.0.0", port=port, reload=False)