|
import inspect
import logging
import os
from datetime import datetime
from operator import itemgetter
from typing import Any, Optional, Tuple, Dict, TypedDict
from urllib import parse
from uuid import uuid4

import uvicorn
|
|
|
|
|
|
|
# Root logging goes to backend.log, truncated on every start (filemode "w").
_LOG_FORMAT = "%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s"
logging.basicConfig(
    filename="backend.log",
    filemode="w",
    level=logging.DEBUG,
    format=_LOG_FORMAT,
    datefmt="%H:%M:%S",
)

# Shared logger for the HTTP app and the Socket.IO server. Propagation to the
# root handler (configured above) is already the default; kept explicit.
logger = logging.getLogger("socketio_server_pubsub")
logger.propagate = True
|
|
|
import sys |
|
|
|
|
|
from fastapi import FastAPI |
|
from fastapi.middleware.cors import CORSMiddleware |
|
from pymongo import MongoClient |
|
from dotenv import dotenv_values |
|
from routes import router as api_router |
|
from contextlib import asynccontextmanager |
|
import requests |
|
|
|
from typing import List |
|
from datetime import date |
|
from mongodb.operations.calls import * |
|
from mongodb.models.calls import UserCall, UpdateCall |
|
|
|
|
|
from transformers import AutoProcessor, SeamlessM4Tv2Model |
|
|
|
|
|
from Client import Client |
|
|
|
|
|
|
|
import numpy as np |
|
import torch |
|
|
|
import socketio |
|
|
|
DEBUG = True

ESCAPE_HATCH_SERVER_LOCK_RELEASE_NAME = "remove_server_lock"

# Audio settings. TARGET_SAMPLING_RATE is presumably in Hz (confirm).
TARGET_SAMPLING_RATE = 16000
# incoming_audio processes a client's buffer once it holds this many bytes.
MAX_BYTES_BUFFER = 480_000

# Startup banner, preceded by two empty lines.
for _ in range(2):
    print("")
_RULE = "=" * 20
print(_RULE + " ⭐️ Starting Server... ⭐️ " + _RULE)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
CLIENT_BUILD_PATH = "../streaming-react-app/dist/"

# URL path -> file mapping for the built React client.
# NOTE(review): not referenced in this module; confirm it is still needed.
_SEAMLESS_LOGO = "assets/seamless-db6a2555.svg"
static_files = {
    "/": CLIENT_BUILD_PATH,
    "/" + _SEAMLESS_LOGO: {
        "filename": CLIENT_BUILD_PATH + _SEAMLESS_LOGO,
        "content_type": "image/svg+xml",
    },
}
|
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Processor and model are loaded from the same Hugging Face checkpoint.
_SEAMLESS_CHECKPOINT = "facebook/seamless-m4t-v2-large"
processor = AutoProcessor.from_pretrained(_SEAMLESS_CHECKPOINT)

# NOTE(review): the model is pinned to CPU even when `device` is CUDA, and the
# processor outputs used in incoming_audio are never moved either; change the
# placement of both together if GPU inference is wanted.
model = SeamlessM4Tv2Model.from_pretrained(_SEAMLESS_CHECKPOINT).to("cpu")
|
|
|
# Values from the local .env file (a dict; missing file -> empty dict).
config = dotenv_values(".env")

# MongoDB connection string. A real environment variable wins; otherwise fall
# back to the .env value, which was previously loaded but never consulted.
uri = os.environ.get("MONGODB_URI", config.get("MONGODB_URI"))
if uri is None:
    # Same exception type/args as the original os.environ[...] lookup.
    raise KeyError("MONGODB_URI")
