Spaces:
Edmond98
/
Running on A100

Edmond7 commited on
Commit
4b305c9
1 Parent(s): 9104ce6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -38
app.py CHANGED
@@ -3,6 +3,7 @@ from fastapi import FastAPI, HTTPException, File, UploadFile, Depends, Security
3
  from fastapi.security.api_key import APIKeyHeader, APIKey
4
  from fastapi.responses import JSONResponse
5
  from pydantic import BaseModel
 
6
  import numpy as np
7
  import io
8
  import soundfile as sf
@@ -21,10 +22,9 @@ import time
21
  import tempfile
22
 
23
  # Import functions from other modules
24
- from asr import transcribe, ASR_LANGUAGES
25
  from tts import synthesize, TTS_LANGUAGES
26
  from lid import identify
27
- from asr import ASR_SAMPLING_RATE
28
 
29
  # Configure logging
30
  logging.basicConfig(level=logging.INFO)
@@ -60,15 +60,18 @@ s3_client = boto3.client(
60
  # Define request models
61
  class AudioRequest(BaseModel):
62
  audio: str # Base64 encoded audio or video data
63
- language: str
64
 
65
  class TTSRequest(BaseModel):
66
  text: str
67
- language: str
68
- speed: float
69
 
70
  class LanguageRequest(BaseModel):
71
- language: str
 
 
 
72
 
73
  async def get_api_key(api_key_header: str = Security(api_key_header)):
74
  if api_key_header == API_KEY:
@@ -140,7 +143,13 @@ async def transcribe_audio(request: AudioRequest, api_key: APIKey = Depends(get_
140
  if sample_rate != ASR_SAMPLING_RATE:
141
  audio_array = librosa.resample(audio_array, orig_sr=sample_rate, target_sr=ASR_SAMPLING_RATE)
142
 
143
- result = transcribe(audio_array, request.language)
 
 
 
 
 
 
144
  processing_time = time.time() - start_time
145
  return JSONResponse(content={"transcription": result, "processing_time_seconds": processing_time})
146
  except Exception as e:
@@ -156,7 +165,7 @@ async def transcribe_audio(request: AudioRequest, api_key: APIKey = Depends(get_
156
  )
157
 
158
  @app.post("/transcribe_file")
159
- async def transcribe_audio_file(file: UploadFile = File(...), language: str = "", api_key: APIKey = Depends(get_api_key)):
160
  start_time = time.time()
161
  try:
162
  contents = await file.read()
@@ -169,7 +178,13 @@ async def transcribe_audio_file(file: UploadFile = File(...), language: str = ""
169
  if sample_rate != ASR_SAMPLING_RATE:
170
  audio_array = librosa.resample(audio_array, orig_sr=sample_rate, target_sr=ASR_SAMPLING_RATE)
171
 
172
- result = transcribe(audio_array, language)
 
 
 
 
 
 
173
  processing_time = time.time() - start_time
174
  return JSONResponse(content={"transcription": result, "processing_time_seconds": processing_time})
175
  except Exception as e:
@@ -189,19 +204,23 @@ async def synthesize_speech(request: TTSRequest, api_key: APIKey = Depends(get_a
189
  start_time = time.time()
190
  logger.info(f"Synthesize request received: text='{request.text}', language='{request.language}', speed={request.speed}")
191
  try:
192
- # Extract the ISO code from the full language name
193
- lang_code = request.language.split()[0].strip()
 
 
 
 
194
 
195
  # Input validation
196
  if not request.text:
197
  raise ValueError("Text cannot be empty")
198
  if lang_code not in TTS_LANGUAGES:
199
- raise ValueError(f"Unsupported language: {request.language}")
200
  if not 0.5 <= request.speed <= 2.0:
201
  raise ValueError(f"Speed must be between 0.5 and 2.0, got {request.speed}")
202
 
203
  logger.info(f"Calling synthesize function with lang_code: {lang_code}")
204
- result, filtered_text = synthesize(request.text, request.language, request.speed)
205
  logger.info(f"Synthesize function completed. Filtered text: '{filtered_text}'")
206
 
207
  if result is None:
@@ -279,8 +298,6 @@ async def synthesize_speech(request: TTSRequest, api_key: APIKey = Depends(get_a
279
  status_code=500,
280
  content={"message": "An unexpected error occurred during speech synthesis", "details": error_details, "processing_time_seconds": processing_time}
281
  )
282
- finally:
283
- logger.info("Synthesize request completed")
284
 
285
  @app.post("/identify")
286
  async def identify_language(request: AudioRequest, api_key: APIKey = Depends(get_api_key)):
@@ -328,22 +345,14 @@ async def identify_language_file(file: UploadFile = File(...), api_key: APIKey =
328
  async def get_asr_languages(request: LanguageRequest, api_key: APIKey = Depends(get_api_key)):
329
  start_time = time.time()
330
  try:
331
- if request.language.lower() not in [lang.lower() for lang in ASR_LANGUAGES]:
332
- raise ValueError(f"Unsupported language: {request.language}")
 
 
 
333
 
334
- matching_languages = [lang for lang in ASR_LANGUAGES if lang.lower().startswith(request.language.lower())]
335
- processing_time = time.time() - start_time
336
- return JSONResponse
337
- matching_languages = [lang for lang in ASR_LANGUAGES if lang.lower().startswith(request.language.lower())]
338
  processing_time = time.time() - start_time
339
  return JSONResponse(content={"languages": matching_languages, "processing_time_seconds": processing_time})
340
- except ValueError as ve:
341
- logger.error(f"ValueError in get_asr_languages: {str(ve)}", exc_info=True)
342
- processing_time = time.time() - start_time
343
- return JSONResponse(
344
- status_code=400,
345
- content={"message": "Invalid input", "details": str(ve), "processing_time_seconds": processing_time}
346
- )
347
  except Exception as e:
348
  logger.error(f"Error in get_asr_languages: {str(e)}", exc_info=True)
349
  error_details = {
@@ -360,19 +369,14 @@ async def get_asr_languages(request: LanguageRequest, api_key: APIKey = Depends(
360
  async def get_tts_languages(request: LanguageRequest, api_key: APIKey = Depends(get_api_key)):
361
  start_time = time.time()
362
  try:
363
- if request.language.lower() not in [lang.lower() for lang in TTS_LANGUAGES]:
364
- raise ValueError(f"Unsupported language: {request.language}")
 
 
 
365
 
366
- matching_languages = [lang for lang in TTS_LANGUAGES if lang.lower().startswith(request.language.lower())]
367
  processing_time = time.time() - start_time
368
  return JSONResponse(content={"languages": matching_languages, "processing_time_seconds": processing_time})
369
- except ValueError as ve:
370
- logger.error(f"ValueError in get_tts_languages: {str(ve)}", exc_info=True)
371
- processing_time = time.time() - start_time
372
- return JSONResponse(
373
- status_code=400,
374
- content={"message": "Invalid input", "details": str(ve), "processing_time_seconds": processing_time}
375
- )
376
  except Exception as e:
377
  logger.error(f"Error in get_tts_languages: {str(e)}", exc_info=True)
378
  error_details = {
 
3
  from fastapi.security.api_key import APIKeyHeader, APIKey
4
  from fastapi.responses import JSONResponse
5
  from pydantic import BaseModel
6
+ from typing import Optional
7
  import numpy as np
8
  import io
9
  import soundfile as sf
 
22
  import tempfile
23
 
24
  # Import functions from other modules
25
+ from asr import transcribe, ASR_LANGUAGES, ASR_SAMPLING_RATE
26
  from tts import synthesize, TTS_LANGUAGES
27
  from lid import identify
 
28
 
29
  # Configure logging
30
  logging.basicConfig(level=logging.INFO)
 
60
  # Define request models
61
  class AudioRequest(BaseModel):
62
  audio: str # Base64 encoded audio or video data
63
+ language: Optional[str] = None
64
 
65
  class TTSRequest(BaseModel):
66
  text: str
67
+ language: Optional[str] = None
68
+ speed: float = 1.0
69
 
70
  class LanguageRequest(BaseModel):
71
+ language: Optional[str] = None
72
+
73
+ class TranscribeFileRequest(BaseModel):
74
+ language: Optional[str] = None
75
 
76
  async def get_api_key(api_key_header: str = Security(api_key_header)):
77
  if api_key_header == API_KEY:
 
143
  if sample_rate != ASR_SAMPLING_RATE:
144
  audio_array = librosa.resample(audio_array, orig_sr=sample_rate, target_sr=ASR_SAMPLING_RATE)
145
 
146
+ if request.language is None:
147
+ # If no language is provided, use language identification
148
+ identified_language = identify(audio_array)
149
+ result = transcribe(audio_array, identified_language)
150
+ else:
151
+ result = transcribe(audio_array, request.language)
152
+
153
  processing_time = time.time() - start_time
154
  return JSONResponse(content={"transcription": result, "processing_time_seconds": processing_time})
155
  except Exception as e:
 
165
  )
166
 
167
  @app.post("/transcribe_file")
168
+ async def transcribe_audio_file(file: UploadFile = File(...), request: TranscribeFileRequest = Depends(), api_key: APIKey = Depends(get_api_key)):
169
  start_time = time.time()
170
  try:
171
  contents = await file.read()
 
178
  if sample_rate != ASR_SAMPLING_RATE:
179
  audio_array = librosa.resample(audio_array, orig_sr=sample_rate, target_sr=ASR_SAMPLING_RATE)
180
 
181
+ if request.language is None:
182
+ # If no language is provided, use language identification
183
+ identified_language = identify(audio_array)
184
+ result = transcribe(audio_array, identified_language)
185
+ else:
186
+ result = transcribe(audio_array, request.language)
187
+
188
  processing_time = time.time() - start_time
189
  return JSONResponse(content={"transcription": result, "processing_time_seconds": processing_time})
190
  except Exception as e:
 
204
  start_time = time.time()
205
  logger.info(f"Synthesize request received: text='{request.text}', language='{request.language}', speed={request.speed}")
206
  try:
207
+ if request.language is None:
208
+ # If no language is provided, default to English
209
+ lang_code = "eng"
210
+ else:
211
+ # Extract the ISO code from the full language name
212
+ lang_code = request.language.split()[0].strip()
213
 
214
  # Input validation
215
  if not request.text:
216
  raise ValueError("Text cannot be empty")
217
  if lang_code not in TTS_LANGUAGES:
218
+ raise ValueError(f"Unsupported language: {lang_code}")
219
  if not 0.5 <= request.speed <= 2.0:
220
  raise ValueError(f"Speed must be between 0.5 and 2.0, got {request.speed}")
221
 
222
  logger.info(f"Calling synthesize function with lang_code: {lang_code}")
223
+ result, filtered_text = synthesize(request.text, lang_code, request.speed)
224
  logger.info(f"Synthesize function completed. Filtered text: '{filtered_text}'")
225
 
226
  if result is None:
 
298
  status_code=500,
299
  content={"message": "An unexpected error occurred during speech synthesis", "details": error_details, "processing_time_seconds": processing_time}
300
  )
 
 
301
 
302
  @app.post("/identify")
303
  async def identify_language(request: AudioRequest, api_key: APIKey = Depends(get_api_key)):
 
345
  async def get_asr_languages(request: LanguageRequest, api_key: APIKey = Depends(get_api_key)):
346
  start_time = time.time()
347
  try:
348
+ if request.language is None or request.language == "":
349
+ # If no language is provided, return all languages
350
+ matching_languages = ASR_LANGUAGES
351
+ else:
352
+ matching_languages = [lang for lang in ASR_LANGUAGES if lang.lower().startswith(request.language.lower())]
353
 
 
 
 
 
354
  processing_time = time.time() - start_time
355
  return JSONResponse(content={"languages": matching_languages, "processing_time_seconds": processing_time})
 
 
 
 
 
 
 
356
  except Exception as e:
357
  logger.error(f"Error in get_asr_languages: {str(e)}", exc_info=True)
358
  error_details = {
 
369
  async def get_tts_languages(request: LanguageRequest, api_key: APIKey = Depends(get_api_key)):
370
  start_time = time.time()
371
  try:
372
+ if request.language is None or request.language == "":
373
+ # If no language is provided, return all languages
374
+ matching_languages = TTS_LANGUAGES
375
+ else:
376
+ matching_languages = [lang for lang in TTS_LANGUAGES if lang.lower().startswith(request.language.lower())]
377
 
 
378
  processing_time = time.time() - start_time
379
  return JSONResponse(content={"languages": matching_languages, "processing_time_seconds": processing_time})
 
 
 
 
 
 
 
380
  except Exception as e:
381
  logger.error(f"Error in get_tts_languages: {str(e)}", exc_info=True)
382
  error_details = {