benjolo commited on
Commit
000aa13
·
verified ·
1 Parent(s): fb8d5d3

Update backend/main.py

Browse files
Files changed (1) hide show
  1. backend/main.py +50 -52
backend/main.py CHANGED
@@ -24,6 +24,8 @@ from mongodb.operations.calls import *
24
  from mongodb.models.calls import UserCall, UpdateCall
25
  # from mongodb.endpoints.calls import *
26
 
 
 
27
  from transformers import AutoProcessor, SeamlessM4Tv2Model
28
 
29
  # from seamless_communication.inference import Translator
@@ -129,7 +131,6 @@ static_files = {
129
  },
130
  }
131
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
132
- # processor = AutoProcessor.from_pretrained("facebook/seamless-m4t-v2-large", force_download=True)
133
  processor = AutoProcessor.from_pretrained("facebook/seamless-m4t-v2-large")
134
 
135
  # PM - hardcoding temporarily as my GPU doesnt have enough vram
@@ -152,34 +153,11 @@ def get_collection_calls():
152
  return app.database["call_test"]
153
 
154
 
155
- @app.get("/test/", response_description="Welcome User")
156
  def test():
157
-
158
  return {"message": "Welcome to InterpreTalk!"}
159
 
160
 
161
- @app.post("/test_post/", response_description="List more test call records")
162
- def test_post():
163
- request_data = {
164
- "call_id": "TESTID000001"
165
- }
166
-
167
- result = create_calls(get_collection_calls(), request_data)
168
-
169
- # return {"message": "Welcome to InterpreTalk!"}
170
- return result
171
-
172
- @app.put("/test_put/", response_description="List test call records")
173
- def test_put():
174
-
175
- # result = list_calls(get_collection_calls(), 100)
176
- # result = send_captions("TEST", "TEST", "TEST", "oUjUxTYTQFVVjEarIcZ0")
177
- result = send_captions("TEST", "TEST", "TEST", "TESTID000001")
178
-
179
- print(result)
180
- return result
181
-
182
-
183
  async def send_translated_text(client_id, original_text, translated_text, room_id):
184
  print('SEND_TRANSLATED_TEXT IS WOKRING IN FASTAPI BACKEND...')
185
  print(rooms) # Debugging
@@ -207,10 +185,33 @@ async def connect(sid, environ):
207
  gunicorn_logger.warning(clients)
208
 
209
  @sio.on("disconnect")
210
- async def disconnect(sid): # BO - also pass call id as parameter for updating MongoDB
211
  gunicorn_logger.debug(f"📤 [event: disconnected] sid={sid}")
212
  clients.pop(sid, None)
213
- # BO -> Update Call record with call duration, key terms
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
214
 
215
  @sio.on("target_language")
216
  async def target_language(sid, target_lang):
@@ -232,16 +233,16 @@ async def call_user(sid, call_id):
232
  # BO - Get call id from dictionary created during socketio connection
233
  client_id = clients[sid].client_id
234
 
235
- # gunicorn_logger.warning(f"NOW TRYING TO CREATE DB RECORD FOR Caller with ID: {client_id} for call: {call_id}")
236
- # # BO -> Create Call Record with Caller and call_id field (None for callee, duration, terms..)
237
- # request_data = {
238
- # "call_id": str(call_id),
239
- # "caller_id": str(client_id),
240
- # "creation_date": str(datetime.now())
241
- # }
242
 
243
- # response = create_calls(get_collection_calls(), request_data)
244
- # print(response) # BO - print created db call record
245
 
246
  @sio.on("audio_config")
247
  async def audio_config(sid, sample_rate):
@@ -265,27 +266,27 @@ async def answer_call(sid, call_id):
265
  # BO - Get call id from dictionary created during socketio connection
266
  client_id = clients[sid].client_id
267
 