|
|
|
|
|
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Own the MongoDB client for the lifetime of the FastAPI app.

    On startup, creates the client on ``app.mongodb_client`` and the
    ``IT-Cluster1`` database handle on ``app.database``. Both are read by the
    ``get_collection_*`` helpers. A failed ping is reported but does not abort
    startup. On shutdown the client is always closed, even when the app exits
    with an error.
    """
    app.mongodb_client = MongoClient(uri)
    app.database = app.mongodb_client['IT-Cluster1']
    try:
        app.mongodb_client.admin.command('ping')
        print("MongoDB Connection Established...")
    except Exception as e:
        # Best-effort connectivity check: keep starting, but record the
        # failure in the log file as well as on the console.
        print(e)
        logger.warning(f"MongoDB ping failed: {e}")

    try:
        yield
    finally:
        # Without finally, an exception thrown in at the yield would skip this.
        print("Closing MongoDB Connection...")
        app.mongodb_client.close()
|
|
|
# NOTE(review): FastAPI has no dedicated `logger` option; presumably the
# keyword is just kept in `app.extra` — confirm it is actually used.
app = FastAPI(logger=logger, lifespan=lifespan)

# Fully permissive CORS: every origin, method and header, credentials allowed.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)

# HTTP routes defined in the project's `routes` module.
app.include_router(api_router)
|
|
|
|
|
|
|
# Socket.IO server running on the ASGI event loop. Both the Socket.IO and
# Engine.IO layers log through this module's logger, and any origin may connect.
sio = socketio.AsyncServer(
    logger=logger,
    engineio_logger=logger,
    async_mode="asgi",
    cors_allowed_origins="*",
)

# ASGI wrapper around the server; mounted into the FastAPI app further below.
socketio_app = socketio.ASGIApp(sio)
|
|
|
|
|
from fastapi import APIRouter, Body, Request, status |
|
|
|
# NOTE(review): not referenced in this module; audio is buffered per Client
# (see Client.add_bytes in incoming_audio).
bytes_data = bytearray()

model_name = "seamlessM4T_v2_large"
# NOTE(review): vocoder_name is not referenced in this module either.
if model_name == "seamlessM4T_v2_large":
    vocoder_name = "vocoder_v2"
else:
    vocoder_name = "vocoder_36langs"

# Live connections: Socket.IO sid -> Client.
clients = {}
# Call rooms: call_id -> list of member sids (the join handlers cap it at two).
rooms = {}
|
|
|
|
|
def get_collection_users():
    """Return the ``user_records`` collection from the app's MongoDB database."""
    database = app.database
    return database["user_records"]
|
|
|
def get_collection_calls():
    """Return the ``call_test`` collection holding call records and captions."""
    database = app.database
    return database["call_test"]
|
|
|
|
|
@app.get("/test/", response_description="List all existing call records", response_model=List[UserCall])
def test():
    """Debug endpoint: return the call records from ``list_calls`` (limit argument 100).

    The records are also printed to stdout.
    """
    calls_collection = get_collection_calls()
    records = list_calls(calls_collection, 100)
    print(records)
    return records
|
|
|
|
|
@app.put("/test_put/", response_description="Test call record after the caption update", response_model=UserCall)
def test_put():
    """Debug endpoint: write a fixed ``TEST`` caption to call ``TESTID000001``.

    Returns whatever ``send_captions`` returns, validated against ``UserCall``.
    The result is also printed to stdout.
    """
    # Fixed author/original/translated text so the write is easy to spot in the DB.
    result = send_captions("TEST", "TEST", "TEST", "TESTID000001")
    print(result)
    return result
|
|
|
|
|
@app.post("/test_post/", response_description="Newly created test call record", response_model=UserCall)
def test_post():
    """Debug endpoint: create a call record with the fixed id ``TESTID000001``.

    Returns whatever ``create_calls`` returns, validated against ``UserCall``.
    """
    request_data = {"call_id": "TESTID000001"}
    return create_calls(get_collection_calls(), request_data)
|
|
|
|
|
async def send_translated_text(client_id, original_text, translated_text, room_id):
    """Emit a ``translated_text`` event carrying one caption to ``room_id``.

    Args:
        client_id: Identifier of the speaker, sent as ``author``.
        original_text: Transcript in the speaker's language.
        translated_text: The transcript translated for the other participant.
        room_id: Socket.IO room (the call id) that receives the event.
    """
    # Full connection/room state helps when debugging, but printing it on
    # every caption floods stdout; keep it at debug level in the log instead.
    logger.debug(f"send_translated_text: rooms={rooms}, clients={clients}")

    data = {
        "author": str(client_id),
        "original_text": str(original_text),
        "translated_text": str(translated_text),
        "timestamp": str(datetime.now()),
    }
    logger.warning("SENDING TRANSLATED TEXT TO CLIENT")
    await sio.emit("translated_text", data, room=room_id)
    # Previously logged "SEND AUDIO", which misreported what was sent.
    logger.warning(f"SUCCESSFULLY SENT TRANSLATED TEXT TO ROOM {room_id}")
