Update app.py
Browse files
app.py
CHANGED
@@ -3,6 +3,7 @@ from fastapi import FastAPI, HTTPException, File, UploadFile, Depends, Security
|
|
3 |
from fastapi.security.api_key import APIKeyHeader, APIKey
|
4 |
from fastapi.responses import JSONResponse
|
5 |
from pydantic import BaseModel
|
|
|
6 |
import numpy as np
|
7 |
import io
|
8 |
import soundfile as sf
|
@@ -21,10 +22,9 @@ import time
|
|
21 |
import tempfile
|
22 |
|
23 |
# Import functions from other modules
|
24 |
-
from asr import transcribe, ASR_LANGUAGES
|
25 |
from tts import synthesize, TTS_LANGUAGES
|
26 |
from lid import identify
|
27 |
-
from asr import ASR_SAMPLING_RATE
|
28 |
|
29 |
# Configure logging
|
30 |
logging.basicConfig(level=logging.INFO)
|
@@ -60,15 +60,18 @@ s3_client = boto3.client(
|
|
60 |
# Define request models
|
61 |
class AudioRequest(BaseModel):
|
62 |
audio: str # Base64 encoded audio or video data
|
63 |
-
language: str
|
64 |
|
65 |
class TTSRequest(BaseModel):
|
66 |
text: str
|
67 |
-
language: str
|
68 |
-
speed: float
|
69 |
|
70 |
class LanguageRequest(BaseModel):
|
71 |
-
language: str
|
|
|
|
|
|
|
72 |
|
73 |
async def get_api_key(api_key_header: str = Security(api_key_header)):
|
74 |
if api_key_header == API_KEY:
|
@@ -140,7 +143,13 @@ async def transcribe_audio(request: AudioRequest, api_key: APIKey = Depends(get_
|
|
140 |
if sample_rate != ASR_SAMPLING_RATE:
|
141 |
audio_array = librosa.resample(audio_array, orig_sr=sample_rate, target_sr=ASR_SAMPLING_RATE)
|
142 |
|
143 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
144 |
processing_time = time.time() - start_time
|
145 |
return JSONResponse(content={"transcription": result, "processing_time_seconds": processing_time})
|
146 |
except Exception as e:
|
@@ -156,7 +165,7 @@ async def transcribe_audio(request: AudioRequest, api_key: APIKey = Depends(get_
|
|
156 |
)
|
157 |
|
158 |
@app.post("/transcribe_file")
|
159 |
-
async def transcribe_audio_file(file: UploadFile = File(...),
|
160 |
start_time = time.time()
|
161 |
try:
|
162 |
contents = await file.read()
|
@@ -169,7 +178,13 @@ async def transcribe_audio_file(file: UploadFile = File(...), language: str = ""
|
|
169 |
if sample_rate != ASR_SAMPLING_RATE:
|
170 |
audio_array = librosa.resample(audio_array, orig_sr=sample_rate, target_sr=ASR_SAMPLING_RATE)
|
171 |
|
172 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
173 |
processing_time = time.time() - start_time
|
174 |
return JSONResponse(content={"transcription": result, "processing_time_seconds": processing_time})
|
175 |
except Exception as e:
|
@@ -189,19 +204,23 @@ async def synthesize_speech(request: TTSRequest, api_key: APIKey = Depends(get_a
|
|
189 |
start_time = time.time()
|
190 |
logger.info(f"Synthesize request received: text='{request.text}', language='{request.language}', speed={request.speed}")
|
191 |
try:
|
192 |
-
|
193 |
-
|
|
|
|
|
|
|
|
|
194 |
|
195 |
# Input validation
|
196 |
if not request.text:
|
197 |
raise ValueError("Text cannot be empty")
|
198 |
if lang_code not in TTS_LANGUAGES:
|
199 |
-
raise ValueError(f"Unsupported language: {
|
200 |
if not 0.5 <= request.speed <= 2.0:
|
201 |
raise ValueError(f"Speed must be between 0.5 and 2.0, got {request.speed}")
|
202 |
|
203 |
logger.info(f"Calling synthesize function with lang_code: {lang_code}")
|
204 |
-
result, filtered_text = synthesize(request.text,
|
205 |
logger.info(f"Synthesize function completed. Filtered text: '{filtered_text}'")
|
206 |
|
207 |
if result is None:
|
@@ -279,8 +298,6 @@ async def synthesize_speech(request: TTSRequest, api_key: APIKey = Depends(get_a
|
|
279 |
status_code=500,
|
280 |
content={"message": "An unexpected error occurred during speech synthesis", "details": error_details, "processing_time_seconds": processing_time}
|
281 |
)
|
282 |
-
finally:
|
283 |
-
logger.info("Synthesize request completed")
|
284 |
|
285 |
@app.post("/identify")
|
286 |
async def identify_language(request: AudioRequest, api_key: APIKey = Depends(get_api_key)):
|
@@ -328,22 +345,14 @@ async def identify_language_file(file: UploadFile = File(...), api_key: APIKey =
|
|
328 |
async def get_asr_languages(request: LanguageRequest, api_key: APIKey = Depends(get_api_key)):
|
329 |
start_time = time.time()
|
330 |
try:
|
331 |
-
if request.language
|
332 |
-
|
|
|
|
|
|
|
333 |
|
334 |
-
matching_languages = [lang for lang in ASR_LANGUAGES if lang.lower().startswith(request.language.lower())]
|
335 |
-
processing_time = time.time() - start_time
|
336 |
-
return JSONResponse
|
337 |
-
matching_languages = [lang for lang in ASR_LANGUAGES if lang.lower().startswith(request.language.lower())]
|
338 |
processing_time = time.time() - start_time
|
339 |
return JSONResponse(content={"languages": matching_languages, "processing_time_seconds": processing_time})
|
340 |
-
except ValueError as ve:
|
341 |
-
logger.error(f"ValueError in get_asr_languages: {str(ve)}", exc_info=True)
|
342 |
-
processing_time = time.time() - start_time
|
343 |
-
return JSONResponse(
|
344 |
-
status_code=400,
|
345 |
-
content={"message": "Invalid input", "details": str(ve), "processing_time_seconds": processing_time}
|
346 |
-
)
|
347 |
except Exception as e:
|
348 |
logger.error(f"Error in get_asr_languages: {str(e)}", exc_info=True)
|
349 |
error_details = {
|
@@ -360,19 +369,14 @@ async def get_asr_languages(request: LanguageRequest, api_key: APIKey = Depends(
|
|
360 |
async def get_tts_languages(request: LanguageRequest, api_key: APIKey = Depends(get_api_key)):
|
361 |
start_time = time.time()
|
362 |
try:
|
363 |
-
if request.language
|
364 |
-
|
|
|
|
|
|
|
365 |
|
366 |
-
matching_languages = [lang for lang in TTS_LANGUAGES if lang.lower().startswith(request.language.lower())]
|
367 |
processing_time = time.time() - start_time
|
368 |
return JSONResponse(content={"languages": matching_languages, "processing_time_seconds": processing_time})
|
369 |
-
except ValueError as ve:
|
370 |
-
logger.error(f"ValueError in get_tts_languages: {str(ve)}", exc_info=True)
|
371 |
-
processing_time = time.time() - start_time
|
372 |
-
return JSONResponse(
|
373 |
-
status_code=400,
|
374 |
-
content={"message": "Invalid input", "details": str(ve), "processing_time_seconds": processing_time}
|
375 |
-
)
|
376 |
except Exception as e:
|
377 |
logger.error(f"Error in get_tts_languages: {str(e)}", exc_info=True)
|
378 |
error_details = {
|
|
|
3 |
from fastapi.security.api_key import APIKeyHeader, APIKey
|
4 |
from fastapi.responses import JSONResponse
|
5 |
from pydantic import BaseModel
|
6 |
+
from typing import Optional
|
7 |
import numpy as np
|
8 |
import io
|
9 |
import soundfile as sf
|
|
|
22 |
import tempfile
|
23 |
|
24 |
# Import functions from other modules
|
25 |
+
from asr import transcribe, ASR_LANGUAGES, ASR_SAMPLING_RATE
|
26 |
from tts import synthesize, TTS_LANGUAGES
|
27 |
from lid import identify
|
|
|
28 |
|
29 |
# Configure logging
|
30 |
logging.basicConfig(level=logging.INFO)
|
|
|
60 |
# Define request models
|
61 |
class AudioRequest(BaseModel):
|
62 |
audio: str # Base64 encoded audio or video data
|
63 |
+
language: Optional[str] = None
|
64 |
|
65 |
class TTSRequest(BaseModel):
|
66 |
text: str
|
67 |
+
language: Optional[str] = None
|
68 |
+
speed: float = 1.0
|
69 |
|
70 |
class LanguageRequest(BaseModel):
|
71 |
+
language: Optional[str] = None
|
72 |
+
|
73 |
+
class TranscribeFileRequest(BaseModel):
|
74 |
+
language: Optional[str] = None
|
75 |
|
76 |
async def get_api_key(api_key_header: str = Security(api_key_header)):
|
77 |
if api_key_header == API_KEY:
|
|
|
143 |
if sample_rate != ASR_SAMPLING_RATE:
|
144 |
audio_array = librosa.resample(audio_array, orig_sr=sample_rate, target_sr=ASR_SAMPLING_RATE)
|
145 |
|
146 |
+
if request.language is None:
|
147 |
+
# If no language is provided, use language identification
|
148 |
+
identified_language = identify(audio_array)
|
149 |
+
result = transcribe(audio_array, identified_language)
|
150 |
+
else:
|
151 |
+
result = transcribe(audio_array, request.language)
|
152 |
+
|
153 |
processing_time = time.time() - start_time
|
154 |
return JSONResponse(content={"transcription": result, "processing_time_seconds": processing_time})
|
155 |
except Exception as e:
|
|
|
165 |
)
|
166 |
|
167 |
@app.post("/transcribe_file")
|
168 |
+
async def transcribe_audio_file(file: UploadFile = File(...), request: TranscribeFileRequest = Depends(), api_key: APIKey = Depends(get_api_key)):
|
169 |
start_time = time.time()
|
170 |
try:
|
171 |
contents = await file.read()
|
|
|
178 |
if sample_rate != ASR_SAMPLING_RATE:
|
179 |
audio_array = librosa.resample(audio_array, orig_sr=sample_rate, target_sr=ASR_SAMPLING_RATE)
|
180 |
|
181 |
+
if request.language is None:
|
182 |
+
# If no language is provided, use language identification
|
183 |
+
identified_language = identify(audio_array)
|
184 |
+
result = transcribe(audio_array, identified_language)
|
185 |
+
else:
|
186 |
+
result = transcribe(audio_array, request.language)
|
187 |
+
|
188 |
processing_time = time.time() - start_time
|
189 |
return JSONResponse(content={"transcription": result, "processing_time_seconds": processing_time})
|
190 |
except Exception as e:
|
|
|
204 |
start_time = time.time()
|
205 |
logger.info(f"Synthesize request received: text='{request.text}', language='{request.language}', speed={request.speed}")
|
206 |
try:
|
207 |
+
if request.language is None:
|
208 |
+
# If no language is provided, default to English
|
209 |
+
lang_code = "eng"
|
210 |
+
else:
|
211 |
+
# Extract the ISO code from the full language name
|
212 |
+
lang_code = request.language.split()[0].strip()
|
213 |
|
214 |
# Input validation
|
215 |
if not request.text:
|
216 |
raise ValueError("Text cannot be empty")
|
217 |
if lang_code not in TTS_LANGUAGES:
|
218 |
+
raise ValueError(f"Unsupported language: {lang_code}")
|
219 |
if not 0.5 <= request.speed <= 2.0:
|
220 |
raise ValueError(f"Speed must be between 0.5 and 2.0, got {request.speed}")
|
221 |
|
222 |
logger.info(f"Calling synthesize function with lang_code: {lang_code}")
|
223 |
+
result, filtered_text = synthesize(request.text, lang_code, request.speed)
|
224 |
logger.info(f"Synthesize function completed. Filtered text: '{filtered_text}'")
|
225 |
|
226 |
if result is None:
|
|
|
298 |
status_code=500,
|
299 |
content={"message": "An unexpected error occurred during speech synthesis", "details": error_details, "processing_time_seconds": processing_time}
|
300 |
)
|
|
|
|
|
301 |
|
302 |
@app.post("/identify")
|
303 |
async def identify_language(request: AudioRequest, api_key: APIKey = Depends(get_api_key)):
|
|
|
345 |
async def get_asr_languages(request: LanguageRequest, api_key: APIKey = Depends(get_api_key)):
|
346 |
start_time = time.time()
|
347 |
try:
|
348 |
+
if request.language is None or request.language == "":
|
349 |
+
# If no language is provided, return all languages
|
350 |
+
matching_languages = ASR_LANGUAGES
|
351 |
+
else:
|
352 |
+
matching_languages = [lang for lang in ASR_LANGUAGES if lang.lower().startswith(request.language.lower())]
|
353 |
|
|
|
|
|
|
|
|
|
354 |
processing_time = time.time() - start_time
|
355 |
return JSONResponse(content={"languages": matching_languages, "processing_time_seconds": processing_time})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
356 |
except Exception as e:
|
357 |
logger.error(f"Error in get_asr_languages: {str(e)}", exc_info=True)
|
358 |
error_details = {
|
|
|
369 |
async def get_tts_languages(request: LanguageRequest, api_key: APIKey = Depends(get_api_key)):
|
370 |
start_time = time.time()
|
371 |
try:
|
372 |
+
if request.language is None or request.language == "":
|
373 |
+
# If no language is provided, return all languages
|
374 |
+
matching_languages = TTS_LANGUAGES
|
375 |
+
else:
|
376 |
+
matching_languages = [lang for lang in TTS_LANGUAGES if lang.lower().startswith(request.language.lower())]
|
377 |
|
|
|
378 |
processing_time = time.time() - start_time
|
379 |
return JSONResponse(content={"languages": matching_languages, "processing_time_seconds": processing_time})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
380 |
except Exception as e:
|
381 |
logger.error(f"Error in get_tts_languages: {str(e)}", exc_info=True)
|
382 |
error_details = {
|