Hammad712 committed on
Commit
94ba3d3
·
verified ·
1 Parent(s): 1d61cef

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +91 -54
main.py CHANGED
@@ -1,24 +1,36 @@
1
- import os
 
 
2
  import torch
3
  import librosa
4
  import numpy as np
5
- import tempfile
6
  from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
7
  from librosa.sequence import dtw
8
- from fastapi import FastAPI, UploadFile, File, HTTPException
9
- from fastapi.responses import JSONResponse
10
  import shutil
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
- # Define the QuranRecitationComparer class as provided
13
  class QuranRecitationComparer:
14
- def __init__(self, model_name="jonatasgrosman/wav2vec2-large-xlsr-53-arabic", auth_token=None):
15
  """Initialize the Quran recitation comparer with a specific Wav2Vec2 model."""
16
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
 
18
  # Load model and processor once during initialization
19
- if auth_token:
20
- self.processor = Wav2Vec2Processor.from_pretrained(model_name, token=auth_token)
21
- self.model = Wav2Vec2ForCTC.from_pretrained(model_name, token=auth_token)
22
  else:
23
  self.processor = Wav2Vec2Processor.from_pretrained(model_name)
24
  self.model = Wav2Vec2ForCTC.from_pretrained(model_name)
@@ -107,77 +119,102 @@ class QuranRecitationComparer:
107
  """
108
  Predict the similarity between two audio files.
109
  This method can be called repeatedly without reloading the model.
 
 
 
 
 
 
 
 
110
  """
111
  # Get embeddings (using cache if available)
112
  embedding1 = self.get_embedding_for_file(file_path1)
113
  embedding2 = self.get_embedding_for_file(file_path2)
114
 
115
- # Compute DTW distance (transposing so that each column represents a frame)
116
  norm_distance = self.compute_dtw_distance(embedding1.T, embedding2.T)
117
 
118
  # Interpret results
119
  interpretation, similarity_score = self.interpret_similarity(norm_distance)
120
 
121
- print(f"Similarity Score: {similarity_score:.1f}/100")
122
- print(f"Interpretation: {interpretation}")
123
-
124
  return similarity_score, interpretation
125
 
126
  def clear_cache(self):
127
  """Clear the embedding cache to free memory."""
128
  self.embedding_cache = {}
129
 
130
- # Create FastAPI application
131
- app = FastAPI(
132
- title="Quran Recitation Comparison API",
133
- description="API for comparing similarity between Quran recitations",
134
- version="1.0.0"
135
- )
136
-
137
- # Global instance of the comparer
138
  comparer = None
139
 
140
  @app.on_event("startup")
141
  async def startup_event():
 
142
  global comparer
143
- # Optionally, set the HF authentication token from an environment variable
144
- auth_token = os.getenv("HF_TOKEN", None)
145
- comparer = QuranRecitationComparer(auth_token=auth_token)
146
- print("Model initialized and ready for predictions.")
 
 
147
 
148
- # Root endpoint
149
  @app.get("/")
150
  async def root():
151
- return {"message": "Welcome to the Quran Recitation Comparison API"}
152
-
153
- # Compare endpoint that accepts two audio files
154
- @app.post("/compare")
155
- async def compare_recitations(file1: UploadFile = File(...), file2: UploadFile = File(...)):
156
- if comparer is None:
157
- raise HTTPException(status_code=503, detail="Model not initialized")
158
-
 
 
 
 
 
 
 
 
 
 
 
 
 
159
  try:
160
- # Save the uploaded files to temporary files
161
- with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp1:
162
- tmp1.write(await file1.read())
163
- file_path1 = tmp1.name
164
- with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp2:
165
- tmp2.write(await file2.read())
166
- file_path2 = tmp2.name
167
-
168
- # Use the comparer to predict similarity
169
- similarity_score, interpretation = comparer.predict(file_path1, file_path2)
170
-
 
 
 
 
 
 
 
 
 
 
 
171
  # Clean up temporary files
172
- os.remove(file_path1)
173
- os.remove(file_path2)
174
 
175
- return {"similarity_score": similarity_score, "interpretation": interpretation}
176
-
177
- except Exception as e:
178
- raise HTTPException(status_code=400, detail=str(e))
 
 
 
 
179
 
180
- # Run the application with uvicorn if this module is executed directly.
181
  if __name__ == "__main__":
182
- import uvicorn
183
- uvicorn.run("main:app", host="0.0.0.0", port=7860, reload=False)
 
1
+ from fastapi import FastAPI, HTTPException, UploadFile, File, Form
2
+ from pydantic import BaseModel
3
+ from typing import Optional
4
  import torch
5
  import librosa
6
  import numpy as np
7
+ import os
8
  from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
9
  from librosa.sequence import dtw
10
+ import tempfile
 
11
  import shutil