|
|
|
@sio.on("connect")
async def connect(sid, environ):
    """Register a new Socket.IO connection as a ``Client`` keyed by ``sid``.

    The user id comes from the ``client_id`` query-string parameter. It is
    ``None`` when that parameter is absent.
    """
    print(f"📥 [event: connected] sid={sid}")
    query = environ["QUERY_STRING"]
    params = dict(parse.parse_qsl(query))
    user_id = params.get("client_id")
    logger.info(f"📥 [event: connected] sid={sid}, client_id={user_id}")

    new_client = Client(sid, user_id)
    clients[sid] = new_client
    logger.warning(f"Client connected: {sid}")
    logger.warning(clients)
|
|
|
@sio.on("disconnect")
async def disconnect(sid):
    """Forget a disconnected client and free its seat in any call room.

    Leaving ``sid`` in ``rooms`` would keep its room "full" for the join
    handlers. ``incoming_audio`` could also pick the stale sid as the listener
    and then fail to find it in ``clients``.
    """
    logger.debug(f"📤 [event: disconnected] sid={sid}")
    clients.pop(sid, None)

    # Iterate over a snapshot because emptied rooms are deleted in the loop.
    for call_id, members in list(rooms.items()):
        if sid in members:
            members.remove(sid)
            if not members:
                del rooms[call_id]
|
|
|
|
|
@sio.on("target_language")
async def target_language(sid, target_lang):
    """Store the language code ``target_lang`` on client ``sid``."""
    logger.info(f"📥 [event: target_language] sid={sid}, target_lang={target_lang}")
    client = clients[sid]
    client.target_language = target_lang
|
|
|
@sio.on("call_user")
async def call_user(sid, call_id):
    """Put the caller ``sid`` into the two-person room ``call_id``.

    The room holds at most two sids. If it is already full, the request is
    refused and the client's ``call_id`` is left untouched. A client that is
    already a member is not added twice.
    """
    # Look up the client first so an unknown sid fails before touching `rooms`.
    client = clients[sid]
    logger.warning(f"CALL {sid}: entering room {call_id}")

    members = rooms.setdefault(call_id, [])
    if sid in members:
        logger.warning(f"CALL {sid}: already in room {call_id}")
        return
    if len(members) >= 2:
        logger.warning(f"CALL {sid}: room {call_id} is full")
        return

    members.append(sid)
    client.call_id = call_id
    # AsyncServer.enter_room is a coroutine in current python-socketio and a
    # plain call in older releases. If it is not awaited, the sid never joins
    # and room-targeted emits never reach it.
    joined = sio.enter_room(sid, call_id)
    if inspect.isawaitable(joined):
        await joined
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@sio.on("audio_config")
async def audio_config(sid, sample_rate):
    """Store the sender-reported ``sample_rate`` as ``original_sr`` on client ``sid``."""
    client = clients[sid]
    client.original_sr = sample_rate
