|
from fastapi import FastAPI, HTTPException |
|
from fastapi.responses import JSONResponse, FileResponse |
|
from pydantic import BaseModel |
|
import numpy as np |
|
import io |
|
import soundfile as sf |
|
import base64 |
|
import logging |
|
import torch |
|
import librosa |
|
from transformers import Wav2Vec2ForCTC, AutoProcessor |
|
from pathlib import Path |
|
|
|
|
|
from asr import transcribe, ASR_LANGUAGES |
|
from tts import synthesize, TTS_LANGUAGES |
|
from lid import identify |
|
|
|
|
|
# Configure root logging at INFO level. logging.basicConfig is a no-op if the
# root logger already has handlers (e.g. if one of the asr/tts/lid modules
# imported above configured logging first).
logging.basicConfig(level=logging.INFO)

# Module-level logger used by the endpoint handlers in this file.
logger = logging.getLogger(__name__)

# ASGI application; the endpoints in this file are registered on it via the
# @app.post / @app.get decorators.
app = FastAPI(title="MMS: Scaling Speech Technology to 1000+ languages")
|
|
|
|
|
class AudioRequest(BaseModel):
    """Request body shared by the /transcribe and /identify endpoints."""

    # Base64-encoded audio file contents; handlers decode it with
    # base64.b64decode and read the result with soundfile.
    audio: str

    # Language forwarded to `transcribe` by /transcribe. Required by the
    # schema even for /identify, which does not read it.
    language: str
|
|
|
class TTSRequest(BaseModel):
    """Request body for the /synthesize endpoint."""

    # Text to synthesize.
    text: str

    # Language forwarded to `synthesize`; presumably one of the entries served
    # by /tts_languages (TTS_LANGUAGES) — confirm against the tts module.
    language: str

    # Speed value forwarded unchanged to `synthesize`; its valid range is not
    # checked here.
    speed: float
|
|
|
@app.post("/transcribe")
async def transcribe_audio(request: AudioRequest):
    """Transcribe base64-encoded audio with the ASR backend.

    The request's ``audio`` field is base64-decoded and read with
    ``soundfile``. Multi-channel audio is down-mixed to mono by averaging the
    channels, then passed to ``transcribe`` together with ``request.language``.

    Returns:
        JSONResponse: ``{"transcription": <result of transcribe>}``.

    Raises:
        HTTPException: 400 if the payload is not valid base64 or cannot be
            decoded as audio; 500 if transcription itself fails.
    """
    try:
        audio_bytes = base64.b64decode(request.audio)
        audio_array, sample_rate = sf.read(io.BytesIO(audio_bytes))
    except (ValueError, RuntimeError) as e:
        # Malformed base64 raises binascii.Error (a ValueError); soundfile
        # raises RuntimeError subclasses for data it cannot decode. Both are
        # problems with the client's input, not server faults.
        logger.warning("Invalid audio in transcribe_audio: %s", e)
        raise HTTPException(status_code=400, detail=f"Invalid audio data: {e}") from e

    try:
        # soundfile returns (frames, channels) for multi-channel input;
        # average across channels to obtain a mono signal.
        if audio_array.ndim > 1:
            audio_array = audio_array.mean(axis=1)

        # NOTE(review): `sample_rate` is not forwarded, so `transcribe` has to
        # assume a rate. Confirm it resamples, or that clients send audio at
        # the rate the ASR model expects (librosa is imported but unused here).
        # NOTE(review): this call is synchronous inside an async handler and
        # blocks the event loop while it runs; only move it to a threadpool
        # after confirming the ASR model state is safe to use concurrently.
        result = transcribe(audio_array, request.language)
        return JSONResponse(content={"transcription": result})
    except Exception as e:
        logger.exception("Error in transcribe_audio: %s", e)
        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}") from e
