AshDavid12 committed on
Commit
4c42c49
·
1 Parent(s): 35b4964
Files changed (2) hide show
  1. client.py +9 -29
  2. infer.py +23 -7
client.py CHANGED
@@ -2,38 +2,14 @@ import asyncio
2
  import websockets
3
  import requests
4
  import ssl
 
5
 
6
  # Parameters for reading and sending the audio
7
  AUDIO_FILE_URL = "https://raw.githubusercontent.com/AshDavid12/runpod-serverless-forked/main/test_hebrew.wav" # Use WAV file
8
 
9
- processed_segments = set()
10
-
11
- def process_transcription_results(transcription_result):
12
- global processed_segments # Ensure we use the same set across multiple calls
13
- new_segments = []
14
-
15
- # Iterate over all segments in the transcription result
16
- for segment in transcription_result.get("segments", []):
17
- # You can use a unique identifier like 'id' or a combination of 'start' and 'end' times
18
- segment_id = segment.get("id")
19
-
20
- # Check if the segment is already processed
21
- if segment_id not in processed_segments:
22
- # Process the new segment (do your actual processing here)
23
- new_segments.append(segment)
24
-
25
- # Mark the segment as processed by adding its 'id' to the set
26
- processed_segments.add(segment_id)
27
- print(f"Processed segment ID: {segment_id}")
28
- else:
29
- print(f"Skipping already processed segment ID: {segment_id}")
30
-
31
- # Return only new segments that have not been processed before
32
- return new_segments
33
-
34
 
35
  async def send_audio(websocket):
36
- buffer_size = 512 * 1024 # Buffer audio chunks up to 512KB before sending
37
  audio_buffer = bytearray()
38
 
39
  with requests.get(AUDIO_FILE_URL, stream=True, allow_redirects=False) as response:
@@ -48,7 +24,7 @@ async def send_audio(websocket):
48
  # Send buffered audio data once it's large enough
49
  if len(audio_buffer) >= buffer_size:
50
  await websocket.send(audio_buffer)
51
- print(f"Sent {len(audio_buffer)} bytes of audio data.")
52
  audio_buffer.clear()
53
  await asyncio.sleep(0.01)
54
 
@@ -56,6 +32,7 @@ async def send_audio(websocket):
56
  else:
57
  print(f"Failed to download audio file. Status code: {response.status_code}")
58
 
 
59
  async def receive_transcription(websocket):
60
  while True:
61
  try:
@@ -70,6 +47,7 @@ async def receive_transcription(websocket):
70
  print(f"Error receiving transcription: {e}")
71
  break
72
 
 
73
  async def send_heartbeat(websocket):
74
  while True:
75
  try:
@@ -78,7 +56,7 @@ async def send_heartbeat(websocket):
78
  except websockets.ConnectionClosed:
79
  print("Connection closed, stopping heartbeat")
80
  break
81
- await asyncio.sleep(120) # Send ping every 30 seconds (adjust as needed)
82
 
83
 
84
  async def run_client():
@@ -87,11 +65,13 @@ async def run_client():
87
  ssl_context.check_hostname = False
88
  ssl_context.verify_mode = ssl.CERT_NONE
89
 
90
- async with websockets.connect(uri, ssl=ssl_context, timeout=600) as websocket:
 
91
  await asyncio.gather(
92
  send_audio(websocket),
93
  receive_transcription(websocket),
94
  send_heartbeat(websocket)
95
  )
96
 
 
97
  asyncio.run(run_client())
 
2
  import websockets
3
  import requests
4
  import ssl
5
+ import logging
6
 
7
  # Parameters for reading and sending the audio
8
  AUDIO_FILE_URL = "https://raw.githubusercontent.com/AshDavid12/runpod-serverless-forked/main/test_hebrew.wav" # Use WAV file
9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  async def send_audio(websocket):
12
+ buffer_size = 1024 # Buffer audio chunks up to 512KB before sending
13
  audio_buffer = bytearray()
14
 