268
- # # BO -> Update Call Record with Callee field based on call_id
269
- # gunicorn_logger.warning(f"NOW UPDATING MongoDB RECORD FOR Caller with ID: {client_id} for call: {call_id}")
270
- # # BO -> Create Call Record with callee_id field (None for callee, duration, terms..)
271
- # request_data = {
272
- # "callee_id": client_id
273
- # }
274
-
275
- # response = update_calls(get_collection_calls(), call_id, request_data)
276
- # print(response) # BO - print created db call record
277
 
278
 
279
  @sio.on("incoming_audio")
280
  async def incoming_audio(sid, data, call_id):
281
- gunicorn_logger.info("RUNNNING INCOMING AUDIO FUNCTION")
282
  try:
283
  clients[sid].add_bytes(data)
284
 
285
  if clients[sid].get_length() >= MAX_BYTES_BUFFER:
286
  gunicorn_logger.info('Buffer full, now outputting...')
287
  output_path = clients[sid].output_path
288
- vad_result, resampled_audio = clients[sid].resample_and_write_to_file()
 
289
  # source lang is speakers tgt language 😃
290
  src_lang = clients[sid].target_language
291
  if vad_result:
@@ -304,21 +305,18 @@ async def incoming_audio(sid, data, call_id):
304
  translated_data = model.generate(**t2t_tokens, tgt_lang=tgt_lang, generate_speech=False)[0].tolist()[0]
305
  translated_text = processor.decode(translated_data, skip_special_tokens=True)
306
  print(f"TRANSLATED TEXT = {translated_text}")
307
-
308
- # BO -> send translated_text to mongodb as caption record update based on call_id
309
- # send_captions(clients[sid].client_id, asr_text, translated_text, call_id)
310
 
311
  # TRANSLATED TEXT
312
  # PM - text_output is a list with 1 string
313
  await send_translated_text(clients[sid].client_id, asr_text, translated_text, call_id)
314
 
315
  # BO -> send translated_text to mongodb as caption record update based on call_id
316
- # send_captions(clients[sid].client_id, asr_text, translated_text, call_id)
317
 
318
  except Exception as e:
319
  gunicorn_logger.error(f"Error in incoming_audio: {e.with_traceback()}")
320
 
321
- def send_captions(client_id, original_text, translated_text, call_id):
322
  # BO -> Update Call Record with Callee field based on call_id
323
  print(f"Now updating Caption field in call record for Caller with ID: {client_id} for call: {call_id}")
324
 
 
24
  from mongodb.models.calls import UserCall, UpdateCall
25
  # from mongodb.endpoints.calls import *
26
 
27
+ from utils.text_rank import extract_terms
28
+
29
  from transformers import AutoProcessor, SeamlessM4Tv2Model
30
 
31
  # from seamless_communication.inference import Translator
 
131
  },
132
  }
133
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
134
  processor = AutoProcessor.from_pretrained("facebook/seamless-m4t-v2-large")
135
 
136
  # PM - hardcoding temporarily as my GPU doesnt have enough vram
 
153
  return app.database["call_test"]
154
 
155
 
156
+ @app.get("/home/", response_description="Welcome User")
157
  def test():
 
158
  return {"message": "Welcome to InterpreTalk!"}
159
 
160
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
161
  async def send_translated_text(client_id, original_text, translated_text, room_id):
162
  print('SEND_TRANSLATED_TEXT IS WOKRING IN FASTAPI BACKEND...')
163
  print(rooms) # Debugging
 
185
  gunicorn_logger.warning(clients)
186
 
187
  @sio.on("disconnect")
188
+ async def disconnect(sid):
189
  gunicorn_logger.debug(f"📤 [event: disconnected] sid={sid}")
190
  clients.pop(sid, None)
