File size: 8,215 Bytes
e8aa012
 
 
9d710fb
e8aa012
 
 
 
 
9d710fb
e8aa012
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9d710fb
e8aa012
 
 
 
9d710fb
 
e8aa012
 
 
 
 
 
 
9d710fb
 
 
 
 
 
 
 
 
 
e8aa012
9d710fb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e8aa012
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9d710fb
 
 
 
 
 
 
 
 
 
 
 
 
 
e8aa012
 
 
 
9d710fb
 
e8aa012
 
9d710fb
e8aa012
 
 
 
 
9d710fb
 
 
e8aa012
9d710fb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e8aa012
 
 
 
 
 
 
 
 
9d710fb
 
e8aa012
9d710fb
 
 
 
 
 
e8aa012
 
 
9d710fb
e8aa012
 
 
 
 
 
9d710fb
 
 
 
 
e8aa012
9d710fb
 
e8aa012
9d710fb
 
 
e8aa012
 
 
 
 
 
 
 
 
 
9d710fb
 
 
 
 
e8aa012
 
 
 
 
 
 
 
9d710fb
e8aa012
9d710fb
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
# Import the necessary components from whisper_online.py
import logging
import os
from typing import Optional

import librosa
import soundfile
import uvicorn
from fastapi import FastAPI, WebSocket
from pydantic import BaseModel, ConfigDict
from starlette.websockets import WebSocketDisconnect

from libs.whisper_streaming.whisper_online import (
    ASRBase,
    OnlineASRProcessor,
    VACOnlineASRProcessor,
    add_shared_args,
    asr_factory,
    set_logging,
    create_tokenizer,
    load_audio,
    load_audio_chunk, OpenaiApiASR,
    set_logging
)

import argparse
import sys
import numpy as np
import io
import soundfile
import wave
import requests
import argparse

# from libs.whisper_streaming.whisper_online_server import online

logger = logging.getLogger(__name__)

SAMPLING_RATE = 16000  # target sample rate (Hz) used for all audio in this service
WARMUP_FILE = "mono16k.test_hebrew.wav"  # mono 16 kHz clip used to warm up the model
AUDIO_FILE_URL = "https://raw.githubusercontent.com/AshDavid12/runpod-serverless-forked/main/test_hebrew.wav"

app = FastAPI()
# NOTE(review): `args` is created as a parser here and is intended to be
# replaced by the parsed Namespace later — the one name serves both roles.
args = argparse.ArgumentParser()
add_shared_args(args)

def drop_option_from_parser(parser, option_name):
    """
    Build a copy of *parser* that omits one option.

    Args:
        parser (argparse.ArgumentParser): The original argument parser.
        option_name (str): The option string to drop (e.g., '--model').

    Returns:
        argparse.ArgumentParser: A new parser without the specified option.
    """
    # Mirror the descriptive attributes of the source parser.
    rebuilt = argparse.ArgumentParser(
        description=parser.description,
        epilog=parser.epilog,
        formatter_class=parser.formatter_class,
    )

    # Carry over every action except the auto-added help action (the new
    # parser creates its own) and the option being dropped.
    for action in parser._actions:
        if "-h" in action.option_strings or option_name in action.option_strings:
            continue
        rebuilt._add_action(action)

    return rebuilt

def convert_to_mono_16k(input_wav: str, output_wav: str) -> None:
    """
    Converts any .wav file to mono 16 kHz.

    Args:
        input_wav (str): Path to the input .wav file.
        output_wav (str): Path to save the output .wav file with mono 16 kHz.
    """
    # Load at the file's native sampling rate; mono=False keeps all channels.
    audio_data, original_sr = librosa.load(input_wav, sr=None, mono=False)
    logger.info("Loaded audio with shape: %s, original sampling rate: %d" % (audio_data.shape, original_sr))

    # Average multi-channel audio down to a single mono channel.
    if audio_data.ndim > 1:
        audio_data = librosa.to_mono(audio_data)

    # Resample to the service-wide 16 kHz rate.
    resampled_audio = librosa.resample(audio_data, orig_sr=original_sr, target_sr=SAMPLING_RATE)

    # BUG FIX: the original called sf.write(), but no `sf` alias exists at
    # module scope (soundfile is imported unaliased), so this line raised
    # NameError. Use the soundfile module directly.
    soundfile.write(output_wav, resampled_audio, SAMPLING_RATE)

    logger.info(f"Converted audio saved to {output_wav}")

def download_warmup_file():
    """
    Ensure the mono 16 kHz warm-up file exists locally.

    Downloads the source clip from AUDIO_FILE_URL if needed, then converts
    it to mono 16 kHz via convert_to_mono_16k(). No-op when WARMUP_FILE
    already exists.
    """
    audio_file_path = "test_hebrew.wav"
    if not os.path.exists(WARMUP_FILE):
        if not os.path.exists(audio_file_path):
            response = requests.get(AUDIO_FILE_URL)
            # Fail loudly on HTTP errors instead of silently writing an
            # error page to disk as if it were audio.
            response.raise_for_status()
            with open(audio_file_path, 'wb') as f:
                f.write(response.content)

        convert_to_mono_16k(audio_file_path, WARMUP_FILE)




class State(BaseModel):
    """Per-connection transcription state for one WebSocket client."""

    # Required so pydantic accepts non-pydantic field types
    # (WebSocket, ASRBase, OnlineASRProcessor).
    model_config = ConfigDict(arbitrary_types_allowed=True)

    websocket: WebSocket  # the live client connection
    asr: ASRBase  # loaded whisper backend (used for warm-up transcribe)
    online: OnlineASRProcessor  # streaming processor fed via insert_audio_chunk()
    min_limit: int  # minimum number of samples to buffer before processing

    is_first: bool = True  # True until the first sufficiently large chunk arrives
    last_end: Optional[float] = None  # end timestamp (ms) of the last emitted segment

