Hammad712 committed on
Commit
521243d
·
verified ·
1 Parent(s): 3dceebc

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +109 -266
main.py CHANGED
@@ -1,187 +1,68 @@
1
- from fastapi import FastAPI, HTTPException, UploadFile, File
2
- from pydantic import BaseModel
3
  import torch
4
  import librosa
5
  import numpy as np
6
- import os
7
- from transformers import AutoProcessor, AutoModelForCTC
8
  import tempfile
9
- import shutil
10
- import uvicorn
11
- from fastapi.middleware.cors import CORSMiddleware
12
- import warnings
13
-
14
- # Ignore deprecation warnings
15
- warnings.filterwarnings("ignore")
16
-
17
- # Load environment variables
18
- HF_TOKEN = os.getenv("HF_TOKEN")
19
 
20
- app = FastAPI(title="Quran Recitation Comparer API")
21
-
22
- # Add CORS middleware
23
- app.add_middleware(
24
- CORSMiddleware,
25
- allow_origins=["*"],
26
- allow_credentials=True,
27
- allow_methods=["*"],
28
- allow_headers=["*"],
29
- )
30
-
31
- class ComparisonResult(BaseModel):
32
- similarity_score: float
33
- interpretation: str
34
-
35
- # Custom implementation of DTW
36
- def custom_dtw(X, Y, metric='euclidean'):
37
- """
38
- Custom Dynamic Time Warping implementation.
39
-
40
- Args:
41
- X: First sequence
42
- Y: Second sequence
43
- metric: Distance metric ('euclidean' or 'cosine')
44
-
45
- Returns:
46
- D: Cost matrix
47
- wp: Warping path
48
- """
49
- n, m = len(X), len(Y)
50
- D = np.zeros((n + 1, m + 1))
51
- D[0, 1:] = np.inf
52
- D[1:, 0] = np.inf
53
- D[0, 0] = 0
54
-
55
- for i in range(1, n + 1):
56
- for j in range(1, m + 1):
57
- if metric == 'euclidean':
58
- cost = np.sum((X[i-1] - Y[j-1])**2)
59
- elif metric == 'cosine':
60
- cost = 1 - np.dot(X[i-1], Y[j-1]) / (np.linalg.norm(X[i-1]) * np.linalg.norm(Y[j-1]))
61
- D[i, j] = cost + min(D[i-1, j], D[i, j-1], D[i-1, j-1])
62
-
63
- wp = [(n, m)]
64
- i, j = n, m
65
- while i > 0 or j > 0:
66
- if i == 0:
67
- j -= 1
68
- elif j == 0:
69
- i -= 1
70
- else:
71
- min_idx = np.argmin([D[i-1, j-1], D[i-1, j], D[i, j-1]])
72
- if min_idx == 0:
73
- i -= 1
74
- j -= 1
75
- elif min_idx == 1:
76
- i -= 1
77
- else:
78
- j -= 1
79
- wp.append((i, j))
80
-
81
- wp.reverse()
82
- return D, wp
83
 
 
84
  class QuranRecitationComparer:
85
- def __init__(self, model_name="jonatasgrosman/wav2vec2-large-xlsr-53-arabic", token=None):
86
- """Initialize the Quran recitation comparer with a specific Wav2Vec2 model."""
 
 
87
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
88
- print(f"Using device: {self.device}")
89
-
90
- try:
91
- if token:
92
- print(f"Loading model {model_name} with token...")
93
- # Use 'use_auth_token' instead of the deprecated 'token' parameter
94
- self.processor = AutoProcessor.from_pretrained(model_name, use_auth_token=token)
95
- self.model = AutoModelForCTC.from_pretrained(model_name, use_auth_token=token)
96
- else:
97
- print(f"Loading model {model_name} without token...")
98
- self.processor = AutoProcessor.from_pretrained(model_name)
99
- self.model = AutoModelForCTC.from_pretrained(model_name)
100
-
101
- self.model = self.model.to(self.device)
102
- self.model.eval()
103
- # Ensure that hidden states are returned by default
104
- self.model.config.output_hidden_states = True
105
- print("Model loaded successfully!")
106
- except Exception as e:
107
- print(f"Error loading model: {str(e)}")
108
- raise
109
-
110
  # Cache for embeddings to avoid recomputation
