# Spaces:
# Paused
# Paused
# from operator import itemgetter | |
# import os | |
# from datetime import datetime | |
# import uvicorn | |
# from typing import Any, Optional, Tuple, Dict, TypedDict | |
# from urllib import parse | |
# from uuid import uuid4 | |
# import logging | |
# from fastapi.logger import logger as fastapi_logger | |
# import sys | |
# # sys.path.append('/Users/benolojo/DCU/CA4/ca400_FinalYearProject/2024-ca400-olojob2-majdap2/src/backend/') | |
# from fastapi import FastAPI | |
# from fastapi.middleware.cors import CORSMiddleware | |
# from fastapi import APIRouter, Body, Request, status | |
# from pymongo import MongoClient | |
# from dotenv import dotenv_values | |
# from routes import router as api_router | |
# from contextlib import asynccontextmanager | |
# import requests | |
# from typing import List | |
# from datetime import date | |
# from mongodb.operations.calls import * | |
# from mongodb.models.calls import UserCall, UpdateCall | |
# # from mongodb.endpoints.calls import * | |
# # from transformers import AutoProcessor, SeamlessM4Tv2Model | |
# # from seamless_communication.inference import Translator | |
# from Client import Client | |
# #---------------------------------- | |
# # base seamless imports | |
# # --------------------------------- | |
# import numpy as np | |
# import torch | |
# # --------------------------------- | |
# import socketio | |
# ############################################### | |
# # Configure logger | |
# gunicorn_error_logger = logging.getLogger("gunicorn.error") | |
# gunicorn_logger = logging.getLogger("gunicorn") | |
# uvicorn_access_logger = logging.getLogger("uvicorn.access") | |
# gunicorn_error_logger.propagate = True | |
# gunicorn_logger.propagate = True | |
# uvicorn_access_logger.propagate = True | |
# uvicorn_access_logger.handlers = gunicorn_error_logger.handlers | |
# fastapi_logger.handlers = gunicorn_error_logger.handlers | |
# ############################################### | |
# # sio is the main socket.io entrypoint | |
# sio = socketio.AsyncServer( | |
# async_mode="asgi", | |
# cors_allowed_origins="*", | |
# logger=gunicorn_logger, | |
# engineio_logger=gunicorn_logger, | |
# ) | |
# # sio.logger.setLevel(logging.DEBUG) | |
# socketio_app = socketio.ASGIApp(sio) | |
# # app.mount("/", socketio_app) | |
# config = dotenv_values(".env") | |
# # Read connection string from environment vars | |
# # uri = os.environ['MONGODB_URI'] | |
# # Read connection string from .env file | |
# uri = config['MONGODB_URI'] | |
# # Set transformers cache | |
# # os.environ['HF_HOME'] = './.cache/' | |
# # os.environ['SENTENCE_TRANSFORMERS_HOME'] = './.cache' | |
# # MongoDB Connection Lifespan Events | |
# @asynccontextmanager | |
# async def lifespan(app: FastAPI): | |
# # startup logic | |
# app.mongodb_client = MongoClient(uri) | |
# app.database = app.mongodb_client['IT-Cluster1'] #connect to interpretalk primary db | |
# try: | |
# app.mongodb_client.admin.command('ping') | |
# print("MongoDB Connection Established...") | |
# except Exception as e: | |
# print(e) | |
# yield | |
# # shutdown logic | |
# print("Closing MongoDB Connection...") | |
# app.mongodb_client.close() | |
# app = FastAPI(lifespan=lifespan, logger=gunicorn_logger) | |
# # New CORS functionality | |
# app.add_middleware( | |
# CORSMiddleware, | |
# allow_origins=["*"], # configured node app port | |
# allow_credentials=True, | |
# allow_methods=["*"], | |
# allow_headers=["*"], | |
# ) | |
# app.include_router(api_router) # include routers for user, calls and transcripts operations | |
# DEBUG = True | |
# ESCAPE_HATCH_SERVER_LOCK_RELEASE_NAME = "remove_server_lock" | |
# TARGET_SAMPLING_RATE = 16000 | |
# MAX_BYTES_BUFFER = 480_000 | |
# print("") | |
# print("") | |
# print("=" * 20 + " ⭐️ Starting Server... ⭐️ " + "=" * 20) | |
# ############################################### | |
# # Configure socketio server | |
# ############################################### | |
# # TODO PM - change this to the actual path | |
# # seamless remnant code | |
# # CLIENT_BUILD_PATH = "../streaming-react-app/dist/" | |
# # static_files = { | |
# # "/": CLIENT_BUILD_PATH, | |
# # "/assets/seamless-db6a2555.svg": { | |
# # "filename": CLIENT_BUILD_PATH + "assets/seamless-db6a2555.svg", | |
# # "content_type": "image/svg+xml", | |
# # }, | |
# # } | |
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | |
# # processor = AutoProcessor.from_pretrained("facebook/seamless-m4t-v2-large", force_download=True) | |
# #cache_dir="/.cache" | |
# # PM - hardcoding temporarily as my GPU doesn't have enough VRAM | |
# # model = SeamlessM4Tv2Model.from_pretrained("facebook/seamless-m4t-v2-large").to("cpu") | |
# # model = SeamlessM4Tv2Model.from_pretrained("facebook/seamless-m4t-v2-large", force_download=True).to(device) | |
# bytes_data = bytearray() | |
# model_name = "seamlessM4T_v2_large" | |
# # vocoder_name = "vocoder_v2" if model_name == "seamlessM4T_v2_large" else "vocoder_36langs" | |
# clients = {} | |
# rooms = {} | |
# import torch | |
# from transformers import pipeline | |
# translator = pipeline("automatic-speech-recognition", | |
# "facebook/seamless-m4t-v2-large", | |
# torch_dtype=torch.float32, | |
# device="cpu") | |
# converter = pipeline("translation", | |
# "facebook/seamless-m4t-v2-large", | |
# torch_dtype=torch.float32, | |
# device="cpu") | |
# def get_collection_users(): | |
# return app.database["user_records"] | |
# def get_collection_calls(): | |
# # return app.database["call_records"] | |
# return app.database["call_test"] | |
# @app.get("/test/", response_description="Welcome User") | |
# def test(): | |
# return {"message": "Welcome to InterpreTalk!"} | |
# @app.post("/test_post/", response_description="List more test call records") | |
# def test_post(): | |
# request_data = { | |
# "call_id": "TESTID000001" | |
# } | |
# result = create_calls(get_collection_calls(), request_data) | |
# # return {"message": "Welcome to InterpreTalk!"} | |
# return result | |
# @app.put("/test_put/", response_description="List test call records") | |
# def test_put(): | |
# # result = list_calls(get_collection_calls(), 100) | |
# # result = send_captions("TEST", "TEST", "TEST", "oUjUxTYTQFVVjEarIcZ0") | |
# result = send_captions("TEST", "TEST", "TEST", "TESTID000001") | |
# print(result) | |
# return result | |
# async def send_translated_text(client_id, original_text, translated_text, room_id): | |
# print('SEND_TRANSLATED_TEXT IS WOKRING IN FASTAPI BACKEND...') | |
# print(rooms) | |
# print(clients) | |
# data = { | |
# "author": str(client_id), | |
# "original_text": str(original_text), | |
# "translated_text": str(translated_text), | |
# "timestamp": str(datetime.now()) | |
# } | |
# gunicorn_logger.info("SENDING TRANSLATED TEXT TO CLIENT") | |
# await sio.emit("translated_text", data, room=room_id) | |
# gunicorn_logger.info("SUCCESSFULLY SEND AUDIO TO FRONTEND") | |
# @sio.on("connect") | |
# async def connect(sid, environ): | |
# print(f"📥 [event: connected] sid={sid}") | |
# query_params = dict(parse.parse_qsl(environ["QUERY_STRING"])) | |
# client_id = query_params.get("client_id") | |
# gunicorn_logger.info(f"📥 [event: connected] sid={sid}, client_id={client_id}") | |
# # sid = socket ID; client_id = client-specific ID, always the same for the same user | |
# clients[sid] = Client(sid, client_id) | |
# gunicorn_logger.warning(f"Client connected: {sid}") | |
# gunicorn_logger.warning(clients) | |
# @sio.on("disconnect") | |
# async def disconnect(sid): # BO - also pass call id as parameter for updating MongoDB | |
# gunicorn_logger.debug(f"📤 [event: disconnected] sid={sid}") | |
# clients.pop(sid, None) | |
# # BO -> Update Call record with call duration, key terms | |
# @sio.on("target_language") | |
# async def target_language(sid, target_lang): | |
# gunicorn_logger.info(f"📥 [event: target_language] sid={sid}, target_lang={target_lang}") | |
# clients[sid].target_language = target_lang | |
# @sio.on("call_user") | |
# async def call_user(sid, call_id): | |
# clients[sid].call_id = call_id | |
# gunicorn_logger.info(f"CALL {sid}: entering room {call_id}") | |
# rooms[call_id] = rooms.get(call_id, []) | |
# if sid not in rooms[call_id] and len(rooms[call_id]) < 2: | |
# rooms[call_id].append(sid) | |
# sio.enter_room(sid, call_id) | |
# else: | |
# gunicorn_logger.info(f"CALL {sid}: room {call_id} is full") | |
# # await sio.emit("room_full", room=call_id, to=sid) | |
# # # BO - Get call id from dictionary created during socketio connection | |
# # client_id = clients[sid].client_id | |
# # gunicorn_logger.warning(f"NOW TRYING TO CREATE DB RECORD FOR Caller with ID: {client_id} for call: {call_id}") | |
# # # # BO -> Create Call Record with Caller and call_id field (None for callee, duration, terms..) | |
# # request_data = { | |
# # "call_id": str(call_id), | |
# # "caller_id": str(client_id), | |
# # "creation_date": str(datetime.now()) | |
# # } | |
# # response = create_calls(get_collection_calls(), request_data) | |
# # print(response) # BO - print created db call record | |
# @sio.on("audio_config") | |
# async def audio_config(sid, sample_rate): | |
# clients[sid].original_sr = sample_rate | |
# @sio.on("answer_call") | |
# async def answer_call(sid, call_id): | |
# clients[sid].call_id = call_id | |
# gunicorn_logger.info(f"ANSWER {sid}: entering room {call_id}") | |
# rooms[call_id] = rooms.get(call_id, []) | |
# if sid not in rooms[call_id] and len(rooms[call_id]) < 2: | |
# rooms[call_id].append(sid) | |
# sio.enter_room(sid, call_id) | |
# else: | |
# gunicorn_logger.info(f"ANSWER {sid}: room {call_id} is full") | |
# # await sio.emit("room_full", room=call_id, to=sid) | |
# # # BO - Get call id from dictionary created during socketio connection | |
# # client_id = clients[sid].client_id | |
# # # BO -> Update Call Record with Callee field based on call_id | |
# # gunicorn_logger.warning(f"NOW UPDATING MongoDB RECORD FOR Caller with ID: {client_id} for call: {call_id}") | |
# # # # BO -> Create Call Record with callee_id field (None for callee, duration, terms..) | |
# # request_data = { | |
# # "callee_id": client_id | |
# # } | |
# # response = update_calls(get_collection_calls(), call_id, request_data) | |
# # print(response) # BO - print created db call record | |
# @sio.on("incoming_audio") | |
# async def incoming_audio(sid, data, call_id): | |
# try: | |
# clients[sid].add_bytes(data) | |
# if clients[sid].get_length() >= MAX_BYTES_BUFFER: | |
# gunicorn_logger.info('Buffer full, now outputting...') | |
# output_path = clients[sid].output_path | |
# vad_result, resampled_audio = clients[sid].resample_and_write_to_file() | |
# # source language is the speaker's own target language 😃 | |
# src_lang = clients[sid].target_language | |
# if vad_result: | |
# gunicorn_logger.info('Speech detected, now processing audio.....') | |
# tgt_sid = next(id for id in rooms[call_id] if id != sid) | |
# tgt_lang = clients[tgt_sid].target_language | |
# # following example from https://github.com/facebookresearch/seamless_communication/blob/main/docs/m4t/README.md#transformers-usage | |
# # output_tokens = processor(audios=resampled_audio, src_lang=src_lang, return_tensors="pt") | |
# # model_output = model.generate(**output_tokens, tgt_lang=src_lang, generate_speech=False)[0].tolist()[0] | |
# # asr_text = processor.decode(model_output, skip_special_tokens=True) | |
# asr_text = translator(resampled_audio, generate_kwargs={"tgt_lang": src_lang})['text'] | |
# print(f"ASR TEXT = {asr_text}") | |
# # ASR TEXT => ORIGINAL TEXT | |
# # t2t_tokens = processor(text=asr_text, src_lang=src_lang, tgt_lang=tgt_lang, return_tensors="pt") | |
# # print(f"FIRST TYPE = {type(output_tokens)}, SECOND TYPE = {type(t2t_tokens)}") | |
# # translated_data = model.generate(**t2t_tokens, tgt_lang=tgt_lang, generate_speech=False)[0].tolist()[0] | |
# # translated_text = processor.decode(translated_data, skip_special_tokens=True) | |
# translated_text = converter(asr_text, src_lang=src_lang, tgt_lang=tgt_lang) | |
# print(f"TRANSLATED TEXT = {translated_text}") | |
# # BO -> send translated_text to mongodb as caption record update based on call_id | |
# # send_captions(clients[sid].client_id, asr_text, translated_text, call_id) | |
# # TRANSLATED TEXT | |
# # PM - text_output is a list with 1 string | |
# await send_translated_text(clients[sid].client_id, asr_text, translated_text, call_id) | |
# # # BO -> send translated_text to mongodb as caption record update based on call_id | |
# # send_captions(clients[sid].client_id, asr_text, translated_text, call_id) | |
# except Exception as e: | |
# gunicorn_logger.error(f"Error in incoming_audio: {e.with_traceback()}") | |
# def send_captions(client_id, original_text, translated_text, call_id): | |
# # BO -> Update Call Record with Callee field based on call_id | |
# print(f"Now updating Caption field in call record for Caller with ID: {client_id} for call: {call_id}") | |
# data = { | |
# "author": str(client_id), | |
# "original_text": str(original_text), | |
# "translated_text": str(translated_text), | |
# "timestamp": str(datetime.now()) | |
# } | |
# response = update_captions(get_collection_calls(), call_id, data) | |
# return response | |
# app.mount("/", socketio_app) | |
# if __name__ == '__main__': | |
# uvicorn.run("main:app", host='0.0.0.0', port=7860, log_level="debug") | |
# # Running in Docker Container | |
# if __name__ != "__main__": | |
# fastapi_logger.setLevel(gunicorn_logger.level) | |
# else: | |
# fastapi_logger.setLevel(logging.DEBUG) | |
# scan_cache_dir comes from the third-party huggingface_hub package;
# presumably it reports on the local Hugging Face cache — confirm against
# the library docs.
from huggingface_hub import scan_cache_dir
# The scan runs once, at import time, and its result is kept at module level.
# NOTE(review): nothing in the visible part of this file reads hf_cache_info.
# NOTE(review): the call is unguarded, so any exception it raises (e.g. if no
# cache directory exists yet — TODO confirm) will propagate at import time.
hf_cache_info = scan_cache_dir()