AshDavid12 committed on
Commit
5a62402
·
1 Parent(s): aab7acf

complete change

Browse files
Files changed (2) hide show
  1. client.py +163 -70
  2. infer.py +233 -119
client.py CHANGED
@@ -2,80 +2,173 @@ import asyncio
2
  import websockets
3
  import requests
4
  import ssl
 
 
 
5
 
6
  # Parameters for reading and sending the audio
7
  AUDIO_FILE_URL = "https://raw.githubusercontent.com/AshDavid12/runpod-serverless-forked/main/test_hebrew.wav" # Use WAV file
8
  #AUDIO_FILE_URL = "https://raw.githubusercontent.com/AshDavid12/hugging_face_ivrit_streaming/main/long_hebrew.wav"
9
 
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  async def send_audio(websocket):
12
- buffer_size = 512 * 1024 #HAVE TO HAVE 512!!
13
- audio_buffer = bytearray()
14
-
15
- with requests.get(AUDIO_FILE_URL, stream=True, allow_redirects=False) as response:
16
- if response.status_code == 200:
17
- print("Starting to stream audio file...")
18
-
19
- for chunk in response.iter_content(chunk_size=1024): # Stream in chunks
20
- if chunk:
21
- audio_buffer.extend(chunk)
22
- #print(f"Received audio chunk of size {len(chunk)} bytes.")
23
-
24
- # Send buffered audio data once it's large enough
25
- #if len(audio_buffer) >= buffer_size:
26
- await websocket.send(audio_buffer)
27
- #print(f"Sent {len(audio_buffer)} bytes of audio data.")
28
- audio_buffer.clear()
29
- await asyncio.sleep(0.01)
30
-
31
- print("Finished sending audio.")
32
- else:
33
- print(f"Failed to download audio file. Status code: {response.status_code}")
34
-
35
-
36
- async def receive_transcription(websocket):
37
- while True:
38
- try:
39
-
40
- transcription = await websocket.recv()
41
- # Receive transcription from the server
42
- print(f"Transcription: {transcription}")
43
- except Exception as e:
44
- print(f"Error receiving transcription: {e}")
45
- #await asyncio.sleep(30)
46
- break
47
-
48
-
49
- async def send_heartbeat(websocket):
50
- while True:
51
- try:
52
- await websocket.ping()
53
- print("Sent keepalive ping")
54
- except websockets.ConnectionClosed:
55
- print("Connection closed, stopping heartbeat")
56
- break
57
- await asyncio.sleep(30) # Send ping every 30 seconds (adjust as needed)
58
-
59
-
60
- async def run_client():
61
- uri = ("wss://gigaverse-ivrit-ai-streaming.hf.space/wtranscribe") # WebSocket URL
62
- ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
63
- ssl_context.check_hostname = False
64
- ssl_context.verify_mode = ssl.CERT_NONE
65
- while True:
66
- try:
67
- async with websockets.connect(uri, ssl=ssl_context, ping_timeout=1000, ping_interval=50) as websocket:
68
- await asyncio.gather(
69
- send_audio(websocket),
70
- receive_transcription(websocket),
71
- send_heartbeat(websocket)
72
- )
73
- except websockets.ConnectionClosedError as e:
74
- print(f"WebSocket closed with error: {e}")
75
- # except Exception as e:
76
- # print(f"Unexpected error: {e}")
77
- #
78
- # print("Reconnecting in 5 seconds...")
79
- # await asyncio.sleep(5) # Wait 5 seconds before reconnecting
80
-
81
- asyncio.run(run_client())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import websockets
3
  import requests
4
  import ssl
5
+ import wave
6
+ import logging
7
+ import sys
8
 
9
  # Parameters for reading and sending the audio
10
  AUDIO_FILE_URL = "https://raw.githubusercontent.com/AshDavid12/runpod-serverless-forked/main/test_hebrew.wav" # Use WAV file
