Spaces:
Paused
Paused
Update backend/main.py
Browse files- backend/main.py +50 -52
backend/main.py
CHANGED
@@ -24,6 +24,8 @@ from mongodb.operations.calls import *
|
|
24 |
from mongodb.models.calls import UserCall, UpdateCall
|
25 |
# from mongodb.endpoints.calls import *
|
26 |
|
|
|
|
|
27 |
from transformers import AutoProcessor, SeamlessM4Tv2Model
|
28 |
|
29 |
# from seamless_communication.inference import Translator
|
@@ -129,7 +131,6 @@ static_files = {
|
|
129 |
},
|
130 |
}
|
131 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
132 |
-
# processor = AutoProcessor.from_pretrained("facebook/seamless-m4t-v2-large", force_download=True)
|
133 |
processor = AutoProcessor.from_pretrained("facebook/seamless-m4t-v2-large")
|
134 |
|
135 |
# PM - hardcoding temporarily as my GPU doesn't have enough VRAM
|
@@ -152,34 +153,11 @@ def get_collection_calls():
|
|
152 |
return app.database["call_test"]
|
153 |
|
154 |
|
155 |
-
@app.get("/
|
156 |
def test():
|
157 |
-
|
158 |
return {"message": "Welcome to InterpreTalk!"}
|
159 |
|
160 |
|
161 |
-
@app.post("/test_post/", response_description="List more test call records")
|
162 |
-
def test_post():
|
163 |
-
request_data = {
|
164 |
-
"call_id": "TESTID000001"
|
165 |
-
}
|
166 |
-
|
167 |
-
result = create_calls(get_collection_calls(), request_data)
|
168 |
-
|
169 |
-
# return {"message": "Welcome to InterpreTalk!"}
|
170 |
-
return result
|
171 |
-
|
172 |
-
@app.put("/test_put/", response_description="List test call records")
|
173 |
-
def test_put():
|
174 |
-
|
175 |
-
# result = list_calls(get_collection_calls(), 100)
|
176 |
-
# result = send_captions("TEST", "TEST", "TEST", "oUjUxTYTQFVVjEarIcZ0")
|
177 |
-
result = send_captions("TEST", "TEST", "TEST", "TESTID000001")
|
178 |
-
|
179 |
-
print(result)
|
180 |
-
return result
|
181 |
-
|
182 |
-
|
183 |
async def send_translated_text(client_id, original_text, translated_text, room_id):
|
184 |
print('SEND_TRANSLATED_TEXT IS WOKRING IN FASTAPI BACKEND...')
|
185 |
print(rooms) # Debugging
|
@@ -207,10 +185,33 @@ async def connect(sid, environ):
|
|
207 |
gunicorn_logger.warning(clients)
|
208 |
|
209 |
@sio.on("disconnect")
|
210 |
-
async def disconnect(sid):
|
211 |
gunicorn_logger.debug(f"📤 [event: disconnected] sid={sid}")
|
212 |
clients.pop(sid, None)
|
213 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
214 |
|
215 |
@sio.on("target_language")
|
216 |
async def target_language(sid, target_lang):
|
@@ -232,16 +233,16 @@ async def call_user(sid, call_id):
|
|
232 |
# BO - Get call id from dictionary created during socketio connection
|
233 |
client_id = clients[sid].client_id
|
234 |
|
235 |
-
|
236 |
-
#
|
237 |
-
|
238 |
-
|
239 |
-
|
240 |
-
|
241 |
-
|
242 |
|
243 |
-
|
244 |
-
|
245 |
|
246 |
@sio.on("audio_config")
|
247 |
async def audio_config(sid, sample_rate):
|
@@ -265,27 +266,27 @@ async def answer_call(sid, call_id):
|
|
265 |
# BO - Get call id from dictionary created during socketio connection
|
266 |
client_id = clients[sid].client_id
|
267 |
|
268 |
-
#
|
269 |
-
|
270 |
-
#
|
271 |
-
|
272 |
-
|
273 |
-
|
274 |
-
|
275 |
-
|
276 |
-
|
277 |
|
278 |
|
279 |
@sio.on("incoming_audio")
|
280 |
async def incoming_audio(sid, data, call_id):
|
281 |
-
gunicorn_logger.info("RUNNNING INCOMING AUDIO FUNCTION")
|
282 |
try:
|
283 |
clients[sid].add_bytes(data)
|
284 |
|
285 |
if clients[sid].get_length() >= MAX_BYTES_BUFFER:
|
286 |
gunicorn_logger.info('Buffer full, now outputting...')
|
287 |
output_path = clients[sid].output_path
|
288 |
-
|
|
|
289 |
# source lang is speakers tgt language 😃
|
290 |
src_lang = clients[sid].target_language
|
291 |
if vad_result:
|
@@ -304,21 +305,18 @@ async def incoming_audio(sid, data, call_id):
|
|
304 |
translated_data = model.generate(**t2t_tokens, tgt_lang=tgt_lang, generate_speech=False)[0].tolist()[0]
|
305 |
translated_text = processor.decode(translated_data, skip_special_tokens=True)
|
306 |
print(f"TRANSLATED TEXT = {translated_text}")
|
307 |
-
|
308 |
-
# BO -> send translated_text to mongodb as caption record update based on call_id
|
309 |
-
# send_captions(clients[sid].client_id, asr_text, translated_text, call_id)
|
310 |
|
311 |
# TRANSLATED TEXT
|
312 |
# PM - text_output is a list with 1 string
|
313 |
await send_translated_text(clients[sid].client_id, asr_text, translated_text, call_id)
|
314 |
|
315 |
# BO -> send translated_text to mongodb as caption record update based on call_id
|
316 |
-
|
317 |
|
318 |
except Exception as e:
|
319 |
gunicorn_logger.error(f"Error in incoming_audio: {e.with_traceback()}")
|
320 |
|
321 |
-
def send_captions(client_id, original_text, translated_text, call_id):
|
322 |
# BO -> Update Call Record with Callee field based on call_id
|
323 |
print(f"Now updating Caption field in call record for Caller with ID: {client_id} for call: {call_id}")
|
324 |
|
|
|
24 |
from mongodb.models.calls import UserCall, UpdateCall
|
25 |
# from mongodb.endpoints.calls import *
|
26 |
|
27 |
+
from utils.text_rank import extract_terms
|
28 |
+
|
29 |
from transformers import AutoProcessor, SeamlessM4Tv2Model
|
30 |
|
31 |
# from seamless_communication.inference import Translator
|
|
|
131 |
},
|
132 |
}
|
133 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
|
|
134 |
processor = AutoProcessor.from_pretrained("facebook/seamless-m4t-v2-large")
|
135 |
|
136 |
# PM - hardcoding temporarily as my GPU doesn't have enough VRAM
|
|
|
153 |
return app.database["call_test"]
|
154 |
|
155 |
|
156 |
+
@app.get("/home/", response_description="Welcome User")
def test():
    """Return the static welcome payload served for GET /home/."""
    welcome = {"message": "Welcome to InterpreTalk!"}
    return welcome
