Hammad712 commited on
Commit
60a573f
·
verified ·
1 Parent(s): 44a3c0a

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +164 -7
main.py CHANGED
@@ -1,27 +1,150 @@
1
  import os
 
2
  from fastapi import FastAPI, UploadFile, File
 
 
 
 
 
3
  from google import genai
4
  from google.genai import types
5
- import uvicorn
6
 
7
  app = FastAPI()
8
 
 
 
 
 
9
  # Retrieve the GenAI API key from the environment variable.
10
- api_key = os.getenv("GENAI_API_KEY")
11
- if not api_key:
12
  raise EnvironmentError("GENAI_API_KEY environment variable not set")
13
 
14
  # Initialize the GenAI client.
15
- client = genai.Client(api_key=api_key)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
  @app.get("/")
18
  async def root():
19
  return {
20
  "message": "Welcome to the Audio Similarity API!",
21
  "usage": {
22
- "endpoint": "/compare-audio",
23
- "description": "POST two audio files (user recitation and professional qarri) for similarity analysis.",
24
- "instructions": "Send audio files as form-data with keys 'audio1' and 'audio2'."
 
 
 
 
 
 
 
25
  }
26
  }
27
 
@@ -30,6 +153,10 @@ async def compare_audio(
30
  audio1: UploadFile = File(...),
31
  audio2: UploadFile = File(...)
32
  ):
 
 
 
 
33
  # Read the uploaded audio files.
34
  audio1_bytes = await audio1.read()
35
  audio2_bytes = await audio2.read()
@@ -65,5 +192,35 @@ Provide your response with:
65
  # Return the model's response.
66
  return {"result": response.text}
67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  if __name__ == "__main__":
69
  uvicorn.run(app, host="0.0.0.0", port=8000)
 
1
  import os
2
+ import tempfile
3
  from fastapi import FastAPI, UploadFile, File
4
+ import uvicorn
5
+ import torch
6
+ import librosa
7
+ from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
8
+ from librosa.sequence import dtw
9
  from google import genai
10
  from google.genai import types
 
11
 
12
  app = FastAPI()
13
 
14
+ # ---------------------------
15
+ # Gemini-based Comparison API
16
+ # ---------------------------
17
+
18
  # Retrieve the GenAI API key from the environment variable.
19
+ genai_api_key = os.getenv("GENAI_API_KEY")
20
+ if not genai_api_key:
21
  raise EnvironmentError("GENAI_API_KEY environment variable not set")
22
 
23
  # Initialize the GenAI client.
