Hammad712 committed on
Commit
612c535
·
verified ·
1 Parent(s): 51b44fe

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +28 -21
main.py CHANGED
@@ -4,6 +4,7 @@ from fastapi import FastAPI, UploadFile, File
4
  import uvicorn
5
  import torch
6
  import librosa
 
7
  from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
8
  from librosa.sequence import dtw
9
  from google import genai
@@ -11,17 +12,9 @@ from google.genai import types
11
 
12
  app = FastAPI()
13
 
14
- # ---------------------------
15
- # Gemini-based Comparison API
16
- # ---------------------------
17
-
18
- # Retrieve the GenAI API key from the environment variable.
19
- genai_api_key = os.getenv("GENAI_API_KEY")
20
- if not genai_api_key:
21
- raise EnvironmentError("GENAI_API_KEY environment variable not set")
22
-
23
- # Initialize the GenAI client.
24
- client = genai.Client(api_key=genai_api_key)
25
 
26
  # ---------------------------
27
  # DTW-based Comparison Class
@@ -30,7 +23,7 @@ class QuranRecitationComparer:
30
  def __init__(self, model_name="jonatasgrosman/wav2vec2-large-xlsr-53-arabic", auth_token=None):
31
  """Initialize the Quran recitation comparer with a specific Wav2Vec2 model."""
32
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
33
- # Load model and processor once during initialization
34
  if auth_token:
35
  self.processor = Wav2Vec2Processor.from_pretrained(model_name, token=auth_token)
36
  self.model = Wav2Vec2ForCTC.from_pretrained(model_name, token=auth_token)
@@ -39,14 +32,19 @@ class QuranRecitationComparer:
39
  self.model = Wav2Vec2ForCTC.from_pretrained(model_name)
40
  self.model = self.model.to(self.device)
41
  self.model.eval()
42
- # Cache for embeddings to avoid recomputation
43
  self.embedding_cache = {}
44
 
45
  def load_audio(self, file_path, target_sr=16000, trim_silence=True, normalize=True):
46
  """Load and preprocess an audio file."""
47
  if not os.path.exists(file_path):
48
  raise FileNotFoundError(f"Audio file not found: {file_path}")
49
- y, sr = librosa.load(file_path, sr=target_sr)
 
 
 
 
 
50
  if normalize:
51
  y = librosa.util.normalize(y)
52
  if trim_silence:
@@ -121,15 +119,26 @@ class QuranRecitationComparer:
121
  """Clear the embedding cache to free memory."""
122
  self.embedding_cache = {}
123
 
124
- # Retrieve HuggingFace auth token from environment variable (if needed).
125
- hf_auth_token = os.getenv("HF_AUTH_TOKEN")
126
- # Initialize the comparer instance once at startup.
127
- comparer = QuranRecitationComparer(auth_token=hf_auth_token)
 
 
 
 
 
 
 
 
 
 
 
 
128
 
129
  # ---------------------------
130
  # API Endpoints
131
  # ---------------------------
132
-
133
  @app.get("/")
134
  async def root():
135
  return {
@@ -188,8 +197,6 @@ Provide your response with:
188
  )
189
  ]
190
  )
191
-
192
- # Return the model's response.
193
  return {"result": response.text}
194
 
195
  @app.post("/compare-dtw")
 
4
  import uvicorn
5
  import torch
6
  import librosa
7
+ from audioread.exceptions import NoBackendError
8
  from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
9
  from librosa.sequence import dtw
10
  from google import genai
 
12
 
13
  app = FastAPI()
14
 
15
+ # Global variables to hold our loaded models/clients.
16
+ client = None
17
+ comparer = None
 
 
 
 
 
 
 
 
18
 
19
  # ---------------------------
20
  # DTW-based Comparison Class
 
23
  def __init__(self, model_name="jonatasgrosman/wav2vec2-large-xlsr-53-arabic", auth_token=None):
24
  """Initialize the Quran recitation comparer with a specific Wav2Vec2 model."""
25
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
26
+ # Load model and processor once during initialization.
27
  if auth_token:
28
  self.processor = Wav2Vec2Processor.from_pretrained(model_name, token=auth_token)
29
  self.model = Wav2Vec2ForCTC.from_pretrained(model_name, token=auth_token)
 
32
  self.model = Wav2Vec2ForCTC.from_pretrained(model_name)
33
  self.model = self.model.to(self.device)
34
  self.model.eval()
35
+ # Cache for embeddings to avoid recomputation.
36
  self.embedding_cache = {}
37
 
38
  def load_audio(self, file_path, target_sr=16000, trim_silence=True, normalize=True):
39
  """Load and preprocess an audio file."""
40
  if not os.path.exists(file_path):
41
  raise FileNotFoundError(f"Audio file not found: {file_path}")
42
+ try:
43
+ y, sr = librosa.load(file_path, sr=target_sr)
44
+ except NoBackendError as e:
45
+ raise RuntimeError(
46
+ "Failed to load audio using librosa. Please ensure you have a valid audio backend installed (e.g., ffmpeg)."
47
+ ) from e
48
  if normalize:
49
  y = librosa.util.normalize(y)
50
  if trim_silence:
 
119
  """Clear the embedding cache to free memory."""
120
  self.embedding_cache = {}
121
 
122
+ # ---------------------------
123
+ # Application Startup
124
+ # ---------------------------
125
@app.on_event("startup")
async def startup_event():
    """Populate the module-level ``client`` and ``comparer`` at app startup.

    Reads ``GENAI_API_KEY`` from the environment and uses it to build the
    GenAI client, then constructs the recitation comparer, forwarding the
    value of ``HF_AUTH_TOKEN`` as its auth token.

    Raises:
        EnvironmentError: If ``GENAI_API_KEY`` is unset or empty.
    """
    global client, comparer

    # The GenAI key is mandatory: refuse to start without it.
    api_key = os.getenv("GENAI_API_KEY")
    if api_key:
        client = genai.Client(api_key=api_key)
    else:
        raise EnvironmentError("GENAI_API_KEY environment variable not set")

    # The HuggingFace token is optional; os.getenv yields None when absent.
    # Build the comparer a single time here so it is shared via the global.
    comparer = QuranRecitationComparer(auth_token=os.getenv("HF_AUTH_TOKEN"))
138
 
139
  # ---------------------------
140
  # API Endpoints
141
  # ---------------------------
 
142
  @app.get("/")
143
  async def root():
144
  return {
 
197
  )
198
  ]
199
  )
 
 
200
  return {"result": response.text}
201
 
202
  @app.post("/compare-dtw")