15
  with requests.get(AUDIO_FILE_URL, stream=True, allow_redirects=False) as response:
 
24
  # Send buffered audio data once it's large enough
25
  if len(audio_buffer) >= buffer_size:
26
  await websocket.send(audio_buffer)
27
+ #print(f"Sent {len(audio_buffer)} bytes of audio data.")
28
  audio_buffer.clear()
29
  await asyncio.sleep(0.01)
30
 
 
32
  else:
33
  print(f"Failed to download audio file. Status code: {response.status_code}")
34
 
35
+
36
  async def receive_transcription(websocket):
37
  while True:
38
  try:
 
47
  print(f"Error receiving transcription: {e}")
48
  break
49
 
50
+
51
  async def send_heartbeat(websocket):
52
  while True:
53
  try:
 
56
  except websockets.ConnectionClosed:
57
  print("Connection closed, stopping heartbeat")
58
  break
59
+ await asyncio.sleep(30) # Send ping every 30 seconds (adjust as needed)
60
 
61
 
62
  async def run_client():
 
65
  ssl_context.check_hostname = False
66
  ssl_context.verify_mode = ssl.CERT_NONE
67
 
68
+ async with websockets.connect(uri, ssl=ssl_context, timeout=120) as websocket:
69
+ print(f"here")
70
  await asyncio.gather(
71
  send_audio(websocket),
72
  receive_transcription(websocket),
73
  send_heartbeat(websocket)
74
  )
75
 
76
+
77
  asyncio.run(run_client())
infer.py CHANGED
@@ -160,14 +160,13 @@ def transcribe_core_ws(audio_file, last_transcribed_time):
160
 
161
  # Track the new segments and update the last transcribed time
162
  for s in segs:
163
- words= []
164
  logging.info(f"Processing segment with start time: {s.start} and end time: {s.end}")
165
 
166
  # Only process segments that start after the last transcribed time
167
  if s.start >= last_transcribed_time:
168
  logging.info(f"New segment found starting at {s.start} seconds.")