24
+ client = genai.Client(api_key=genai_api_key)
25
+
26
+ # ---------------------------
27
+ # DTW-based Comparison Class
28
+ # ---------------------------
29
+ class QuranRecitationComparer:
30
+ def __init__(self, model_name="jonatasgrosman/wav2vec2-large-xlsr-53-arabic", auth_token=None):
31
+ """Initialize the Quran recitation comparer with a specific Wav2Vec2 model."""
32
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
33
+ # Load model and processor once during initialization
34
+ if auth_token:
35
+ self.processor = Wav2Vec2Processor.from_pretrained(model_name, token=auth_token)
36
+ self.model = Wav2Vec2ForCTC.from_pretrained(model_name, token=auth_token)
37
+ else:
38
+ self.processor = Wav2Vec2Processor.from_pretrained(model_name)
39
+ self.model = Wav2Vec2ForCTC.from_pretrained(model_name)
40
+ self.model = self.model.to(self.device)
41
+ self.model.eval()
42
+ # Cache for embeddings to avoid recomputation
43
+ self.embedding_cache = {}
44
+
45
+ def load_audio(self, file_path, target_sr=16000, trim_silence=True, normalize=True):
46
+ """Load and preprocess an audio file."""
47
+ if not os.path.exists(file_path):
48
+ raise FileNotFoundError(f"Audio file not found: {file_path}")
49
+ y, sr = librosa.load(file_path, sr=target_sr)
50
+ if normalize:
51
+ y = librosa.util.normalize(y)
52
+ if trim_silence:
53
+ y, _ = librosa.effects.trim(y, top_db=30)
54
+ return y
55
+
56
+ def get_deep_embedding(self, audio, sr=16000):
57
+ """Extract frame-wise deep embeddings using the pretrained model."""
58
+ input_values = self.processor(
59
+ audio,
60
+ sampling_rate=sr,
61
+ return_tensors="pt"
62
+ ).input_values.to(self.device)
63
+ with torch.no_grad():
64
+ outputs = self.model(input_values, output_hidden_states=True)
65
+ hidden_states = outputs.hidden_states[-1]
66
+ embedding_seq = hidden_states.squeeze(0).cpu().numpy()
67
+ return embedding_seq
68
+
69
+ def compute_dtw_distance(self, features1, features2):
70
+ """Compute the DTW distance between two sequences of features."""
71
+ D, wp = dtw(X=features1, Y=features2, metric='euclidean')
72
+ distance = D[-1, -1]
73
+ normalized_distance = distance / len(wp)
74
+ return normalized_distance
75
+
76
+ def interpret_similarity(self, norm_distance):
77
+ """Interpret the normalized distance value."""
78
+ if norm_distance == 0:
79
+ result = "The recitations are identical based on the deep embeddings."
80
+ score = 100
81
+ elif norm_distance < 1:
82
+ result = "The recitations are extremely similar."
83
+ score = 95
84
+ elif norm_distance < 5:
85
+ result = "The recitations are very similar with minor differences."
86
+ score = 80
87
+ elif norm_distance < 10:
88
+ result = "The recitations show moderate similarity."
89
+ score = 60
90
+ elif norm_distance < 20:
91
+ result = "The recitations show some noticeable differences."
92
+ score = 40
93
+ else:
94
+ result = "The recitations are quite different."
95
+ score = max(0, 100 - norm_distance)
96
+ return result, score
97
+
98
+ def get_embedding_for_file(self, file_path):
99
+ """Get embedding for a file, using cache if available."""
100
+ if file_path in self.embedding_cache:
101
+ return self.embedding_cache[file_path]
102
+ audio = self.load_audio(file_path)
103
+ embedding = self.get_deep_embedding(audio)
104
+ self.embedding_cache[file_path] = embedding
105
+ return embedding
106
+
107
+ def predict(self, file_path1, file_path2):
108
+ """
109
+ Predict the similarity between two audio files.
110
+ Returns:
111
+ float: Similarity score
112
+ str: Interpretation of similarity
113
+ """
114
+ embedding1 = self.get_embedding_for_file(file_path1)
115
+ embedding2 = self.get_embedding_for_file(file_path2)
116
+ norm_distance = self.compute_dtw_distance(embedding1.T, embedding2.T)
117
+ interpretation, similarity_score = self.interpret_similarity(norm_distance)
118
+ return similarity_score, interpretation
119
+
120
+ def clear_cache(self):
121
+ """Clear the embedding cache to free memory."""
122
+ self.embedding_cache = {}
123
+
124
+ # Retrieve HuggingFace auth token from environment variable (if needed).
125
+ hf_auth_token = os.getenv("HF_AUTH_TOKEN")
126
+ # Initialize the comparer instance once at startup.
127
+ comparer = QuranRecitationComparer(auth_token=hf_auth_token)
128
+
129
+ # ---------------------------
130
+ # API Endpoints
131
+ # ---------------------------
132
 
133
  @app.get("/")
134
  async def root():
135
  return {
136
  "message": "Welcome to the Audio Similarity API!",
137
  "usage": {
138
+ "endpoints": {
139
+ "gemini": {
140
+ "path": "/compare-audio",
141
+ "description": "POST two audio files (user recitation and professional qarri) for similarity analysis using Gemini."
142
+ },
143
+ "dtw": {
144
+ "path": "/compare-dtw",
145
+ "description": "POST two audio files (user recitation and professional qarri) for similarity analysis using deep embeddings and DTW."
146
+ }
147
+ }
148
  }
149
  }
150
 
 
153
  audio1: UploadFile = File(...),
154
  audio2: UploadFile = File(...)
155
  ):
156
+ """
157
+ Compare two audio files using the Gemini approach.
158
+ The first audio is the user's recitation and the second is the professional qarri recitation.
159
+ """
160
  # Read the uploaded audio files.
161
  audio1_bytes = await audio1.read()
162
  audio2_bytes = await audio2.read()
 
192
  # Return the model's response.
193
  return {"result": response.text}
194
 
195
+ @app.post("/compare-dtw")
196
+ async def compare_dtw(
197
+ audio1: UploadFile = File(...),
198
+ audio2: UploadFile = File(...)
199
+ ):
200
+ """
201
+ Compare two audio files using deep embeddings and DTW.
202
+ The first audio is the user's recitation and the second is the professional qarri recitation.
203
+ """
204
+ # Save the uploaded files to temporary files so they can be processed by the comparer.
205
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp1:
206
+ tmp1.write(await audio1.read())
207
+ tmp1_path = tmp1.name
208
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp2:
209
+ tmp2.write(await audio2.read())
210
+ tmp2_path = tmp2.name
211
+
212
+ try:
213
+ # Get similarity score and interpretation using DTW-based approach.
214
+ similarity_score, interpretation = comparer.predict(tmp1_path, tmp2_path)
215
+ finally:
216
+ # Clean up temporary files.
217
+ os.remove(tmp1_path)
218
+ os.remove(tmp2_path)
219
+
220
+ return {
221
+ "similarity_score": similarity_score,
222
+ "interpretation": interpretation
223
+ }
224
+
225
  if __name__ == "__main__":
226
  uvicorn.run(app, host="0.0.0.0", port=8000)