11
  #AUDIO_FILE_URL = "https://raw.githubusercontent.com/AshDavid12/hugging_face_ivrit_streaming/main/long_hebrew.wav"
12
 
13
 
14
+
15
+ # Set up logging
16
+ logging.basicConfig(level=logging.DEBUG, format='%(asctime)s %(levelname)s: %(message)s',
17
+ handlers=[logging.StreamHandler(sys.stdout)], force=True)
18
+ logger = logging.getLogger(__name__)
19
+
20
async def send_receive():
    """Open the WebSocket connection and run the sender and receiver together.

    Connects to the transcription server, then schedules `send_audio` and
    `receive_transcriptions` on the same connection and waits for both.
    Any exception raised while connecting or while waiting on the two tasks
    is logged and not re-raised.
    """
    uri = "wss://gigaverse-ivrit-ai-streaming.hf.space/ws"  # Update with your server's address if needed
    logger.info(f"Connecting to server at {uri}")
    try:
        async with websockets.connect(uri) as websocket:
            logger.info("WebSocket connection established")
            # Schedule both directions up front so they run side by side.
            workers = [
                asyncio.create_task(job)
                for job in (send_audio(websocket), receive_transcriptions(websocket))
            ]
            await asyncio.gather(*workers)
    except Exception as e:
        logger.error(f"WebSocket connection error: {e}")
32
+
33
async def send_audio(websocket, wav_file='path/to/your/audio.wav'):
    """Stream a WAV file to the server in fixed-duration chunks.

    The file must be mono, 16-bit, 16 kHz; otherwise an error is logged and
    nothing is sent. Frames are read in 0.1 s chunks and each chunk is sent
    as one WebSocket message, followed by a sleep of the same duration to
    simulate real-time streaming.

    Args:
        websocket: Open connection exposing an awaitable ``send(bytes)``.
        wav_file: Path of the WAV file to stream. The default is a
            placeholder and must be replaced with a real path.

    Errors while opening, reading or sending are logged, not raised.
    """
    logger.info(f"Opening WAV file: {wav_file}")

    try:
        wf = wave.open(wav_file, 'rb')
    except Exception as e:
        # Nothing was opened, so there is nothing to close; bailing out here
        # keeps the cleanup below from failing on an unbound file handle.
        logger.error(f"Send audio error: {e}")
        return

    try:
        # Log WAV file parameters
        channels = wf.getnchannels()
        sampwidth = wf.getsampwidth()
        framerate = wf.getframerate()
        nframes = wf.getnframes()
        duration = nframes / framerate
        logger.debug(f"WAV file parameters: channels={channels}, sample_width={sampwidth}, framerate={framerate}, frames={nframes}, duration={duration:.2f}s")

        # Ensure the WAV file has the expected parameters
        if channels != 1 or sampwidth != 2 or framerate != 16000:
            logger.error("WAV file must be mono channel, 16-bit samples, 16kHz sampling rate")
            return

        chunk_duration = 0.1  # in seconds
        chunk_size = int(framerate * chunk_duration)  # frames per chunk
        logger.info(f"Starting to send audio data in chunks of {chunk_duration}s ({chunk_size} frames)")

        total_chunks = 0
        total_bytes_sent = 0

        while True:
            data = wf.readframes(chunk_size)
            if not data:
                logger.info("End of WAV file reached")
                break
            await websocket.send(data)
            total_chunks += 1
            total_bytes_sent += len(data)
            logger.debug(f"Sent chunk {total_chunks}: {len(data)} bytes")
            await asyncio.sleep(chunk_duration)  # Simulate real-time streaming

        logger.info(f"Finished sending audio data: {total_chunks} chunks sent, total bytes sent: {total_bytes_sent}")
    except Exception as e:
        logger.error(f"Send audio error: {e}")
    finally:
        # Reached only after a successful open, so `wf` is always bound here.
        wf.close()
        logger.info("WAV file closed")