|
159 |
|
160 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
161 |
async def send_translated_text(client_id, original_text, translated_text, room_id):
|
162 |
print('SEND_TRANSLATED_TEXT IS WOKRING IN FASTAPI BACKEND...')
|
163 |
print(rooms) # Debugging
|
|
|
185 |
gunicorn_logger.warning(clients)
|
186 |
|
187 |
@sio.on("disconnect")
async def disconnect(sid):
    """Socket.IO handler: drop a client's entry once its connection closes.

    Args:
        sid: Socket.IO session id of the client that disconnected.
    """
    gunicorn_logger.debug(f"📤 [event: disconnected] sid={sid}")
    # pop() with a default so an sid that is not registered is ignored
    # instead of raising KeyError.
    clients.pop(sid, None)
|
191 |
+
|
192 |
+
@sio.on("term_extraction")
async def term_extraction(sid, call_id):
    """Socket.IO handler: extract key terms from a call's captions and save them.

    Fetches the caption text for ``call_id`` with ``get_caption_text``, runs
    ``extract_terms`` over it, and writes the result to the call record's
    ``key_terms`` field through ``update_calls``. When no caption text comes
    back the handler returns without touching the record.

    Args:
        sid: Socket.IO session id of the emitting client (only logged here).
        call_id: Identifier of the call record to read from and update.
    """
    gunicorn_logger.debug(f"📤 [event: term_extraction] sid={sid}, call={call_id}")

    # Combined caption field of the call record for this call_id.
    caption_text = get_caption_text(get_collection_calls(), call_id)

    # Nothing to extract from. No minimum-length threshold is enforced, even
    # though term extraction is poor on short captions.
    if not caption_text:
        return

    print("THE COMBINED TEXT IS:", caption_text)

    # The text length is passed as extract_terms' second argument.
    # NOTE(review): confirm what that parameter controls in utils.text_rank.
    terms = extract_terms(caption_text, len(caption_text))
    print("THE KEY TERMS ARE:", terms)

    # BO -> Update the call record; only the key_terms field is written here.
    update_calls(get_collection_calls(), call_id, {"key_terms": terms})
