import base64
import io
import logging
import os
import tempfile
from pathlib import Path

import librosa
import magic
import numpy as np
import soundfile as sf
import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.responses import FileResponse, JSONResponse, Response
from moviepy.editor import VideoFileClip
from pydantic import BaseModel
from transformers import AutoProcessor, Wav2Vec2ForCTC

from asr import ASR_LANGUAGES, ASR_SAMPLING_RATE, transcribe
from lid import identify
from tts import TTS_LANGUAGES, synthesize
# Configure the root logger so INFO-level (and more severe) records are emitted.
logging.basicConfig(level=logging.INFO)

# Module-level logger used by every endpoint handler in this file.
logger = logging.getLogger(__name__)

# FastAPI application object that the route decorators below register against.
app = FastAPI(title="MMS: Scaling Speech Technology to 1000+ languages")
# Request body shared by the /transcribe and /identify endpoints.
class AudioRequest(BaseModel):
    # Base64-encoded media file; the handlers pass it through base64.b64decode.
    audio: str
    # Language selector forwarded to transcribe(); /identify does not read it.
    language: str
# Request body for the /synthesize endpoint; the three fields are forwarded,
# in this order, to tts.synthesize().
class TTSRequest(BaseModel):
    # Text to be spoken.
    text: str
    # Language selector for the TTS model.
    language: str
    # Speed value handed to synthesize(); presumably a rate multiplier --
    # TODO confirm the accepted range against tts.synthesize.
    speed: float
def detect_mime_type(input_bytes):
    """Return the MIME type string that libmagic reports for *input_bytes*."""
    # mime=True makes the detector return a MIME type such as "audio/x-wav"
    # rather than a human-readable description.
    return magic.Magic(mime=True).from_buffer(input_bytes)
def extract_audio(input_bytes):
    """Decode raw media bytes into audio samples and their sample rate.

    The MIME type is sniffed from the bytes themselves. Audio payloads are
    decoded in memory with soundfile; video payloads have their audio track
    extracted with moviepy.

    Args:
        input_bytes: Raw contents of an audio or video file.

    Returns:
        A ``(samples, sample_rate)`` tuple. For audio input this is exactly
        what ``soundfile.read`` returns; for video input it is the clip's
        sound array and the audio track's ``fps``.

    Raises:
        ValueError: If the MIME type is neither ``audio/*`` nor ``video/*``,
            or if a video payload has no audio track.
    """
    mime_type = detect_mime_type(input_bytes)

    if mime_type.startswith('audio/'):
        return sf.read(io.BytesIO(input_bytes))

    if mime_type.startswith('video/'):
        # VideoFileClip needs a real filesystem path (it hands the name to
        # ffmpeg); an in-memory BytesIO has no ``.name``. Spool the payload to
        # a temporary file instead. delete=False plus an explicit close lets
        # the file be reopened by name on every platform.
        tmp = tempfile.NamedTemporaryFile(delete=False)
        try:
            tmp.write(input_bytes)
            tmp.close()
            video = VideoFileClip(tmp.name)
            try:
                audio = video.audio
                if audio is None:
                    raise ValueError("Video file contains no audio track")
                # Materialise the samples before the clip's readers are closed.
                audio_array = audio.to_soundarray()
                sample_rate = audio.fps
            finally:
                video.close()
        finally:
            tmp.close()
            os.unlink(tmp.name)
        return audio_array, sample_rate

    raise ValueError(f"Unsupported MIME type: {mime_type}")
