Hammad712 commited on
Commit
1d61cef
·
verified ·
1 Parent(s): 0318876

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +153 -223
main.py CHANGED
@@ -2,252 +2,182 @@ import os
2
  import torch
3
  import librosa
4
  import numpy as np
5
- from typing import List, Dict, Any, Optional
6
- from fastapi import FastAPI, UploadFile, File, HTTPException, BackgroundTasks
7
- from fastapi.responses import JSONResponse
8
- from fastapi.middleware.cors import CORSMiddleware
9
- from pydantic import BaseModel
10
- from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
11
  import tempfile
12
- import uuid
 
 
 
13
  import shutil
14
- from contextlib import asynccontextmanager
15
-
16
- # Disable numba JIT to avoid caching issues
17
- os.environ["NUMBA_DISABLE_JIT"] = "1"
18
-
19
- # Global variables
20
- MODEL = None
21
- PROCESSOR = None
22
- UPLOAD_DIR = os.path.join(tempfile.gettempdir(), "quran_comparison_uploads")
23
- os.makedirs(UPLOAD_DIR, exist_ok=True)
24
-
25
- # Response models
26
- class SimilarityResponse(BaseModel):
27
- similarity_score: float
28
- interpretation: str
29
-
30
- class ErrorResponse(BaseModel):
31
- error: str
32
-
33
- # Initialize model from environment variable
34
- def initialize_model():
35
- global MODEL, PROCESSOR
36
- hf_token = os.environ.get("HF_TOKEN", None)
37
- model_name = os.environ.get("MODEL_NAME", "jonatasgrosman/wav2vec2-large-xlsr-53-arabic")
38
-
39
- try:
40
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
41
- print(f"Loading model on device: {device}")
42
-
43
- # Load model and processor using updated parameter `token`
44
- if hf_token:
45
- PROCESSOR = Wav2Vec2Processor.from_pretrained(model_name, token=hf_token)
46
- MODEL = Wav2Vec2ForCTC.from_pretrained(model_name, token=hf_token)
47
- else:
48
- PROCESSOR = Wav2Vec2Processor.from_pretrained(model_name)
49
- MODEL = Wav2Vec2ForCTC.from_pretrained(model_name)
50
-
51
- MODEL = MODEL.to(device)
52
- MODEL.eval()
53
- print("Model loaded successfully")
54
- except Exception as e:
55
- print(f"Error loading model: {e}")
56
- raise e
57
 
58
- # Lifespan event handler to initialize the model at startup
59
- @asynccontextmanager
60
- async def lifespan(app: FastAPI):
61
- initialize_model()
62
- yield
63
 
64
- # Create the FastAPI app with the lifespan handler and add CORS middleware
65
- app = FastAPI(
66
- title="Quran Recitation Comparison API",
67
- description="API for comparing similarity between Quran recitations using Wav2Vec2 embeddings",
68
- version="1.0.0",
69
- lifespan=lifespan
70
- )
71
 
72
- app.add_middleware(
73
- CORSMiddleware,
74
- allow_origins=["*"], # Allows all origins
75
- allow_credentials=True,
76
- allow_methods=["*"], # Allows all methods
77
- allow_headers=["*"], # Allows all headers
78
- )
79
 
80
- # Root endpoint
81
- @app.get("/")
82
- async def root():
83
- """Welcome endpoint."""
84
- return {"message": "Welcome to the Quran Recitation Comparison API"}
 
 
85
 
86
- # Load audio file
87
- def load_audio(file_path, target_sr=16000, trim_silence=True, normalize=True):
88
- """Load and preprocess an audio file."""
89
- try:
90
  y, sr = librosa.load(file_path, sr=target_sr)
 
91
  if normalize:
92
  y = librosa.util.normalize(y)
 
93
  if trim_silence:
94
  y, _ = librosa.effects.trim(y, top_db=30)
 
95
  return y
96
- except Exception as e:
97
- raise HTTPException(status_code=400, detail=f"Error loading audio: {e}")
98
-
99
- # Get deep embedding
100
- def get_deep_embedding(audio, sr=16000):
101
- """Extract frame-wise deep embeddings using the pretrained model."""
102
- global MODEL, PROCESSOR
103
- if MODEL is None or PROCESSOR is None:
104
- raise HTTPException(status_code=500, detail="Model not initialized")
105
- try:
106
- device = next(MODEL.parameters()).device
107
- input_values = PROCESSOR(
108
- audio,
109
- sampling_rate=sr,
110
  return_tensors="pt"
111
- ).input_values.to(device)
112
-
113
  with torch.no_grad():