111
  self.embedding_cache = {}
112
 
113
- def load_audio(self, file_path, target_sr=16000, normalize=True):
114
  """Load and preprocess an audio file."""
115
  if not os.path.exists(file_path):
116
  raise FileNotFoundError(f"Audio file not found: {file_path}")
117
-
118
- print(f"Loading audio: {file_path}")
119
  y, sr = librosa.load(file_path, sr=target_sr)
120
-
121
  if normalize:
122
  y = librosa.util.normalize(y)
123
-
124
- # Trim silence using a simplified approach
125
- trim_y = []
126
- threshold = 0.02 # Threshold for silence detection
127
- for i in range(len(y)):
128
- if abs(y[i]) > threshold:
129
- trim_y.append(y[i])
130
-
131
- if len(trim_y) > 0:
132
- y = np.array(trim_y)
133
-
134
  return y
135
 
136
  def get_deep_embedding(self, audio, sr=16000):
137
  """Extract frame-wise deep embeddings using the pretrained model."""
138
- try:
139
- inputs = self.processor(
140
- audio,
141
- sampling_rate=sr,
142
- return_tensors="pt"
143
- ).input_values.to(self.device)
144
-
145
- with torch.no_grad():
146
- # Call the model without explicitly passing output_hidden_states
147
- outputs = self.model(inputs)
148
-
149
- hidden_states = outputs.hidden_states[-1]
150
- embedding_seq = hidden_states.squeeze(0).cpu().numpy()
151
-
152
- return embedding_seq
153
- except Exception as e:
154
- print(f"Error in get_deep_embedding: {str(e)}")
155
- raise
156
 
157
  def compute_dtw_distance(self, features1, features2):
158
  """Compute the DTW distance between two sequences of features."""