78
+
79
async def receive_transcriptions(websocket):
    """Log and print each message obtained by iterating `websocket`.

    Any exception raised while iterating is logged and swallowed.
    """
    try:
        logger.info("Starting to receive transcriptions")
        # Iterating the connection yields incoming messages one at a time.
        async for incoming in websocket:
            logger.info(f"Received transcription: {incoming}")
            print(f"Transcription: {incoming}")
    except Exception as e:
        logger.error(f"Receive transcription error: {e}")
87
+
88
+ if __name__ == "__main__":
89
+ asyncio.run(send_receive())
90
+
91
+
92
+
93
+
94
+
95
+
96
+
97
+
98
+
99
+
100
+
101
+
102
+
103
+
104
+ # async def send_audio(websocket):
105
+ # buffer_size = 512 * 1024 #HAVE TO HAVE 512!!
106
+ # audio_buffer = bytearray()
107
+ #
108
+ # with requests.get(AUDIO_FILE_URL, stream=True, allow_redirects=False) as response:
109
+ # if response.status_code == 200:
110
+ # print("Starting to stream audio file...")
111
+ #
112
+ # for chunk in response.iter_content(chunk_size=1024): # Stream in chunks
113
+ # if chunk:
114
+ # audio_buffer.extend(chunk)
115
+ # #print(f"Received audio chunk of size {len(chunk)} bytes.")
116
+ #
117
+ # # Send buffered audio data once it's large enough
118
+ # if len(audio_buffer) >= buffer_size:
119
+ # await websocket.send(audio_buffer)
120
+ # #print(f"Sent {len(audio_buffer)} bytes of audio data.")
121
+ # audio_buffer.clear()
122
+ # await asyncio.sleep(0.01)
123
+ #
124
+ # print("Finished sending audio.")
125
+ # else:
126
+ # print(f"Failed to download audio file. Status code: {response.status_code}")
127
+ #
128
+ #
129
+ # async def receive_transcription(websocket):
130
+ # while True:
131
+ # try:
132
+ #
133
+ # transcription = await websocket.recv()
134
+ # # Receive transcription from the server
135
+ # print(f"Transcription: {transcription}")
136
+ # except Exception as e:
137
+ # print(f"Error receiving transcription: {e}")
138
+ # #await asyncio.sleep(30)
139
+ # break
140
+ #
141
+ #
142
+ # async def send_heartbeat(websocket):
143
+ # while True:
144
+ # try:
145
+ # await websocket.ping()
146
+ # print("Sent keepalive ping")
147
+ # except websockets.ConnectionClosed:
148
+ # print("Connection closed, stopping heartbeat")
149
+ # break
150
+ # await asyncio.sleep(30) # Send ping every 30 seconds (adjust as needed)
151
+ #
152
+ #
153
+ # async def run_client():
154
+ # uri = ("wss://gigaverse-ivrit-ai-streaming.hf.space/wtranscribe") # WebSocket URL
155
+ # ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
156
+ # ssl_context.check_hostname = False
157
+ # ssl_context.verify_mode = ssl.CERT_NONE
158
+ # while True:
159
+ # try:
160
+ # async with websockets.connect(uri, ssl=ssl_context, ping_timeout=1000, ping_interval=50) as websocket:
161
+ # await asyncio.gather(
162
+ # send_audio(websocket),
163
+ # receive_transcription(websocket),
164
+ # send_heartbeat(websocket)
165
+ # )
166
+ # except websockets.ConnectionClosedError as e:
167
+ # print(f"WebSocket closed with error: {e}")
168
+ # # except Exception as e:
169
+ # # print(f"Unexpected error: {e}")
170
+ # #
171
+ # # print("Reconnecting in 5 seconds...")
172
+ # # await asyncio.sleep(5) # Wait 5 seconds before reconnecting
173
+ #
174
+ # asyncio.run(run_client())
infer.py CHANGED
@@ -1,6 +1,8 @@
1
  import base64
2
  import faster_whisper
3
  import tempfile
 
 
