AshDavid12 committed on
Commit
e37aac1
·
1 Parent(s): 544d523

changing chunks-try get logs in hf

Browse files
Files changed (2) hide show
  1. client.py +13 -14
  2. infer.py +2 -64
client.py CHANGED
@@ -16,7 +16,7 @@ async def send_audio(websocket):
16
  if response.status_code == 200:
17
  print("Starting to stream audio file...")
18
 
19
- for chunk in response.iter_content(chunk_size=512): # Stream in chunks
20
  if chunk:
21
  audio_buffer.extend(chunk)
22
  #print(f"Received audio chunk of size {len(chunk)} bytes.")
@@ -26,7 +26,7 @@ async def send_audio(websocket):
26
  await websocket.send(audio_buffer)
27
  #print(f"Sent {len(audio_buffer)} bytes of audio data.")
28
  audio_buffer.clear()
29
- await asyncio.sleep(0.001)
30
 
31
  print("Finished sending audio.")
32
  else:
@@ -37,12 +37,12 @@ async def receive_transcription(websocket):
37
  while True:
38
  try:
39
 
40
- transcription = await asyncio.wait_for(websocket.recv(), timeout=300)
41
  # Receive transcription from the server
42
  print(f"Transcription: {transcription}")
43
  except Exception as e:
44
  print(f"Error receiving transcription: {e}")
45
- await asyncio.sleep(30)
46
  break
47
 
48
 
@@ -54,7 +54,7 @@ async def send_heartbeat(websocket):
54
  except websockets.ConnectionClosed:
55
  print("Connection closed, stopping heartbeat")
56
  break
57
- await asyncio.sleep(3) # Send ping every 30 seconds (adjust as needed)
58
 
59
 
60
  async def run_client():
@@ -64,19 +64,18 @@ async def run_client():
64
  ssl_context.verify_mode = ssl.CERT_NONE
65
  while True:
66
  try:
67
- async with websockets.connect(uri, ssl=ssl_context, ping_timeout=500, ping_interval=20) as websocket:
68
  await asyncio.gather(
69
  send_audio(websocket),
70
  receive_transcription(websocket),
71
- #send_heartbeat(websocket)
72
  )
73
  except websockets.ConnectionClosedError as e:
74
  print(f"WebSocket closed with error: {e}")
75
- except Exception as e:
76
- print(f"Unexpected error: {e}")
77
-
78
- print("Reconnecting in 5 seconds...")
79
- await asyncio.sleep(5) # Wait 5 seconds before reconnecting
80
 
81
- if __name__ == "__main__":
82
- asyncio.run(run_client())
 
16
  if response.status_code == 200:
17
  print("Starting to stream audio file...")
18
 
19
+ for chunk in response.iter_content(chunk_size=1024): # Stream in chunks
20
  if chunk:
21
  audio_buffer.extend(chunk)
22
  #print(f"Received audio chunk of size {len(chunk)} bytes.")
 
26
  await websocket.send(audio_buffer)
27
  #print(f"Sent {len(audio_buffer)} bytes of audio data.")
28
  audio_buffer.clear()
29
+ await asyncio.sleep(0.01)
30
 
31
  print("Finished sending audio.")
32
  else:
 
37
  while True:
38
  try:
39
 
40
+ transcription = await websocket.recv()
41
  # Receive transcription from the server
42
  print(f"Transcription: {transcription}")
43
  except Exception as e:
44
  print(f"Error receiving transcription: {e}")
45
+ #await asyncio.sleep(30)
46
  break
47
 
48
 
 
54
  except websockets.ConnectionClosed:
55
  print("Connection closed, stopping heartbeat")
56
  break
57
+ await asyncio.sleep(30) # Send ping every 30 seconds (adjust as needed)
58
 
59
 
60
  async def run_client():
 
64
  ssl_context.verify_mode = ssl.CERT_NONE
65
  while True:
66
  try:
67
+ async with websockets.connect(uri, ssl=ssl_context, ping_timeout=1000, ping_interval=50) as websocket:
68
  await asyncio.gather(
69
  send_audio(websocket),
70
  receive_transcription(websocket),
71
+ send_heartbeat(websocket)
72
  )
73
  except websockets.ConnectionClosedError as e:
74
  print(f"WebSocket closed with error: {e}")
75
+ # except Exception as e:
76
+ # print(f"Unexpected error: {e}")
77
+ #
78
+ # print("Reconnecting in 5 seconds...")
79
+ # await asyncio.sleep(5) # Wait 5 seconds before reconnecting
80
 
81
+ asyncio.run(run_client())
 
infer.py CHANGED
@@ -13,8 +13,8 @@ import sys
13
  import asyncio
14
 
15
  # Configure logging
16
- #logging.basicConfig(level=logging.DEBUG, format='%(asctime)s %(levelname)s: %(message)s',handlers=[logging.StreamHandler(sys.stdout)])
17
- logging.getLogger("asyncio").setLevel(logging.DEBUG)
18
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
19
  logging.info(f'Device selected: {device}')
20
 
@@ -79,61 +79,6 @@ async def read_root():
79
  return {"message": "This is the Ivrit AI Streaming service."}
