Spaces:
Sleeping
Sleeping
Update main.py
Browse files
main.py
CHANGED
@@ -4,6 +4,7 @@ from fastapi import FastAPI, UploadFile, File
|
|
4 |
import uvicorn
|
5 |
import torch
|
6 |
import librosa
|
|
|
7 |
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
|
8 |
from librosa.sequence import dtw
|
9 |
from google import genai
|
@@ -11,17 +12,9 @@ from google.genai import types
|
|
11 |
|
12 |
app = FastAPI()
|
13 |
|
14 |
-
#
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
# Retrieve the GenAI API key from the environment variable.
|
19 |
-
genai_api_key = os.getenv("GENAI_API_KEY")
|
20 |
-
if not genai_api_key:
|
21 |
-
raise EnvironmentError("GENAI_API_KEY environment variable not set")
|
22 |
-
|
23 |
-
# Initialize the GenAI client.
|
24 |
-
client = genai.Client(api_key=genai_api_key)
|
25 |
|
26 |
# ---------------------------
|
27 |
# DTW-based Comparison Class
|
@@ -30,7 +23,7 @@ class QuranRecitationComparer:
|
|
30 |
def __init__(self, model_name="jonatasgrosman/wav2vec2-large-xlsr-53-arabic", auth_token=None):
|
31 |
"""Initialize the Quran recitation comparer with a specific Wav2Vec2 model."""
|
32 |
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
33 |
-
# Load model and processor once during initialization
|
34 |
if auth_token:
|
35 |
self.processor = Wav2Vec2Processor.from_pretrained(model_name, token=auth_token)
|
36 |
self.model = Wav2Vec2ForCTC.from_pretrained(model_name, token=auth_token)
|
@@ -39,14 +32,19 @@ class QuranRecitationComparer:
|
|
39 |
self.model = Wav2Vec2ForCTC.from_pretrained(model_name)
|
40 |
self.model = self.model.to(self.device)
|
41 |
self.model.eval()
|
42 |
-
# Cache for embeddings to avoid recomputation
|
43 |
self.embedding_cache = {}
|
44 |
|
45 |
def load_audio(self, file_path, target_sr=16000, trim_silence=True, normalize=True):
|
46 |
"""Load and preprocess an audio file."""
|
47 |
if not os.path.exists(file_path):
|
48 |
raise FileNotFoundError(f"Audio file not found: {file_path}")
|
49 |
-
|
|
|
|
|
|
|
|
|
|
|
50 |
if normalize:
|
51 |
y = librosa.util.normalize(y)
|
52 |
if trim_silence:
|
@@ -121,15 +119,26 @@ class QuranRecitationComparer:
|
|
121 |
"""Clear the embedding cache to free memory."""
|
122 |
self.embedding_cache = {}
|
123 |
|
124 |
-
#
|
125 |
-
|
126 |
-
#
|
127 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
128 |
|
129 |
# ---------------------------
|
130 |
# API Endpoints
|
131 |
# ---------------------------
|
132 |
-
|
133 |
@app.get("/")
|
134 |
async def root():
|
135 |
return {
|
@@ -188,8 +197,6 @@ Provide your response with:
|
|
188 |
)
|
189 |
]
|
190 |
)
|
191 |
-
|
192 |
-
# Return the model's response.
|
193 |
return {"result": response.text}
|
194 |
|
195 |
@app.post("/compare-dtw")
|
|
|
4 |
import uvicorn
|
5 |
import torch
|
6 |
import librosa
|
7 |
+
from audioread.exceptions import NoBackendError
|
8 |
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
|
9 |
from librosa.sequence import dtw
|
10 |
from google import genai
|
|
|
12 |
|
13 |
app = FastAPI()
|
14 |
|
15 |
+
# Global variables to hold our loaded models/clients.
|
16 |
+
client = None
|
17 |
+
comparer = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
|
19 |
# ---------------------------
|
20 |
# DTW-based Comparison Class
|
|
|
23 |
def __init__(self, model_name="jonatasgrosman/wav2vec2-large-xlsr-53-arabic", auth_token=None):
|
24 |
"""Initialize the Quran recitation comparer with a specific Wav2Vec2 model."""
|
25 |
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
26 |
+
# Load model and processor once during initialization.
|
27 |
if auth_token:
|
28 |
self.processor = Wav2Vec2Processor.from_pretrained(model_name, token=auth_token)
|
29 |
self.model = Wav2Vec2ForCTC.from_pretrained(model_name, token=auth_token)
|
|
|
32 |
self.model = Wav2Vec2ForCTC.from_pretrained(model_name)
|
33 |
self.model = self.model.to(self.device)
|
34 |
self.model.eval()
|
35 |
+
# Cache for embeddings to avoid recomputation.
|
36 |
self.embedding_cache = {}
|
37 |
|
38 |
def load_audio(self, file_path, target_sr=16000, trim_silence=True, normalize=True):
|
39 |
"""Load and preprocess an audio file."""
|
40 |
if not os.path.exists(file_path):
|
41 |
raise FileNotFoundError(f"Audio file not found: {file_path}")
|
42 |
+
try:
|
43 |
+
y, sr = librosa.load(file_path, sr=target_sr)
|
44 |
+
except NoBackendError as e:
|
45 |
+
raise RuntimeError(
|
46 |
+
"Failed to load audio using librosa. Please ensure you have a valid audio backend installed (e.g., ffmpeg)."
|
47 |
+
) from e
|
48 |
if normalize:
|
49 |
y = librosa.util.normalize(y)
|
50 |
if trim_silence:
|
|
|
119 |
"""Clear the embedding cache to free memory."""
|
120 |
self.embedding_cache = {}
|
121 |
|
122 |
+
# ---------------------------
|
123 |
+
# Application Startup
|
124 |
+
# ---------------------------
|
125 |
+
@app.on_event("startup")
async def startup_event():
    """Create the shared GenAI client and recitation comparer at app startup.

    Populates the module-level ``client`` and ``comparer`` globals so the
    endpoints can reuse them instead of rebuilding them per request.

    Raises:
        EnvironmentError: If ``GENAI_API_KEY`` is unset or empty.
    """
    global client, comparer

    # Validate the key first so a misconfigured deployment fails before
    # the comparer (and its Wav2Vec2 model) is constructed below.
    api_key = os.getenv("GENAI_API_KEY")
    if not api_key:
        raise EnvironmentError("GENAI_API_KEY environment variable not set")
    client = genai.Client(api_key=api_key)

    # HF_AUTH_TOKEN is optional: os.getenv returns None when it is unset,
    # which matches the comparer's default for auth_token.
    comparer = QuranRecitationComparer(auth_token=os.getenv("HF_AUTH_TOKEN"))
|
138 |
|
139 |
# ---------------------------
|
140 |
# API Endpoints
|
141 |
# ---------------------------
|
|
|
142 |
@app.get("/")
|
143 |
async def root():
|
144 |
return {
|
|
|
197 |
)
|
198 |
]
|
199 |
)
|
|
|
|
|
200 |
return {"result": response.text}
|
201 |
|
202 |
@app.post("/compare-dtw")
|