|
|
|
|
|
@sio.on("answer_call")
async def answer_call(sid, call_id):
    """Put the answering client ``sid`` into the two-person room ``call_id``.

    The room holds at most two sids. If it is already full, the request is
    refused and the client's ``call_id`` is left untouched. A client that is
    already a member is not added twice.
    """
    # Look up the client first so an unknown sid fails before touching `rooms`.
    client = clients[sid]
    logger.warning(f"ANSWER {sid}: entering room {call_id}")

    members = rooms.setdefault(call_id, [])
    if sid in members:
        logger.warning(f"ANSWER {sid}: already in room {call_id}")
        return
    if len(members) >= 2:
        logger.warning(f"ANSWER {sid}: room {call_id} is full")
        return

    members.append(sid)
    client.call_id = call_id
    # AsyncServer.enter_room is a coroutine in current python-socketio and a
    # plain call in older releases. If it is not awaited, the sid never joins
    # and room-targeted emits never reach it.
    joined = sio.enter_room(sid, call_id)
    if inspect.isawaitable(joined):
        await joined
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _speech_to_translated_text(audio, src_lang, tgt_lang):
    """Transcribe ``audio`` in ``src_lang``, then translate the transcript to ``tgt_lang``.

    Returns:
        ``(asr_text, translated_text)`` as decoded strings.
    """
    # Speech -> text in the *source* language, i.e. plain speech recognition.
    audio_inputs = processor(audios=audio, src_lang=src_lang, return_tensors="pt")
    asr_ids = model.generate(**audio_inputs, tgt_lang=src_lang, generate_speech=False)[0].tolist()[0]
    asr_text = processor.decode(asr_ids, skip_special_tokens=True)
    print(f"ASR TEXT = {asr_text}")

    # Text -> text into the listener's language.
    text_inputs = processor(text=asr_text, src_lang=src_lang, tgt_lang=tgt_lang, return_tensors="pt")
    translated_ids = model.generate(**text_inputs, tgt_lang=tgt_lang, generate_speech=False)[0].tolist()[0]
    translated_text = processor.decode(translated_ids, skip_special_tokens=True)
    print(f"TRANSLATED TEXT = {translated_text}")
    return asr_text, translated_text


@sio.on("incoming_audio")
async def incoming_audio(sid, data, call_id):
    """Buffer an audio chunk from ``sid`` and caption it once the buffer is full.

    The chunk is appended to the sender's ``Client``. When the client's buffered
    length reaches ``MAX_BYTES_BUFFER``:
    - ``resample_and_write_to_file()`` is called on the client.
    - If the returned VAD flag is truthy, the resampled audio is transcribed and
      translated for the other member of room ``call_id``.
    - Both texts are emitted to that room.

    Errors are logged with their traceback and not re-raised.
    """
    try:
        speaker = clients[sid]
        speaker.add_bytes(data)
        if speaker.get_length() < MAX_BYTES_BUFFER:
            return

        logger.warning('Buffer full, now outputting...')
        vad_result, resampled_audio = speaker.resample_and_write_to_file()
        if not vad_result:
            return

        logger.warning('Speech detected, now processing audio.....')
        # The listener is the other member of the room. Tolerate an unknown
        # room, a speaker who is alone, or a listener who has already left.
        listener_sid = next((member for member in rooms.get(call_id, ()) if member != sid), None)
        listener = clients.get(listener_sid)
        if listener is None:
            logger.warning(f"No listener in room {call_id} for {sid}; skipping translation")
            return

        # NOTE(review): each client's `target_language` is used both as the
        # language that client speaks (source here) and the one it receives
        # (target for the other side) — confirm this matches the frontend.
        src_lang = speaker.target_language
        tgt_lang = listener.target_language

        # NOTE: inference runs synchronously on the event loop, so other
        # Socket.IO events wait until both generate() calls finish.
        asr_text, translated_text = _speech_to_translated_text(resampled_audio, src_lang, tgt_lang)
        await send_translated_text(speaker.client_id, asr_text, translated_text, call_id)
    except Exception as e:
        # logger.exception includes the active traceback. The previous
        # e.with_traceback() call lacked its required argument, raised a
        # TypeError here and hid the original error.
        logger.exception(f"Error in incoming_audio: {e}")
|
|
|
def send_captions(client_id, original_text, translated_text, call_id):
    """Store one caption entry on call ``call_id`` via ``update_captions``.

    The entry records the author, both texts and a local-time timestamp, all
    as strings. Returns ``update_captions``'s response unchanged.
    """
    print(f"Now updating Caption field in call record for Caller with ID: {client_id} for call: {call_id}")

    caption = dict(
        author=str(client_id),
        original_text=str(original_text),
        translated_text=str(translated_text),
        timestamp=str(datetime.now()),
    )
    return update_captions(get_collection_calls(), call_id, caption)
|
|
|
# Mounted last: Starlette matches routes in registration order, so the FastAPI
# endpoints registered above are tried before the Socket.IO app at "/".
app.mount("/", socketio_app)


if __name__ == '__main__':
    # Import-string form; assumes this file is importable as module `main`.
    server_options = {"host": '127.0.0.1', "port": 8080, "log_level": "info"}
    uvicorn.run("main:app", **server_options)
|
|
|
|