80
 
81
 
82
- @app.post("/transcribe")
83
- async def transcribe(input_data: InputData):
84
- logging.info(f'Received transcription request with data: {input_data}')
85
- datatype = input_data.type
86
- if not datatype:
87
- logging.error('datatype field not provided')
88
- raise HTTPException(status_code=400, detail="datatype field not provided. Should be 'blob' or 'url'.")
89
-
90
- if datatype not in ['blob', 'url']:
91
- logging.error(f'Invalid datatype: {datatype}')
92
- raise HTTPException(status_code=400, detail=f"datatype should be 'blob' or 'url', but is {datatype} instead.")
93
-
94
- with tempfile.TemporaryDirectory() as d:
95
- audio_file = f'{d}/audio.mp3'
96
- logging.debug(f'Created temporary directory: {d}')
97
-
98
- if datatype == 'blob':
99
- if not input_data.data:
100
- logging.error("Missing 'data' for 'blob' input")
101
- raise HTTPException(status_code=400, detail="Missing 'data' for 'blob' input.")
102
- logging.info('Decoding base64 blob data')
103
- mp3_bytes = base64.b64decode(input_data.data)
104
- open(audio_file, 'wb').write(mp3_bytes)
105
- logging.info(f'Audio file written: {audio_file}')
106
- elif datatype == 'url':
107
- if not input_data.url:
108
- logging.error("Missing 'url' for 'url' input")
109
- raise HTTPException(status_code=400, detail="Missing 'url' for 'url' input.")
110
- logging.info(f'Downloading file from URL: {input_data.url}')
111
- success = download_file(input_data.url, MAX_PAYLOAD_SIZE, audio_file, None)
112
- if not success:
113
- logging.error(f"Error downloading data from {input_data.url}")
114
- raise HTTPException(status_code=400, detail=f"Error downloading data from {input_data.url}")
115
-
116
- result = transcribe_core(audio_file)
117
- return {"result": result}
118
-
119
-
120
- def transcribe_core(audio_file):
121
- logging.info('Starting transcription...')
122
- ret = {'segments': []}
123
-
124
- segs, _ = model.transcribe(audio_file, language='he', word_timestamps=True)
125
- logging.info('Transcription completed')
126
-
127
- for s in segs:
128
- words = [{'start': w.start, 'end': w.end, 'word': w.word, 'probability': w.probability} for w in s.words]
129
- seg = {
130
- 'id': s.id, 'seek': s.seek, 'start': s.start, 'end': s.end, 'text': s.text, 'avg_logprob': s.avg_logprob,
131
- 'compression_ratio': s.compression_ratio, 'no_speech_prob': s.no_speech_prob, 'words': words
132
- }
133
- logging.info(f'Transcription segment: {seg}')
134
- ret['segments'].append(seg)
135
-
136
- return ret
137
 
138
 
139
  def transcribe_core_ws(audio_file, last_transcribed_time):
@@ -219,14 +164,7 @@ async def websocket_transcribe(websocket: WebSocket):
219
  # Estimate the duration of the chunk based on its size (e.g., 16kHz audio)
220
  chunk_duration = len(audio_chunk) / (16000 * 2) # Assuming 16kHz mono WAV (2 bytes per sample)
221
  accumulated_audio_time += chunk_duration
222
- #logging.info(f"Received and buffered {len(audio_chunk)} bytes, total buffered: {accumulated_audio_size} bytes, total time: {accumulated_audio_time:.2f} seconds")
223
-
224
- # Transcribe when enough time (audio) is accumulated (e.g., at least 5 seconds of audio)
225
- #if accumulated_audio_time >= min_transcription_time:
226
- #logging.info("Buffered enough audio time, starting transcription.")
227
-
228
 
229
- # Call the transcription function with the last processed time
230
  partial_result, last_transcribed_time = transcribe_core_ws(temp_audio_file.name, last_transcribed_time)
231
  accumulated_audio_time = 0 # Reset the accumulated audio time
232
  processed_segments.extend(partial_result['new_segments'])
 
13
  import asyncio
14
 
15
  # Configure logging
16
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s: %(message)s',handlers=[logging.StreamHandler(sys.stdout)], force=True)
17
+ #logging.getLogger("asyncio").setLevel(logging.DEBUG)
18
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
19
  logging.info(f'Device selected: {device}')
20
 
 
79
  return {"message": "This is the Ivrit AI Streaming service."}
80
 
81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
 
84
  def transcribe_core_ws(audio_file, last_transcribed_time):
 
164
  # Estimate the duration of the chunk based on its size (e.g., 16kHz audio)
165
  chunk_duration = len(audio_chunk) / (16000 * 2) # Assuming 16kHz mono WAV (2 bytes per sample)
166
  accumulated_audio_time += chunk_duration
 
 
 
 
 
 
167
 
 
168
  partial_result, last_transcribed_time = transcribe_core_ws(temp_audio_file.name, last_transcribed_time)
169
  accumulated_audio_time = 0 # Reset the accumulated audio time
170
  processed_segments.extend(partial_result['new_segments'])