|
|
|
@app.post("/synthesize")
async def synthesize_speech(request: TTSRequest):
    """Synthesize speech for ``request.text`` and return it as a WAV download.

    ``synthesize`` is called with the request's text, language and speed. The
    audio it returns is encoded as WAV in memory and sent back with an
    ``attachment`` Content-Disposition header.

    Raises:
        HTTPException: 500 if synthesis or WAV encoding fails.
    """
    # FileResponse (imported at module level) serves a file from a filesystem
    # path and cannot send an in-memory BytesIO; the original code only failed
    # once the response was being sent, outside this handler's try/except.
    # A plain Response carrying the encoded bytes works.
    from fastapi.responses import Response

    try:
        # The filtered text returned by `synthesize` is not sent to the client.
        audio, _filtered_text = synthesize(request.text, request.language, request.speed)

        buffer = io.BytesIO()
        # NOTE(review): the 22050 Hz sample rate is hard-coded; confirm it
        # matches the rate of the audio produced by `synthesize`.
        sf.write(buffer, audio, 22050, format='wav')
        wav_bytes = buffer.getvalue()
    except Exception as e:
        logger.exception("Error in synthesize_speech: %s", e)
        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}") from e

    return Response(
        content=wav_bytes,
        media_type="audio/wav",
        headers={"Content-Disposition": "attachment; filename=synthesized_audio.wav"},
    )
|
|
|
@app.post("/identify")
async def identify_language(request: AudioRequest):
    """Identify the spoken language of base64-encoded audio.

    The request's ``audio`` field is base64-decoded and read with
    ``soundfile``. Multi-channel audio is down-mixed to mono, as /transcribe
    does, before being passed to ``identify``. ``request.language`` is not used.

    Returns:
        JSONResponse: ``{"language_identification": <result of identify>}``.

    Raises:
        HTTPException: 400 if the payload is not valid base64 or cannot be
            decoded as audio; 500 if identification itself fails.
    """
    try:
        audio_bytes = base64.b64decode(request.audio)
        audio_array, sample_rate = sf.read(io.BytesIO(audio_bytes))
    except (ValueError, RuntimeError) as e:
        # binascii.Error (a ValueError) for malformed base64, RuntimeError
        # subclasses from soundfile for undecodable audio: client errors.
        logger.warning("Invalid audio in identify_language: %s", e)
        raise HTTPException(status_code=400, detail=f"Invalid audio data: {e}") from e

    try:
        # Match /transcribe: soundfile yields (frames, channels) for
        # multi-channel input, so average the channels into one signal.
        if audio_array.ndim > 1:
            audio_array = audio_array.mean(axis=1)

        # NOTE(review): `sample_rate` is not forwarded; confirm `identify`
        # handles the incoming rate.
        result = identify(audio_array)
        return JSONResponse(content={"language_identification": result})
    except Exception as e:
        logger.exception("Error in identify_language: %s", e)
        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}") from e
|
|
|
@app.get("/asr_languages")
async def get_asr_languages():
    """Serve the ``ASR_LANGUAGES`` table from the ``asr`` module as JSON.

    Any failure while building the response is logged and reported to the
    client as HTTP 500.
    """
    try:
        response = JSONResponse(content=ASR_LANGUAGES)
    except Exception as exc:
        logger.error("Error in get_asr_languages: %s", exc)
        raise HTTPException(
            status_code=500,
            detail="An error occurred: " + str(exc),
        )
    return response
|
|
|
@app.get("/tts_languages")
async def get_tts_languages():
    """Serve the ``TTS_LANGUAGES`` table from the ``tts`` module as JSON.

    Any failure while building the response is logged and reported to the
    client as HTTP 500.
    """
    try:
        response = JSONResponse(content=TTS_LANGUAGES)
    except Exception as exc:
        logger.error("Error in get_tts_languages: %s", exc)
        raise HTTPException(
            status_code=500,
            detail="An error occurred: " + str(exc),
        )
    return response
|
|