159
- if features1.ndim == 1:
160
- features1 = features1.reshape(-1, 1)
161
- if features2.ndim == 1:
162
- features2 = features2.reshape(-1, 1)
163
-
164
- print(f"Feature shapes: {features1.shape}, {features2.shape}")
165
-
166
- max_length = 300
167
- if features1.shape[0] > max_length or features2.shape[0] > max_length:
168
- step1 = max(1, features1.shape[0] // max_length)
169
- step2 = max(1, features2.shape[0] // max_length)
170
- features1 = features1[::step1]
171
- features2 = features2[::step2]
172
- print(f"Subsampled feature shapes: {features1.shape}, {features2.shape}")
173
-
174
- try:
175
- D, wp = custom_dtw(X=features1, Y=features2, metric='euclidean')
176
- distance = D[-1, -1]
177
- normalized_distance = distance / len(wp)
178
- return normalized_distance
179
- except Exception as e:
180
- print(f"Error in compute_dtw_distance: {str(e)}")
181
- mean_1 = np.mean(features1, axis=0)
182
- mean_2 = np.mean(features2, axis=0)
183
- euclidean_distance = np.sqrt(np.sum((mean_1 - mean_2) ** 2))
184
- return euclidean_distance
185
 
186
  def interpret_similarity(self, norm_distance):
187
  """Interpret the normalized distance value."""
@@ -203,142 +84,104 @@ class QuranRecitationComparer:
203
  else:
204
  result = "The recitations are quite different."
205
  score = max(0, 100 - norm_distance)
206
-
207
  return result, score
208
 
209
  def get_embedding_for_file(self, file_path):
210
  """Get embedding for a file, using cache if available."""
211
  if file_path in self.embedding_cache:
212
- print(f"Using cached embedding for {file_path}")
213
  return self.embedding_cache[file_path]
214
-
215
- print(f"Computing new embedding for {file_path}")
216
- try:
217
- audio = self.load_audio(file_path)
218
- embedding = self.get_deep_embedding(audio)
219
-
220
- self.embedding_cache[file_path] = embedding
221
- print(f"Embedding shape: {embedding.shape}")
222
-
223
- return embedding
224
- except Exception as e:
225
- print(f"Error getting embedding: {str(e)}")
226
- raise
227
 
228
  def predict(self, file_path1, file_path2):
229
  """
230
  Predict the similarity between two audio files.
231
-
232
  Args:
233
- file_path1 (str): Path to first audio file
234
- file_path2 (str): Path to second audio file
235
-
236
  Returns:
237
- float: Similarity score
238
- str: Interpretation of similarity
239
  """
240
- print(f"Comparing {file_path1} and {file_path2}")
241
- try:
242
- embedding1 = self.get_embedding_for_file(file_path1)
243
- embedding2 = self.get_embedding_for_file(file_path2)
244
-
245
- print("Computing DTW distance...")
246
- norm_distance = self.compute_dtw_distance(embedding1.T, embedding2.T)
247
- print(f"Normalized distance: {norm_distance}")
248
-
249
- interpretation, similarity_score = self.interpret_similarity(norm_distance)
250
- print(f"Similarity score: {similarity_score}, Interpretation: {interpretation}")
251
-
252
- return similarity_score, interpretation
253
- except Exception as e:
254
- print(f"Error in predict: {str(e)}")
255
- return 0, f"Error comparing files: {str(e)}"
256
 
257
  def clear_cache(self):
258
  """Clear the embedding cache to free memory."""
259
  self.embedding_cache = {}
260
- print("Embedding cache cleared")
261
 
262
- # Global variable for the comparer instance
263
- comparer = None
264
 
 
 
265
  @app.on_event("startup")
266
- async def startup_event():
267
- """Initialize the model when the application starts."""
268
  global comparer
269
- print("Initializing model... This may take a moment.")
270
- try:
271
- comparer = QuranRecitationComparer(
272
- model_name="jonatasgrosman/wav2vec2-large-xlsr-53-arabic",
273
- token=HF_TOKEN
274
- )
275
- print("Model initialized and ready for predictions!")
276
- except Exception as e:
277
- print(f"Error initializing model: {str(e)}")
278
 
279
- @app.get("/")
 
 
280
  async def root():
281
- """Root endpoint to check if the API is running."""
282
- status = "active" if comparer else "model not loaded"
283
- return {"message": "Quran Recitation Comparer API is running", "status": status}
284
 
285
- @app.post("/compare", response_model=ComparisonResult)
286
- async def compare_files(
287
- file1: UploadFile = File(...),
288
- file2: UploadFile = File(...)
289
- ):
290
  """
291
- Compare two audio files and return similarity metrics.
292
-
293
- - **file1**: First audio file (MP3, WAV, etc.)
294
- - **file2**: Second audio file (MP3, WAV, etc.)
295
 
296
- Returns similarity score and interpretation.
 
297
  """
298
- if not comparer:
299
- raise HTTPException(status_code=500, detail="Model not initialized. Please try again later.")
300
-
301
- print(f"Received files: {file1.filename} and {file2.filename}")
302
- temp_dir = tempfile.mkdtemp()
303
- print(f"Created temporary directory: {temp_dir}")
304
-
305
  try:
306
- temp_file1 = os.path.join(temp_dir, file1.filename)
307
- temp_file2 = os.path.join(temp_dir, file2.filename)
308
-
309
- with open(temp_file1, "wb") as f:
310
- content = await file1.read()
311
- f.write(content)
312
-
313
- with open(temp_file2, "wb") as f:
314
- content = await file2.read()
315
- f.write(content)
316
-
317
- print(f"Files saved to: {temp_file1} and {temp_file2}")
318
-
319
- similarity_score, interpretation = comparer.predict(temp_file1, temp_file2)
320
-
321
- return ComparisonResult(
322
- similarity_score=similarity_score,
323
- interpretation=interpretation
324
- )
325
-
326
  except Exception as e:
327
- print(f"Error processing files: {str(e)}")
328
- raise HTTPException(status_code=500, detail=f"Error processing files: {str(e)}")
329
-
330
  finally:
331
- print(f"Cleaning up temporary directory: {temp_dir}")
332
- shutil.rmtree(temp_dir, ignore_errors=True)
 
 
 
333
 
334
- @app.post("/clear-cache")
 
335
  async def clear_cache():
336
- """Clear the embedding cache to free memory."""
337
- if not comparer:
338
- raise HTTPException(status_code=500, detail="Model not initialized.")
339
-
340
  comparer.clear_cache()
341
- return {"message": "Embedding cache cleared successfully"}
342
-
343
- if __name__ == "__main__":
344
- uvicorn.run("main:app", host="0.0.0.0", port=7860, log_level="info")
 
1
+ import os
 
2
  import torch
3
  import librosa
4
  import numpy as np
 
 
5
  import tempfile
6
+ from fastapi import FastAPI, UploadFile, File, HTTPException
7
+ from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
8
+ from librosa.sequence import dtw
 
 
 
 
 
 
 
9
 
10
+ app = FastAPI(title="Quran Recitation Comparer API", description="Compares two Quran recitations using a deep wav2vec2 model.", version="1.0")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
+ # --- Core Class Definition ---
13
  class QuranRecitationComparer:
14
+ def __init__(self, model_name="jonatasgrosman/wav2vec2-large-xlsr-53-arabic", auth_token=None):
15
+ """
16
+ Initialize the Quran recitation comparer with a specific Wav2Vec2 model.
17
+ """
18
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
+
20
+ # Load model and processor once during initialization
21
+ if auth_token:
22
+ self.processor = Wav2Vec2Processor.from_pretrained(model_name, token=auth_token)
23
+ self.model = Wav2Vec2ForCTC.from_pretrained(model_name, token=auth_token)
24
+ else:
25
+ self.processor = Wav2Vec2Processor.from_pretrained(model_name)
26
+ self.model = Wav2Vec2ForCTC.from_pretrained(model_name)
27
+
28
+ self.model = self.model.to(self.device)
29
+ self.model.eval()
30
+
 
 
 
 
 
 
 
 
 
 
31
  # Cache for embeddings to avoid recomputation
32
  self.embedding_cache = {}
33
 
34
+ def load_audio(self, file_path, target_sr=16000, trim_silence=True, normalize=True):
35
  """Load and preprocess an audio file."""
36
  if not os.path.exists(file_path):
37
  raise FileNotFoundError(f"Audio file not found: {file_path}")
 
 
38
  y, sr = librosa.load(file_path, sr=target_sr)
 
39
  if normalize:
40
  y = librosa.util.normalize(y)
41
+ if trim_silence:
42
+ y, _ = librosa.effects.trim(y, top_db=30)
 
 
 
 
 
 
 
 
 
43
  return y
44
 
45
  def get_deep_embedding(self, audio, sr=16000):
46
  """Extract frame-wise deep embeddings using the pretrained model."""
47
+ input_values = self.processor(
48
+ audio,
49
+ sampling_rate=sr,
50
+ return_tensors="pt"
51
+ ).input_values.to(self.device)
52
+
53
+ with torch.no_grad():
54
+ outputs = self.model(input_values, output_hidden_states=True)
55
+
56
+ hidden_states = outputs.hidden_states[-1]
57
+ embedding_seq = hidden_states.squeeze(0).cpu().numpy()
58
+ return embedding_seq
 
 
 
 
 
 
59
 
60
  def compute_dtw_distance(self, features1, features2):
61
  """Compute the DTW distance between two sequences of features."""
62
+ D, wp = dtw(X=features1, Y=features2, metric='euclidean')
63
+ distance = D[-1, -1]
64
+ normalized_distance = distance / len(wp)
65
+ return normalized_distance
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
67
  def interpret_similarity(self, norm_distance):
68
  """Interpret the normalized distance value."""
 
84
  else:
85
  result = "The recitations are quite different."
86
  score = max(0, 100 - norm_distance)
 
87
  return result, score
88
 
89
  def get_embedding_for_file(self, file_path):
90
  """Get embedding for a file, using cache if available."""
91
  if file_path in self.embedding_cache:
 
92
  return self.embedding_cache[file_path]
93
+ audio = self.load_audio(file_path)
94
+ embedding = self.get_deep_embedding(audio)
95
+ # Store in cache for future use
96
+ self.embedding_cache[file_path] = embedding
97
+ return embedding
 
 
 
 
 
 
 
 
98
 
99
  def predict(self, file_path1, file_path2):
100
  """
101
  Predict the similarity between two audio files.
 
102
  Args:
103
+ file_path1 (str): Path to first audio file.
104
+ file_path2 (str): Path to second audio file.
 
105
  Returns:
106
+ (float, str): Similarity score and interpretation.
 
107
  """
108
+ embedding1 = self.get_embedding_for_file(file_path1)
109
+ embedding2 = self.get_embedding_for_file(file_path2)
110
+ norm_distance = self.compute_dtw_distance(embedding1.T, embedding2.T)
111
+ interpretation, similarity_score = self.interpret_similarity(norm_distance)
112
+ # Optionally log the results instead of printing in production
113
+ print(f"Similarity Score: {similarity_score:.1f}/100")
114
+ print(f"Interpretation: {interpretation}")
115
+ return similarity_score, interpretation
 
 
 
 
 
 
 
 
116
 
117
  def clear_cache(self):
118
  """Clear the embedding cache to free memory."""
119
  self.embedding_cache = {}
 
120
 
 
 
121
 
122
+ # --- FastAPI Startup Event ---
123
+ # In production, consider loading sensitive tokens from environment variables or configuration files.
124
@app.on_event("startup")
def startup_event():
    """Build the module-level ``comparer`` once, when the application starts."""
    global comparer
    # For production, do not hardcode tokens; use os.environ.get(...) or a configuration system.
    hf_token = os.environ.get("HF_TOKEN")
    comparer = QuranRecitationComparer(
        auth_token=hf_token,
        model_name="jonatasgrosman/wav2vec2-large-xlsr-53-arabic",
    )
    print("Model initialized and ready for predictions!")
 
 
134
 
135
+
136
+ # --- API Endpoints ---
137
@app.get("/", summary="Health Check")
async def root():
    # Health check: always answers with the same fixed status message.
    payload = {"message": "Quran Recitation Comparer API is up and running."}
    return payload
140
+
 
141
 
142
def _remove_quietly(path):
    """Delete ``path``; a file that is already gone is not an error."""
    try:
        os.remove(path)
    except FileNotFoundError:
        # Best-effort cleanup: nothing left to remove.
        pass


async def _save_upload_to_temp(upload):
    """Write an uploaded file to a new temporary file and return its path.

    The upload's extension is kept (defaulting to ``.wav``). A missing
    filename is tolerated. If reading or writing fails, the partially written
    temporary file is removed before the error propagates.
    """
    suffix = os.path.splitext(upload.filename or "")[1] or ".wav"
    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
    try:
        with tmp:
            tmp.write(await upload.read())
    except BaseException:
        # delete=False means nobody else will clean this file up.
        _remove_quietly(tmp.name)
        raise
    return tmp.name


@app.post("/predict", summary="Compare Two Audio Files", response_model=dict)
async def predict(file1: UploadFile = File(...), file2: UploadFile = File(...)):
    """
    Compare two uploaded audio files and return a similarity score along with an interpretation.

    - **file1**: The first audio file.
    - **file2**: The second audio file.
    """
    # Every path recorded here is removed in ``finally``, whatever happens.
    temp_paths = []

    try:
        for upload in (file1, file2):
            temp_paths.append(await _save_upload_to_temp(upload))

        # NOTE(review): comparer.predict is synchronous and compute-heavy, so
        # it blocks the event loop while it runs; consider a worker thread.
        similarity_score, interpretation = comparer.predict(*temp_paths)
        # float() keeps the response JSON-friendly if the score is a NumPy scalar.
        return {"similarity_score": float(similarity_score), "interpretation": interpretation}

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # Clean up temporary files
        for path in temp_paths:
            _remove_quietly(path)
        # The comparer caches embeddings by file path. These paths are
        # throw-away temp names that will never be requested again, so drop
        # their entries rather than let the cache grow with every request.
        for path in temp_paths:
            comparer.embedding_cache.pop(path, None)
179
 
180
+
181
@app.post("/clear_cache", summary="Clear Embedding Cache", response_model=dict)
async def clear_cache():
    """
    Clear the embedding cache. This can help free memory if many comparisons have been made.
    """
    # Delegate to the global comparer, then acknowledge with a fixed message.
    comparer.clear_cache()
    response = {"message": "Cache cleared."}
    return response