File size: 6,389 Bytes
72277b5
13db51f
72277b5
 
 
 
 
 
 
 
 
 
13db51f
0b11366
72277b5
 
 
 
 
 
 
 
 
eca4b03
72277b5
 
 
eca4b03
2cd6c79
eca4b03
72277b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13db51f
 
 
0b11366
 
 
db0f21a
0b11366
db0f21a
eca4b03
 
7db1cf9
 
db0f21a
 
 
13db51f
 
 
f6d1a77
 
13db51f
 
 
eca4b03
13db51f
 
 
 
 
 
db0f21a
 
 
13db51f
 
 
 
 
72277b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eca4b03
 
 
72277b5
 
 
 
 
 
 
 
 
db0f21a
72277b5
 
db0f21a
72277b5
 
 
 
 
 
 
 
 
 
 
 
 
 
db0f21a
 
 
 
 
 
 
72277b5
 
763a8af
 
db0f21a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
from contextlib import asynccontextmanager
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
import asyncio
import logging
import os
import traceback
import argparse
import uvicorn
import numpy as np
import tempfile

from core import WhisperLiveKit
from audio_processor import AudioProcessor

_LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"

# Install a root handler with the shared format, then raise the root threshold
# to WARNING so chatty third-party loggers stay quiet. setLevel is kept as a
# separate call because basicConfig does nothing if root already has handlers.
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logging.getLogger().setLevel(logging.WARNING)

# This module's own logger stays fully verbose.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