169
- for w in words:
170
- words.append({'start': w.start, 'end': w.end, 'word': w.word, 'probability': w.probability})
171
  seg = {
172
  'id': s.id, 'seek': s.seek, 'start': s.start, 'end': s.end, 'text': s.text,
173
  'avg_logprob': s.avg_logprob, 'compression_ratio': s.compression_ratio,
@@ -177,10 +176,10 @@ def transcribe_core_ws(audio_file, last_transcribed_time):
177
  ret['new_segments'].append(seg)
178
 
179
  # Update the last transcribed time to the end of the current segment
180
- new_last_transcribed_time = s.end
181
  logging.debug(f"Updated last transcribed time to: {new_last_transcribed_time} seconds")
182
 
183
- logging.info(f"Returning {len(ret['new_segments'])} new segments and updated last transcribed time.")
184
  return ret, new_last_transcribed_time
185
 
186
 
@@ -195,7 +194,10 @@ async def websocket_transcribe(websocket: WebSocket):
195
 
196
  try:
197
  processed_segments = [] # Keeps track of the segments already transcribed
 
 
198
  last_transcribed_time = 0.0
 
199
 
200
  # A temporary file to store the growing audio data
201
  with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_audio_file:
@@ -212,18 +214,33 @@ async def websocket_transcribe(websocket: WebSocket):
212
  # Write audio chunk to file and accumulate size and time
213
  temp_audio_file.write(audio_chunk)
214
  temp_audio_file.flush()
 
 
 
 
 
 
 
 
 
 
 
215
 
216
  # Call the transcription function with the last processed time
217
  partial_result, last_transcribed_time = transcribe_core_ws(temp_audio_file.name, last_transcribed_time)
218
  accumulated_audio_time = 0 # Reset the accumulated audio time
 
 
 
 
219
 
 
220
  response = {
221
  "new_segments": partial_result['new_segments'],
222
  "processed_segments": processed_segments
223
  }
224
  logging.info(f"Sending {len(partial_result['new_segments'])} new segments to the client.")
225
  await websocket.send_json(response)
226
- processed_segments.extend(partial_result['new_segments'])
227
 
228
  except WebSocketDisconnect:
229
  logging.info("WebSocket connection closed by the client.")
@@ -237,4 +254,3 @@ async def websocket_transcribe(websocket: WebSocket):
237
  logging.info("Cleaning up and closing WebSocket connection.")
238
 
239
 
240
-
 
160
 
161
  # Track the new segments and update the last transcribed time
162
  for s in segs:
 
163
  logging.info(f"Processing segment with start time: {s.start} and end time: {s.end}")
164
 
165
  # Only process segments that start after the last transcribed time
166
  if s.start >= last_transcribed_time:
167
  logging.info(f"New segment found starting at {s.start} seconds.")
168
+ words = [{'start': w.start, 'end': w.end, 'word': w.word, 'probability': w.probability} for w in s.words]
169
+
170
  seg = {
171
  'id': s.id, 'seek': s.seek, 'start': s.start, 'end': s.end, 'text': s.text,
172
  'avg_logprob': s.avg_logprob, 'compression_ratio': s.compression_ratio,
 
176
  ret['new_segments'].append(seg)
177
 
178
  # Update the last transcribed time to the end of the current segment
179
+ new_last_transcribed_time = max(new_last_transcribed_time, s.end)
180
  logging.debug(f"Updated last transcribed time to: {new_last_transcribed_time} seconds")
181
 
182
+ #logging.info(f"Returning {len(ret['new_segments'])} new segments and updated last transcribed time.")
183
  return ret, new_last_transcribed_time
184
 
185
 
 
194
 
195
  try:
196
  processed_segments = [] # Keeps track of the segments already transcribed
197
+ accumulated_audio_size = 0 # Track how much audio data has been buffered
198
+ accumulated_audio_time = 0 # Track the total audio duration accumulated
199
  last_transcribed_time = 0.0
200
+ #min_transcription_time = 5.0 # Minimum duration of audio in seconds before transcription starts
201
 
202
  # A temporary file to store the growing audio data
203
  with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_audio_file:
 
214
  # Write audio chunk to file and accumulate size and time
215
  temp_audio_file.write(audio_chunk)
216
  temp_audio_file.flush()
217
+ accumulated_audio_size += len(audio_chunk)
218
+
219
+ # Estimate the duration of the chunk based on its size (e.g., 16kHz audio)
220
+ chunk_duration = len(audio_chunk) / (16000 * 2) # Assuming 16kHz mono WAV (2 bytes per sample)
221
+ accumulated_audio_time += chunk_duration
222
+ #logging.info(f"Received and buffered {len(audio_chunk)} bytes, total buffered: {accumulated_audio_size} bytes, total time: {accumulated_audio_time:.2f} seconds")
223
+
224
+ # Transcribe when enough time (audio) is accumulated (e.g., at least 5 seconds of audio)
225
+ #if accumulated_audio_time >= min_transcription_time:
226
+ #logging.info("Buffered enough audio time, starting transcription.")
227
+
228
 
229
  # Call the transcription function with the last processed time
230
  partial_result, last_transcribed_time = transcribe_core_ws(temp_audio_file.name, last_transcribed_time)
231
  accumulated_audio_time = 0 # Reset the accumulated audio time
232
+ processed_segments.extend(partial_result['new_segments'])
233
+
234
+ # Reset the accumulated audio size after transcription
235
+ accumulated_audio_size = 0
236
 
237
+ # Send the transcription result back to the client with both new and all processed segments
238
  response = {
239
  "new_segments": partial_result['new_segments'],
240
  "processed_segments": processed_segments
241
  }
242
  logging.info(f"Sending {len(partial_result['new_segments'])} new segments to the client.")
243
  await websocket.send_json(response)
 
244
 
245
  except WebSocketDisconnect:
246
  logging.info("WebSocket connection closed by the client.")
 
254
  logging.info("Cleaning up and closing WebSocket connection.")
255
 
256