# Spaces: Runtime error
from fastapi import FastAPI, File, UploadFile, HTTPException | |
from fastapi.middleware.cors import CORSMiddleware | |
import nemo.collections.asr as nemo_asr | |
import shutil | |
import os | |
from tempfile import NamedTemporaryFile | |
from typing import Dict | |
from pydantic import BaseModel | |
import uvicorn | |
# Registry of supported languages: ISO 639-1 language code -> pretrained
# model identifier passed to NeMo's `from_pretrained`.
# To support another language, add its code and model identifier here.
LANGUAGE_MODELS: Dict[str, str] = {
    # Hindi
    "hi": "ai4bharat/indicconformer_stt_hi_hybrid_ctc_rnnt_large",
    # Bengali
    "bn": "ai4bharat/indicconformer_stt_bn_hybrid_ctc_rnnt_large",
    # Tamil
    "ta": "ai4bharat/indicconformer_stt_ta_hybrid_ctc_rnnt_large",
}
class TranscriptionResponse(BaseModel):
    """Response body returned for a successful transcription."""

    # Transcript text produced by the ASR model.
    text: str
    # Language code the audio was transcribed with (echoed from the request).
    language: str
# The FastAPI application. Title, description and version are the API
# metadata shown in the generated OpenAPI schema / docs pages.
app = FastAPI(
    title="Indian Languages ASR API",
    description="API for automatic speech recognition in Indian languages",
    version="1.0.0",
)

# CORS policy: every origin, method and header is allowed, with credentials.
# NOTE(review): this is fully permissive; restrict allow_origins if the API
# is exposed beyond trusted clients.
_CORS_SETTINGS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_SETTINGS)
# Loaded ASR models keyed by language code. Each model is loaded at most once
# per process and then reused.
model_cache = {}


def get_model(language: str):
    """Return the ASR model for ``language``, loading and caching it on first use.

    Parameters:
    - language: key into LANGUAGE_MODELS (e.g. 'hi').

    Returns:
    - The NeMo ASR model object for that language.

    Raises:
    - HTTPException(400) if ``language`` is not a key of LANGUAGE_MODELS.
    - HTTPException(500) if loading the pretrained model fails.
    """
    # Guard: reject languages we have no model for.
    if language not in LANGUAGE_MODELS:
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported language: {language}. Supported languages are: {list(LANGUAGE_MODELS.keys())}",
        )

    # Fast path: model already loaded by an earlier request.
    if language in model_cache:
        return model_cache[language]

    # NOTE(review): from_pretrained is a synchronous call and blocks the
    # calling thread for the whole download/load.
    try:
        loaded = nemo_asr.models.ASRModel.from_pretrained(LANGUAGE_MODELS[language])
        model_cache[language] = loaded
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Error loading model for language {language}: {str(e)}",
        )
    return loaded
def _first_transcript_text(transcriptions):
    """Return the first transcript string from a NeMo ``transcribe()`` result, or None if there is none."""
    # NOTE(review): depending on the NeMo version and decoder, hybrid
    # CTC/RNNT models may return a (best_hypotheses, all_hypotheses) tuple
    # instead of a flat list -- confirm against the installed NeMo version.
    if isinstance(transcriptions, tuple):
        transcriptions = transcriptions[0] if transcriptions else None
    if not transcriptions:
        return None
    first = transcriptions[0]
    # Some NeMo versions return Hypothesis objects that carry the string in
    # `.text` rather than plain strings.
    text = getattr(first, "text", first)
    return text if isinstance(text, str) else None


async def transcribe_audio(
    language: str,
    file: UploadFile = File(...),
):
    """
    Transcribe audio file in the specified Indian language

    Parameters:
    - language: Language code (e.g., 'hi' for Hindi, 'bn' for Bengali)
    - file: Audio file in WAV format (extension matched case-insensitively)

    Returns:
    - Transcription text and language

    Raises:
    - HTTPException(400) for a non-WAV upload or unsupported language.
    - HTTPException(500) if the model cannot be loaded, transcription yields
      no text, or transcription raises.
    """
    # NOTE(review): no route decorator is visible on this function, so as
    # written it is not registered with `app` -- confirm.

    # Validate file format. UploadFile.filename may be None; accept any
    # case of the ".wav" extension.
    filename = file.filename or ""
    if not filename.lower().endswith(".wav"):
        raise HTTPException(status_code=400, detail="Only WAV files are supported")

    # Get the appropriate model (raises HTTPException on failure).
    model = get_model(language)

    # delete=False keeps the file on disk after its handle is closed. The
    # upload is written and the handle closed before NeMo reopens the file by
    # path; an open NamedTemporaryFile cannot be reopened on Windows. The
    # file is removed in `finally`.
    temp_file = NamedTemporaryFile(delete=False, suffix=".wav")
    try:
        with temp_file:
            shutil.copyfileobj(file.file, temp_file)

        # NOTE(review): model.transcribe is synchronous and runs on the event
        # loop, stalling other requests while it works. Offloading it to a
        # thread would also need serialisation, since concurrent transcribe()
        # calls on one shared model may not be safe.
        transcriptions = model.transcribe([temp_file.name])
        text = _first_transcript_text(transcriptions)
        if text is None:
            raise HTTPException(status_code=500, detail="Transcription failed")
        return TranscriptionResponse(text=text, language=language)
    except HTTPException:
        # Our own HTTP errors propagate unchanged instead of being re-wrapped
        # by the generic handler below.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Error during transcription: {str(e)}"
        ) from e
    finally:
        # Clean up temporary file
        os.unlink(temp_file.name)
async def get_supported_languages() -> Dict[str, str]:
    """Report every supported language code with the model identifier it maps to.

    The module-level LANGUAGE_MODELS mapping itself is returned (not a copy).

    NOTE(review): no route decorator is visible on this function, so as
    written it is not registered with ``app`` -- confirm.
    """
    supported_models = LANGUAGE_MODELS
    return supported_models
async def health_check():
    """Liveness probe returning a fixed status payload.

    It does not check whether any ASR model is loaded.

    NOTE(review): no route decorator is visible on this function, so as
    written it is not registered with ``app`` -- confirm.
    """
    status_payload = {"status": "healthy"}
    return status_payload
if __name__ == "__main__":
    # Listen address defaults to 0.0.0.0:8000. It can be overridden with the
    # HOST / PORT environment variables, which hosting platforms commonly use
    # to tell an app where to listen.
    uvicorn.run(
        app,
        host=os.environ.get("HOST", "0.0.0.0"),
        port=int(os.environ.get("PORT", "8000")),
    )