aviadr1 committed on
Commit
e8aa012
·
1 Parent(s): d8dadfc
Files changed (2) hide show
  1. faster-whisper-server-client.py +225 -0
  2. ws_server.py +175 -0
faster-whisper-server-client.py ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import argparse
import json
import os
import subprocess
import threading
import time

import librosa
import numpy as np
import websocket
10
+
11
+ # Define the default WebSocket endpoint
12
+ DEFAULT_WS_URL = "ws://localhost:8000/v1/audio/transcriptions"
13
+
14
+
15
def parse_arguments(argv=None):
    """Parse command-line options for the streaming transcription client.

    Args:
        argv: Optional list of argument strings. Defaults to ``sys.argv[1:]``
            when None, preserving the original behavior; passing an explicit
            list makes the function testable and embeddable.

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser = argparse.ArgumentParser(description="Stream audio to the transcription WebSocket endpoint.")
    parser.add_argument("audio_file", help="Path to the input audio file.")
    parser.add_argument("--url", default=DEFAULT_WS_URL, help="WebSocket endpoint URL.")
    parser.add_argument("--model", type=str, help="Model name to use for transcription.")
    parser.add_argument("--language", type=str, help="Language code for transcription.")
    parser.add_argument(
        "--response_format",
        type=str,
        default="verbose_json",
        choices=["text", "json", "verbose_json"],
        help="Response format.",
    )
    parser.add_argument("--temperature", type=float, default=0.0, help="Temperature for transcription.")
    parser.add_argument("--vad_filter", action="store_true", help="Enable voice activity detection filter.")
    parser.add_argument("--chunk_duration", type=float, default=1.0, help="Duration of each audio chunk in seconds.")
    return parser.parse_args(argv)
32
+
33
+
34
def preprocess_audio(audio_file, target_sr=16000):
    """Load an audio file as mono samples at ``target_sr`` Hz.

    MP3 inputs are first converted to WAV with ffmpeg; the converted file is
    cached next to the source so repeated runs skip the conversion.

    Args:
        audio_file: Path to the input audio file (.wav, .mp3, ...).
        target_sr: Sample rate to resample to (default 16000).

    Returns:
        Tuple ``(audio_data, sr)`` of the sample array and sample rate.
    """
    if audio_file.endswith(".mp3"):
        # Swap only the final extension. str.replace(".mp3", ".wav") would
        # rewrite every ".mp3" occurrence anywhere in the path.
        wav_file = audio_file[:-len(".mp3")] + ".wav"
        if not os.path.exists(wav_file):
            # Argument-list subprocess call avoids the shell-quoting /
            # injection hazard of the previous os.system() string, and
            # check=True surfaces a failed conversion instead of loading
            # a missing/partial file afterwards.
            command = ["ffmpeg", "-i", audio_file, "-ac", "1", "-ar", str(target_sr), wav_file]
            print(f"Converting MP3 to WAV: {' '.join(command)}")
            subprocess.run(command, check=True)
        audio_file = wav_file

    print(f"Loading audio file {audio_file}")
    audio_data, sr = librosa.load(audio_file, sr=target_sr, mono=True)
    return audio_data, sr
50
+
51
def chunk_audio(audio_data, sr, chunk_duration):
    """Split ``audio_data`` into consecutive chunks of ``chunk_duration`` seconds.

    The last chunk may be shorter when the total length is not an exact
    multiple of the chunk size.
    """
    samples_per_chunk = int(chunk_duration * sr)
    chunks = []
    for start in range(0, len(audio_data), samples_per_chunk):
        chunks.append(audio_data[start:start + samples_per_chunk])
    print(f"Split audio into {len(chunks)} chunks of {chunk_duration} seconds each.")
    return chunks
63
+
64
+
65
def build_query_params(args):
    """Translate parsed CLI options into WebSocket query parameters.

    Only truthy string options are forwarded; temperature is always sent
    (as a string) when set, and vad_filter becomes the literal "true".
    """
    params = {
        key: value
        for key, value in (
            ("model", args.model),
            ("language", args.language),
            ("response_format", args.response_format),
        )
        if value
    }
    if args.temperature is not None:
        params["temperature"] = str(args.temperature)
    if args.vad_filter:
        params["vad_filter"] = "true"
    return params
81
+
82
+
83
def websocket_url_with_params(base_url, params):
    """Return ``base_url`` with ``params`` appended as a URL-encoded query string."""
    from urllib.parse import urlencode

    if not params:
        return base_url
    return f"{base_url}?{urlencode(params)}"
95
+
96
+
97
def on_message(ws, message):
    """Handle one server message: parse JSON and accumulate transcriptions.

    Segments (verbose_json) or whole payloads (json/text) are appended to
    ``ws.transcriptions``; non-JSON frames are reported and ignored.
    """
    try:
        data = json.loads(message)
    except json.JSONDecodeError:
        print(f"Received non-JSON message: {message}")
        return

    if ws.args.response_format == "verbose_json":
        segments = data.get('segments', [])
        ws.transcriptions.extend(segments)
        for segment in segments:
            print(f"Received segment: {segment['text']}")
    else:
        # For 'json' or 'text' format
        ws.transcriptions.append(data)
        print(f"Transcription: {data['text']}")
115
+
116
+
117
def on_error(ws, error):
    """Report a WebSocket-level error to stdout."""
    print(f"WebSocket error: {error}")
122
+
123
+
124
def on_close(ws, close_status_code, close_msg):
    """Log that the WebSocket connection has been closed."""
    print("WebSocket connection closed")
129
+
130
+
131
def on_open(ws):
    """Initialize per-connection state once the handshake completes.

    The empty ``transcriptions`` list is filled by on_message as results
    arrive from the server.
    """
    print("WebSocket connection opened")
    ws.transcriptions = []
137
+
138
+
139
def send_audio_chunks(ws, audio_chunks, sr):
    """Stream the audio chunks to the server as binary little-endian float32 frames.

    A short pause between frames approximates real-time capture; after the
    final frame the client lingers briefly so trailing results can arrive,
    then closes the connection.

    NOTE(review): ``sr`` is accepted but unused here - confirm it can be dropped.
    """
    total = len(audio_chunks)
    for position, chunk in enumerate(audio_chunks, start=1):
        payload = chunk.astype('<f4').tobytes()  # force little-endian float32
        ws.send(payload, opcode=websocket.ABNF.OPCODE_BINARY)
        print(f"Sent chunk {position}/{total}")
        time.sleep(0.1)  # Small delay to simulate real-time streaming
    print("All audio chunks sent")
    # Give the server a moment to deliver any remaining messages.
    time.sleep(2)
    ws.close()
    print("Closed WebSocket connection")
154
+
155
+
156
+
157
def format_timestamp(seconds):
    """Render a time offset in the SRT timestamp form ``HH:MM:SS,mmm``."""
    total_ms = int(seconds * 1000)
    total_seconds, milliseconds = divmod(total_ms, 1000)
    total_minutes, secs = divmod(total_seconds, 60)
    hours, minutes = divmod(total_minutes, 60)
    return f"{hours:02}:{minutes:02}:{secs:02},{milliseconds:03}"
167
+
168
+
169
def generate_srt(transcriptions):
    """Print the collected segments to stdout as numbered SRT cues."""
    print("\nGenerated SRT:")
    for cue_number, segment in enumerate(transcriptions, 1):
        window = f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}"
        print(f"{cue_number}")
        print(window)
        print(f"{segment['text']}\n")
181
+
182
+
183
def run_websocket_client(args):
    """
    Run the WebSocket client to stream audio and receive transcriptions.

    Loads and chunks the audio file, opens the WebSocket with the requested
    query parameters, streams chunks from the main thread while the receive
    loop runs in a background thread, then prints an SRT rendering of any
    transcriptions collected by the on_message callback.
    """
    audio_data, sr = preprocess_audio(args.audio_file)
    audio_chunks = chunk_audio(audio_data, sr, args.chunk_duration)

    params = build_query_params(args)
    ws_url = websocket_url_with_params(args.url, params)

    ws = websocket.WebSocketApp(
        ws_url,
        on_open=on_open,
        on_message=on_message,
        on_error=on_error,
        on_close=on_close,
    )
    ws.args = args  # Attach args to ws to access inside callbacks

    # Run the WebSocket in a separate thread to allow sending and receiving simultaneously
    ws_thread = threading.Thread(target=ws.run_forever)
    ws_thread.start()

    # Wait for the connection to open
    # NOTE(review): this busy-wait has no timeout - if the server is
    # unreachable the loop spins forever. Consider a deadline or an
    # Event set from on_open.
    while not ws.sock or not ws.sock.connected:
        time.sleep(0.1)

    # Send the audio chunks
    send_audio_chunks(ws, audio_chunks, sr)

    # Wait for the WebSocket thread to finish
    ws_thread.join()

    # Generate SRT if transcriptions are available.
    # (on_open creates ws.transcriptions; hasattr guards the case where the
    # connection never opened successfully.)
    if hasattr(ws, 'transcriptions') and ws.transcriptions:
        generate_srt(ws.transcriptions)
    else:
        print("No transcriptions received.")
221
+
222
+
223
if __name__ == "__main__":
    # Script entry point: parse CLI options and start the streaming client.
    run_websocket_client(parse_arguments())
ws_server.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Import the necessary components from whisper_online.py
2
+ import logging
3
+ import os
4
+
5
+ import librosa
6
+ import soundfile
7
+ import uvicorn
8
+ from fastapi import FastAPI, WebSocket
9
+ from starlette.websockets import WebSocketDisconnect
10
+
11
+ from libs.whisper_streaming.whisper_online import (
12
+ ASRBase,
13
+ OnlineASRProcessor,
14
+ VACOnlineASRProcessor,
15
+ add_shared_args,
16
+ asr_factory,
17
+ set_logging,
18
+ create_tokenizer,
19
+ load_audio,
20
+ load_audio_chunk, OpenaiApiASR,
21
+ set_logging
22
+ )
23
+
24
+ import argparse
25
+ import sys
26
+ import numpy as np
27
+ import io
28
+ import soundfile as sf
29
+ import wave
30
+ import requests
31
+ import argparse
32
+
33
logger = logging.getLogger(__name__)

# Sample rate (Hz) the ASR pipeline expects for all incoming audio.
SAMPLING_RATE = 16000
# Mono/16 kHz warm-up clip produced by download_warmup_file().
WARMUP_FILE = "mono16k.test_hebrew.wav"
# Source of the Hebrew test recording used for warm-up.
AUDIO_FILE_URL = "https://raw.githubusercontent.com/AshDavid12/runpod-serverless-forked/main/test_hebrew.wav"

# Module-level streaming state shared with the websocket handlers.
is_first = True
# NOTE(review): asr/online are declared here but never assigned at module
# level elsewhere in this file - confirm who is expected to populate them.
asr, online = None, None
min_limit = None  # min_chunk*SAMPLING_RATE
app = FastAPI()
43
+
44
+
45
def convert_to_mono_16k(input_wav: str, output_wav: str) -> None:
    """
    Converts any .wav file to mono 16 kHz.

    Args:
        input_wav (str): Path to the input .wav file.
        output_wav (str): Path to save the output .wav file with mono 16 kHz.
    """
    # Step 1: Load at the file's original sampling rate; mono=False keeps channels.
    audio_data, original_sr = librosa.load(input_wav, sr=None, mono=False)
    # Lazy %-style logging args: formatting is deferred (and skipped entirely
    # when INFO is disabled), unlike the previous eager "%"/f-string calls.
    logger.info("Loaded audio with shape: %s, original sampling rate: %d", audio_data.shape, original_sr)

    # Step 2: If the audio has multiple channels, average them to make it mono.
    if audio_data.ndim > 1:
        audio_data = librosa.to_mono(audio_data)

    # Step 3: Resample the audio to 16 kHz.
    resampled_audio = librosa.resample(audio_data, orig_sr=original_sr, target_sr=SAMPLING_RATE)

    # Step 4: Save the resampled audio as a .wav file in mono at 16 kHz.
    sf.write(output_wav, resampled_audio, SAMPLING_RATE)

    logger.info("Converted audio saved to %s", output_wav)
68
+
69
def download_warmup_file():
    """Fetch the Hebrew test clip (if absent) and convert it to mono 16 kHz.

    Skips both the download and the conversion when the corresponding output
    file already exists on disk.
    """
    audio_file_path = "test_hebrew.wav"
    if not os.path.exists(WARMUP_FILE):
        if not os.path.exists(audio_file_path):
            # Bound the request and fail loudly on HTTP errors; the previous
            # unchecked call could hang forever and would silently write an
            # error page to disk as if it were audio.
            response = requests.get(AUDIO_FILE_URL, timeout=60)
            response.raise_for_status()
            with open(audio_file_path, 'wb') as f:
                f.write(response.content)

        convert_to_mono_16k(audio_file_path, WARMUP_FILE)
79
+
80
+
81
async def receive_audio_chunk(self, websocket: WebSocket):
    # receive all audio that is available by this time
    # blocks operation if less than self.min_chunk seconds is available
    # unblocks if connection is closed or a chunk is available
    #
    # NOTE(review): this is a module-level function yet it takes `self` and
    # reads/writes self.is_first - it looks copied out of a class; confirm
    # callers pass an object carrying is_first. `min_limit` is the module
    # global (samples) set by the websocket endpoint.
    out = []
    while sum(len(x) for x in out) < min_limit:
        raw_bytes = await websocket.receive_bytes()
        if not raw_bytes:
            break

        # Interpret the raw frame as headerless little-endian 16-bit PCM at
        # SAMPLING_RATE. NOTE(review): the local name `sf` shadows the
        # module-level `soundfile as sf` alias.
        sf = soundfile.SoundFile(io.BytesIO(raw_bytes), channels=1,endian="LITTLE",samplerate=SAMPLING_RATE, subtype="PCM_16",format="RAW")
        audio, _ = librosa.load(sf,sr=SAMPLING_RATE,dtype=np.float32)
        out.append(audio)

    if not out:
        return None

    conc = np.concatenate(out)
    # Suppress the very first chunk while it is still below the minimum size.
    if self.is_first and len(conc) < min_limit:
        return None
    self.is_first = False
    return conc
103
+
104
# Define WebSocket endpoint
@app.websocket("/ws_transcribe_streaming")
async def websocket_transcribe(websocket: WebSocket):
    """Accept a WebSocket connection, warm up the ASR, then loop: receive
    audio payloads, transcribe, and send serialized segments back.

    NOTE(review): this handler appears to be work-in-progress - several names
    it uses (`args`, `model`, `audio_file_path`, `input_data`,
    `segment_to_dict`, `get_raw_words_from_segments`) are not defined
    anywhere in this file and will raise NameError at runtime.
    """
    logger.info("New WebSocket connection request received.")
    await websocket.accept()
    logger.info("WebSocket connection established successfully.")

    # NOTE(review): `args` is not a module global here unless main() sets it;
    # also these locals shadow the module-level asr/online without updating them.
    asr, online = asr_factory(args)

    # warm up the ASR because the very first transcribe takes more time than the others.
    # Test results in https://github.com/ufal/whisper_streaming/pull/81
    a = load_audio_chunk(WARMUP_FILE, 0, 1)
    asr.transcribe(a)
    logger.info("Whisper is warmed up.")
    global min_limit
    min_limit = args.min_chunk_size * SAMPLING_RATE

    try:
        out = []
        while True:
            try:
                # Receive JSON data
                # NOTE(review): receive_json() returns a parsed JSON object,
                # yet the result is fed to SoundFile as raw PCM bytes - the
                # client in this commit sends binary frames; confirm which
                # transport is intended.
                raw_bytes = await websocket.receive_json()

                # NOTE(review): local `sf` shadows the module-level
                # `soundfile as sf` alias.
                sf = soundfile.SoundFile(io.BytesIO(raw_bytes), channels=1, endian="LITTLE", samplerate=SAMPLING_RATE,
                                         subtype="PCM_16", format="RAW")
                audio, _ = librosa.load(sf, sr=SAMPLING_RATE, dtype=np.float32)
                out.append(audio)

                # Call the transcribe function
                # segments, info = await asyncio.to_thread(model.transcribe,
                # NOTE(review): `model`, `audio_file_path` and `input_data`
                # are undefined; the decoded `audio` above is never used.
                segments, info = model.transcribe(
                    audio_file_path,
                    language='he',
                    initial_prompt=input_data.init_prompt,
                    beam_size=5,
                    word_timestamps=True,
                    condition_on_previous_text=True
                )

                # Convert segments to list and serialize
                # NOTE(review): segment_to_dict / get_raw_words_from_segments
                # are not defined in this file.
                segments_list = list(segments)
                segments_serializable = [segment_to_dict(s) for s in segments_list]
                logger.info(get_raw_words_from_segments(segments_list))
                # Send the serialized segments back to the client
                await websocket.send_json(segments_serializable)

            except WebSocketDisconnect:
                logger.info("WebSocket connection closed by the client.")
                break
            except Exception as e:
                logger.error(f"Unexpected error during WebSocket transcription: {e}")
                await websocket.send_json({"error": str(e)})
    finally:
        logger.info("Cleaning up and closing WebSocket connection.")
159
+
160
def main():
    """Configure the transcription options and launch the FastAPI app.

    Fixes the original argument handling, which bound the ArgumentParser
    itself to ``args``, discarded the ``parse_args()`` result, and declared
    ``global asr, online`` without ever assigning them - so the parsed
    options were lost before the server started.
    """
    parser = argparse.ArgumentParser()
    # add_shared_args registers the shared whisper_streaming options on the
    # parser in place - NOTE(review): confirm it does not need its return value.
    add_shared_args(parser)

    # Publish the parsed namespace as the module-global `args` that the
    # websocket endpoint reads.
    global args
    args = parser.parse_args([
        '--lan', 'he',
        '--model', 'ivrit-ai/faster-whisper-v2-d4',
        '--backend', 'faster-whisper',
        '--vad',
    ])

    uvicorn.run(app)