Hammad712 committed (verified)
Commit 1dbeaf5 · 1 Parent(s): 94ba3d3

Update main.py

Files changed (1)
  1. main.py +91 -12
main.py CHANGED
@@ -1,16 +1,15 @@
-from fastapi import FastAPI, HTTPException, UploadFile, File, Form
+from fastapi import FastAPI, HTTPException, UploadFile, File
 from pydantic import BaseModel
-from typing import Optional
 import torch
 import librosa
 import numpy as np
 import os
 from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
-from librosa.sequence import dtw
 import tempfile
 import shutil
 from dotenv import load_dotenv
 import uvicorn
+import scipy.spatial.distance as distance
 
 # Load environment variables
 load_dotenv()
@@ -22,16 +21,73 @@ class ComparisonResult(BaseModel):
     similarity_score: float
     interpretation: str
 
+# Custom implementation of DTW to replace librosa.sequence.dtw
+def custom_dtw(X, Y, metric='euclidean'):
+    """
+    Custom Dynamic Time Warping implementation.
+
+    Args:
+        X: First sequence
+        Y: Second sequence
+        metric: Distance metric ('euclidean' or 'cosine')
+
+    Returns:
+        D: Cost matrix
+        wp: Warping path
+    """
+    # Get sequence lengths
+    n, m = len(X), len(Y)
+
+    # Initialize cost matrix
+    D = np.zeros((n + 1, m + 1))
+    D[0, 1:] = np.inf
+    D[1:, 0] = np.inf
+    D[0, 0] = 0
+
+    # Fill cost matrix
+    for i in range(1, n + 1):
+        for j in range(1, m + 1):
+            if metric == 'euclidean':
+                cost = np.sum((X[i-1] - Y[j-1])**2)
+            elif metric == 'cosine':
+                cost = 1 - np.dot(X[i-1], Y[j-1]) / (np.linalg.norm(X[i-1]) * np.linalg.norm(Y[j-1]))
+            D[i, j] = cost + min(D[i-1, j], D[i, j-1], D[i-1, j-1])
+
+    # Backtracking
+    wp = [(n, m)]
+    i, j = n, m
+    while i > 0 or j > 0:
+        if i == 0:
+            j -= 1
+        elif j == 0:
+            i -= 1
+        else:
+            min_idx = np.argmin([D[i-1, j-1], D[i-1, j], D[i, j-1]])
+            if min_idx == 0:
+                i -= 1
+                j -= 1
+            elif min_idx == 1:
+                i -= 1
+            else:
+                j -= 1
+        wp.append((i, j))
+
+    wp.reverse()
+    return D, wp
+
 class QuranRecitationComparer:
     def __init__(self, model_name="jonatasgrosman/wav2vec2-large-xlsr-53-arabic", token=None):
         """Initialize the Quran recitation comparer with a specific Wav2Vec2 model."""
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        print(f"Using device: {self.device}")
 
         # Load model and processor once during initialization
         if token:
+            print(f"Loading model {model_name} with token...")
             self.processor = Wav2Vec2Processor.from_pretrained(model_name, use_auth_token=token)
             self.model = Wav2Vec2ForCTC.from_pretrained(model_name, use_auth_token=token)
         else:
+            print(f"Loading model {model_name} without token...")
             self.processor = Wav2Vec2Processor.from_pretrained(model_name)
             self.model = Wav2Vec2ForCTC.from_pretrained(model_name)
 
@@ -40,18 +96,21 @@ class QuranRecitationComparer:
 
         # Cache for embeddings to avoid recomputation
         self.embedding_cache = {}
+        print("Model loaded successfully!")
 
     def load_audio(self, file_path, target_sr=16000, trim_silence=True, normalize=True):
         """Load and preprocess an audio file."""
         if not os.path.exists(file_path):
             raise FileNotFoundError(f"Audio file not found: {file_path}")
 
+        print(f"Loading audio: {file_path}")
         y, sr = librosa.load(file_path, sr=target_sr)
 
         if normalize:
            y = librosa.util.normalize(y)
 
         if trim_silence:
+            # Use librosa.effects.trim which should be available in most versions
            y, _ = librosa.effects.trim(y, top_db=30)
 
         return y
@@ -74,7 +133,7 @@ class QuranRecitationComparer:
 
     def compute_dtw_distance(self, features1, features2):
         """Compute the DTW distance between two sequences of features."""
-        D, wp = dtw(X=features1, Y=features2, metric='euclidean')
+        D, wp = custom_dtw(X=features1, Y=features2, metric='euclidean')
         distance = D[-1, -1]
         normalized_distance = distance / len(wp)
         return normalized_distance
@@ -105,13 +164,16 @@ class QuranRecitationComparer:
     def get_embedding_for_file(self, file_path):
         """Get embedding for a file, using cache if available."""
         if file_path in self.embedding_cache:
+            print(f"Using cached embedding for {file_path}")
            return self.embedding_cache[file_path]
 
+        print(f"Computing new embedding for {file_path}")
        audio = self.load_audio(file_path)
        embedding = self.get_deep_embedding(audio)
 
        # Store in cache for future use
        self.embedding_cache[file_path] = embedding
+        print(f"Embedding shape: {embedding.shape}")
 
        return embedding
 
@@ -128,21 +190,26 @@ class QuranRecitationComparer:
            float: Similarity score
            str: Interpretation of similarity
        """
+        print(f"Comparing {file_path1} and {file_path2}")
        # Get embeddings (using cache if available)
        embedding1 = self.get_embedding_for_file(file_path1)
        embedding2 = self.get_embedding_for_file(file_path2)
 
        # Compute DTW distance
+        print("Computing DTW distance...")
        norm_distance = self.compute_dtw_distance(embedding1.T, embedding2.T)
+        print(f"Normalized distance: {norm_distance}")
 
        # Interpret results
        interpretation, similarity_score = self.interpret_similarity(norm_distance)
+        print(f"Similarity score: {similarity_score}, Interpretation: {interpretation}")
 
        return similarity_score, interpretation
 
    def clear_cache(self):
        """Clear the embedding cache to free memory."""
        self.embedding_cache = {}
+        print("Embedding cache cleared")
 
 # Global variable for the comparer instance
 comparer = None
@@ -152,11 +219,15 @@ async def startup_event():
     """Initialize the model when the application starts."""
     global comparer
     print("Initializing model... This may take a moment.")
-    comparer = QuranRecitationComparer(
-        model_name="jonatasgrosman/wav2vec2-large-xlsr-53-arabic",
-        token=HF_TOKEN
-    )
-    print("Model initialized and ready for predictions!")
+    try:
+        comparer = QuranRecitationComparer(
+            model_name="jonatasgrosman/wav2vec2-large-xlsr-53-arabic",
+            token=HF_TOKEN
+        )
+        print("Model initialized and ready for predictions!")
+    except Exception as e:
+        print(f"Error initializing model: {str(e)}")
+        raise
 
 @app.get("/")
 async def root():
@@ -179,7 +250,9 @@ async def compare_files(
     if not comparer:
         raise HTTPException(status_code=500, detail="Model not initialized. Please try again later.")
 
+    print(f"Received files: {file1.filename} and {file2.filename}")
     temp_dir = tempfile.mkdtemp()
+    print(f"Created temporary directory: {temp_dir}")
 
     try:
         # Save uploaded files to temporary directory
@@ -187,10 +260,14 @@ async def compare_files(
         temp_file2 = os.path.join(temp_dir, file2.filename)
 
         with open(temp_file1, "wb") as f:
-            shutil.copyfileobj(file1.file, f)
+            content = await file1.read()
+            f.write(content)
 
         with open(temp_file2, "wb") as f:
-            shutil.copyfileobj(file2.file, f)
+            content = await file2.read()
+            f.write(content)
+
+        print(f"Files saved to: {temp_file1} and {temp_file2}")
 
         # Compare the files
         similarity_score, interpretation = comparer.predict(temp_file1, temp_file2)
@@ -201,10 +278,12 @@ async def compare_files(
         )
 
     except Exception as e:
+        print(f"Error processing files: {str(e)}")
         raise HTTPException(status_code=500, detail=f"Error processing files: {str(e)}")
 
     finally:
         # Clean up temporary files
+        print(f"Cleaning up temporary directory: {temp_dir}")
         shutil.rmtree(temp_dir, ignore_errors=True)
 
 @app.post("/clear-cache")
@@ -217,4 +296,4 @@ async def clear_cache():
     return {"message": "Embedding cache cleared successfully"}
 
 if __name__ == "__main__":
-    uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=True)
+    uvicorn.run("main:app", host="0.0.0.0", port=7860, log_level="info")
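
Note: the core change in this commit is swapping librosa.sequence.dtw for the new custom_dtw helper. A minimal sketch (not part of the commit) of exercising that helper on toy feature sequences, assuming the main.py above is importable on the path (importing it also pulls in torch/transformers and the FastAPI app object):

import numpy as np
from main import custom_dtw  # assumes main.py from this commit is on the path

# Two toy "embedding" sequences: 10 and 12 frames of 4-dimensional vectors
X = np.random.rand(10, 4)
Y = np.random.rand(12, 4)

D, wp = custom_dtw(X, Y, metric='euclidean')

# D is the (len(X)+1) x (len(Y)+1) accumulated-cost matrix; wp is the warping
# path as a list of (i, j) index pairs, mirroring how compute_dtw_distance
# normalizes the final cost by the path length.
print("Total cost:", D[-1, -1])
print("Path length:", len(wp))
print("Normalized distance:", D[-1, -1] / len(wp))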