@app.post("/transcribe")
async def transcribe_audio(request: AudioRequest):
    """Transcribe base64-encoded audio or video in the requested language.

    The payload is decoded, reduced to a mono float32 signal, resampled to
    ``ASR_SAMPLING_RATE`` when necessary, and passed to ``transcribe``.

    Args:
        request: Body carrying the base64 media and the language selector.

    Returns:
        A JSON response of the form ``{"transcription": <result>}``.

    Raises:
        HTTPException: Status 500 wrapping any error raised while decoding,
            preprocessing or transcribing.
    """
    try:
        input_bytes = base64.b64decode(request.audio)
        audio_array, sample_rate = extract_audio(input_bytes)

        # Multi-dimensional input is averaged over axis 1 to get one channel.
        if audio_array.ndim > 1:
            audio_array = audio_array.mean(axis=1)
        audio_array = audio_array.astype(np.float32)

        if sample_rate != ASR_SAMPLING_RATE:
            audio_array = librosa.resample(
                audio_array, orig_sr=sample_rate, target_sr=ASR_SAMPLING_RATE
            )

        result = transcribe(audio_array, request.language)
        return JSONResponse(content={"transcription": result})
    except Exception as e:
        # logger.exception records the traceback, which logger.error dropped.
        logger.exception("Error in transcribe_audio: %s", e)
        raise HTTPException(
            status_code=500, detail=f"An error occurred: {str(e)}"
        ) from e
@app.post("/synthesize")
async def synthesize_speech(request: TTSRequest):
    """Synthesize speech for the given text and return it as a WAV download.

    Args:
        request: Body carrying the text, language selector and speed.

    Returns:
        An ``audio/wav`` response whose body is the encoded waveform, sent
        with an attachment ``Content-Disposition`` header.

    Raises:
        HTTPException: Status 500 wrapping any error raised during synthesis
            or WAV encoding.
    """
    try:
        # The second element returned by synthesize() is not used here.
        audio, _filtered_text = synthesize(
            request.text, request.language, request.speed
        )

        # NOTE(review): the 22050 Hz rate is hard-coded; confirm it matches
        # the rate of the waveform produced by tts.synthesize.
        buffer = io.BytesIO()
        sf.write(buffer, audio, 22050, format='wav')

        # FileResponse only serves files by filesystem path, so the in-memory
        # WAV is returned through a plain Response instead.
        return Response(
            content=buffer.getvalue(),
            media_type="audio/wav",
            headers={"Content-Disposition": "attachment; filename=synthesized_audio.wav"},
        )
    except Exception as e:
        # logger.exception records the traceback, which logger.error dropped.
        logger.exception("Error in synthesize_speech: %s", e)
        raise HTTPException(
            status_code=500, detail=f"An error occurred: {str(e)}"
        ) from e
@app.post("/identify")
async def identify_language(request: AudioRequest):
    """Run language identification on base64-encoded audio or video.

    Args:
        request: Body carrying the base64 media. Its ``language`` field is
            not read by this endpoint.

    Returns:
        A JSON response of the form ``{"language_identification": <result>}``.

    Raises:
        HTTPException: Status 500 wrapping any error raised while decoding
            or identifying.
    """
    try:
        input_bytes = base64.b64decode(request.audio)
        # NOTE(review): unlike /transcribe, the samples are passed on without
        # mono down-mixing or resampling, and the sample rate is discarded --
        # confirm that lid.identify accepts audio in this form.
        audio_array, _sample_rate = extract_audio(input_bytes)
        result = identify(audio_array)
        return JSONResponse(content={"language_identification": result})
    except Exception as e:
        # logger.exception records the traceback, which logger.error dropped.
        logger.exception("Error in identify_language: %s", e)
        raise HTTPException(
            status_code=500, detail=f"An error occurred: {str(e)}"
        ) from e
@app.get("/asr_languages")
async def get_asr_languages():
    """Return the ``ASR_LANGUAGES`` table exported by the ``asr`` module as JSON."""
    try:
        response = JSONResponse(content=ASR_LANGUAGES)
    except Exception as exc:
        reason = str(exc)
        logger.error(f"Error in get_asr_languages: {reason}")
        raise HTTPException(status_code=500, detail=f"An error occurred: {reason}")
    return response
@app.get("/tts_languages")
async def get_tts_languages():
    """Return the ``TTS_LANGUAGES`` table exported by the ``tts`` module as JSON."""
    try:
        response = JSONResponse(content=TTS_LANGUAGES)
    except Exception as exc:
        reason = str(exc)
        logger.error(f"Error in get_tts_languages: {reason}")
        raise HTTPException(status_code=500, detail=f"An error occurred: {reason}")
    return response