async def receive_audio_chunk(state: State) -> Optional[np.ndarray]:
    """
    Read raw PCM frames from the websocket until at least ``state.min_limit``
    samples are buffered, and return them as a single float32 array.

    Returns None when an empty frame arrives before anything was buffered,
    or when the very first chunk is still shorter than ``min_limit``
    (that undersized first chunk is discarded).
    """
    # receive all audio that is available by this time
    # blocks operation if less than self.min_chunk seconds is available
    # unblocks if connection is closed or a chunk is available
    out = []
    while sum(len(x) for x in out) < state.min_limit:
        raw_bytes = await state.websocket.receive_bytes()
        if not raw_bytes:
            break
#            print("received audio:",len(raw_bytes), "bytes", raw_bytes[:10])
        # Interpret the bytes as headerless 16-bit little-endian mono PCM at
        # 16 kHz, then decode to float32 via librosa.
        sf = soundfile.SoundFile(io.BytesIO(raw_bytes), channels=1,endian="LITTLE",samplerate=SAMPLING_RATE, subtype="PCM_16",format="RAW")
        audio, _ = librosa.load(sf,sr=SAMPLING_RATE,dtype=np.float32)
        out.append(audio)
    if not out:
        return None
    flat_out = np.concatenate(out)
    if state.is_first and len(flat_out) < state.min_limit:
        return None

    state.is_first = False
    return flat_out

def format_output_transcript(state, o) -> Optional[dict]:
    """
    Format one transcription result tuple as a JSON-serializable segment.

    Args:
        state: Connection state; only ``state.last_end`` (milliseconds) is
            read and updated here.
        o: Tuple ``(beg_sec, end_sec, text)`` from the online processor;
            ``beg_sec`` is None when the segment carries no text.

    Returns:
        Dict with string values "start", "end" (whole milliseconds) and
        "text", or None when there is nothing to emit.

    Fix over the original: the return annotation claimed ``dict`` although
    the empty-segment path returns None.
    """
    # Succeeding [beg, end] intervals must not overlap (ELITR protocol
    # requirement), so beg is clamped to the previous segment's end. The
    # difference is usually negligible, appx 20 ms.
    if o[0] is not None:
        beg, end = o[0]*1000,o[1]*1000
        if state.last_end is not None:
            beg = max(beg, state.last_end)

        state.last_end = end
        # Mirror the classic "beg end text" server line on stderr for debugging.
        print("%1.0f %1.0f %s" % (beg,end,o[2]),flush=True,file=sys.stderr)
        return {
            "start": "%1.0f" % beg,
            "end": "%1.0f" % end,
            "text": "%s" % o[2],
        }
    else:
        logger.debug("No text in this segment")
        return None

# Define WebSocket endpoint
@app.websocket("/ws_transcribe_streaming")
async def websocket_transcribe(websocket: WebSocket):
    """
    Streaming transcription endpoint.

    Receives raw PCM audio frames from the client and sends back JSON
    segments of the form {"start", "end", "text"} as they are produced.
    The loop ends when receive_audio_chunk() returns None (empty frame).
    """
    logger.info("New WebSocket connection request received.")
    await websocket.accept()
    logger.info("WebSocket connection established successfully.")

    # initialize the ASR model
    # NOTE(review): the model is built and warmed up per connection, which is
    # expensive — confirm this is intended rather than a shared module-level ASR.
    logger.info("Loading whisper model...")
    asr, online = asr_factory(args)
    state = State(
        websocket=websocket,
        asr=asr,
        online=online,
        # min_chunk_size is in seconds; convert to a sample count.
        min_limit=args.min_chunk_size * SAMPLING_RATE,
    )

    # warm up the ASR because the very first transcribe takes more time than the others.
    # Test results in https://github.com/ufal/whisper_streaming/pull/81
    logger.info("Warming up the whisper model...")
    a = load_audio_chunk(WARMUP_FILE, 0, 1)
    asr.transcribe(a)
    logger.info("Whisper is warmed up.")

    try:
        while True:
            a = await receive_audio_chunk(state)
            if a is None:
                break
            state.online.insert_audio_chunk(a)
            # `online` and `state.online` are the same object here.
            o = online.process_iter()
            try:
                if result := format_output_transcript(state, o):
                    await websocket.send_json(result)

            except BrokenPipeError:
                logger.info("broken pipe -- connection closed?")
                break
            except WebSocketDisconnect:
                logger.info("WebSocket connection closed by the client.")
                break
            except Exception as e:
                logger.error(f"Unexpected error during WebSocket transcription: {e}")
                await websocket.send_json({"error": str(e)})
    finally:
        logger.info("Cleaning up and closing WebSocket connection.")

def main():
    """
    Configure the argument namespace and launch the FastAPI app via uvicorn.

    Rebuilds the shared-args parser without its '--model' option, adds a
    replacement with service-specific help text, parses the hard-coded
    configuration, then starts the server.
    """
    global args
    args = drop_option_from_parser(args, '--model')
    args.add_argument('--model', type=str,
                      help="Name size of the Whisper model to use. The model is automatically downloaded from the model hub if not present in model cache dir.")

    # BUG FIX: parse_args() returns a Namespace; the original discarded the
    # return value, leaving the global `args` pointing at the parser, so
    # attribute reads such as args.min_chunk_size in websocket_transcribe
    # would raise AttributeError. Assign the parsed Namespace back.
    args = args.parse_args([
        '--lan', 'he',
        '--model', 'ivrit-ai/faster-whisper-v2-d4',
        '--backend', 'faster-whisper',
        '--vad',
    ])

    uvicorn.run(app)

if __name__ == "__main__":
    main()