Hammad712 committed on
Commit
0318876
·
verified ·
1 Parent(s): ef20d33

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +24 -57
main.py CHANGED
@@ -20,8 +20,6 @@ os.environ["NUMBA_DISABLE_JIT"] = "1"
20
  MODEL = None
21
  PROCESSOR = None
22
  UPLOAD_DIR = os.path.join(tempfile.gettempdir(), "quran_comparison_uploads")
23
-
24
- # Ensure upload directory exists
25
  os.makedirs(UPLOAD_DIR, exist_ok=True)
26
 
27
  # Response models
@@ -63,7 +61,7 @@ async def lifespan(app: FastAPI):
63
  initialize_model()
64
  yield
65
 
66
- # Create the FastAPI app with the lifespan handler and CORS middleware
67
  app = FastAPI(
68
  title="Quran Recitation Comparison API",
69
  description="API for comparing similarity between Quran recitations using Wav2Vec2 embeddings",
@@ -90,13 +88,10 @@ def load_audio(file_path, target_sr=16000, trim_silence=True, normalize=True):
90
  """Load and preprocess an audio file."""
91
  try:
92
  y, sr = librosa.load(file_path, sr=target_sr)
93
-
94
  if normalize:
95
  y = librosa.util.normalize(y)
96
-
97
  if trim_silence:
98
  y, _ = librosa.effects.trim(y, top_db=30)
99
-
100
  return y
101
  except Exception as e:
102
  raise HTTPException(status_code=400, detail=f"Error loading audio: {e}")
@@ -105,10 +100,8 @@ def load_audio(file_path, target_sr=16000, trim_silence=True, normalize=True):
105
  def get_deep_embedding(audio, sr=16000):
106
  """Extract frame-wise deep embeddings using the pretrained model."""
107
  global MODEL, PROCESSOR
108
-
109
  if MODEL is None or PROCESSOR is None:
110
  raise HTTPException(status_code=500, detail="Model not initialized")
111
-
112
  try:
113
  device = next(MODEL.parameters()).device
114
  input_values = PROCESSOR(
@@ -122,28 +115,22 @@ def get_deep_embedding(audio, sr=16000):
122
 
123
  hidden_states = outputs.hidden_states[-1]
124
  embedding_seq = hidden_states.squeeze(0).cpu().numpy()
125
-
126
  return embedding_seq
127
  except Exception as e:
128
  raise HTTPException(status_code=500, detail=f"Error extracting embeddings: {e}")
129
 
130
- # Custom DTW implementation to avoid librosa.sequence.dtw issues
131
  def custom_dtw(X, Y, metric='euclidean'):
132
  """
133
- Custom implementation of DTW to avoid librosa.sequence.dtw issues.
134
-
135
- Parameters:
136
- X, Y : numpy.ndarray
137
- The two sequences to be aligned
138
- metric : str, optional
139
- The distance metric to use
140
-
141
- Returns:
142
- D : numpy.ndarray
143
- The accumulated cost matrix
144
- wp : list
145
- The warping path
146
  """
 
 
 
 
 
 
147
  n, m = len(X[0]), len(Y[0])
148
  D = np.zeros((n+1, m+1))
149
  D[0, :] = np.inf
@@ -157,8 +144,7 @@ def custom_dtw(X, Y, metric='euclidean'):
157
  elif metric == 'cosine':
158
  cost = 1 - np.dot(X[:, i-1], Y[:, j-1]) / (np.linalg.norm(X[:, i-1]) * np.linalg.norm(Y[:, j-1]))
159
  else:
160
- cost = np.sum(np.abs(X[:, i-1] - Y[:, j-1])) # Manhattan by default
161
-
162
  D[i, j] = cost + min(D[i-1, j], D[i, j-1], D[i-1, j-1])
163
 
164
  i, j = n, m
@@ -183,9 +169,8 @@ def compute_dtw_distance(features1, features2):
183
  except Exception as e:
184
  raise HTTPException(status_code=500, detail=f"Error computing DTW distance: {e}")
185
 
186
- # Interpret similarity
187
  def interpret_similarity(norm_distance):
188
- """Interpret the normalized distance value."""
189
  if norm_distance == 0:
190
  result = "The recitations are identical based on the deep embeddings."
191
  score = 100
@@ -204,12 +189,10 @@ def interpret_similarity(norm_distance):
204
  else:
205
  result = "The recitations are quite different."
206
  score = max(0, 100 - norm_distance)
207
-
208
  return result, score
209
 
210
  # Clean up temporary files
211
  def cleanup_temp_files(file_paths):
212
- """Remove temporary files."""
213
  for file_path in file_paths:
214
  if os.path.exists(file_path):
215
  try:
@@ -224,63 +207,47 @@ async def compare_recitations(
224
  file1: UploadFile = File(...),
225
  file2: UploadFile = File(...)
226
  ):
227
- """
228
- Compare two Quran recitations and return similarity metrics.
229
-
230
- - **file1**: First audio file
231
- - **file2**: Second audio file
232
-
233
- Returns:
234
- - **similarity_score**: Score between 0-100 indicating similarity
235
- - **interpretation**: Text interpretation of the similarity
236
- """
237
- if MODEL is None or PROCESSOR is None:
238
- raise HTTPException(status_code=500, detail="Model not initialized")
239
-
240
  temp_file1 = os.path.join(UPLOAD_DIR, f"{uuid.uuid4()}.wav")
241
  temp_file2 = os.path.join(UPLOAD_DIR, f"{uuid.uuid4()}.wav")
242
-
243
  try:
 
244
  with open(temp_file1, "wb") as f:
245
  shutil.copyfileobj(file1.file, f)
246
-
247
  with open(temp_file2, "wb") as f:
248
  shutil.copyfileobj(file2.file, f)
249
 
 
250
  audio1 = load_audio(temp_file1)
251
  audio2 = load_audio(temp_file2)
252
 
 
253
  embedding1 = get_deep_embedding(audio1)
254
  embedding2 = get_deep_embedding(audio2)
255
 
 
256
  norm_distance = compute_dtw_distance(embedding1.T, embedding2.T)
257
-
258
  interpretation, similarity_score = interpret_similarity(norm_distance)
259
 
260
  background_tasks.add_task(cleanup_temp_files, [temp_file1, temp_file2])
261
-
262
- return {
263
- "similarity_score": similarity_score,
264
- "interpretation": interpretation
265
- }
266
 
 
 
 
267
  except Exception as e:
268
  background_tasks.add_task(cleanup_temp_files, [temp_file1, temp_file2])
269
- raise HTTPException(status_code=500, detail=str(e))
 
270
 
271
  # Health check endpoint
272
  @app.get("/health")
273
  async def health_check():
274
- """Health check endpoint."""
275
  if MODEL is None or PROCESSOR is None:
276
- return JSONResponse(
277
- status_code=503,
278
- content={"status": "error", "message": "Model not initialized"}
279
- )
280
  return {"status": "ok", "model_loaded": True}
281
 
282
  # Run the FastAPI app
283
  if __name__ == "__main__":
284
  import uvicorn
285
- port = int(os.environ.get("PORT", 7860)) # Default to port 7860
286
  uvicorn.run("main:app", host="0.0.0.0", port=port, reload=False)
 
20
  MODEL = None
21
  PROCESSOR = None
22
  UPLOAD_DIR = os.path.join(tempfile.gettempdir(), "quran_comparison_uploads")
 
 
23
  os.makedirs(UPLOAD_DIR, exist_ok=True)
24
 
25
  # Response models
 
61
  initialize_model()
62
  yield
63
 
64
+ # Create the FastAPI app with the lifespan handler and add CORS middleware
65
  app = FastAPI(
66
  title="Quran Recitation Comparison API",
67
  description="API for comparing similarity between Quran recitations using Wav2Vec2 embeddings",
 
88
  """Load and preprocess an audio file."""
89
  try:
90
  y, sr = librosa.load(file_path, sr=target_sr)
 
91
  if normalize:
92
  y = librosa.util.normalize(y)
 
93
  if trim_silence:
94
  y, _ = librosa.effects.trim(y, top_db=30)
 
95
  return y
96
  except Exception as e:
97
  raise HTTPException(status_code=400, detail=f"Error loading audio: {e}")
 
100
  def get_deep_embedding(audio, sr=16000):
101
  """Extract frame-wise deep embeddings using the pretrained model."""
102
  global MODEL, PROCESSOR
 
103
  if MODEL is None or PROCESSOR is None:
104
  raise HTTPException(status_code=500, detail="Model not initialized")
 
105
  try:
106
  device = next(MODEL.parameters()).device
107
  input_values = PROCESSOR(
 
115
 
116
  hidden_states = outputs.hidden_states[-1]
117
  embedding_seq = hidden_states.squeeze(0).cpu().numpy()
 
118
  return embedding_seq
119
  except Exception as e:
120
  raise HTTPException(status_code=500, detail=f"Error extracting embeddings: {e}")
121
 
122
+ # Custom DTW implementation to avoid issues with librosa.sequence.dtw
123
  def custom_dtw(X, Y, metric='euclidean'):
124
  """
125
+ Custom implementation of DTW.
126
+ X and Y are expected to be 2D numpy arrays.
 
 
 
 
 
 
 
 
 
 
 
127
  """
128
+ # Check inputs are 2D and non-empty
129
+ if X.ndim != 2 or Y.ndim != 2:
130
+ raise ValueError("Input features must be 2D arrays.")
131
+ if X.shape[1] == 0 or Y.shape[1] == 0:
132
+ raise ValueError("Empty embedding sequence encountered.")
133
+
134
  n, m = len(X[0]), len(Y[0])
135
  D = np.zeros((n+1, m+1))
136
  D[0, :] = np.inf
 
144
  elif metric == 'cosine':
145
  cost = 1 - np.dot(X[:, i-1], Y[:, j-1]) / (np.linalg.norm(X[:, i-1]) * np.linalg.norm(Y[:, j-1]))
146
  else:
147
+ cost = np.sum(np.abs(X[:, i-1] - Y[:, j-1]))
 
148
  D[i, j] = cost + min(D[i-1, j], D[i, j-1], D[i-1, j-1])
149
 
150
  i, j = n, m
 
169
  except Exception as e:
170
  raise HTTPException(status_code=500, detail=f"Error computing DTW distance: {e}")
171
 
172
+ # Interpret similarity based on the normalized distance
173
  def interpret_similarity(norm_distance):
 
174
  if norm_distance == 0:
175
  result = "The recitations are identical based on the deep embeddings."
176
  score = 100
 
189
  else:
190
  result = "The recitations are quite different."
191
  score = max(0, 100 - norm_distance)
 
192
  return result, score
193
 
194
  # Clean up temporary files
195
  def cleanup_temp_files(file_paths):
 
196
  for file_path in file_paths:
197
  if os.path.exists(file_path):
198
  try:
 
207
  file1: UploadFile = File(...),
208
  file2: UploadFile = File(...)
209
  ):
 
 
 
 
 
 
 
 
 
 
 
 
 
210
  temp_file1 = os.path.join(UPLOAD_DIR, f"{uuid.uuid4()}.wav")
211
  temp_file2 = os.path.join(UPLOAD_DIR, f"{uuid.uuid4()}.wav")
 
212
  try:
213
+ # Save uploaded files to temporary locations
214
  with open(temp_file1, "wb") as f:
215
  shutil.copyfileobj(file1.file, f)
 
216
  with open(temp_file2, "wb") as f:
217
  shutil.copyfileobj(file2.file, f)
218
 
219
+ # Load audio files
220
  audio1 = load_audio(temp_file1)
221
  audio2 = load_audio(temp_file2)
222
 
223
+ # Extract embeddings
224
  embedding1 = get_deep_embedding(audio1)
225
  embedding2 = get_deep_embedding(audio2)
226
 
227
+ # Compute DTW distance (transpose so each column represents a frame)
228
  norm_distance = compute_dtw_distance(embedding1.T, embedding2.T)
 
229
  interpretation, similarity_score = interpret_similarity(norm_distance)
230
 
231
  background_tasks.add_task(cleanup_temp_files, [temp_file1, temp_file2])
232
+ return {"similarity_score": similarity_score, "interpretation": interpretation}
 
 
 
 
233
 
234
+ except HTTPException as he:
235
+ background_tasks.add_task(cleanup_temp_files, [temp_file1, temp_file2])
236
+ raise he
237
  except Exception as e:
238
  background_tasks.add_task(cleanup_temp_files, [temp_file1, temp_file2])
239
+ print(f"Unexpected error in /compare: {e}")
240
+ raise HTTPException(status_code=500, detail="An unexpected error occurred during comparison.")
241
 
242
  # Health check endpoint
243
  @app.get("/health")
244
  async def health_check():
 
245
  if MODEL is None or PROCESSOR is None:
246
+ return JSONResponse(status_code=503, content={"status": "error", "message": "Model not initialized"})
 
 
 
247
  return {"status": "ok", "model_loaded": True}
248
 
249
  # Run the FastAPI app
250
  if __name__ == "__main__":
251
  import uvicorn
252
+ port = int(os.environ.get("PORT", 7860))
253
  uvicorn.run("main:app", host="0.0.0.0", port=port, reload=False)