12
+ from dotenv import load_dotenv
13
+ import uvicorn
14
+
15
+ # Load environment variables
16
+ load_dotenv()
17
+ HF_TOKEN = os.getenv("HF_TOKEN")
18
+
19
+ app = FastAPI(title="Quran Recitation Comparer API")
20
+
21
+ class ComparisonResult(BaseModel):
22
+ similarity_score: float
23
+ interpretation: str
24
 
 
25
  class QuranRecitationComparer:
26
+ def __init__(self, model_name="jonatasgrosman/wav2vec2-large-xlsr-53-arabic", token=None):
27
  """Initialize the Quran recitation comparer with a specific Wav2Vec2 model."""
28
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
29
 
30
  # Load model and processor once during initialization
31
+ if token:
32
+ self.processor = Wav2Vec2Processor.from_pretrained(model_name, use_auth_token=token)
33
+ self.model = Wav2Vec2ForCTC.from_pretrained(model_name, use_auth_token=token)
34
  else:
35
  self.processor = Wav2Vec2Processor.from_pretrained(model_name)
36
  self.model = Wav2Vec2ForCTC.from_pretrained(model_name)
 
119
  """
120
  Predict the similarity between two audio files.
121
  This method can be called repeatedly without reloading the model.
122
+
123
+ Args:
124
+ file_path1 (str): Path to first audio file
125
+ file_path2 (str): Path to second audio file
126
+
127
+ Returns:
128
+ float: Similarity score
129
+ str: Interpretation of similarity
130
  """
131
  # Get embeddings (using cache if available)
132
  embedding1 = self.get_embedding_for_file(file_path1)
133
  embedding2 = self.get_embedding_for_file(file_path2)
134
 
135
+ # Compute DTW distance
136
  norm_distance = self.compute_dtw_distance(embedding1.T, embedding2.T)
137
 
138
  # Interpret results
139
  interpretation, similarity_score = self.interpret_similarity(norm_distance)
140
 
 
 
 
141
  return similarity_score, interpretation
142
 
143
  def clear_cache(self):
144
  """Clear the embedding cache to free memory."""
145
  self.embedding_cache = {}
146
 
147
+ # Global variable for the comparer instance
 
 
 
 
 
 
 
148
  comparer = None
149
 
150
  @app.on_event("startup")
151
  async def startup_event():
152
+ """Initialize the model when the application starts."""
153
  global comparer
154
+ print("Initializing model... This may take a moment.")
155
+ comparer = QuranRecitationComparer(
156
+ model_name="jonatasgrosman/wav2vec2-large-xlsr-53-arabic",
157
+ token=HF_TOKEN
158
+ )
159
+ print("Model initialized and ready for predictions!")
160
 
 
161
  @app.get("/")
162
  async def root():
163
+ """Root endpoint to check if the API is running."""
164
+ return {"message": "Quran Recitation Comparer API is running", "status": "active"}
165
+
166
+ @app.post("/compare", response_model=ComparisonResult)
167
+ async def compare_files(
168
+ file1: UploadFile = File(...),
169
+ file2: UploadFile = File(...)
170
+ ):
171
+ """
172
+ Compare two audio files and return similarity metrics.
173
+
174
+ - **file1**: First audio file (MP3, WAV, etc.)
175
+ - **file2**: Second audio file (MP3, WAV, etc.)
176
+
177
+ Returns similarity score and interpretation.
178
+ """
179
+ if not comparer:
180
+ raise HTTPException(status_code=500, detail="Model not initialized. Please try again later.")
181
+
182
+ temp_dir = tempfile.mkdtemp()
183
+
184
  try:
185
+ # Save uploaded files to temporary directory
186
+ temp_file1 = os.path.join(temp_dir, file1.filename)
187
+ temp_file2 = os.path.join(temp_dir, file2.filename)
188
+
189
+ with open(temp_file1, "wb") as f:
190
+ shutil.copyfileobj(file1.file, f)
191
+
192
+ with open(temp_file2, "wb") as f:
193
+ shutil.copyfileobj(file2.file, f)
194
+
195
+ # Compare the files
196
+ similarity_score, interpretation = comparer.predict(temp_file1, temp_file2)
197
+
198
+ return ComparisonResult(
199
+ similarity_score=similarity_score,
200
+ interpretation=interpretation
201
+ )
202
+
203
+ except Exception as e:
204
+ raise HTTPException(status_code=500, detail=f"Error processing files: {str(e)}")
205
+
206
+ finally:
207
  # Clean up temporary files
208
+ shutil.rmtree(temp_dir, ignore_errors=True)
 
209
 
210
+ @app.post("/clear-cache")
211
+ async def clear_cache():
212
+ """Clear the embedding cache to free memory."""
213
+ if not comparer:
214
+ raise HTTPException(status_code=500, detail="Model not initialized.")
215
+
216
+ comparer.clear_cache()
217
+ return {"message": "Embedding cache cleared successfully"}
218
 
 
219
  if __name__ == "__main__":
220
+ uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=True)