|
215 |
|
216 |
@sio.on("target_language")
|
217 |
async def target_language(sid, target_lang):
|
|
|
233 |
# BO - Get call id from dictionary created during socketio connection
|
234 |
client_id = clients[sid].client_id
|
235 |
|
236 |
+
gunicorn_logger.warning(f"NOW TRYING TO CREATE DB RECORD FOR Caller with ID: {client_id} for call: {call_id}")
|
237 |
+
# BO -> Create Call Record with Caller and call_id field (None for callee, duration, terms..)
|
238 |
+
request_data = {
|
239 |
+
"call_id": str(call_id),
|
240 |
+
"caller_id": str(client_id),
|
241 |
+
"creation_date": str(datetime.now())
|
242 |
+
}
|
243 |
|
244 |
+
response = create_calls(get_collection_calls(), request_data)
|
245 |
+
print(response) # BO - print created db call record
|
246 |
|
247 |
@sio.on("audio_config")
|
248 |
async def audio_config(sid, sample_rate):
|
|
|
266 |
# BO - Get call id from dictionary created during socketio connection
|
267 |
client_id = clients[sid].client_id
|
268 |
|
269 |
+
# BO -> Update Call Record with Callee field based on call_id
|
270 |
+
gunicorn_logger.warning(f"NOW UPDATING MongoDB RECORD FOR Caller with ID: {client_id} for call: {call_id}")
|
271 |
+
# BO -> Create Call Record with callee_id field (None for callee, duration, terms..)
|
272 |
+
request_data = {
|
273 |
+
"callee_id": client_id
|
274 |
+
}
|
275 |
+
|
276 |
+
response = update_calls(get_collection_calls(), call_id, request_data)
|
277 |
+
print(response) # BO - print created db call record
|
278 |
|
279 |
|
280 |
@sio.on("incoming_audio")
|
281 |
async def incoming_audio(sid, data, call_id):
|
|
|
282 |
try:
|
283 |
clients[sid].add_bytes(data)
|
284 |
|
285 |
if clients[sid].get_length() >= MAX_BYTES_BUFFER:
|
286 |
gunicorn_logger.info('Buffer full, now outputting...')
|
287 |
output_path = clients[sid].output_path
|
288 |
+
resampled_audio = clients[sid].resample_and_clear()
|
289 |
+
vad_result = clients[sid].vad_analyse(resampled_audio)
|
290 |
# source lang is speakers tgt language 😃
|
291 |
src_lang = clients[sid].target_language
|
292 |
if vad_result:
|
|
|
305 |
translated_data = model.generate(**t2t_tokens, tgt_lang=tgt_lang, generate_speech=False)[0].tolist()[0]
|
306 |
translated_text = processor.decode(translated_data, skip_special_tokens=True)
|
307 |
print(f"TRANSLATED TEXT = {translated_text}")
|
|
|
|
|
|
|
308 |
|
309 |
# TRANSLATED TEXT
|
310 |
# PM - text_output is a list with 1 string
|
311 |
await send_translated_text(clients[sid].client_id, asr_text, translated_text, call_id)
|
312 |
|
313 |
# BO -> send translated_text to mongodb as caption record update based on call_id
|
314 |
+
await send_captions(clients[sid].client_id, asr_text, translated_text, call_id)
|
315 |
|
316 |
except Exception as e:
|
317 |
gunicorn_logger.error(f"Error in incoming_audio: {e.with_traceback()}")
|
318 |
|
319 |
+
async def send_captions(client_id, original_text, translated_text, call_id):
|
320 |
# BO -> Update Call Record with Callee field based on call_id
|
321 |
print(f"Now updating Caption field in call record for Caller with ID: {client_id} for call: {call_id}")
|
322 |
|