4
  import torch
5
  import time
6
  import requests
@@ -15,6 +17,7 @@ import asyncio
15
  # Configure logging
16
  logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s: %(message)s',
17
  handlers=[logging.StreamHandler(sys.stdout)], force=True)
 
18
  #logging.getLogger("asyncio").setLevel(logging.DEBUG)
19
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
20
  logging.info(f'Device selected: {device}')
@@ -31,49 +34,49 @@ logging.info(f'Max payload size set to: {MAX_PAYLOAD_SIZE} bytes')
31
  app = FastAPI()
32
 
33
 
34
- class InputData(BaseModel):
35
- type: str
36
- data: Optional[str] = None # Used for blob input
37
- url: Optional[str] = None # Used for url input
38
-
39
-
40
- def download_file(url, max_size_bytes, output_filename, api_key=None):
41
- """
42
- Download a file from a given URL with size limit and optional API key.
43
- """
44
- logging.debug(f'Starting file download from URL: {url}')
45
- try:
46
- headers = {}
47
- if api_key:
48
- headers['Authorization'] = f'Bearer {api_key}'
49
- logging.debug('API key provided, added to headers')
50
-
51
- response = requests.get(url, stream=True, headers=headers)
52
- response.raise_for_status()
53
-
54
- file_size = int(response.headers.get('Content-Length', 0))
55
- logging.info(f'File size: {file_size} bytes')
56
-
57
- if file_size > max_size_bytes:
58
- logging.error(f'File size exceeds limit: {file_size} > {max_size_bytes}')
59
- return False
60
-
61
- downloaded_size = 0
62
- with open(output_filename, 'wb') as file:
63
- for chunk in response.iter_content(chunk_size=8192):
64
- downloaded_size += len(chunk)
65
- logging.debug(f'Downloaded {downloaded_size} bytes')
66
- if downloaded_size > max_size_bytes:
67
- logging.error('Downloaded size exceeds maximum allowed payload size')
68
- return False
69
- file.write(chunk)
70
-
71
- logging.info(f'File downloaded successfully: {output_filename}')
72
- return True
73
-
74
- except requests.RequestException as e:
75
- logging.error(f"Error downloading file: {e}")
76
- return False
77
 
78
 
79
  @app.get("/")
@@ -81,97 +84,208 @@ async def read_root():
81
  return {"message": "This is the Ivrit AI Streaming service."}
82
 
83
 
84
- async def transcribe_core_ws(audio_file):
85
- """
86
- Transcribe the audio file and return only the segments that have not been processed yet.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
 
88
- :param audio_file: Path to the growing audio file.
89
- :param last_transcribed_time: The last time (in seconds) that was transcribed.
90
- :return: Newly transcribed segments and the updated last transcribed time.
91
- """
92
- ret = {'segments': []}
93
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  try:
95
- # Transcribe the entire audio file
96
- logging.debug(f"Initiating model transcription for file: {audio_file}")
97
 
98
- segs, _ = await asyncio.to_thread(model.transcribe, audio_file, language='he', word_timestamps=True)
99
- logging.info('Transcription completed successfully.')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  except Exception as e:
101
- logging.error(f"Error during transcription: {e}")
102
- raise e
103
 
104
- # Track the new segments and update the last transcribed time
105
- for s in segs:
106
- logging.info(f"Processing segment with start time: {s.start} and end time: {s.end}")
 
 
107
 
108
- # Only process segments that start after the last transcribed time
109
- logging.info(f"New segment found starting at {s.start} seconds.")
110
- words = [{'start': w.start, 'end': w.end, 'word': w.word, 'probability': w.probability} for w in s.words]
111
 
112
- seg = {
113
- 'id': s.id, 'seek': s.seek, 'start': s.start, 'end': s.end, 'text': s.text,
114
- 'avg_logprob': s.avg_logprob, 'compression_ratio': s.compression_ratio,
115
- 'no_speech_prob': s.no_speech_prob, 'words': words
116
- }
117
- logging.info(f'Adding new transcription segment: {seg}')
118
- ret['segements'].append(seg)
119
 
