aviadr1 committed
Commit 9d710fb · Parent(s): e8aa012

Files changed (4):
  1. faster-whisper-server-client.py +104 -42
  2. pyproject.toml +1 -1
  3. ws_client.py +288 -0
  4. ws_server.py +111 -47
faster-whisper-server-client.py CHANGED
@@ -2,6 +2,9 @@ import argparse
 import json
 import threading
 import time
+from pathlib import Path
+from typing import List
+
 import websocket
 import os
 
@@ -31,34 +34,91 @@ def parse_arguments():
     return parser.parse_args()
 
 
-def preprocess_audio(audio_file, target_sr=16000):
-    """
-    Load the audio file, convert to mono 16kHz, and return the audio data.
-    """
-    if audio_file.endswith(".mp3"):
-        # Convert MP3 to WAV using ffmpeg
-        wav_file = audio_file.replace(".mp3", ".wav")
-        if not os.path.exists(wav_file):
-            command = f'ffmpeg -i "{audio_file}" -ac 1 -ar {target_sr} "{wav_file}"'
-            print(f"Converting MP3 to WAV: {command}")
-            os.system(command)
-        audio_file = wav_file
-
-    print(f"Loading audio file {audio_file}")
-    audio_data, sr = librosa.load(audio_file, sr=target_sr, mono=True)
-    return audio_data, sr
-
-def chunk_audio(audio_data, sr, chunk_duration):
-    """
-    Split the audio data into chunks of specified duration.
-    """
-    chunk_samples = int(chunk_duration * sr)
-    total_samples = len(audio_data)
-    chunks = [
-        audio_data[i:i + chunk_samples]
-        for i in range(0, total_samples, chunk_samples)
-    ]
-    print(f"Split audio into {len(chunks)} chunks of {chunk_duration} seconds each.")
-    return chunks
+# def preprocess_audio(audio_file, target_sr=16000):
+#     """
+#     Load the audio file, convert to mono 16kHz, and return the audio data.
+#     """
+#     if audio_file.endswith(".mp3"):
+#         # Convert MP3 to WAV using ffmpeg
+#         wav_file = audio_file.replace(".mp3", ".wav")
+#         if not os.path.exists(wav_file):
+#             command = f'ffmpeg -i "{audio_file}" -ac 1 -ar {target_sr} "{wav_file}"'
+#             print(f"Converting MP3 to WAV: {command}")
+#             os.system(command)
+#         audio_file = wav_file
+#
+#     print(f"Loading audio file {audio_file}")
+#     audio_data, sr = librosa.load(audio_file, sr=target_sr, mono=True)
+#     return audio_data, sr
+#
+# def chunk_audio(audio_data, sr, chunk_duration):
+#     """
+#     Split the audio data into chunks of specified duration.
+#     """
+#     chunk_samples = int(chunk_duration * sr)
+#     total_samples = len(audio_data)
+#     chunks = [
+#         audio_data[i:i + chunk_samples]
+#         for i in range(0, total_samples, chunk_samples)
+#     ]
+#     print(f"Split audio into {len(chunks)} chunks of {chunk_duration} seconds each.")
+#     return chunks
+
+
+def read_audio_in_chunks(audio_file, target_sr=16000, chunk_duration=1) -> List[np.ndarray]:
+    """
+    Reads a 16kHz mono audio file in 1-second chunks and returns them as little-endian 16-bit integer arrays.
+
+    Args:
+        audio_file (str): Path to the audio file.
+        target_sr (int): Expected sample rate (16000 by default).
+        chunk_duration (int): Duration of each chunk in seconds (1 second by default).
+
+    Returns:
+        List of numpy arrays: Each array is a 1-second chunk of the audio as 16-bit integers.
+
+    Raises:
+        ValueError: If the audio file's sample rate doesn't match expectations.
+    """
+    if not str(audio_file).endswith(".wav"):
+        # Convert MP3 to WAV using ffmpeg
+        wav_file = Path(audio_file).with_suffix(".wav")
+        if not wav_file.exists():
+            command = f'ffmpeg -i "{audio_file}" -ac 1 -ar {target_sr} "{wav_file}"'
+            print(f"Converting MP3 to WAV: {command}")
+            os.system(command)
+        audio_file = wav_file
+
+    # Load the audio file
+    audio_data, sr = librosa.load(audio_file, sr=None, mono=True)
+
+    # Raise an exception if the sample rate doesn't match
+    if sr != target_sr:
+        raise ValueError(f"Unexpected sample rate {sr}. Expected {target_sr}.")
+
+    # Convert the audio data to 16-bit PCM (little-endian)
+    audio_data_int16 = (audio_data * 32767).astype(np.int16)
+
+    # Check if the current byte order is little-endian
+    if audio_data_int16.dtype.byteorder == '>' or (
+            audio_data_int16.dtype.byteorder == '=' and np.dtype(np.int16).byteorder == '>'):
+        print("Byte swap performed to convert to little-endian.")
+        # Ensure little-endian format (if the current format is big-endian)
+        audio_data_little_endian = audio_data_int16.byteswap().newbyteorder('L')
+    else:
+        print("No byte swap needed. Already little-endian.")
+        audio_data_little_endian = audio_data_int16
+
+    # Calculate the number of samples per chunk
+    samples_per_chunk = target_sr * chunk_duration
+
+    # Split the audio into chunks
+    chunks = [
+        audio_data_little_endian[i:i + samples_per_chunk]
+        for i in range(0, len(audio_data_little_endian), samples_per_chunk)
+    ]
+
+    return chunks
 
 
@@ -184,31 +244,33 @@ def run_websocket_client(args):
     """
     Run the WebSocket client to stream audio and receive transcriptions.
     """
-    audio_data, sr = preprocess_audio(args.audio_file)
-    audio_chunks = chunk_audio(audio_data, sr, args.chunk_duration)
-
-    params = build_query_params(args)
-    ws_url = websocket_url_with_params(args.url, params)
-
-    ws = websocket.WebSocketApp(
-        ws_url,
-        on_open=on_open,
-        on_message=on_message,
-        on_error=on_error,
-        on_close=on_close,
-    )
-    ws.args = args  # Attach args to ws to access inside callbacks
-
-    # Run the WebSocket in a separate thread to allow sending and receiving simultaneously
-    ws_thread = threading.Thread(target=ws.run_forever)
-    ws_thread.start()
-
-    # Wait for the connection to open
-    while not ws.sock or not ws.sock.connected:
-        time.sleep(0.1)
-
-    # Send the audio chunks
-    send_audio_chunks(ws, audio_chunks, sr)
+    try:
+        audio_chunks = read_audio_in_chunks(args.audio_file)
+
+        params = build_query_params(args)
+        ws_url = websocket_url_with_params(args.url, params)
+
+        ws = websocket.WebSocketApp(
+            ws_url,
+            on_open=on_open,
+            on_message=on_message,
+            on_error=on_error,
+            on_close=on_close,
+        )
+        ws.args = args  # Attach args to ws to access inside callbacks
+
+        # Run the WebSocket in a separate thread to allow sending and receiving simultaneously
+        ws_thread = threading.Thread(target=ws.run_forever)
+        ws_thread.start()
+
+        # Wait for the connection to open
+        while not ws.sock or not ws.sock.connected:
+            time.sleep(0.1)
+
+        # Send the audio chunks
+        send_audio_chunks(ws, audio_chunks, 16000)
+    except Exception as e:
+        print(f"An error occurred: {e}")
 
     # Wait for the WebSocket thread to finish
     ws_thread.join()
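
The wire format this change establishes is raw 16 kHz mono PCM, 16-bit little-endian, sent as one binary WebSocket message per chunk; ws_server.py decodes exactly that with soundfile in RAW/PCM_16 mode. (Note that the committed send_audio_chunks still casts each chunk with astype('<f4') before sending; the sketch below follows the int16 representation that read_audio_in_chunks prepares and that the server's PCM_16 decode expects.) A minimal, self-contained sketch of that round trip, using a synthetic sine wave in place of a real file; the variable names here are illustrative, not part of the commit:

import io

import librosa
import numpy as np
import soundfile

SAMPLING_RATE = 16000

# Synthetic one-second chunk standing in for the client's file audio.
t = np.linspace(0, 1, SAMPLING_RATE, endpoint=False)
chunk_float = 0.5 * np.sin(2 * np.pi * 440 * t).astype(np.float32)

# Client side: scale float32 [-1, 1] to int16 and serialize little-endian,
# mirroring read_audio_in_chunks.
raw_bytes = (chunk_float * 32767).astype('<i2').tobytes()

# Server side: decode the raw bytes the same way receive_audio_chunk does.
sf = soundfile.SoundFile(io.BytesIO(raw_bytes), channels=1, endian="LITTLE",
                         samplerate=SAMPLING_RATE, subtype="PCM_16", format="RAW")
audio, _ = librosa.load(sf, sr=SAMPLING_RATE, dtype=np.float32)

# 16-bit quantization costs at most about 1/32767 per sample.
assert np.max(np.abs(audio - chunk_float)) < 1e-3
print(f"Round-trip OK: {len(audio)} samples")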
pyproject.toml CHANGED
@@ -32,7 +32,7 @@ transformers = "^4.44.2"
 soundfile = "^0.12.1"
 faster-whisper = "^1.0.3"
 fastapi = "^0.114.2"
-websockets = "^13.0.1"
+#websockets = "^13.0.1"
 #websocket-client = "^1.8.0"
 librosa = "^0.10.2.post1"
 uvicorn = "^0.30.6"
ws_client.py ADDED
@@ -0,0 +1,288 @@
+import argparse
+import json
+import threading
+import time
+from pathlib import Path
+from typing import List
+
+import websocket
+import os
+
+import librosa
+import numpy as np
+
+# Define the default WebSocket endpoint
+DEFAULT_WS_URL = "ws://localhost:8000/v1/ws_transcribe_streaming"
+
+
+def parse_arguments():
+    parser = argparse.ArgumentParser(description="Stream audio to the transcription WebSocket endpoint.")
+    parser.add_argument("audio_file", help="Path to the input audio file.")
+    parser.add_argument("--url", default=DEFAULT_WS_URL, help="WebSocket endpoint URL.")
+    parser.add_argument("--model", type=str, help="Model name to use for transcription.")
+    parser.add_argument("--language", type=str, help="Language code for transcription.")
+    parser.add_argument(
+        "--response_format",
+        type=str,
+        default="verbose_json",
+        choices=["text", "json", "verbose_json"],
+        help="Response format.",
+    )
+    parser.add_argument("--temperature", type=float, default=0.0, help="Temperature for transcription.")
+    parser.add_argument("--vad_filter", action="store_true", help="Enable voice activity detection filter.")
+    parser.add_argument("--chunk_duration", type=float, default=1.0, help="Duration of each audio chunk in seconds.")
+    return parser.parse_args()
+
+
+# def preprocess_audio(audio_file, target_sr=16000):
+#     """
+#     Load the audio file, convert to mono 16kHz, and return the audio data.
+#     """
+#     if audio_file.endswith(".mp3"):
+#         # Convert MP3 to WAV using ffmpeg
+#         wav_file = audio_file.replace(".mp3", ".wav")
+#         if not os.path.exists(wav_file):
+#             command = f'ffmpeg -i "{audio_file}" -ac 1 -ar {target_sr} "{wav_file}"'
+#             print(f"Converting MP3 to WAV: {command}")
+#             os.system(command)
+#         audio_file = wav_file
+#
+#     print(f"Loading audio file {audio_file}")
+#     audio_data, sr = librosa.load(audio_file, sr=target_sr, mono=True)
+#     return audio_data, sr
+#
+# def chunk_audio(audio_data, sr, chunk_duration):
+#     """
+#     Split the audio data into chunks of specified duration.
+#     """
+#     chunk_samples = int(chunk_duration * sr)
+#     total_samples = len(audio_data)
+#     chunks = [
+#         audio_data[i:i + chunk_samples]
+#         for i in range(0, total_samples, chunk_samples)
+#     ]
+#     print(f"Split audio into {len(chunks)} chunks of {chunk_duration} seconds each.")
+#     return chunks
+
+
+def read_audio_in_chunks(audio_file, target_sr=16000, chunk_duration=1) -> List[np.ndarray]:
+    """
+    Reads a 16kHz mono audio file in 1-second chunks and returns them as little-endian 16-bit integer arrays.
+
+    Args:
+        audio_file (str): Path to the audio file.
+        target_sr (int): Expected sample rate (16000 by default).
+        chunk_duration (int): Duration of each chunk in seconds (1 second by default).
+
+    Returns:
+        List of numpy arrays: Each array is a 1-second chunk of the audio as 16-bit integers.
+
+    Raises:
+        ValueError: If the audio file's sample rate doesn't match expectations.
+    """
+    if not str(audio_file).endswith(".wav"):
+        # Convert MP3 to WAV using ffmpeg
+        wav_file = Path(audio_file).with_suffix(".wav")
+        if not wav_file.exists():
+            command = f'ffmpeg -i "{audio_file}" -ac 1 -ar {target_sr} "{wav_file}"'
+            print(f"Converting MP3 to WAV: {command}")
+            os.system(command)
+        audio_file = wav_file
+
+    # Load the audio file
+    audio_data, sr = librosa.load(audio_file, sr=None, mono=True)
+
+    # Raise an exception if the sample rate doesn't match
+    if sr != target_sr:
+        raise ValueError(f"Unexpected sample rate {sr}. Expected {target_sr}.")
+
+    # Convert the audio data to 16-bit PCM (little-endian)
+    audio_data_int16 = (audio_data * 32767).astype(np.int16)
+
+    # Check if the current byte order is little-endian
+    if audio_data_int16.dtype.byteorder == '>' or (
+            audio_data_int16.dtype.byteorder == '=' and np.dtype(np.int16).byteorder == '>'):
+        print("Byte swap performed to convert to little-endian.")
+        # Ensure little-endian format (if the current format is big-endian)
+        audio_data_little_endian = audio_data_int16.byteswap().newbyteorder('L')
+    else:
+        print("No byte swap needed. Already little-endian.")
+        audio_data_little_endian = audio_data_int16
+
+    # Calculate the number of samples per chunk
+    samples_per_chunk = target_sr * chunk_duration
+
+    # Split the audio into chunks
+    chunks = [
+        audio_data_little_endian[i:i + samples_per_chunk]
+        for i in range(0, len(audio_data_little_endian), samples_per_chunk)
+    ]
+
+    return chunks
+
+
+def build_query_params(args):
+    """
+    Build the query parameters for the WebSocket URL based on command-line arguments.
+    """
+    params = {}
+    if args.model:
+        params["model"] = args.model
+    if args.language:
+        params["language"] = args.language
+    if args.response_format:
+        params["response_format"] = args.response_format
+    if args.temperature is not None:
+        params["temperature"] = str(args.temperature)
+    if args.vad_filter:
+        params["vad_filter"] = "true"
+    return params
+
+
+def websocket_url_with_params(base_url, params):
+    """
+    Append query parameters to the WebSocket URL.
+    """
+    from urllib.parse import urlencode
+
+    if params:
+        query_string = urlencode(params)
+        url = f"{base_url}?{query_string}"
+    else:
+        url = base_url
+    return url
+
+
+def on_message(ws, message):
+    """
+    Callback function when a message is received from the server.
+    """
+    try:
+        data = json.loads(message)
+        # Accumulate transcriptions
+        if ws.args.response_format == "verbose_json":
+            segments = data.get('segments', [])
+            ws.transcriptions.extend(segments)
+            for segment in segments:
+                print(f"Received segment: {segment['text']}")
+        else:
+            # For 'json' or 'text' format
+            ws.transcriptions.append(data)
+            print(f"Transcription: {data['text']}")
+    except json.JSONDecodeError:
+        print(f"Received non-JSON message: {message}")
+
+
+def on_error(ws, error):
+    """
+    Callback function when an error occurs.
+    """
+    print(f"WebSocket error: {error}")
+
+
+def on_close(ws, close_status_code, close_msg):
+    """
+    Callback function when the WebSocket connection is closed.
+    """
+    print("WebSocket connection closed")
+
+
+def on_open(ws):
+    """
+    Callback function when the WebSocket connection is opened.
+    """
+    print("WebSocket connection opened")
+    ws.transcriptions = []  # Initialize the list to store transcriptions
+
+
+def send_audio_chunks(ws, audio_chunks, sr):
+    """
+    Send audio chunks to the WebSocket server.
+    """
+    for idx, chunk in enumerate(audio_chunks):
+        # Ensure little-endian format
+        audio_bytes = chunk.astype('<f4').tobytes()
+        ws.send(audio_bytes, opcode=websocket.ABNF.OPCODE_BINARY)
+        print(f"Sent chunk {idx + 1}/{len(audio_chunks)}")
+        time.sleep(0.1)  # Small delay to simulate real-time streaming
+    print("All audio chunks sent")
+    # Optionally, wait to receive remaining messages
+    time.sleep(2)
+    ws.close()
+    print("Closed WebSocket connection")
+
+
+def format_timestamp(seconds):
+    """
+    Convert seconds to SRT timestamp format (HH:MM:SS,mmm).
+    """
+    total_milliseconds = int(seconds * 1000)
+    hours = total_milliseconds // (3600 * 1000)
+    minutes = (total_milliseconds % (3600 * 1000)) // (60 * 1000)
+    secs = (total_milliseconds % (60 * 1000)) // 1000
+    milliseconds = total_milliseconds % 1000
+    return f"{hours:02}:{minutes:02}:{secs:02},{milliseconds:03}"
+
+
+def generate_srt(transcriptions):
+    """
+    Generate and print SRT content from transcriptions.
+    """
+    print("\nGenerated SRT:")
+    for idx, segment in enumerate(transcriptions, 1):
+        start_time = format_timestamp(segment['start'])
+        end_time = format_timestamp(segment['end'])
+        text = segment['text']
+        print(f"{idx}")
+        print(f"{start_time} --> {end_time}")
+        print(f"{text}\n")
+
+
+def run_websocket_client(args):
+    """
+    Run the WebSocket client to stream audio and receive transcriptions.
+    """
+    try:
+        audio_chunks = read_audio_in_chunks(args.audio_file)
+
+        # params = build_query_params(args)
+        # ws_url = websocket_url_with_params(args.url, params)
+        ws_url = args.url
+
+        ws = websocket.WebSocketApp(
+            ws_url,
+            on_open=on_open,
+            on_message=on_message,
+            on_error=on_error,
+            on_close=on_close,
+        )
+        ws.args = args  # Attach args to ws to access inside callbacks
+
+        # Run the WebSocket in a separate thread to allow sending and receiving simultaneously
+        ws_thread = threading.Thread(target=ws.run_forever)
+        ws_thread.start()
+
+        # Wait for the connection to open
+        while not ws.sock or not ws.sock.connected:
+            time.sleep(0.1)
+
+        # Send the audio chunks
+        send_audio_chunks(ws, audio_chunks, 16000)
+    except Exception as e:
+        print(f"An error occurred: {e}")
+
+    # Wait for the WebSocket thread to finish
+    ws_thread.join()
+
+    # Generate SRT if transcriptions are available
+    if hasattr(ws, 'transcriptions') and ws.transcriptions:
+        generate_srt(ws.transcriptions)
+    else:
+        print("No transcriptions received.")
+
+
+if __name__ == "__main__":
+    args = parse_arguments()
+    run_websocket_client(args)
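
The SRT helpers at the bottom of ws_client.py can be exercised without a running server. A small sketch, assuming ws_client.py is importable from the working directory; the segment dicts are fabricated for illustration and use only the start/end/text keys that on_message accumulates from verbose_json responses:

from ws_client import format_timestamp, generate_srt

# Fabricated segments in the shape the client collects.
segments = [
    {"start": 0.0, "end": 2.5, "text": "First transcribed segment."},
    {"start": 2.5, "end": 5.0, "text": "Second transcribed segment."},
]

# 3661.5 s is 1 h, 1 min, 1 s and 500 ms.
assert format_timestamp(3661.5) == "01:01:01,500"

# Prints numbered SRT blocks with HH:MM:SS,mmm --> HH:MM:SS,mmm ranges.
generate_srt(segments)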
ws_server.py CHANGED
@@ -1,11 +1,13 @@
 # Import the necessary components from whisper_online.py
 import logging
 import os
+from typing import Optional
 
 import librosa
 import soundfile
 import uvicorn
 from fastapi import FastAPI, WebSocket
+from pydantic import BaseModel, ConfigDict
 from starlette.websockets import WebSocketDisconnect
 
 from libs.whisper_streaming.whisper_online import (
@@ -25,22 +27,51 @@ import argparse
 import sys
 import numpy as np
 import io
-import soundfile as sf
+import soundfile
 import wave
 import requests
 import argparse
 
+# from libs.whisper_streaming.whisper_online_server import online
+
 logger = logging.getLogger(__name__)
 
 SAMPLING_RATE = 16000
 WARMUP_FILE = "mono16k.test_hebrew.wav"
 AUDIO_FILE_URL = "https://raw.githubusercontent.com/AshDavid12/runpod-serverless-forked/main/test_hebrew.wav"
 
-is_first = True
-asr, online = None, None
-min_limit = None  # min_chunk*SAMPLING_RATE
 app = FastAPI()
+args = argparse.ArgumentParser()
+add_shared_args(args)
+
+def drop_option_from_parser(parser, option_name):
+    """
+    Reinitializes the parser and copies all options except the specified option.
+
+    Args:
+        parser (argparse.ArgumentParser): The original argument parser.
+        option_name (str): The option string to drop (e.g., '--model').
+
+    Returns:
+        argparse.ArgumentParser: A new parser without the specified option.
+    """
+    # Create a new parser with the same description and other attributes
+    new_parser = argparse.ArgumentParser(
+        description=parser.description,
+        epilog=parser.epilog,
+        formatter_class=parser.formatter_class
+    )
+
+    # Iterate through all the arguments of the original parser
+    for action in parser._actions:
+        if "-h" in action.option_strings:
+            continue
+
+        # Check if the option is not the one to drop
+        if option_name not in action.option_strings:
+            new_parser._add_action(action)
+
+    return new_parser
 
 def convert_to_mono_16k(input_wav: str, output_wav: str) -> None:
     """
@@ -78,28 +109,68 @@ def download_warmup_file():
     convert_to_mono_16k(audio_file_path, WARMUP_FILE)
 
 
-async def receive_audio_chunk(self, websocket: WebSocket):
+class State(BaseModel):
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
+    websocket: WebSocket
+    asr: ASRBase
+    online: OnlineASRProcessor
+    min_limit: int
+
+    is_first: bool = True
+    last_end: Optional[float] = None
+
+
+async def receive_audio_chunk(state: State) -> Optional[np.ndarray]:
     # receive all audio that is available by this time
     # blocks operation if less than self.min_chunk seconds is available
     # unblocks if connection is closed or a chunk is available
     out = []
-    while sum(len(x) for x in out) < min_limit:
-        raw_bytes = await websocket.receive_bytes()
+    while sum(len(x) for x in out) < state.min_limit:
+        raw_bytes = await state.websocket.receive_bytes()
         if not raw_bytes:
             break
-
+        # print("received audio:",len(raw_bytes), "bytes", raw_bytes[:10])
         sf = soundfile.SoundFile(io.BytesIO(raw_bytes), channels=1, endian="LITTLE", samplerate=SAMPLING_RATE, subtype="PCM_16", format="RAW")
         audio, _ = librosa.load(sf, sr=SAMPLING_RATE, dtype=np.float32)
         out.append(audio)
-
     if not out:
         return None
 
-    conc = np.concatenate(out)
-    if self.is_first and len(conc) < min_limit:
+    flat_out = np.concatenate(out)
+    if state.is_first and len(flat_out) < state.min_limit:
         return None
-    self.is_first = False
-    return conc
+
+    state.is_first = False
+    return flat_out
+
+def format_output_transcript(state, o) -> dict:
+    # output format in stdout is like:
+    # 0 1720 Takhle to je
+    # - the first two words are:
+    #    - beg and end timestamp of the text segment, as estimated by the Whisper model. The timestamps are not accurate, but they're useful anyway
+    # - the next words: segment transcript
+
+    # This function differs from whisper_online.output_transcript in the following:
+    # succeeding [beg,end] intervals are not overlapping because the ELITR protocol (implemented in online-text-flow events) requires it.
+    # Therefore, beg is the max of the previous end and the current beg output by Whisper.
+    # Usually it differs negligibly, by approx 20 ms.
+
+    if o[0] is not None:
+        beg, end = o[0] * 1000, o[1] * 1000
+        if state.last_end is not None:
+            beg = max(beg, state.last_end)
+
+        state.last_end = end
+        print("%1.0f %1.0f %s" % (beg, end, o[2]), flush=True, file=sys.stderr)
+        return {
+            "start": "%1.0f" % beg,
+            "end": "%1.0f" % end,
+            "text": "%s" % o[2],
+        }
+    else:
+        logger.debug("No text in this segment")
+        return None
 
 # Define WebSocket endpoint
 @app.websocket("/ws_transcribe_streaming")
@@ -108,46 +179,37 @@ async def websocket_transcribe(websocket: WebSocket):
     await websocket.accept()
     logger.info("WebSocket connection established successfully.")
 
+    # initialize the ASR model
+    logger.info("Loading whisper model...")
     asr, online = asr_factory(args)
+    state = State(
+        websocket=websocket,
+        asr=asr,
+        online=online,
+        min_limit=args.min_chunk_size * SAMPLING_RATE,
+    )
 
     # warm up the ASR because the very first transcribe takes more time than the others.
     # Test results in https://github.com/ufal/whisper_streaming/pull/81
+    logger.info("Warming up the whisper model...")
     a = load_audio_chunk(WARMUP_FILE, 0, 1)
     asr.transcribe(a)
     logger.info("Whisper is warmed up.")
-    global min_limit
-    min_limit = args.min_chunk_size * SAMPLING_RATE
 
     try:
-        out = []
         while True:
+            a = await receive_audio_chunk(state)
+            if a is None:
+                break
+            state.online.insert_audio_chunk(a)
+            o = online.process_iter()
             try:
-                # Receive JSON data
-                raw_bytes = await websocket.receive_json()
-
-                sf = soundfile.SoundFile(io.BytesIO(raw_bytes), channels=1, endian="LITTLE", samplerate=SAMPLING_RATE,
-                                         subtype="PCM_16", format="RAW")
-                audio, _ = librosa.load(sf, sr=SAMPLING_RATE, dtype=np.float32)
-                out.append(audio)
-
-                # Call the transcribe function
-                # segments, info = await asyncio.to_thread(model.transcribe,
-                segments, info = model.transcribe(
-                    audio_file_path,
-                    language='he',
-                    initial_prompt=input_data.init_prompt,
-                    beam_size=5,
-                    word_timestamps=True,
-                    condition_on_previous_text=True
-                )
-
-                # Convert segments to list and serialize
-                segments_list = list(segments)
-                segments_serializable = [segment_to_dict(s) for s in segments_list]
-                logger.info(get_raw_words_from_segments(segments_list))
-                # Send the serialized segments back to the client
-                await websocket.send_json(segments_serializable)
+                if result := format_output_transcript(state, o):
+                    await websocket.send_json(result)
+
+            except BrokenPipeError:
+                logger.info("broken pipe -- connection closed?")
+                break
             except WebSocketDisconnect:
                 logger.info("WebSocket connection closed by the client.")
                 break
@@ -158,8 +220,11 @@ async def websocket_transcribe(websocket: WebSocket):
         logger.info("Cleaning up and closing WebSocket connection.")
 
 def main():
-    args = argparse.ArgumentParser()
-    args = add_shared_args(args)
+    global args
+    args = drop_option_from_parser(args, '--model')
+    args.add_argument('--model', type=str,
+                      help="Name size of the Whisper model to use. The model is automatically downloaded from the model hub if not present in model cache dir.")
+
     args.parse_args([
         '--lan', 'he',
         '--model', 'ivrit-ai/faster-whisper-v2-d4',
@@ -168,8 +233,7 @@ def main():
        # '--vac', '--buffer_trimming', 'segment', '--buffer_trimming_sec', '15', '--min_chunk_size', '1.0', '--vac_chunk_size', '0.04', '--start_at', '0.0', '--offline', '--comp_unaware', '--log_level', 'DEBUG'
     ])
 
-    global asr, online
-
-
-    uvicorn.run(app)
+    uvicorn.run(app)
 
+if __name__ == "__main__":
+    main()
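
The timestamp handling in format_output_transcript is the part worth internalizing: successive [beg, end] intervals are kept non-overlapping by clamping each segment's start to the previous segment's end. A standalone sketch of just that rule, with a SimpleNamespace standing in for State and fabricated millisecond values:

from types import SimpleNamespace

# Stub carrying only the field the clamping rule mutates.
state = SimpleNamespace(last_end=None)

def clamp_interval(state, beg_ms, end_ms):
    # Same rule as format_output_transcript: beg is the max of the
    # previous end and the start Whisper estimated for this segment.
    if state.last_end is not None:
        beg_ms = max(beg_ms, state.last_end)
    state.last_end = end_ms
    return beg_ms, end_ms

print(clamp_interval(state, 0, 1720))     # (0, 1720)
print(clamp_interval(state, 1650, 2400))  # (1720, 2400): start clamped forward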