AshDavid12 committed
Commit 92ce07c · Parent(s): b0d532b

added chunk tracking

Files changed (2):
  1. client.py +14 -5
  2. infer.py +16 -11
client.py CHANGED
@@ -7,15 +7,24 @@ import ssl
 AUDIO_FILE_URL = "https://raw.githubusercontent.com/AshDavid12/runpod_serverless_whisper/main/me-hebrew.wav"  # Use WAV file
 
 async def send_audio(websocket):
-    # Stream the audio file in real-time
+    buffer_size = 512 * 1024  # Buffer audio chunks up to 512KB before sending
+    audio_buffer = bytearray()
+
     with requests.get(AUDIO_FILE_URL, stream=True, allow_redirects=False) as response:
         if response.status_code == 200:
             print("Starting to stream audio file...")
 
-            for chunk in response.iter_content(chunk_size=8192):  # Stream in chunks of 8192 bytes
+            for chunk in response.iter_content(chunk_size=8192):  # Stream in chunks
                 if chunk:
-                    await websocket.send(chunk)  # Send each chunk over WebSocket
-                    print(f"Sent audio chunk of size {len(chunk)} bytes")
+                    audio_buffer.extend(chunk)
+                    print(f"Received audio chunk of size {len(chunk)} bytes.")
+
+                    # Send buffered audio data once it's large enough
+                    if len(audio_buffer) >= buffer_size:
+                        await websocket.send(audio_buffer)
+                        print(f"Sent {len(audio_buffer)} bytes of audio data.")
+                        audio_buffer.clear()
+                        await asyncio.sleep(0.01)
 
             print("Finished sending audio.")
         else:
@@ -42,7 +51,7 @@ async def send_heartbeat(websocket):
 
 
 async def run_client():
-    uri = ("wss://gigaverse-ivrit-ai-streaming.hf.space/ws/transcribe")  # WebSocket URL
+    uri = ("wss://gigaverse-ivrit-ai-streaming.hf.space/wtranscribe")  # WebSocket URL
     ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
     ssl_context.check_hostname = False
     ssl_context.verify_mode = ssl.CERT_NONE
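
For reference, the new sending loop amounts to this pattern: accumulate the 8 KB HTTP chunks in a bytearray and only push a WebSocket frame once 512 KB has built up, yielding briefly after each send. Below is a minimal self-contained sketch of that pattern, assuming the `requests` and `websockets` packages; unlike the committed code, it also flushes the final partial buffer so the tail of the file is not dropped.

# Minimal sketch of the buffered-sender pattern introduced in this commit.
# Assumes the `requests` and `websockets` packages; URL and sizes are from the diff.
import asyncio
import ssl

import requests
import websockets

AUDIO_FILE_URL = "https://raw.githubusercontent.com/AshDavid12/runpod_serverless_whisper/main/me-hebrew.wav"
WS_URI = "wss://gigaverse-ivrit-ai-streaming.hf.space/wtranscribe"
BUFFER_SIZE = 512 * 1024  # flush threshold: 512 KB


async def stream_buffered():
    # Same permissive TLS setup as run_client() in the diff
    ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
    ssl_context.check_hostname = False
    ssl_context.verify_mode = ssl.CERT_NONE

    async with websockets.connect(WS_URI, ssl=ssl_context) as ws:
        buffer = bytearray()
        with requests.get(AUDIO_FILE_URL, stream=True) as response:
            response.raise_for_status()
            for chunk in response.iter_content(chunk_size=8192):
                buffer.extend(chunk)
                if len(buffer) >= BUFFER_SIZE:
                    await ws.send(bytes(buffer))  # one large frame instead of many small ones
                    buffer.clear()
                    await asyncio.sleep(0.01)  # yield so other tasks (e.g. a receiver) can run
        if buffer:  # flush the tail that never reached the threshold (not in the committed code)
            await ws.send(bytes(buffer))


asyncio.run(stream_buffered())

At 16 kHz mono 16-bit PCM (the assumption infer.py makes below), each 512 KB frame carries roughly 16 seconds of audio.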
infer.py CHANGED
@@ -9,7 +9,7 @@ from fastapi import FastAPI, HTTPException, WebSocket,WebSocketDisconnect
 import websockets
 from pydantic import BaseModel
 from typing import Optional
-import sys
+import sys
 import asyncio
 
 # Configure logging
@@ -186,7 +186,7 @@ def transcribe_core_ws(audio_file, last_transcribed_time):
 import tempfile
 
 
-@app.websocket("/ws/transcribe")
+@app.websocket("/wtranscribe")
 async def websocket_transcribe(websocket: WebSocket):
     logging.info("New WebSocket connection request received.")
     await websocket.accept()
@@ -195,6 +195,8 @@ async def websocket_transcribe(websocket: WebSocket):
     try:
         processed_segments = []  # Keeps track of the segments already transcribed
         accumulated_audio_size = 0  # Track how much audio data has been buffered
+        accumulated_audio_time = 0  # Track the total audio duration accumulated
+        min_transcription_time = 5.0  # Minimum duration of audio in seconds before transcription starts
 
         # A temporary file to store the growing audio data
         with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_audio_file:
@@ -208,20 +210,23 @@ async def websocket_transcribe(websocket: WebSocket):
                     logging.warning("Received empty audio chunk, skipping processing.")
                     continue
 
-                # Write audio chunk to file and accumulate size
+                # Write audio chunk to file and accumulate size and time
                 temp_audio_file.write(audio_chunk)
                 temp_audio_file.flush()
                 accumulated_audio_size += len(audio_chunk)
-                logging.info(
-                    f"Received and buffered {len(audio_chunk)} bytes, total buffered: {accumulated_audio_size} bytes")
 
-                # Buffer at least 512KB before transcription
-                if accumulated_audio_size >= (512 * 1024):  # Adjust this size as needed
-                    logging.info("Buffered enough data, starting transcription.")
-
-                    partial_result, processed_segments = transcribe_core_ws(temp_audio_file.name,
-                                                                            processed_segments)
-                    accumulated_audio_size = 0  # Reset the accumulated audio size
+                # Estimate the duration of the chunk based on its size (e.g., 16kHz audio)
+                chunk_duration = len(audio_chunk) / (16000 * 2)  # Assuming 16kHz mono WAV (2 bytes per sample)
+                accumulated_audio_time += chunk_duration
+                logging.info(f"Received and buffered {len(audio_chunk)} bytes, total buffered: {accumulated_audio_size} bytes, total time: {accumulated_audio_time:.2f} seconds")
+
+                # Transcribe when enough time (audio) is accumulated (e.g., at least 5 seconds of audio)
+                if accumulated_audio_time >= min_transcription_time:
+                    logging.info("Buffered enough audio time, starting transcription.")
+
+                    # Call the transcription function with the last processed time
+                    partial_result, processed_segments = transcribe_core_ws(temp_audio_file.name, processed_segments)
+                    accumulated_audio_time = 0  # Reset the accumulated audio time
 
                 # Send the transcription result back to the client
                 logging.info(f"Sending {len(partial_result['new_segments'])} new segments to the client.")