120
- # Update the last transcribed time to the end of the current segment
121
 
122
 
123
- #logging.info(f"Returning {len(ret['new_segments'])} new segments and updated last transcribed time.")
124
- return ret
125
 
126
 
127
- import tempfile
128
 
129
 
130
- @app.websocket("/wtranscribe")
131
- async def websocket_transcribe(websocket: WebSocket):
132
- logging.info("New WebSocket connection request received.")
133
- await websocket.accept()
134
- logging.info("WebSocket connection established successfully.")
135
 
136
- try:
137
- processed_segments = [] # Keeps track of the segments already transcribed
138
- accumulated_audio_size = 0 # Track how much audio data has been buffered
139
- accumulated_audio_time = 0 # Track the total audio duration accumulated
140
- last_transcribed_time = 0.0
141
- #min_transcription_time = 5.0 # Minimum duration of audio in seconds before transcription starts
142
-
143
- # A temporary file to store the growing audio data
144
-
145
- while True:
146
- try:
147
- with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_audio_file: ##new temp file for every chunk
148
- logging.info(f"Temporary audio file created at {temp_audio_file.name}")
149
- # Receive the next chunk of audio data
150
- audio_chunk = await websocket.receive_bytes()
151
- if not audio_chunk:
152
- logging.warning("Received empty audio chunk, skipping processing.")
153
- continue
154
-
155
- # Write audio chunk to file and accumulate size and time
156
- temp_audio_file.write(audio_chunk)
157
- temp_audio_file.flush()
158
- accumulated_audio_size += len(audio_chunk)
159
-
160
- # Estimate the duration of the chunk based on its size (e.g., 16kHz audio)
161
- chunk_duration = len(audio_chunk) / (16000 * 2) # Assuming 16kHz mono WAV (2 bytes per sample)
162
- accumulated_audio_time += chunk_duration
163
-
164
- partial_result = await transcribe_core_ws(temp_audio_file)
165
- accumulated_audio_time = 0 # Reset the accumulated audio time
166
- await websocket.send_json(partial_result)
167
-
168
- except WebSocketDisconnect:
169
- logging.info("WebSocket connection closed by the client.")
170
- break
171
 
172
- except Exception as e:
173
- logging.error(f"Unexpected error during WebSocket transcription: {e}")
174
- await websocket.send_json({"error": str(e)})
175
 
176
- finally:
177
- logging.info("Cleaning up and closing WebSocket connection.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import base64
2
  import faster_whisper
3
  import tempfile
4
+
5
+ import numpy as np
6
  import torch
7
  import time
8
  import requests
 
17
  # Configure logging
18
  logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s: %(message)s',
19
  handlers=[logging.StreamHandler(sys.stdout)], force=True)
20
+ logger = logging.getLogger(__name__)
21
  #logging.getLogger("asyncio").setLevel(logging.DEBUG)
22
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
23
  logging.info(f'Device selected: {device}')
 
34
  app = FastAPI()
35
 
36
 
