aviadr1 committed on
Commit 861eb71 · 1 Parent(s): bb46abd

connecting and streaming v1

Files changed (5)
  1. infer.py +62 -0
  2. model.py +54 -0
  3. poetry.lock +0 -0
  4. pyproject.toml +30 -17
  5. streaming_client.py +257 -0
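
Taken together: infer.py gains a /ws_transcribe_streaming WebSocket endpoint that accepts base64-encoded WAV audio plus an optional init_prompt and returns transcribed segments as JSON; model.py adds (de)serialization helpers for faster-whisper's Segment/Word types; streaming_client.py adapts whisper_streaming to call that endpoint remotely. A minimal client-side sketch of the wire protocol (assumptions: the server is running locally on port 8000, as hard-coded in streaming_client.py; "hebrew.wav" is a placeholder for any 16 kHz WAV file):

import base64, json
from websocket import create_connection  # websocket-client, newly added to pyproject.toml

ws = create_connection("ws://localhost:8000/ws_transcribe_streaming")
with open("hebrew.wav", "rb") as f:  # placeholder input file
    request = {"audio": base64.b64encode(f.read()).decode("utf-8"), "init_prompt": ""}
ws.send(json.dumps(request))
# The server replies with a list of dicts shaped by model.segment_to_dict,
# or a single {"error": ...} dict on failure.
segments = json.loads(ws.recv())
for seg in segments:
    print(f"[{seg['start']:.2f} - {seg['end']:.2f}] {seg['text']}")
ws.close()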
infer.py CHANGED
@@ -1,4 +1,5 @@
 import base64
+import json
 import os
 import wave
 
@@ -17,6 +18,8 @@ from typing import Optional
 import sys
 import asyncio
 
+from model import segment_to_dict
+
 # Configure logging
 logging.basicConfig(level=logging.DEBUG, format='%(asctime)s %(levelname)s: %(message)s',
                     handlers=[logging.StreamHandler(sys.stdout)], force=True)
@@ -37,6 +40,61 @@ logging.info(f'Max payload size set to: {MAX_PAYLOAD_SIZE} bytes')
 app = FastAPI()
 
 
+# Define Pydantic model for input
+class TranscribeInput(BaseModel):
+    audio: str  # Base64-encoded audio data
+    init_prompt: str = ""
+
+
+# Define WebSocket endpoint
+@app.websocket("/ws_transcribe_streaming")
+async def websocket_transcribe(websocket: WebSocket):
+    logger.info("New WebSocket connection request received.")
+    await websocket.accept()
+    logger.info("WebSocket connection established successfully.")
+
+    try:
+        while True:
+            try:
+                # Receive JSON data
+                data = await websocket.receive_json()
+                # Parse input data
+                input_data = TranscribeInput(**data)
+                # Decode base64 audio data
+                audio_bytes = base64.b64decode(input_data.audio)
+                # Write audio data to a temporary file
+                with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_audio_file:
+                    temp_audio_file.write(audio_bytes)
+                    temp_audio_file.flush()
+                    audio_file_path = temp_audio_file.name
+
+                # Call the transcribe function off the event loop
+                segments, info = await asyncio.to_thread(
+                    model.transcribe,
+                    audio_file_path,
+                    language='he',
+                    initial_prompt=input_data.init_prompt,
+                    beam_size=5,
+                    word_timestamps=True,
+                    condition_on_previous_text=True
+                )
+
+                # Convert segments to a list and serialize
+                segments_list = list(segments)
+                segments_serializable = [segment_to_dict(s) for s in segments_list]
+
+                # Send the serialized segments back to the client
+                await websocket.send_json(segments_serializable)
+
+            except WebSocketDisconnect:
+                logger.info("WebSocket connection closed by the client.")
+                break
+            except Exception as e:
+                logger.error(f"Unexpected error during WebSocket transcription: {e}")
+                await websocket.send_json({"error": str(e)})
+    finally:
+        logger.info("Cleaning up and closing WebSocket connection.")
+
 # class InputData(BaseModel):
 #     type: str
 #     data: Optional[str] = None  # Used for blob input
@@ -439,3 +497,7 @@ async def download_audio(filename: str):
 #
 #     finally:
 #         logging.info("Cleaning up and closing WebSocket connection.")
+
+if __name__ == "__main__":
+    import uvicorn
+    uvicorn.run(app)
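
With the new __main__ block the server can be launched directly via python infer.py, which serves on uvicorn's default 127.0.0.1:8000 and so matches the ws://localhost:8000/ws_transcribe_streaming address hard-coded in streaming_client.py. Note that each WebSocket message is transcribed as a complete WAV blob; the "streaming" behaviour comes from the client side, which repeatedly sends its growing, trimmed audio buffer.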
model.py ADDED
@@ -0,0 +1,54 @@
+# Function to convert segments to dictionaries
+from faster_whisper.transcribe import Segment, Word
+
+
+# Function to dump a Word instance to a dictionary
+def word_to_dict(word: Word) -> dict:
+    return {
+        "start": word.start,
+        "end": word.end,
+        "word": word.word,
+        "probability": word.probability
+    }
+
+# Function to load a Word instance from a dictionary
+def dict_to_word(data: dict) -> Word:
+    return Word(
+        start=data["start"],
+        end=data["end"],
+        word=data["word"],
+        probability=data["probability"]
+    )
+
+# Function to dump a Segment instance to a dictionary
+def segment_to_dict(segment: Segment) -> dict:
+    return {
+        "id": segment.id,
+        "seek": segment.seek,
+        "start": segment.start,
+        "end": segment.end,
+        "text": segment.text,
+        "tokens": segment.tokens,
+        "temperature": segment.temperature,
+        "avg_logprob": segment.avg_logprob,
+        "compression_ratio": segment.compression_ratio,
+        "no_speech_prob": segment.no_speech_prob,
+        "words": [word_to_dict(word) for word in segment.words] if segment.words else None
+    }
+
+# Function to load a Segment instance from a dictionary
+def dict_to_segment(data: dict) -> Segment:
+    return Segment(
+        id=data["id"],
+        seek=data["seek"],
+        start=data["start"],
+        end=data["end"],
+        text=data["text"],
+        tokens=data["tokens"],
+        temperature=data["temperature"],
+        avg_logprob=data["avg_logprob"],
+        compression_ratio=data["compression_ratio"],
+        no_speech_prob=data["no_speech_prob"],
+        words=[dict_to_word(word) for word in data["words"]] if data["words"] else None
+    )
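
A quick round trip through these helpers, useful as a sanity check (hypothetical values; assumes a faster-whisper version where Segment and Word are NamedTuples, as in the 1.0.x line pinned here, so keyword construction and equality both work):

import json
from faster_whisper.transcribe import Segment, Word
from model import segment_to_dict, dict_to_segment

seg = Segment(
    id=1, seek=0, start=0.0, end=1.2, text=" shalom", tokens=[50364],
    temperature=0.0, avg_logprob=-0.25, compression_ratio=1.0, no_speech_prob=0.01,
    words=[Word(start=0.0, end=1.2, word=" shalom", probability=0.93)],
)
wire = json.dumps(segment_to_dict(seg))          # what the server sends over the socket
assert dict_to_segment(json.loads(wire)) == seg  # what the client reconstructs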
poetry.lock CHANGED
The diff for this file is too large to render. See raw diff
 
pyproject.toml CHANGED
@@ -5,25 +5,38 @@ description = ""
 authors = ["AshDavid12 <[email protected]>"]
 
 [tool.poetry.dependencies]
-python = "3.9.1"
-pytube = "^15.0.0"
-pydantic = "^2.8.2"
-python-dotenv = "^1.0.1"
-fastapi = "^0.111.1"
+python = "3.11.7"
+# pytube = "^15.0.0"
+#pydantic = "^2.8.2"
+#python-dotenv = "^1.0.1"
+#fastapi = "^0.111.1"
+#requests = "^2.32.3"
+#httpx = "^0.27.0"
+#uvicorn = "^0.30.1"
+#asyncio = "^3.4.3"
+#confluent-kafka = "^2.5.0"
+#flask = "^3.0.3"
+#faster-whisper = "1.0.3"
+#runpod = "^1.7.0"
+#librosa = ">=0.10.2.post1,<0.11.0"
+#soundfile = "^0.12.1"
+#openai = "^1.42.0"
+#numpy = "^1.22.0"
+#torch = "2.1.0"
+#sounddevice = "^0.5.0"
+torch = "^2.4.1"
+whisper = "^1.1.10"
 requests = "^2.32.3"
-httpx = "^0.27.0"
-uvicorn = "^0.30.1"
-asyncio = "^3.4.3"
-confluent-kafka = "^2.5.0"
-flask = "^3.0.3"
-faster-whisper = "1.0.3"
-runpod = "^1.7.0"
-librosa = ">=0.10.2.post1,<0.11.0"
+transformers = "^4.44.2"
 soundfile = "^0.12.1"
-openai = "^1.42.0"
-numpy = "^1.22.0"
-torch = "2.1.0"
-sounddevice = "^0.5.0"
+faster-whisper = "^1.0.3"
+fastapi = "^0.114.2"
+websockets = "^13.0.1"
+websocket-client = "^1.8.0"
+librosa = "^0.10.2.post1"
+uvicorn = "^0.30.6"
+torchaudio = "^2.4.1"
+silero-vad = "^5.1"
 
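One pin worth double-checking: on PyPI, whisper 1.1.x is Graphite's time-series database library, not OpenAI's speech model (published as openai-whisper), so whisper = "^1.1.10" may not be the intended package; the modules in this commit import only faster_whisper. Note also that torch now appears both commented out (2.1.0) and active (^2.4.1).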
streaming_client.py ADDED
@@ -0,0 +1,257 @@
+# remote_whisper.py
+
+import sys
+import time
+import logging
+import os
+import requests
+
+import json
+import base64
+import numpy as np
+import soundfile as sf
+import io
+
+# Import the necessary components from whisper_online.py
+from libs.whisper_streaming.whisper_online import (
+    ASRBase,
+    OnlineASRProcessor,
+    VACOnlineASRProcessor,
+    add_shared_args,
+    asr_factory as original_asr_factory,
+    set_logging,
+    create_tokenizer,
+    load_audio,
+    load_audio_chunk, OpenaiApiASR,
+)
+from model import dict_to_segment
+
+logger = logging.getLogger(__name__)
+
+# Define the RemoteFasterWhisperASR class
+class RemoteFasterWhisperASR(ASRBase):
+    """Uses a remote FasterWhisper model via WebSocket."""
+    sep = ""  # Same as FasterWhisperASR
+
+    def load_model(self, *args, **kwargs):
+        import websocket
+        self.ws = websocket.WebSocket()
+        # Replace with your server address
+        server_address = "ws://localhost:8000/ws_transcribe_streaming"  # Update with the actual server address
+        self.ws.connect(server_address)
+        logger.info(f"Connected to remote ASR server at {server_address}")
+
+    def transcribe(self, audio, init_prompt=""):
+        # Convert audio data to WAV bytes
+        if isinstance(audio, str):
+            # If audio is a filename, read the file
+            with open(audio, 'rb') as f:
+                audio_bytes = f.read()
+        elif isinstance(audio, np.ndarray):
+            # Write audio data to a buffer in WAV format
+            audio_bytes_io = io.BytesIO()
+            sf.write(audio_bytes_io, audio, samplerate=16000, format='WAV', subtype='PCM_16')
+            audio_bytes = audio_bytes_io.getvalue()
+        else:
+            raise ValueError("Unsupported audio input type")
+
+        # Encode to base64
+        audio_b64 = base64.b64encode(audio_bytes).decode('utf-8')
+        data = {
+            'audio': audio_b64,
+            'init_prompt': init_prompt
+        }
+        self.ws.send(json.dumps(data))
+        response = self.ws.recv()
+        segments = json.loads(response)
+        segments = [dict_to_segment(s) for s in segments]
+        return segments
+
+    def ts_words(self, segments):
+        o = []
+        for segment in segments:
+            for word in segment.words:
+                if segment.no_speech_prob > 0.9:
+                    continue
+                # not stripping the spaces -- should not be merged with them!
+                w = word.word
+                t = (word.start, word.end, w)
+                o.append(t)
+        return o
+
+    def segments_end_ts(self, res):
+        return [s.end for s in res]
+
+    def use_vad(self):
+        self.transcribe_kargs["vad_filter"] = True
+
+    def set_translate_task(self):
+        self.transcribe_kargs["task"] = "translate"
+
+# Update asr_factory to include RemoteFasterWhisperASR
+def asr_factory(args, logfile=sys.stderr):
+    """
+    Creates and configures an ASR and Online ASR Processor instance based on the specified backend and arguments.
+    """
+    backend = args.backend
+    if backend == "openai-api":
+        logger.debug("Using OpenAI API.")
+        asr = OpenaiApiASR(lan=args.lan)
+    elif backend == "remote-faster-whisper":
+        asr_cls = RemoteFasterWhisperASR
+    else:
+        # Use the original asr_factory for other backends
+        return original_asr_factory(args, logfile)
+
+    # For RemoteFasterWhisperASR
+    t = time.time()
+    logger.info(f"Initializing Remote Faster Whisper ASR for language '{args.lan}'...")
+    asr = asr_cls(modelsize=args.model, lan=args.lan, cache_dir=args.model_cache_dir, model_dir=args.model_dir)
+    e = time.time()
+    logger.info(f"Initialization done. It took {round(e - t, 2)} seconds.")
+
+    # Apply common configurations
+    if getattr(args, 'vad', False):  # Checks if VAD argument is present and True
+        logger.info("Setting VAD filter")
+        asr.use_vad()
+
+    language = args.lan
+    if args.task == "translate":
+        asr.set_translate_task()
+        tgt_language = "en"  # Whisper translates into English
+    else:
+        tgt_language = language  # Whisper transcribes in this language
+
+    # Create the tokenizer
+    if args.buffer_trimming == "sentence":
+        tokenizer = create_tokenizer(tgt_language)
+    else:
+        tokenizer = None
+
+    # Create the OnlineASRProcessor
+    if args.vac:
+        online = VACOnlineASRProcessor(
+            args.min_chunk_size,
+            asr,
+            tokenizer,
+            logfile=logfile,
+            buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec)
+        )
+    else:
+        online = OnlineASRProcessor(
+            asr,
+            tokenizer,
+            logfile=logfile,
+            buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec)
+        )
+
+    return asr, online
+
+# Now, write the main function that uses RemoteFasterWhisperASR
+def main():
+    import argparse
+    import sys
+    import numpy as np
+    import io
+    import soundfile as sf
+
+    # Download the audio file if not already present
+    AUDIO_FILE_URL = "https://raw.githubusercontent.com/AshDavid12/runpod-serverless-forked/main/test_hebrew.wav"
+    audio_file_path = "test_hebrew.wav"
+    if not os.path.exists(audio_file_path):
+        response = requests.get(AUDIO_FILE_URL)
+        with open(audio_file_path, 'wb') as f:
+            f.write(response.content)
+
+    # Set up arguments
+    class Args:
+        def __init__(self):
+            self.audio_path = audio_file_path
+            self.lan = 'he'
+            self.model = None  # Not used in RemoteFasterWhisperASR
+            self.model_cache_dir = None
+            self.model_dir = None
+            self.backend = 'remote-faster-whisper'
+            self.task = 'transcribe'
+            self.vad = False
+            self.vac = True  # Use VAC as default
+            self.buffer_trimming = 'segment'
+            self.buffer_trimming_sec = 15
+            self.min_chunk_size = 1.0
+            self.vac_chunk_size = 0.04
+            self.start_at = 0.0
+            self.offline = False
+            self.comp_unaware = False
+            self.log_level = 'DEBUG'
+
+    args = Args()
+
+    # Set up logging
+    set_logging(args, logger)
+
+    audio_path = args.audio_path
+
+    SAMPLING_RATE = 16000
+
+    duration = len(load_audio(audio_path)) / SAMPLING_RATE
+    logger.info("Audio duration is: %2.2f seconds" % duration)
+
+    asr, online = asr_factory(args, logfile=sys.stderr)
+    if args.vac:
+        min_chunk = args.vac_chunk_size
+    else:
+        min_chunk = args.min_chunk_size
+
+    # Load the audio into the LRU cache before we start the timer
+    a = load_audio_chunk(audio_path, 0, 1)
+
+    # Warm up the ASR because the very first transcribe takes more time
+    asr.transcribe(a)
+
+    beg = args.start_at
+    start = time.time() - beg
+
+    def output_transcript(o, now=None):
+        # Output format in stdout is like:
+        # 4186.3606 0 1720 Takhle to je
+        # - The first three numbers are:
+        #   - Emission time from the beginning of processing, in milliseconds
+        #   - Begin and end timestamp of the text segment, as estimated by the Whisper model
+        # - The next words: segment transcript
+        if now is None:
+            now = time.time() - start
+        if o[0] is not None:
+            print("%1.4f %1.0f %1.0f %s" % (now * 1000, o[0] * 1000, o[1] * 1000, o[2]), flush=True)
+        else:
+            # No text, so no output
+            pass
+
+    end = 0
+    while True:
+        now = time.time() - start
+        if now < end + min_chunk:
+            time.sleep(min_chunk + end - now)
+        end = time.time() - start
+        a = load_audio_chunk(audio_path, beg, end)
+        beg = end
+        online.insert_audio_chunk(a)
+
+        try:
+            o = online.process_iter()
+        except AssertionError as e:
+            logger.error(f"Assertion error: {e}")
+            pass
+        else:
+            output_transcript(o)
+        now = time.time() - start
+        logger.debug(f"## Last processed {end:.2f} s, now is {now:.2f}, latency is {now - end:.2f}")
+
+        if end >= duration:
+            break
+    now = None
+
+    o = online.finish()
+    output_transcript(o, now=now)
+
+if __name__ == "__main__":
+    main()
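
To exercise the whole path (assuming libs/whisper_streaming/whisper_online.py is present in the repo, since it is not part of this commit): start the server with python infer.py, then run python streaming_client.py. The client downloads test_hebrew.wav on first use, warms up the remote ASR with a one-second chunk, and then simulates real-time capture by feeding the file in 0.04 s chunks (VAC mode), printing timestamped partial transcripts as they stabilize.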