# Process-wide processor; populated by ``lifespan`` at startup, None before.
audio_processor = None

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Initialise the global ``audio_processor`` when the app starts.

    A ``WhisperLiveKit`` is built from the parsed CLI ``args`` before the
    ``AudioProcessor`` is created, and the context then yields for the
    lifetime of the application.

    NOTE(review): ``args`` is a module-level global assigned only inside the
    ``if __name__ == "__main__"`` block; starting the app any other way
    (e.g. ``uvicorn module:app``) would raise NameError here — confirm the
    intended entry point.
    """
    global audio_processor
    # Bound to a local on purpose: the suspended generator frame keeps this
    # reference alive across ``yield``, i.e. until shutdown.
    live_kit = WhisperLiveKit(args)
    audio_processor = AudioProcessor()
    yield

app = FastAPI(lifespan=lifespan)

# Fully permissive CORS: any origin, method and header, with credentials.
# NOTE(review): wildcard origins together with allow_credentials=True is a
# risky combination for anything beyond local development — confirm intent.
_CORS_OPTIONS = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)

# Serve front-end assets from the "static" directory (resolved relative to the
# working directory) under the /static URL prefix.
app.mount("/static", StaticFiles(directory="static"), name="static")

@app.get("/")
async def read_root():
    """Return the front-end entry page for requests to the site root."""
    index_page = "static/index.html"
    return FileResponse(index_page)

@app.get("/health")
async def health_check():
    """Liveness probe: unconditionally reports the service as healthy."""
    payload = {"status": "healthy"}
    return JSONResponse(payload)

@app.post("/detect-language")
async def detect_language(file: UploadFile = File(...)):
    """Detect the spoken language of an uploaded audio file.

    The upload is written to a named temporary file because
    ``audio_processor.detect_language`` is called with a filesystem path.
    That file is always removed before the response is sent.

    Returns:
        JSONResponse with ``language``, ``confidence`` and a ``probabilities``
        mapping of language -> probability (values coerced to built-in
        ``float``; the processor presumably returns non-native numeric types
        such as NumPy scalars — TODO confirm). On failure, or if the processor
        is not initialised, ``{"error": ...}`` with status 500.
    """
    # Check before touching the filesystem: the previous version created the
    # temp file first and then leaked it on this early-return path.
    if not audio_processor:
        return JSONResponse(
            {"error": "Audio processor not initialized"},
            status_code=500
        )

    file_path = None
    try:
        contents = await file.read()
        # delete=False so the file outlives the ``with`` block and can be
        # opened by name by the processor; ``finally`` below removes it.
        with tempfile.NamedTemporaryFile(delete=False) as temp_file:
            file_path = temp_file.name
            temp_file.write(contents)

        detected_lang, confidence, probs = await audio_processor.detect_language(file_path)

        return JSONResponse({
            "language": detected_lang,
            "confidence": float(confidence),
            "probabilities": {lang: float(prob) for lang, prob in probs.items()}
        })
    except Exception as e:
        logger.error(f"Error in language detection: {e}")
        logger.error(f"Traceback: {traceback.format_exc()}")
        return JSONResponse(
            {"error": str(e)},
            status_code=500
        )
    finally:
        # Runs on success and failure alike, so the temp file never leaks.
        if file_path is not None:
            try:
                os.remove(file_path)
            except OSError as cleanup_error:
                logger.warning(f"Could not remove temporary file {file_path}: {cleanup_error}")

async def handle_websocket_results(websocket, results_generator):
    """Consume results from the audio processor and send them via WebSocket.

    Each item from ``results_generator`` is turned into a JSON payload:

    * a dict with ``buffer_transcription`` is reduced to just that key
      (checked first);
    * otherwise, a dict with ``full_transcription`` is reduced to just that key;
    * any other dict is sent unchanged;
    * non-dict items are wrapped as ``{"text": str(item)}``.

    NOTE(review): when a dict carries one of the two transcription keys plus
    other fields, those other fields are not forwarded — confirm intended.

    The coroutine returns instead of raising when a send fails, when the
    client disconnects, or when the generator raises, so it is safe to run as
    a background task. ``asyncio.CancelledError`` is not an ``Exception``
    subclass, so cancellation by the caller still propagates.
    """
    try:
        async for response in results_generator:
            try:
                logger.debug("Sending response: %s", response)
                if isinstance(response, dict):
                    if 'buffer_transcription' in response:
                        payload = {'buffer_transcription': response['buffer_transcription']}
                    elif 'full_transcription' in response:
                        payload = {'full_transcription': response['full_transcription']}
                    else:
                        payload = response
                else:
                    payload = {"text": str(response)}
                await websocket.send_json(payload)
            except WebSocketDisconnect:
                # The client went away; this is normal shutdown, not an error.
                logger.info("WebSocket disconnected while sending results")
                return
            except Exception as e:
                # Log once and stop. The previous version re-raised here, so
                # the outer handler logged the same failure and traceback a
                # second time.
                logger.error(f"Error sending message: {e}")
                logger.error(f"Traceback: {traceback.format_exc()}")
                return
    except Exception as e:
        # Failures raised by the results generator itself.
        logger.warning(f"Error in WebSocket results handler: {e}")
        logger.warning(f"Traceback: {traceback.format_exc()}")

@app.websocket("/asr")
async def websocket_endpoint(websocket: WebSocket):
    """Stream binary audio frames from a client into the audio processor.

    After accepting the connection, a background task
    (``handle_websocket_results``) pushes processor results back over the
    same socket. Meanwhile this coroutine feeds every received binary frame
    to ``audio_processor.process_audio``. The receive loop ends on client
    disconnect or on the first processing error. In every case the results
    task is then cancelled and awaited, so any results not yet sent are
    dropped.

    NOTE(review): ``audio_processor`` is a single module-level object shared
    by every connection; concurrent clients would feed the same processor —
    confirm single-client use is intended.
    """
    logger.info("New WebSocket connection request")
    results_task = None

    try:
        await websocket.accept()
        logger.info("WebSocket connection accepted")

        if not audio_processor:
            raise RuntimeError("Audio processor not initialized")

        results_generator = await audio_processor.create_tasks()
        results_task = asyncio.create_task(
            handle_websocket_results(websocket, results_generator)
        )

        keep_receiving = True
        while keep_receiving:
            try:
                chunk = await websocket.receive_bytes()
                logger.debug(f"Received audio chunk of size: {len(chunk)}")
                await audio_processor.process_audio(chunk)
            except WebSocketDisconnect:
                logger.info("WebSocket connection closed")
                keep_receiving = False
            except Exception as e:
                logger.error(f"Error processing WebSocket message: {e}")
                logger.error(f"Traceback: {traceback.format_exc()}")
                keep_receiving = False

    except Exception as e:
        logger.error(f"Error in WebSocket endpoint: {e}")
        logger.error(f"Traceback: {traceback.format_exc()}")
    finally:
        # Stop the sender task and wait for it to unwind before returning.
        if results_task is not None:
            results_task.cancel()
            try:
                await results_task
            except asyncio.CancelledError:
                pass

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to run the server on")
    parser.add_argument("--port", type=int, default=8000, help="Port to run the server on")
    parser.add_argument("--model", type=str, default="base", help="Whisper model to use")
    parser.add_argument("--backend", type=str, default="faster-whisper", help="Backend to use")
    parser.add_argument("--task", type=str, default="transcribe", help="Task to perform")
    # Assigned at module scope on purpose: ``lifespan`` reads the global
    # ``args`` when the application starts.
    args = parser.parse_args()

    # Report the effective configuration through the configured logger rather
    # than a bare print, so it shares the log format and stream.
    logger.info("Starting server with arguments: %s", args)

    uvicorn.run(app, host=args.host, port=args.port)