37
+ # class InputData(BaseModel):
38
+ # type: str
39
+ # data: Optional[str] = None # Used for blob input
40
+ # url: Optional[str] = None # Used for url input
41
+ #
42
+ #
43
+ # def download_file(url, max_size_bytes, output_filename, api_key=None):
44
+ # """
45
+ # Download a file from a given URL with size limit and optional API key.
46
+ # """
47
+ # logging.debug(f'Starting file download from URL: {url}')
48
+ # try:
49
+ # headers = {}
50
+ # if api_key:
51
+ # headers['Authorization'] = f'Bearer {api_key}'
52
+ # logging.debug('API key provided, added to headers')
53
+ #
54
+ # response = requests.get(url, stream=True, headers=headers)
55
+ # response.raise_for_status()
56
+ #
57
+ # file_size = int(response.headers.get('Content-Length', 0))
58
+ # logging.info(f'File size: {file_size} bytes')
59
+ #
60
+ # if file_size > max_size_bytes:
61
+ # logging.error(f'File size exceeds limit: {file_size} > {max_size_bytes}')
62
+ # return False
63
+ #
64
+ # downloaded_size = 0
65
+ # with open(output_filename, 'wb') as file:
66
+ # for chunk in response.iter_content(chunk_size=8192):
67
+ # downloaded_size += len(chunk)
68
+ # logging.debug(f'Downloaded {downloaded_size} bytes')
69
+ # if downloaded_size > max_size_bytes:
70
+ # logging.error('Downloaded size exceeds maximum allowed payload size')
71
+ # return False
72
+ # file.write(chunk)
73
+ #
74
+ # logging.info(f'File downloaded successfully: {output_filename}')
75
+ # return True
76
+ #
77
+ # except requests.RequestException as e:
78
+ # logging.error(f"Error downloading file: {e}")
79
+ # return False
80
 
81
 
82
  @app.get("/")
 
84
  return {"message": "This is the Ivrit AI Streaming service."}
85
 
86
 
87
+ # async def transcribe_core_ws(audio_file):
88
+ # ret = {'segments': []}
89
+ #
90
+ # try:
91
+ #
92
+ # logging.debug(f"Initiating model transcription for file: {audio_file}")
93
+ #
94
+ # segs, _ = await asyncio.to_thread(model.transcribe, audio_file, language='he', word_timestamps=True)
95
+ # logging.info('Transcription completed successfully.')
96
+ # except Exception as e:
97
+ # logging.error(f"Error during transcription: {e}")
98
+ # raise e
99
+ #
100
+ # # Track the new segments and update the last transcribed time
101
+ # for s in segs:
102
+ # logging.info(f"Processing segment with start time: {s.start} and end time: {s.end}")
103
+ #
104
+ # # Only process segments that start after the last transcribed time
105
+ # logging.info(f"New segment found starting at {s.start} seconds.")
106
+ # words = [{'start': w.start, 'end': w.end, 'word': w.word, 'probability': w.probability} for w in s.words]
107
+ #
108
+ # seg = {
109
+ # 'id': s.id, 'seek': s.seek, 'start': s.start, 'end': s.end, 'text': s.text,
110
+ # 'avg_logprob': s.avg_logprob, 'compression_ratio': s.compression_ratio,
111
+ # 'no_speech_prob': s.no_speech_prob, 'words': words
112
+ # }
113
+ # logging.info(f'Adding new transcription segment: {seg}')
114
+ # ret['segements'].append(seg)
115
+ #
116
+ # # Update the last transcribed time to the end of the current segment
117
+ #
118
+ #
119
+ # #logging.info(f"Returning {len(ret['new_segments'])} new segments and updated last transcribed time.")
120
+ # return ret
121
+
122
+
123
+ import tempfile
124
 
 
 
 
 
 
125
 
126
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """WebSocket endpoint to handle client connections.

    Accepts the connection, then hands it to `process_audio_stream`. A
    client disconnect is logged at info level; any other exception is
    logged as an error and followed by a best-effort close of the socket.
    """
    await websocket.accept()
    # The peer address is not guaranteed to be available; fall back to a
    # placeholder instead of failing on `None.host`.
    peer = websocket.client
    client_ip = peer.host if peer is not None else "unknown"
    logger.info(f"Client connected: {client_ip}")
    try:
        await process_audio_stream(websocket)
    except WebSocketDisconnect:
        logger.info(f"Client disconnected: {client_ip}")
    except Exception as e:
        logger.error(f"Unexpected error: {e}")
        try:
            await websocket.close()
        except Exception as close_error:
            # The socket may already be closed or broken; closing is only a
            # courtesy here, so record the failure and move on.
            logger.debug(f"Error closing WebSocket: {close_error}")