114
- outputs = MODEL(input_values, output_hidden_states=True)
115
-
116
  hidden_states = outputs.hidden_states[-1]
117
  embedding_seq = hidden_states.squeeze(0).cpu().numpy()
 
118
  return embedding_seq
119
- except Exception as e:
120
- raise HTTPException(status_code=500, detail=f"Error extracting embeddings: {e}")
121
-
122
- # Custom DTW implementation to avoid issues with librosa's dtw
123
- def custom_dtw(X, Y, metric='euclidean'):
124
- """
125
- Custom implementation of DTW.
126
- X and Y are expected to be 2D numpy arrays.
127
- """
128
- # Check inputs are 2D and non-empty
129
- if X.ndim != 2 or Y.ndim != 2:
130
- raise ValueError("Input features must be 2D arrays.")
131
- if X.shape[1] == 0 or Y.shape[1] == 0:
132
- raise ValueError("Empty embedding sequence encountered.")
133
-
134
- n, m = len(X[0]), len(Y[0])
135
- D = np.zeros((n+1, m+1))
136
- D[0, :] = np.inf
137
- D[:, 0] = np.inf
138
- D[0, 0] = 0
139
-
140
- for i in range(1, n+1):
141
- for j in range(1, m+1):
142
- if metric == 'euclidean':
143
- cost = np.sqrt(np.sum((X[:, i-1] - Y[:, j-1])**2))
144
- elif metric == 'cosine':
145
- cost = 1 - np.dot(X[:, i-1], Y[:, j-1]) / (np.linalg.norm(X[:, i-1]) * np.linalg.norm(Y[:, j-1]))
146
- else:
147
- cost = np.sum(np.abs(X[:, i-1] - Y[:, j-1]))
148
- D[i, j] = cost + min(D[i-1, j], D[i, j-1], D[i-1, j-1])
149
-
150
- i, j = n, m
151
- wp = [(i, j)]
152
- while i > 1 or j > 1:
153
- candidates = [(i-1, j-1), (i-1, j), (i, j-1)]
154
- valid_candidates = [(ii, jj) for ii, jj in candidates if ii > 0 and jj > 0]
155
- i, j = min(valid_candidates, key=lambda x: D[x[0], x[1]])
156
- wp.append((i, j))
157
-
158
- wp.reverse()
159
- return D, wp
160
-
161
- # Compute DTW distance
162
- def compute_dtw_distance(features1, features2):
163
- """Compute the DTW distance between two sequences of features."""
164
- try:
165
- D, wp = custom_dtw(features1, features2, metric='euclidean')
166
  distance = D[-1, -1]
167
  normalized_distance = distance / len(wp)
168
  return normalized_distance
169
- except Exception as e:
170
- raise HTTPException(status_code=500, detail=f"Error computing DTW distance: {e}")
171
-
172
- # Interpret similarity based on the normalized distance
173
- def interpret_similarity(norm_distance):
174
- if norm_distance == 0:
175
- result = "The recitations are identical based on the deep embeddings."
176
- score = 100
177
- elif norm_distance < 1:
178
- result = "The recitations are extremely similar."
179
- score = 95
180
- elif norm_distance < 5:
181
- result = "The recitations are very similar with minor differences."
182
- score = 80
183
- elif norm_distance < 10:
184
- result = "The recitations show moderate similarity."
185
- score = 60
186
- elif norm_distance < 20:
187
- result = "The recitations show some noticeable differences."
188
- score = 40
189
- else:
190
- result = "The recitations are quite different."
191
- score = max(0, 100 - norm_distance)
192
- return result, score
193
-
194
- # Clean up temporary files
195
- def cleanup_temp_files(file_paths):
196
- for file_path in file_paths:
197
- if os.path.exists(file_path):
198
- try:
199
- os.remove(file_path)
200
- except Exception as e:
201
- print(f"Error removing temporary file {file_path}: {e}")
202
-
203
- # API endpoint for comparing recitations
204
- @app.post("/compare", response_model=SimilarityResponse)
205
- async def compare_recitations(
206
- background_tasks: BackgroundTasks,
207
- file1: UploadFile = File(...),
208
- file2: UploadFile = File(...)
209
- ):
210
- temp_file1 = os.path.join(UPLOAD_DIR, f"{uuid.uuid4()}.wav")
211
- temp_file2 = os.path.join(UPLOAD_DIR, f"{uuid.uuid4()}.wav")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
212
  try:
213
- # Save uploaded files to temporary locations
214
- with open(temp_file1, "wb") as f:
215
- shutil.copyfileobj(file1.file, f)
216
- with open(temp_file2, "wb") as f:
217
- shutil.copyfileobj(file2.file, f)
218
-
219
- # Load audio files
220
- audio1 = load_audio(temp_file1)
221
- audio2 = load_audio(temp_file2)
222
-
223
- # Extract embeddings
224
- embedding1 = get_deep_embedding(audio1)
225
- embedding2 = get_deep_embedding(audio2)
226
-
227
- # Compute DTW distance (transpose so each column represents a frame)
228
- norm_distance = compute_dtw_distance(embedding1.T, embedding2.T)
229
- interpretation, similarity_score = interpret_similarity(norm_distance)
230
-
231
- background_tasks.add_task(cleanup_temp_files, [temp_file1, temp_file2])
232
  return {"similarity_score": similarity_score, "interpretation": interpretation}
233
-
234
- except HTTPException as he:
235
- background_tasks.add_task(cleanup_temp_files, [temp_file1, temp_file2])
236
- raise he
237
  except Exception as e:
238
- background_tasks.add_task(cleanup_temp_files, [temp_file1, temp_file2])
239
- print(f"Unexpected error in /compare: {e}")
240
- raise HTTPException(status_code=500, detail="An unexpected error occurred during comparison.")
241
-
242
- # Health check endpoint
243
- @app.get("/health")
244
- async def health_check():
245
- if MODEL is None or PROCESSOR is None:
246
- return JSONResponse(status_code=503, content={"status": "error", "message": "Model not initialized"})
247
- return {"status": "ok", "model_loaded": True}
248
-
249
- # Run the FastAPI app
250
  if __name__ == "__main__":
251
  import uvicorn
252
- port = int(os.environ.get("PORT", 7860))
253
- uvicorn.run("main:app", host="0.0.0.0", port=port, reload=False)
 
2
  import torch
3
  import librosa
4
  import numpy as np
 
 
 
 
 
 
5
  import tempfile
6
+ from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
7
+ from librosa.sequence import dtw
8
+ from fastapi import FastAPI, UploadFile, File, HTTPException
9
+ from fastapi.responses import JSONResponse
10
  import shutil
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
+ # Define the QuranRecitationComparer class as provided
13
+ class QuranRecitationComparer:
14
+ def __init__(self, model_name="jonatasgrosman/wav2vec2-large-xlsr-53-arabic", auth_token=None):
15
+ """Initialize the Quran recitation comparer with a specific Wav2Vec2 model."""
16
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
 
18
+ # Load model and processor once during initialization
19
+ if auth_token:
20
+ self.processor = Wav2Vec2Processor.from_pretrained(model_name, token=auth_token)
21
+ self.model = Wav2Vec2ForCTC.from_pretrained(model_name, token=auth_token)
22
+ else:
23
+ self.processor = Wav2Vec2Processor.from_pretrained(model_name)
24
+ self.model = Wav2Vec2ForCTC.from_pretrained(model_name)
25
 
26
+ self.model = self.model.to(self.device)
27
+ self.model.eval()
 
 
 
 
 
28
 
29
+ # Cache for embeddings to avoid recomputation
30
+ self.embedding_cache = {}
31
+
32
+ def load_audio(self, file_path, target_sr=16000, trim_silence=True, normalize=True):
33
+ """Load and preprocess an audio file."""
34
+ if not os.path.exists(file_path):
35
+ raise FileNotFoundError(f"Audio file not found: {file_path}")
36
 
 
 
 
 
37
  y, sr = librosa.load(file_path, sr=target_sr)
38
+
39
  if normalize:
40
  y = librosa.util.normalize(y)
41
+
42
  if trim_silence:
43
  y, _ = librosa.effects.trim(y, top_db=30)
44
+
45
  return y
46
+
47
+ def get_deep_embedding(self, audio, sr=16000):
48
+ """Extract frame-wise deep embeddings using the pretrained model."""
49
+ input_values = self.processor(
50
+ audio,
51
+ sampling_rate=sr,
 
 
 
 
 
 
 
 
52
  return_tensors="pt"
53
+ ).input_values.to(self.device)
54
+
55
  with torch.no_grad():
56
+ outputs = self.model(input_values, output_hidden_states=True)
57
+
58
  hidden_states = outputs.hidden_states[-1]
59
  embedding_seq = hidden_states.squeeze(0).cpu().numpy()
60
+
61
  return embedding_seq
62
+
63
+ def compute_dtw_distance(self, features1, features2):
64
+ """Compute the DTW distance between two sequences of features."""
65
+ D, wp = dtw(X=features1, Y=features2, metric='euclidean')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  distance = D[-1, -1]
67
  normalized_distance = distance / len(wp)
