Hammad712 committed
Commit 8bb5ed1 · verified · 1 Parent(s): 7304ea8

Update main.py

Files changed (1)
  1. main.py +60 -144
main.py CHANGED
@@ -1,153 +1,69 @@
 import os
-os.environ["TRANSFORMERS_CACHE"] = "/tmp"  # Ensure the cache directory is writable
-os.environ["NUMBA_CACHE_DIR"] = "/tmp"  # Ensure a writable cache directory for Numba
-os.environ["NUMBA_DISABLE_CACHE"] = "1"  # Disable Numba caching to avoid errors
-
-import torch
-import librosa
-import numpy as np
-import tempfile
-from fastapi import FastAPI, UploadFile, File, HTTPException
-from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
-from librosa.sequence import dtw
-from contextlib import asynccontextmanager
-
-# --- Core Class Definition ---
-class QuranRecitationComparer:
-    def __init__(self, model_name="jonatasgrosman/wav2vec2-large-xlsr-53-arabic", auth_token=None):
-        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        if auth_token:
-            self.processor = Wav2Vec2Processor.from_pretrained(model_name, token=auth_token)
-            self.model = Wav2Vec2ForCTC.from_pretrained(model_name, token=auth_token)
-        else:
-            self.processor = Wav2Vec2Processor.from_pretrained(model_name)
-            self.model = Wav2Vec2ForCTC.from_pretrained(model_name)
-        self.model = self.model.to(self.device)
-        self.model.eval()
-        self.embedding_cache = {}
-
-    def load_audio(self, file_path, target_sr=16000, trim_silence=True, normalize=True):
-        if not os.path.exists(file_path):
-            raise FileNotFoundError(f"Audio file not found: {file_path}")
-        y, sr = librosa.load(file_path, sr=target_sr)
-        if normalize:
-            y = librosa.util.normalize(y)
-        if trim_silence:
-            y, _ = librosa.effects.trim(y, top_db=30)
-        return y
-
-    def get_deep_embedding(self, audio, sr=16000):
-        input_values = self.processor(
-            audio,
-            sampling_rate=sr,
-            return_tensors="pt"
-        ).input_values.to(self.device)
-        with torch.no_grad():
-            outputs = self.model(input_values, output_hidden_states=True)
-        hidden_states = outputs.hidden_states[-1]
-        embedding_seq = hidden_states.squeeze(0).cpu().numpy()
-        return embedding_seq
-
-    def compute_dtw_distance(self, features1, features2):
-        D, wp = dtw(X=features1, Y=features2, metric='euclidean')
-        distance = D[-1, -1]
-        normalized_distance = distance / len(wp)
-        return normalized_distance
-
-    def interpret_similarity(self, norm_distance):
-        if norm_distance == 0:
-            result = "The recitations are identical based on the deep embeddings."
-            score = 100
-        elif norm_distance < 1:
-            result = "The recitations are extremely similar."
-            score = 95
-        elif norm_distance < 5:
-            result = "The recitations are very similar with minor differences."
-            score = 80
-        elif norm_distance < 10:
-            result = "The recitations show moderate similarity."
-            score = 60
-        elif norm_distance < 20:
-            result = "The recitations show some noticeable differences."
-            score = 40
-        else:
-            result = "The recitations are quite different."
-            score = max(0, 100 - norm_distance)
-        return result, score
-
-    def get_embedding_for_file(self, file_path):
-        if file_path in self.embedding_cache:
-            return self.embedding_cache[file_path]
-        audio = self.load_audio(file_path)
-        embedding = self.get_deep_embedding(audio)
-        self.embedding_cache[file_path] = embedding
-        return embedding
-
-    def predict(self, file_path1, file_path2):
-        embedding1 = self.get_embedding_for_file(file_path1)
-        embedding2 = self.get_embedding_for_file(file_path2)
-        norm_distance = self.compute_dtw_distance(embedding1.T, embedding2.T)
-        interpretation, similarity_score = self.interpret_similarity(norm_distance)
-        print(f"Similarity Score: {similarity_score:.1f}/100")
-        print(f"Interpretation: {interpretation}")
-        return similarity_score, interpretation
-
-    def clear_cache(self):
-        self.embedding_cache = {}
-
-# --- Lifespan Event Handler ---
-@asynccontextmanager
-async def lifespan(app: FastAPI):
-    global comparer
-    auth_token = os.environ.get("HF_TOKEN")
-    comparer = QuranRecitationComparer(
-        model_name="jonatasgrosman/wav2vec2-large-xlsr-53-arabic",
-        auth_token=auth_token
-    )
-    print("Model initialized and ready for predictions!")
-    yield
-    print("Application shutdown: Cleanup if necessary.")
-
-app = FastAPI(
-    title="Quran Recitation Comparer API",
-    description="Compares two Quran recitations using a deep wav2vec2 model.",
-    version="1.0",
-    lifespan=lifespan
-)
-
-# --- API Endpoints ---
-@app.get("/", summary="Health Check")
+from fastapi import FastAPI, UploadFile, File
+from google import genai
+from google.genai import types
+import uvicorn
+
+app = FastAPI()
+
+# Retrieve the GenAI API key from the environment variable.
+api_key = os.getenv("GENAI_API_KEY")
+if not api_key:
+    raise EnvironmentError("GENAI_API_KEY environment variable not set")
+
+# Initialize the GenAI client.
+client = genai.Client(api_key=api_key)
+
+@app.get("/")
 async def root():
-    return {"message": "Quran Recitation Comparer API is up and running."}
-
-@app.post("/predict", summary="Compare Two Audio Files", response_model=dict)
-async def predict(file1: UploadFile = File(...), file2: UploadFile = File(...)):
-    tmp1_path = None
-    tmp2_path = None
-    try:
-        suffix1 = os.path.splitext(file1.filename)[1] or ".wav"
-        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix1) as tmp1:
-            content1 = await file1.read()
-            tmp1.write(content1)
-            tmp1_path = tmp1.name
-
-        suffix2 = os.path.splitext(file2.filename)[1] or ".wav"
-        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix2) as tmp2:
-            content2 = await file2.read()
-            tmp2.write(content2)
-            tmp2_path = tmp2.name
-
-        similarity_score, interpretation = comparer.predict(tmp1_path, tmp2_path)
-        return {"similarity_score": similarity_score, "interpretation": interpretation}
-    except Exception as e:
-        raise HTTPException(status_code=500, detail=str(e))
-    finally:
-        if tmp1_path and os.path.exists(tmp1_path):
-            os.remove(tmp1_path)
-        if tmp2_path and os.path.exists(tmp2_path):
-            os.remove(tmp2_path)
-
-@app.post("/clear_cache", summary="Clear Embedding Cache", response_model=dict)
-async def clear_cache():
-    comparer.clear_cache()
-    return {"message": "Cache cleared."}
+    return {
+        "message": "Welcome to the Audio Similarity API!",
+        "usage": {
+            "endpoint": "/compare-audio",
+            "description": "POST two audio files (user recitation and professional qari) for similarity analysis.",
+            "instructions": "Send audio files as form-data with keys 'audio1' and 'audio2'."
+        }
+    }
+
+@app.post("/compare-audio")
+async def compare_audio(
+    audio1: UploadFile = File(...),
+    audio2: UploadFile = File(...)
+):
+    # Read the uploaded audio files.
+    audio1_bytes = await audio1.read()
+    audio2_bytes = await audio2.read()
+
+    # Create a refined prompt that clearly identifies the audio sources.
+    prompt = (
+        """Please analyze and compare the two provided audio clips.
+        The first audio is the user's recitation, and the second audio is the professional qari recitation.
+        Evaluate their similarity on a scale from 0 to 1, where:
+        - 1 indicates the user's recitation contains no mistakes compared to the professional version,
+        - 0 indicates there are significant mistakes.
+        Provide your response with:
+        1. A numerical similarity score on the first line.
+        2. A single sentence that indicates whether the user's recitation is similar, moderately similar, or dissimilar to the professional qari."""
+    )
+
+    # Generate the content using the Gemini model with the two audio inputs.
+    response = client.models.generate_content(
+        model='gemini-2.0-flash',
+        contents=[
+            prompt,
+            types.Part.from_bytes(
+                data=audio1_bytes,
+                mime_type=audio1.content_type,
+            ),
+            types.Part.from_bytes(
+                data=audio2_bytes,
+                mime_type=audio2.content_type,
+            )
+        ]
+    )
+
+    # Return the model's response.
+    return {"result": response.text}
+
+if __name__ == "__main__":
+    uvicorn.run(app, host="0.0.0.0", port=8000)
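
For quick manual testing of the new /compare-audio endpoint, a minimal client sketch; this assumes the server from this commit is running locally on port 8000, and the file names and the requests dependency are illustrative placeholders, not part of the commit:

    import requests

    # Placeholder audio files; any format Gemini accepts should work.
    with open("user_recitation.mp3", "rb") as f1, open("qari_recitation.mp3", "rb") as f2:
        resp = requests.post(
            "http://localhost:8000/compare-audio",
            files={
                "audio1": ("user_recitation.mp3", f1, "audio/mpeg"),
                "audio2": ("qari_recitation.mp3", f2, "audio/mpeg"),
            },
        )

    resp.raise_for_status()
    result = resp.json()["result"]

    # The prompt asks the model to put the numerical score on the first line
    # and a one-sentence verdict on the second; this parse relies on the
    # model honoring that format, which is not guaranteed.
    lines = result.strip().splitlines()
    print("score:", lines[0])
    print("verdict:", lines[1] if len(lines) > 1 else "(missing)")

Note that the raw model text is returned as-is by the endpoint, so any structured parsing of the score has to happen client-side, as sketched above.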