139
+
140
async def process_audio_stream(websocket: WebSocket):
    """Continuously receive audio chunks and initiate transcription tasks.

    Each binary message is converted with `process_received_audio` and
    appended to a float32 buffer. Once at least `min_chunk_size` seconds of
    samples are buffered and no transcription task is still running, the
    buffer is handed to `transcribe_and_send` as a background task and a
    fresh buffer is started. While a task is running, audio keeps
    accumulating and is sent with the next task.

    The loop ends on an empty message, on client disconnect, or on any
    receive/processing error. Before returning, a still-running
    transcription task is awaited so it is not left orphaned.

    Args:
        websocket: An accepted WebSocket connection to read from.
    """
    sampling_rate = 16000
    min_chunk_size = 1  # in seconds
    min_samples = min_chunk_size * sampling_rate
    audio_buffer = np.array([], dtype=np.float32)

    transcription_task = None
    chunk_counter = 0
    total_bytes_received = 0

    while True:
        try:
            # Receive audio data from client
            data = await websocket.receive_bytes()
            if not data:
                logger.info("No data received, closing connection")
                break
            chunk_counter += 1
            chunk_size = len(data)
            total_bytes_received += chunk_size
            logger.debug(f"Received chunk {chunk_counter}: {chunk_size} bytes")

            audio_chunk = process_received_audio(data)
            logger.debug(f"Processed audio chunk {chunk_counter}: {len(audio_chunk)} samples")

            audio_buffer = np.concatenate((audio_buffer, audio_chunk))
            logger.debug(f"Audio buffer size: {len(audio_buffer)} samples")
        except WebSocketDisconnect:
            # A client going away is the normal end of a stream, not an error.
            logger.info("Client disconnected, stopping audio stream")
            break
        except Exception as e:
            logger.error(f"Error receiving data: {e}")
            break

        # Check if enough audio has been buffered
        if len(audio_buffer) < min_samples:
            continue
        if transcription_task is None or transcription_task.done():
            # Start a new transcription task
            logger.info(f"Starting transcription task for {len(audio_buffer)} samples")
            # The buffer is rebound to a new array right below, so the task
            # can own this one without a defensive copy.
            transcription_task = asyncio.create_task(
                transcribe_and_send(websocket, audio_buffer)
            )
            audio_buffer = np.array([], dtype=np.float32)
            logger.debug("Audio buffer reset after starting transcription task")

    logger.info(f"Audio stream ended: {chunk_counter} chunks, {total_bytes_received} bytes received")

    # Let an in-flight transcription finish instead of abandoning the task.
    if transcription_task is not None and not transcription_task.done():
        await transcription_task
181
+
182
async def transcribe_and_send(websocket: WebSocket, audio_data):
    """Run transcription in a separate thread and send the result to the client.

    `sync_transcribe_audio` is executed via `asyncio.to_thread`. A falsy
    result is only logged as a warning; otherwise the result is sent with
    `websocket.send_json`, and a failure to send is logged, not raised.
    """
    logger.debug(f"Transcription task started for {len(audio_data)} samples")
    outcome = await asyncio.to_thread(sync_transcribe_audio, audio_data)

    # Nothing to deliver: report it and stop early.
    if not outcome:
        logger.warning("No transcription result to send")
        return

    try:
        # Send the result as JSON
        await websocket.send_json(outcome)
        logger.info("Transcription JSON sent to client")
    except Exception as e:
        logger.error(f"Error sending transcription: {e}")