191
+
192
+ @sio.on("term_extraction")
193
+ async def term_extraction(sid, call_id):
194
+ gunicorn_logger.debug(f"📤 [event: term_extraction] sid={sid}, call={call_id}")
195
+
196
+ # call_id = "0FIdAosKy9ysQDkp14T2"
197
+
198
+ # Get combined caption field for call record based on call_id
199
+ combined_text = get_caption_text(get_collection_calls(), call_id)
200
+
201
+ if combined_text: # > min_caption_length: -> poor term extraction on short
202
+ print("THE COMBINED TEXT IS:", combined_text)
203
+
204
+ # Extract Key Terms from Concatenated Caption Field
205
+ key_terms = extract_terms(combined_text, len(combined_text))
206
+
207
+ # BO -> Update Call record with call duration, key terms
208
+ print("THE KEY TERMS ARE:", key_terms)
209
+
210
+ request_data = {
211
+ "key_terms": key_terms
212
+ }
213
+
214
+ update_calls(get_collection_calls(), call_id, request_data)
215
 
216
  @sio.on("target_language")
217
  async def target_language(sid, target_lang):
 
233
  # BO - Get call id from dictionary created during socketio connection
234
  client_id = clients[sid].client_id
235
 
236
+ gunicorn_logger.warning(f"NOW TRYING TO CREATE DB RECORD FOR Caller with ID: {client_id} for call: {call_id}")
237
+ # BO -> Create Call Record with Caller and call_id field (None for callee, duration, terms..)
238
+ request_data = {
239
+ "call_id": str(call_id),
240
+ "caller_id": str(client_id),
241
+ "creation_date": str(datetime.now())
242
+ }
243
 
244
+ response = create_calls(get_collection_calls(), request_data)
245
+ print(response) # BO - print created db call record
246
 
247
  @sio.on("audio_config")
248
  async def audio_config(sid, sample_rate):
 
266
  # BO - Get call id from dictionary created during socketio connection
267
  client_id = clients[sid].client_id
268
 
269
+ # BO -> Update Call Record with Callee field based on call_id
270
+ gunicorn_logger.warning(f"NOW UPDATING MongoDB RECORD FOR Caller with ID: {client_id} for call: {call_id}")
271
+ # BO -> Create Call Record with callee_id field (None for callee, duration, terms..)
272
+ request_data = {
273
+ "callee_id": client_id
274
+ }
275
+
276
+ response = update_calls(get_collection_calls(), call_id, request_data)
277
+ print(response) # BO - print created db call record
278
 
279
 
280
  @sio.on("incoming_audio")
281
  async def incoming_audio(sid, data, call_id):
 
282
  try:
283
  clients[sid].add_bytes(data)
284
 
285
  if clients[sid].get_length() >= MAX_BYTES_BUFFER:
286
  gunicorn_logger.info('Buffer full, now outputting...')
287
  output_path = clients[sid].output_path
288
+ resampled_audio = clients[sid].resample_and_clear()
289
+ vad_result = clients[sid].vad_analyse(resampled_audio)
290
  # source lang is speakers tgt language 😃
291
  src_lang = clients[sid].target_language
292
  if vad_result:
 
305
  translated_data = model.generate(**t2t_tokens, tgt_lang=tgt_lang, generate_speech=False)[0].tolist()[0]
306
  translated_text = processor.decode(translated_data, skip_special_tokens=True)
307
  print(f"TRANSLATED TEXT = {translated_text}")
 
 
 
308
 
309
  # TRANSLATED TEXT
310
  # PM - text_output is a list with 1 string
311
  await send_translated_text(clients[sid].client_id, asr_text, translated_text, call_id)
312
 
313
  # BO -> send translated_text to mongodb as caption record update based on call_id
314
+ await send_captions(clients[sid].client_id, asr_text, translated_text, call_id)
315
 
316
  except Exception as e:
317
  gunicorn_logger.error(f"Error in incoming_audio: {e.with_traceback()}")
318
 
319
+ async def send_captions(client_id, original_text, translated_text, call_id):
320
  # BO -> Update Call Record with Callee field based on call_id
321
  print(f"Now updating Caption field in call record for Caller with ID: {client_id} for call: {call_id}")
322