Spaces:
Paused
Paused
from contextlib import asynccontextmanager | |
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, UploadFile, File | |
from fastapi.middleware.cors import CORSMiddleware | |
from fastapi.responses import JSONResponse | |
from fastapi.staticfiles import StaticFiles | |
from fastapi.responses import FileResponse | |
import asyncio | |
import logging | |
import os | |
import traceback | |
import argparse | |
import uvicorn | |
import numpy as np | |
import tempfile | |
from core import WhisperLiveKit | |
from audio_processor import AudioProcessor | |
# Shared AudioProcessor instance; stays None until the `lifespan` startup hook assigns it.
audio_processor = None

_LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
# basicConfig attaches a formatted handler to the root logger. The root level is
# then raised to WARNING, so loggers that do not set their own level inherit
# WARNING. This module's logger is opened up to DEBUG; its records still reach
# the root handler through propagation.
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logging.getLogger().setLevel(logging.WARNING)
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook: build the speech stack once at startup.

    Creates the WhisperLiveKit instance and assigns the module-level
    ``audio_processor`` before the application starts serving requests.
    Nothing is torn down after ``yield``.

    Args:
        app: The FastAPI application (unused; required by the lifespan protocol).
    """
    # Wrapped with @asynccontextmanager: Starlette deprecates passing a bare
    # async-generator function as ``lifespan``.
    global audio_processor
    # ``kit`` is not used again here, but keeping it bound keeps the instance
    # referenced while this frame is suspended at ``yield`` (i.e. for the whole
    # app lifetime). NOTE(review): presumably AudioProcessor relies on
    # WhisperLiveKit having been constructed first — confirm.
    kit = WhisperLiveKit()
    audio_processor = AudioProcessor()
    # NOTE(review): no shutdown cleanup of the processor happens after yield;
    # add it here if AudioProcessor owns resources (subprocesses, tasks).
    yield
app = FastAPI(lifespan=lifespan)

# Fully permissive CORS: every origin, HTTP method and request header is accepted.
# NOTE(review): combining allow_origins=["*"] with allow_credentials=True makes
# Starlette reflect the caller's Origin instead of "*" for credentialed
# requests, which effectively allows them from any site. Verify this is intended.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)

# Serve frontend assets from the local ./static directory under the /static URL prefix.
app.mount("/static", StaticFiles(directory="static"), name="static")
async def read_root():
    """Serve the single-page frontend stored at static/index.html."""
    # NOTE(review): no route decorator is visible in this copy of the file;
    # presumably this is registered with something like @app.get("/") — confirm.
    index_page = "static/index.html"
    return FileResponse(index_page)
async def health_check():
    """Liveness probe that always answers {"status": "healthy"}.

    It does not check whether ``audio_processor`` has been initialized.
    """
    # NOTE(review): no route decorator is visible in this copy; confirm the path it is mounted on.
    payload = {"status": "healthy"}
    return JSONResponse(payload)
async def detect_language(file: UploadFile = File(...)):
    """Detect the spoken language of an uploaded audio file.

    The upload is spooled to a named temporary file so that
    ``audio_processor.detect_language`` can read it by path. The temporary
    file is removed on every exit path.

    Args:
        file: The uploaded audio file.

    Returns:
        On success, a JSONResponse with ``language``, ``confidence`` and a
        per-language ``probabilities`` mapping. On failure, ``{"error": ...}``
        with status 500 — either when the audio processor is not initialized
        or when writing/detection raises.
    """
    # NOTE(review): no route decorator is visible in this copy of the file;
    # presumably this is registered with something like @app.post(...) — confirm.

    # Check the processor before touching the filesystem. Otherwise a temp file
    # would be created that is never used or removed.
    if not audio_processor:
        return JSONResponse(
            {"error": "Audio processor not initialized"},
            status_code=500,
        )

    file_path = None
    try:
        # delete=False keeps the file on disk after the `with` closes it, so the
        # detector can open it by path; removal happens in `finally` below.
        with tempfile.NamedTemporaryFile(delete=False) as temp_file:
            file_path = temp_file.name
            temp_file.write(await file.read())

        detected_lang, confidence, probs = await audio_processor.detect_language(file_path)
        return JSONResponse({
            "language": detected_lang,
            "confidence": float(confidence),
            "probabilities": {lang: float(prob) for lang, prob in probs.items()},
        })
    except Exception as e:
        logger.error(f"Error in language detection: {e}")
        logger.error(f"Traceback: {traceback.format_exc()}")
        return JSONResponse(
            {"error": str(e)},
            status_code=500,
        )
    finally:
        # Single cleanup path for both success and failure.
        if file_path is not None and os.path.exists(file_path):
            os.remove(file_path)
def _websocket_payload(response):
    """Return the JSON message to send over the WebSocket for one result.

    - A non-dict result is wrapped as ``{"text": str(response)}``.
    - A dict containing ``buffer_transcription`` (checked first) or
      ``full_transcription`` is reduced to that single key.
    - Any other dict is forwarded unchanged.
    """
    if not isinstance(response, dict):
        return {"text": str(response)}
    # NOTE(review): the trimmed shapes drop every other key of the response
    # dict; confirm the frontend does not need those fields.
    for key in ('buffer_transcription', 'full_transcription'):
        if key in response:
            return {key: response[key]}
    return response


async def handle_websocket_results(websocket, results_generator):
    """Consumes results from the audio processor and sends them via WebSocket.

    A send failure is logged at ERROR level and re-raised. It then ends
    consumption and is logged again at WARNING level, together with any error
    raised by the generator itself. Nothing propagates to the caller except
    BaseException subclasses such as cancellation.
    """
    try:
        async for response in results_generator:
            try:
                logger.debug(f"Sending response: {response}")
                await websocket.send_json(_websocket_payload(response))
            except Exception as e:
                logger.error(f"Error sending message: {e}")
                logger.error(f"Traceback: {traceback.format_exc()}")
                raise
    except Exception as e:
        logger.warning(f"Error in WebSocket results handler: {e}")
        logger.warning(f"Traceback: {traceback.format_exc()}")
async def _forward_audio(websocket):
    """Feed binary frames from *websocket* into the shared audio processor.

    Runs until the client disconnects (logged at INFO level) or any other
    error occurs (logged with its traceback). Both cases end the loop
    without raising.
    """
    try:
        while True:
            chunk = await websocket.receive_bytes()
            logger.debug(f"Received audio chunk of size: {len(chunk)}")
            await audio_processor.process_audio(chunk)
    except WebSocketDisconnect:
        logger.info("WebSocket connection closed")
    except Exception as e:
        logger.error(f"Error processing WebSocket message: {e}")
        logger.error(f"Traceback: {traceback.format_exc()}")


async def websocket_endpoint(websocket: WebSocket):
    """Stream audio from a client to the processor and results back to it.

    Steps:
    1. Accept the connection.
    2. Start a background task that relays processor results to the client.
    3. Pump incoming binary frames into the processor until the client leaves.

    The relay task is always cancelled and awaited before returning.
    """
    # NOTE(review): no route decorator is visible in this copy of the file;
    # presumably this is registered with something like @app.websocket(...).
    # NOTE(review): a single module-level ``audio_processor`` serves every
    # connection, and create_tasks() is called on it per connection. Confirm
    # that concurrent clients cannot interfere with each other's state.
    logger.info("New WebSocket connection request")
    results_task = None
    try:
        await websocket.accept()
        logger.info("WebSocket connection accepted")
        if not audio_processor:
            raise RuntimeError("Audio processor not initialized")
        results_generator = await audio_processor.create_tasks()
        results_task = asyncio.create_task(handle_websocket_results(websocket, results_generator))
        await _forward_audio(websocket)
    except Exception as e:
        logger.error(f"Error in WebSocket endpoint: {e}")
        logger.error(f"Traceback: {traceback.format_exc()}")
    finally:
        if results_task:
            results_task.cancel()
            try:
                await results_task
            except asyncio.CancelledError:
                pass
if __name__ == "__main__":
    # NOTE(review): --model, --backend and --task are parsed and printed, but
    # nothing in this file passes them on (WhisperLiveKit() is constructed
    # without arguments in `lifespan`). Confirm where they are meant to take effect.
    cli_options = (
        ("--host", str, "0.0.0.0", "Host to run the server on"),
        ("--port", int, 8000, "Port to run the server on"),
        ("--model", str, "base", "Whisper model to use"),
        ("--backend", str, "faster-whisper", "Backend to use"),
        ("--task", str, "transcribe", "Task to perform"),
    )
    cli = argparse.ArgumentParser()
    for flag, value_type, default_value, help_text in cli_options:
        cli.add_argument(flag, type=value_type, default=default_value, help=help_text)
    options = cli.parse_args()
    print(options)
    uvicorn.run(app, host=options.host, port=options.port)