195
+
196
def sync_transcribe_audio(audio_data):
    """Synchronously transcribe audio data using the ASR model and format the result.

    Calls the module-level `model.transcribe` with Hebrew as the language,
    beam size 5 and word timestamps enabled, then converts every returned
    segment (and its words) into plain dicts of built-in ints, floats and
    strings.

    Returns:
        ``{'segments': [...]}`` on success, or an empty dict if anything in
        the call or the conversion raises (the error is logged).
    """
    try:
        logger.info('Starting transcription...')
        segments, _info = model.transcribe(
            audio_data, language="he", beam_size=5, word_timestamps=True
        )
        logger.info('Transcription completed')

        collected = []
        for s in segments:
            logger.debug(f"Processing segment {s.id} with start time: {s.start} and end time: {s.end}")

            # Flatten the per-word timing information first.
            words = []
            for w in s.words:
                words.append(dict(
                    start=float(w.start),
                    end=float(w.end),
                    word=w.word,
                    probability=float(w.probability),
                ))

            seg = dict(
                id=int(s.id),
                seek=int(s.seek),
                start=float(s.start),
                end=float(s.end),
                text=s.text,
                avg_logprob=float(s.avg_logprob),
                compression_ratio=float(s.compression_ratio),
                no_speech_prob=float(s.no_speech_prob),
                words=words,
            )
            logger.debug(f'Adding new transcription segment: {seg}')
            collected.append(seg)

        logger.debug(f"Total segments in transcription result: {len(collected)}")
        return {'segments': collected}
    except Exception as e:
        logger.error(f"Transcription error: {e}")
        return {}
239
 
240
def process_received_audio(data):
    """Convert received bytes into normalized float32 NumPy array.

    The payload is interpreted as 16-bit integer samples (NumPy's native
    ``int16`` layout) and scaled by 1/32768 into the range [-1, 1).

    A payload that is not a whole number of samples would make
    ``np.frombuffer`` raise and abort the caller's stream, so a trailing
    partial sample is dropped with a warning instead.

    Args:
        data: Raw bytes received from the client.

    Returns:
        A one-dimensional ``float32`` array with one value per sample.
    """
    logger.debug(f"Processing received audio data of size {len(data)} bytes")

    sample_width = np.dtype(np.int16).itemsize
    leftover = len(data) % sample_width
    if leftover:
        logger.warning(f"Dropping {leftover} trailing byte(s): payload is not a whole number of 16-bit samples")
        data = data[:len(data) - leftover]

    audio_int16 = np.frombuffer(data, dtype=np.int16)
    logger.debug(f"Converted to int16 NumPy array with {len(audio_int16)} samples")

    audio_float32 = audio_int16.astype(np.float32) / 32768.0  # Normalize to [-1, 1]
    logger.debug(f"Normalized audio data to float32 with {len(audio_float32)} samples")

    return audio_float32
 
 
 
 
 
 
250
 
 
251
 
252
 
 
 
253
 
254
 
 
255
 
256
 
 
 
 
 
 
257
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
258
 
 
 
 
259
 
260
+ # @app.websocket("/wtranscribe")
261
+ # async def websocket_transcribe(websocket: WebSocket):
262
+ # logging.info("New WebSocket connection request received.")
263
+ # await websocket.accept()
264
+ # logging.info("WebSocket connection established successfully.")
265
+ #
266
+ # try:
267
+ # while True:
268
+ # try:
269
+ # audio_chunk = await websocket.receive_bytes()
270
+ # if not audio_chunk:
271
+ # logging.warning("Received empty audio chunk, skipping processing.")
272
+ # continue
273
+ # with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_audio_file: ##new temp file for every chunk
274
+ # logging.info(f"Temporary audio file created at {temp_audio_file.name}")
275
+ # # Receive the next chunk of audio data
276
+ #
277
+ #
278
+ #
279
+ # partial_result = await transcribe_core_ws(temp_audio_file.name)
280
+ # await websocket.send_json(partial_result)
281
+ #
282
+ # except WebSocketDisconnect:
283
+ # logging.info("WebSocket connection closed by the client.")
284
+ # break
285
+ #
286
+ # except Exception as e:
287
+ # logging.error(f"Unexpected error during WebSocket transcription: {e}")
288
+ # await websocket.send_json({"error": str(e)})
289
+ #
290
+ # finally:
291
+ # logging.info("Cleaning up and closing WebSocket connection.")