# from operator import itemgetter
# import os
# from datetime import datetime
# import uvicorn
# from typing import Any, Optional, Tuple, Dict, TypedDict
# from urllib import parse
# from uuid import uuid4
# import logging
# from fastapi.logger import logger as fastapi_logger
# import sys
# # sys.path.append('/Users/benolojo/DCU/CA4/ca400_FinalYearProject/2024-ca400-olojob2-majdap2/src/backend/')
# from fastapi import FastAPI
# from fastapi.middleware.cors import CORSMiddleware
# from fastapi import APIRouter, Body, Request, status
# from pymongo import MongoClient
# from dotenv import dotenv_values
# from routes import router as api_router
# from contextlib import asynccontextmanager
# import requests
# from typing import List
# from datetime import date
# from mongodb.operations.calls import *
# from mongodb.models.calls import UserCall, UpdateCall
# # from mongodb.endpoints.calls import *
# # from transformers import AutoProcessor, SeamlessM4Tv2Model
# # from seamless_communication.inference import Translator
# from Client import Client
# #----------------------------------
# # base seamless imports
# # ---------------------------------
# import numpy as np
# import torch
# # ---------------------------------
# import socketio
# ###############################################
# # Configure logger
# gunicorn_error_logger = logging.getLogger("gunicorn.error")
# gunicorn_logger = logging.getLogger("gunicorn")
# uvicorn_access_logger = logging.getLogger("uvicorn.access")
# gunicorn_error_logger.propagate = True
# gunicorn_logger.propagate = True
# uvicorn_access_logger.propagate = True
# uvicorn_access_logger.handlers = gunicorn_error_logger.handlers
# fastapi_logger.handlers = gunicorn_error_logger.handlers
# ###############################################
# # sio is the main socket.io entrypoint
# sio = socketio.AsyncServer(
#     async_mode="asgi",
#     cors_allowed_origins="*",
#     logger=gunicorn_logger,
#     engineio_logger=gunicorn_logger,
# )
# # sio.logger.setLevel(logging.DEBUG)
# socketio_app = socketio.ASGIApp(sio)
# # app.mount("/", socketio_app)
# config = dotenv_values(".env")
# # Read connection string from environment vars
# # uri = os.environ['MONGODB_URI']
# # Read connection string from .env file
# uri = config['MONGODB_URI']
# # Set transformers cache
# # os.environ['HF_HOME'] = './.cache/'
# # os.environ['SENTENCE_TRANSFORMERS_HOME'] = './.cache'
# # MongoDB Connection Lifespan Events
# @asynccontextmanager
# async def lifespan(app: FastAPI):
#     # startup logic
#     app.mongodb_client = MongoClient(uri)
#     app.database = app.mongodb_client['IT-Cluster1']  # connect to interpretalk primary db
#     try:
#         app.mongodb_client.admin.command('ping')
#         print("MongoDB Connection Established...")
#     except Exception as e:
#         print(e)
#     yield
#     # shutdown logic
#     print("Closing MongoDB Connection...")
#     app.mongodb_client.close()
# app = FastAPI(lifespan=lifespan, logger=gunicorn_logger)
# # New CORS functionality
# app.add_middleware(
#     CORSMiddleware,
#     allow_origins=["*"],  # configured node app port
#     allow_credentials=True,
#     allow_methods=["*"],
#     allow_headers=["*"],
# )
# app.include_router(api_router) # include routers for user, calls and transcripts operations
# DEBUG = True
# ESCAPE_HATCH_SERVER_LOCK_RELEASE_NAME = "remove_server_lock"
# TARGET_SAMPLING_RATE = 16000
# MAX_BYTES_BUFFER = 480_000
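# # MAX_BYTES_BUFFER corresponds to roughly 15 s of audio, assuming 16-bit mono at
# # TARGET_SAMPLING_RATE (480_000 / (16_000 * 2) = 15 s).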
# print("")
# print("")
# print("=" * 20 + " ⭐️ Starting Server... ⭐️ " + "=" * 20)
# ###############################################
# # Configure socketio server
# ###############################################
# # TODO PM - change this to the actual path
# # seamless remnant code
# # CLIENT_BUILD_PATH = "../streaming-react-app/dist/"
# # static_files = {
# #     "/": CLIENT_BUILD_PATH,
# #     "/assets/seamless-db6a2555.svg": {
# #         "filename": CLIENT_BUILD_PATH + "assets/seamless-db6a2555.svg",
# #         "content_type": "image/svg+xml",
# #     },
# # }
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# # processor = AutoProcessor.from_pretrained("facebook/seamless-m4t-v2-large", force_download=True)
# #cache_dir="/.cache"
# # PM - hardcoding temporarily as my GPU doesn't have enough VRAM
# # model = SeamlessM4Tv2Model.from_pretrained("facebook/seamless-m4t-v2-large").to("cpu")
# # model = SeamlessM4Tv2Model.from_pretrained("facebook/seamless-m4t-v2-large", force_download=True).to(device)
# bytes_data = bytearray()
# model_name = "seamlessM4T_v2_large"
# # vocoder_name = "vocoder_v2" if model_name == "seamlessM4T_v2_large" else "vocoder_36langs"
# clients = {}
# rooms = {}
# import torch
# from transformers import pipeline
# translator = pipeline("automatic-speech-recognition",
#                       "facebook/seamless-m4t-v2-large",
#                       torch_dtype=torch.float32,
#                       device="cpu")
# converter = pipeline("translation",
#                      "facebook/seamless-m4t-v2-large",
#                      torch_dtype=torch.float32,
#                      device="cpu")
# def get_collection_users():
#     return app.database["user_records"]
# def get_collection_calls():
#     # return app.database["call_records"]
#     return app.database["call_test"]
# @app.get("/test/", response_description="Welcome User")
# def test():
#     return {"message": "Welcome to InterpreTalk!"}
# @app.post("/test_post/", response_description="List more test call records")
# def test_post():
#     request_data = {
#         "call_id": "TESTID000001"
#     }
#     result = create_calls(get_collection_calls(), request_data)
#     # return {"message": "Welcome to InterpreTalk!"}
#     return result
# @app.put("/test_put/", response_description="List test call records")
# def test_put():
#     # result = list_calls(get_collection_calls(), 100)
#     # result = send_captions("TEST", "TEST", "TEST", "oUjUxTYTQFVVjEarIcZ0")
#     result = send_captions("TEST", "TEST", "TEST", "TESTID000001")
#     print(result)
#     return result
# async def send_translated_text(client_id, original_text, translated_text, room_id):
#     print('SEND_TRANSLATED_TEXT IS WORKING IN FASTAPI BACKEND...')
#     print(rooms)
#     print(clients)
#     data = {
#         "author": str(client_id),
#         "original_text": str(original_text),
#         "translated_text": str(translated_text),
#         "timestamp": str(datetime.now())
#     }
#     gunicorn_logger.info("SENDING TRANSLATED TEXT TO CLIENT")
#     await sio.emit("translated_text", data, room=room_id)
#     gunicorn_logger.info("SUCCESSFULLY SENT TRANSLATED TEXT TO FRONTEND")
# @sio.on("connect")
# async def connect(sid, environ):
#     print(f"📥 [event: connected] sid={sid}")
#     query_params = dict(parse.parse_qsl(environ["QUERY_STRING"]))
#     client_id = query_params.get("client_id")
#     gunicorn_logger.info(f"📥 [event: connected] sid={sid}, client_id={client_id}")
#     # sid = socket ID; client_id = client-specific ID, always the same for the same user
#     clients[sid] = Client(sid, client_id)
#     gunicorn_logger.warning(f"Client connected: {sid}")
#     gunicorn_logger.warning(clients)
# @sio.on("disconnect")
# async def disconnect(sid):  # BO - also pass call id as parameter for updating MongoDB
#     gunicorn_logger.debug(f"📤 [event: disconnected] sid={sid}")
#     clients.pop(sid, None)
#     # BO -> Update Call record with call duration, key terms
# @sio.on("target_language")
# async def target_language(sid, target_lang):
#     gunicorn_logger.info(f"📥 [event: target_language] sid={sid}, target_lang={target_lang}")
#     clients[sid].target_language = target_lang
# @sio.on("call_user")
# async def call_user(sid, call_id):
#     clients[sid].call_id = call_id
#     gunicorn_logger.info(f"CALL {sid}: entering room {call_id}")
#     rooms[call_id] = rooms.get(call_id, [])
#     if sid not in rooms[call_id] and len(rooms[call_id]) < 2:
#         rooms[call_id].append(sid)
#         sio.enter_room(sid, call_id)
#     else:
#         gunicorn_logger.info(f"CALL {sid}: room {call_id} is full")
#         # await sio.emit("room_full", room=call_id, to=sid)
#     # # BO - Get call id from dictionary created during socketio connection
#     # client_id = clients[sid].client_id
#     # gunicorn_logger.warning(f"NOW TRYING TO CREATE DB RECORD FOR Caller with ID: {client_id} for call: {call_id}")
#     # # # BO -> Create Call Record with Caller and call_id field (None for callee, duration, terms..)
#     # request_data = {
#     #     "call_id": str(call_id),
#     #     "caller_id": str(client_id),
#     #     "creation_date": str(datetime.now())
#     # }
#     # response = create_calls(get_collection_calls(), request_data)
#     # print(response)  # BO - print created db call record
# @sio.on("audio_config")
# async def audio_config(sid, sample_rate):
#     clients[sid].original_sr = sample_rate
# @sio.on("answer_call")
# async def answer_call(sid, call_id):
#     clients[sid].call_id = call_id
#     gunicorn_logger.info(f"ANSWER {sid}: entering room {call_id}")
#     rooms[call_id] = rooms.get(call_id, [])
#     if sid not in rooms[call_id] and len(rooms[call_id]) < 2:
#         rooms[call_id].append(sid)
#         sio.enter_room(sid, call_id)
#     else:
#         gunicorn_logger.info(f"ANSWER {sid}: room {call_id} is full")
#         # await sio.emit("room_full", room=call_id, to=sid)
#     # # BO - Get call id from dictionary created during socketio connection
#     # client_id = clients[sid].client_id
#     # # BO -> Update Call Record with Callee field based on call_id
#     # gunicorn_logger.warning(f"NOW UPDATING MongoDB RECORD FOR Caller with ID: {client_id} for call: {call_id}")
#     # # # BO -> Create Call Record with callee_id field (None for callee, duration, terms..)
#     # request_data = {
#     #     "callee_id": client_id
#     # }
#     # response = update_calls(get_collection_calls(), call_id, request_data)
#     # print(response)  # BO - print created db call record
# @sio.on("incoming_audio")
# async def incoming_audio(sid, data, call_id):
#     try:
#         clients[sid].add_bytes(data)
#         if clients[sid].get_length() >= MAX_BYTES_BUFFER:
#             gunicorn_logger.info('Buffer full, now outputting...')
#             output_path = clients[sid].output_path
#             vad_result, resampled_audio = clients[sid].resample_and_write_to_file()
#             # source lang is the speaker's target language 😃
#             src_lang = clients[sid].target_language
#             if vad_result:
#                 gunicorn_logger.info('Speech detected, now processing audio.....')
#                 tgt_sid = next(id for id in rooms[call_id] if id != sid)
#                 tgt_lang = clients[tgt_sid].target_language
#                 # following example from https://github.com/facebookresearch/seamless_communication/blob/main/docs/m4t/README.md#transformers-usage
#                 # output_tokens = processor(audios=resampled_audio, src_lang=src_lang, return_tensors="pt")
#                 # model_output = model.generate(**output_tokens, tgt_lang=src_lang, generate_speech=False)[0].tolist()[0]
#                 # asr_text = processor.decode(model_output, skip_special_tokens=True)
#                 asr_text = translator(resampled_audio, generate_kwargs={"tgt_lang": src_lang})['text']
#                 print(f"ASR TEXT = {asr_text}")
#                 # ASR TEXT => ORIGINAL TEXT
#                 # t2t_tokens = processor(text=asr_text, src_lang=src_lang, tgt_lang=tgt_lang, return_tensors="pt")
#                 # print(f"FIRST TYPE = {type(output_tokens)}, SECOND TYPE = {type(t2t_tokens)}")
#                 # translated_data = model.generate(**t2t_tokens, tgt_lang=tgt_lang, generate_speech=False)[0].tolist()[0]
#                 # translated_text = processor.decode(translated_data, skip_special_tokens=True)
#                 translated_text = converter(asr_text, src_lang=src_lang, tgt_lang=tgt_lang)
#                 print(f"TRANSLATED TEXT = {translated_text}")
#                 # BO -> send translated_text to mongodb as caption record update based on call_id
#                 # send_captions(clients[sid].client_id, asr_text, translated_text, call_id)
#                 # TRANSLATED TEXT
#                 # PM - text_output is a list with 1 string
#                 await send_translated_text(clients[sid].client_id, asr_text, translated_text, call_id)
#                 # # BO -> send translated_text to mongodb as caption record update based on call_id
#                 # send_captions(clients[sid].client_id, asr_text, translated_text, call_id)
#     except Exception as e:
#         gunicorn_logger.exception(f"Error in incoming_audio: {e}")
# def send_captions(client_id, original_text, translated_text, call_id):
#     # BO -> Update Call Record with Callee field based on call_id
#     print(f"Now updating Caption field in call record for Caller with ID: {client_id} for call: {call_id}")
#     data = {
#         "author": str(client_id),
#         "original_text": str(original_text),
#         "translated_text": str(translated_text),
#         "timestamp": str(datetime.now())
#     }
#     response = update_captions(get_collection_calls(), call_id, data)
#     return response
# app.mount("/", socketio_app)
# if __name__ == '__main__':
#     uvicorn.run("main:app", host='0.0.0.0', port=7860, log_level="debug")
# # Running in Docker Container
# if __name__ != "__main__":
#     fastapi_logger.setLevel(gunicorn_logger.level)
# else:
#     fastapi_logger.setLevel(logging.DEBUG)
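# NOTE: the FastAPI/Socket.IO backend above is commented out; the only code that currently
# runs in this file is the Hugging Face cache scan below.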
from huggingface_hub import scan_cache_dir
hf_cache_info = scan_cache_dir()
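# Minimal usage sketch (not part of the original logic): the HFCacheInfo object returned by
# scan_cache_dir() exposes size_on_disk (bytes) and the set of cached repos, which can be
# printed for a quick overview of which models are already downloaded.
print(f"HF cache size on disk: {hf_cache_info.size_on_disk / 1e9:.2f} GB")
for cached_repo in hf_cache_info.repos:
    print(f"  {cached_repo.repo_type}/{cached_repo.repo_id}: {cached_repo.size_on_disk / 1e6:.1f} MB")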