68
  return normalized_distance
69
+
70
+ def interpret_similarity(self, norm_distance):
71
+ """Interpret the normalized distance value."""
72
+ if norm_distance == 0:
73
+ result = "The recitations are identical based on the deep embeddings."
74
+ score = 100
75
+ elif norm_distance < 1:
76
+ result = "The recitations are extremely similar."
77
+ score = 95
78
+ elif norm_distance < 5:
79
+ result = "The recitations are very similar with minor differences."
80
+ score = 80
81
+ elif norm_distance < 10:
82
+ result = "The recitations show moderate similarity."
83
+ score = 60
84
+ elif norm_distance < 20:
85
+ result = "The recitations show some noticeable differences."
86
+ score = 40
87
+ else:
88
+ result = "The recitations are quite different."
89
+ score = max(0, 100 - norm_distance)
90
+
91
+ return result, score
92
+
93
+ def get_embedding_for_file(self, file_path):
94
+ """Get embedding for a file, using cache if available."""
95
+ if file_path in self.embedding_cache:
96
+ return self.embedding_cache[file_path]
97
+
98
+ audio = self.load_audio(file_path)
99
+ embedding = self.get_deep_embedding(audio)
100
+
101
+ # Store in cache for future use
102
+ self.embedding_cache[file_path] = embedding
103
+
104
+ return embedding
105
+
106
+ def predict(self, file_path1, file_path2):
107
+ """
108
+ Predict the similarity between two audio files.
109
+ This method can be called repeatedly without reloading the model.
110
+ """
111
+ # Get embeddings (using cache if available)
112
+ embedding1 = self.get_embedding_for_file(file_path1)
113
+ embedding2 = self.get_embedding_for_file(file_path2)
114
+
115
+ # Compute DTW distance (transposing so that each column represents a frame)
116
+ norm_distance = self.compute_dtw_distance(embedding1.T, embedding2.T)
117
+
118
+ # Interpret results
119
+ interpretation, similarity_score = self.interpret_similarity(norm_distance)
120
+
121
+ print(f"Similarity Score: {similarity_score:.1f}/100")
122
+ print(f"Interpretation: {interpretation}")
123
+
124
+ return similarity_score, interpretation
125
+
126
+ def clear_cache(self):
127
+ """Clear the embedding cache to free memory."""
128
+ self.embedding_cache = {}
129
+
130
+ # Create FastAPI application
131
+ app = FastAPI(
132
+ title="Quran Recitation Comparison API",
133
+ description="API for comparing similarity between Quran recitations",
134
+ version="1.0.0"
135
+ )
136
+
137
+ # Global instance of the comparer
138
+ comparer = None
139
+
140
+ @app.on_event("startup")
141
+ async def startup_event():
142
+ global comparer
143
+ # Optionally, set the HF authentication token from an environment variable
144
+ auth_token = os.getenv("HF_TOKEN", None)
145
+ comparer = QuranRecitationComparer(auth_token=auth_token)
146
+ print("Model initialized and ready for predictions.")
147
+
148
+ # Root endpoint
149
+ @app.get("/")
150
+ async def root():
151
+ return {"message": "Welcome to the Quran Recitation Comparison API"}
152
+
153
+ # Compare endpoint that accepts two audio files
154
+ @app.post("/compare")
155
+ async def compare_recitations(file1: UploadFile = File(...), file2: UploadFile = File(...)):
156
+ if comparer is None:
157
+ raise HTTPException(status_code=503, detail="Model not initialized")
158
+
159
  try:
160
+ # Save the uploaded files to temporary files
161
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp1:
162
+ tmp1.write(await file1.read())
163
+ file_path1 = tmp1.name
164
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp2:
165
+ tmp2.write(await file2.read())
166
+ file_path2 = tmp2.name
167
+
168
+ # Use the comparer to predict similarity
169
+ similarity_score, interpretation = comparer.predict(file_path1, file_path2)
170
+
171
+ # Clean up temporary files
172
+ os.remove(file_path1)
173
+ os.remove(file_path2)
174
+
 
 
 
 
175
  return {"similarity_score": similarity_score, "interpretation": interpretation}
176
+
 
 
 
177
  except Exception as e:
178
+ raise HTTPException(status_code=400, detail=str(e))
179
+
180
+ # Run the application with uvicorn if this module is executed directly.
 
 
 
 
 
 
 
 
 
181
  if __name__ == "__main__":
182
  import uvicorn
183
+ uvicorn.run("main:app", host="0.0.0.0", port=7860, reload=False)