|
from fastapi import FastAPI, HTTPException |
|
from fastapi.responses import JSONResponse, FileResponse |
|
from pydantic import BaseModel |
|
import numpy as np |
|
import io |
|
import soundfile as sf |
|
import base64 |
|
import logging |
|
import torch |
|
import librosa |
|
from pathlib import Path |
|
import magic |
|
from pydub import AudioSegment |
|
import traceback |
|
from logging.handlers import RotatingFileHandler |
|
|
|
|
|
from asr import transcribe, ASR_LANGUAGES |
|
from tts import synthesize, TTS_LANGUAGES |
|
from lid import identify |
|
from asr import ASR_SAMPLING_RATE |
|
|
|
|
|
# Root logger configuration: records at INFO and above go to the default
# stream handler installed by basicConfig.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# This module's logger additionally writes to a rotating file: 'app.log'
# rolls over at 10,000,000 bytes and up to five old files are kept.
formatter = logging.Formatter(
    '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
file_handler = RotatingFileHandler(
    filename='app.log',
    maxBytes=10_000_000,
    backupCount=5,
)
file_handler.setFormatter(formatter)
file_handler.setLevel(logging.INFO)
logger.addHandler(file_handler)

# The ASGI application object that the route decorators below attach to.
app = FastAPI(title="MMS: Scaling Speech Technology to 1000+ languages")
|
|
|
|
|
class AudioRequest(BaseModel):
    # Request body accepted by the /transcribe and /identify endpoints.

    # Audio payload as a base64 string; the handlers decode it with
    # base64.b64decode before sniffing its MIME type.
    audio: str

    # Language selector. /transcribe forwards it unchanged to transcribe();
    # /identify does not read it.
    language: str
|
|
|
class TTSRequest(BaseModel):
    # Request body accepted by the /synthesize endpoint.

    # Text to be spoken; the handler rejects an empty string.
    text: str

    # Must be a member of TTS_LANGUAGES, otherwise the handler answers 400.
    language: str

    # Passed to synthesize(); the handler only accepts 0.5 to 2.0 inclusive.
    speed: float
|
|
|
def detect_mime_type(input_bytes):
    """Return the MIME type that python-magic reports for a byte buffer.

    Args:
        input_bytes: Raw bytes whose content should be sniffed.

    Returns:
        Whatever ``magic.Magic(mime=True).from_buffer`` yields for the
        buffer; callers treat it as a MIME type string.
    """
    # A fresh detector is built on every call, exactly one per lookup.
    return magic.Magic(mime=True).from_buffer(input_bytes)
|
|
|
def _pcm_segment_to_float(segment):
    """Turn a pydub ``AudioSegment`` into a float array laid out like ``sf.read`` output."""
    samples = np.array(segment.get_array_of_samples())

    if segment.channels > 1:
        # pydub serialises multi-channel audio as interleaved samples
        # (L, R, L, R, ...); regroup them into (frames, channels) so a
        # caller's mean(axis=1) down-mix works and the duration is not doubled.
        samples = samples.reshape((-1, segment.channels))

    # Integer PCM spans [-2**(bits-1), 2**(bits-1) - 1]; dividing by the
    # magnitude of the most negative value maps it onto [-1.0, 1.0), the
    # same range the soundfile branch produces.
    full_scale = float(1 << (8 * segment.sample_width - 1))
    return samples.astype(np.float64) / full_scale


def extract_audio(input_bytes):
    """Decode an in-memory audio upload into samples and a sample rate.

    The container type is sniffed with ``detect_mime_type``:

    * ``audio/*`` is decoded with ``soundfile.read``.
    * ``video/webm`` is decoded with pydub and converted to the same
      representation: floating point in [-1, 1), 1-D for mono and
      ``(frames, channels)`` for multi-channel audio.

    Previously the webm branch returned pydub's raw sample array, which is
    integer PCM and interleaved for multi-channel audio, so the two branches
    disagreed about both scale and shape.

    Args:
        input_bytes: The complete encoded file as bytes.

    Returns:
        A ``(audio_array, sample_rate)`` tuple.

    Raises:
        ValueError: If the detected MIME type is neither ``audio/*`` nor
            ``video/webm``.
    """
    mime_type = detect_mime_type(input_bytes)

    if mime_type.startswith('audio/'):
        return sf.read(io.BytesIO(input_bytes))

    if mime_type.startswith('video/webm'):
        segment = AudioSegment.from_file(io.BytesIO(input_bytes), format="webm")
        return _pcm_segment_to_float(segment), segment.frame_rate

    raise ValueError(f"Unsupported MIME type: {mime_type}")
|
|
|
@app.post("/transcribe")
async def transcribe_audio(request: AudioRequest):
    """Transcribe a base64-encoded audio upload.

    The audio is decoded, down-mixed to mono, cast to float32, resampled to
    ``ASR_SAMPLING_RATE`` when needed and handed to ``transcribe`` together
    with ``request.language``.

    Args:
        request: Body carrying the base64 audio and the language selector.

    Returns:
        * 200 ``{"transcription": ...}`` on success.
        * 400 ``{"message": "Invalid input", "details": ...}`` when the
          payload cannot be base64-decoded or its MIME type is unsupported,
          matching how /synthesize reports bad input.
        * 500 with the error text and traceback for any other failure.
    """
    try:
        try:
            input_bytes = base64.b64decode(request.audio)
            audio_array, sample_rate = extract_audio(input_bytes)
        except ValueError as ve:
            # binascii.Error (malformed base64) is a ValueError subclass and
            # extract_audio raises ValueError for unsupported MIME types;
            # both are the client's fault, so answer 400 rather than 500.
            logger.error(f"Invalid input in transcribe_audio: {str(ve)}", exc_info=True)
            return JSONResponse(
                status_code=400,
                content={"message": "Invalid input", "details": str(ve)}
            )

        # Multi-channel audio arrives as (frames, channels); average the
        # channels to get a mono signal.
        if audio_array.ndim > 1:
            audio_array = audio_array.mean(axis=1)

        audio_array = audio_array.astype(np.float32)

        if sample_rate != ASR_SAMPLING_RATE:
            audio_array = librosa.resample(audio_array, orig_sr=sample_rate, target_sr=ASR_SAMPLING_RATE)

        result = transcribe(audio_array, request.language)
        return JSONResponse(content={"transcription": result})
    except Exception as e:
        logger.error(f"Error in transcribe_audio: {str(e)}", exc_info=True)
        # NOTE(review): the traceback is returned to the client, which
        # exposes server internals. Kept because the response shape is part
        # of the API; consider dropping it outside of debugging.
        error_details = {
            "error": str(e),
            "traceback": traceback.format_exc()
        }
        return JSONResponse(
            status_code=500,
            content={"message": "An error occurred during transcription", "details": error_details}
        )
|
|
|
@app.post("/synthesize")
async def synthesize_speech(request: TTSRequest):
    """Synthesize speech for the requested text and return it as a WAV file.

    Steps: validate the request, call ``synthesize``, peak-normalize the
    waveform, convert it to 16-bit PCM, encode it as WAV in memory and send
    the bytes back as an ``audio/wav`` attachment.

    Args:
        request: Body carrying the text, the language (must be in
            ``TTS_LANGUAGES``) and the speed (0.5 to 2.0 inclusive).

    Returns:
        * 200 ``audio/wav`` attachment named ``synthesized_audio.wav``.
        * 400 ``{"message": "Invalid input", "details": ...}`` for any
          ``ValueError`` raised while handling the request (failed
          validation, no audio produced, silent audio).
        * 500 with error text, exception type and traceback otherwise.
    """
    # Imported locally because the module-level imports only bring in
    # JSONResponse and FileResponse.
    from fastapi.responses import Response

    logger.info(f"Synthesize request received: text='{request.text}', language='{request.language}', speed={request.speed}")
    try:
        logger.info("Validating input parameters")
        if not request.text:
            raise ValueError("Text cannot be empty")
        if request.language not in TTS_LANGUAGES:
            raise ValueError(f"Unsupported language: {request.language}")
        if not 0.5 <= request.speed <= 2.0:
            raise ValueError(f"Speed must be between 0.5 and 2.0, got {request.speed}")

        logger.info("Calling synthesize function")
        result, filtered_text = synthesize(request.text, request.language, request.speed)
        logger.info(f"Synthesize function completed. Filtered text: '{filtered_text}'")

        if result is None:
            logger.error("Synthesize function returned None")
            raise ValueError("Synthesis failed to produce audio")

        sample_rate, audio = result
        logger.info(f"Synthesis result: sample_rate={sample_rate}, audio_shape={audio.shape if isinstance(audio, np.ndarray) else 'not numpy array'}, audio_dtype={audio.dtype if isinstance(audio, np.ndarray) else type(audio)}")

        logger.info("Converting audio to numpy array")
        audio = np.array(audio, dtype=np.float32)
        logger.info(f"Converted audio shape: {audio.shape}, dtype: {audio.dtype}")

        logger.info("Normalizing audio")
        max_value = np.max(np.abs(audio))
        if max_value == 0:
            # Dividing by a zero peak would fill the array with NaNs.
            logger.warning("Audio array is all zeros")
            raise ValueError("Generated audio is silent (all zeros)")
        audio = audio / max_value
        logger.info(f"Normalized audio range: [{audio.min()}, {audio.max()}]")

        logger.info("Converting to int16")
        # After peak normalization the samples lie in [-1, 1], so scaling by
        # 32767 stays inside the int16 range.
        audio = (audio * 32767).astype(np.int16)
        logger.info(f"Int16 audio shape: {audio.shape}, dtype: {audio.dtype}")

        logger.info("Writing audio to buffer")
        buffer = io.BytesIO()
        sf.write(buffer, audio, sample_rate, format='wav')
        buffer.seek(0)
        logger.info(f"Buffer size: {buffer.getbuffer().nbytes} bytes")

        logger.info("Preparing FileResponse")
        # FileResponse expects a filesystem path and stats it while sending,
        # so handing it a BytesIO fails after this handler has returned
        # (outside the try/except below). Send the in-memory bytes with a
        # plain Response instead. Log wording is kept unchanged so existing
        # log output stays the same.
        response = Response(
            content=buffer.getvalue(),
            media_type="audio/wav",
            headers={"Content-Disposition": "attachment; filename=synthesized_audio.wav"}
        )
        logger.info("FileResponse prepared successfully")

        return response

    except ValueError as ve:
        logger.error(f"ValueError in synthesize_speech: {str(ve)}", exc_info=True)
        return JSONResponse(
            status_code=400,
            content={"message": "Invalid input", "details": str(ve)}
        )
    except Exception as e:
        logger.error(f"Unexpected error in synthesize_speech: {str(e)}", exc_info=True)
        # NOTE(review): the traceback is returned to the client, which
        # exposes server internals. Kept because the response shape is part
        # of the API; consider dropping it outside of debugging.
        error_details = {
            "error": str(e),
            "type": type(e).__name__,
            "traceback": traceback.format_exc()
        }
        return JSONResponse(
            status_code=500,
            content={"message": "An unexpected error occurred during speech synthesis", "details": error_details}
        )
    finally:
        # Runs on every path, including the early error returns above.
        logger.info("Synthesize request completed")
|
|
|
@app.post("/identify")
async def identify_language(request: AudioRequest):
    """Run language identification on a base64-encoded audio upload.

    Args:
        request: Body carrying the base64 audio. ``request.language`` is not
            read by this endpoint.

    Returns:
        * 200 ``{"language_identification": ...}`` with whatever
          ``identify`` returns.
        * 400 ``{"message": "Invalid input", "details": ...}`` when the
          payload cannot be base64-decoded or its MIME type is unsupported,
          matching how /synthesize reports bad input.
        * 500 with the error text and traceback for any other failure.
    """
    try:
        try:
            input_bytes = base64.b64decode(request.audio)
            # The decoded sample rate is not used on this path.
            # NOTE(review): unlike /transcribe, no mono down-mix, float32
            # cast or resampling happens here - confirm identify() copes
            # with arbitrary sample rates and channel counts.
            audio_array, _ = extract_audio(input_bytes)
        except ValueError as ve:
            # binascii.Error (malformed base64) is a ValueError subclass and
            # extract_audio raises ValueError for unsupported MIME types;
            # both are the client's fault, so answer 400 rather than 500.
            logger.error(f"Invalid input in identify_language: {str(ve)}", exc_info=True)
            return JSONResponse(
                status_code=400,
                content={"message": "Invalid input", "details": str(ve)}
            )

        result = identify(audio_array)
        return JSONResponse(content={"language_identification": result})
    except Exception as e:
        logger.error(f"Error in identify_language: {str(e)}", exc_info=True)
        # NOTE(review): the traceback is returned to the client, which
        # exposes server internals. Kept because the response shape is part
        # of the API; consider dropping it outside of debugging.
        error_details = {
            "error": str(e),
            "traceback": traceback.format_exc()
        }
        return JSONResponse(
            status_code=500,
            content={"message": "An error occurred during language identification", "details": error_details}
        )
|
|
|
@app.get("/asr_languages")
async def get_asr_languages():
    """Serve ``ASR_LANGUAGES`` as a JSON document.

    Returns:
        A ``JSONResponse`` whose body is ``ASR_LANGUAGES``. If building that
        response raises, a 500 ``JSONResponse`` carrying a message plus the
        error text and traceback is returned instead.
    """
    try:
        languages_response = JSONResponse(content=ASR_LANGUAGES)
    except Exception as exc:
        logger.error(f"Error in get_asr_languages: {str(exc)}", exc_info=True)
        failure_body = {
            "message": "An error occurred while fetching ASR languages",
            "details": {
                "error": str(exc),
                "traceback": traceback.format_exc(),
            },
        }
        return JSONResponse(status_code=500, content=failure_body)
    return languages_response
|
|
|
@app.get("/tts_languages")
async def get_tts_languages():
    """Serve ``TTS_LANGUAGES`` as a JSON document.

    Returns:
        A ``JSONResponse`` whose body is ``TTS_LANGUAGES``. If building that
        response raises, a 500 ``JSONResponse`` carrying a message plus the
        error text and traceback is returned instead.
    """
    try:
        languages_response = JSONResponse(content=TTS_LANGUAGES)
    except Exception as exc:
        logger.error(f"Error in get_tts_languages: {str(exc)}", exc_info=True)
        failure_body = {
            "message": "An error occurred while fetching TTS languages",
            "details": {
                "error": str(exc),
                "traceback": traceback.format_exc(),
            },
        }
        return JSONResponse(status_code=500, content=failure_body)
    return languages_response