# streaming_client.py
import sys
import time
import logging
import os
import requests
import json
import base64
import numpy as np
import soundfile as sf
import io
import librosa
# Import the necessary components from whisper_online.py
from libs.whisper_streaming.whisper_online import (
ASRBase,
OnlineASRProcessor,
VACOnlineASRProcessor,
add_shared_args,
asr_factory as original_asr_factory,
set_logging,
create_tokenizer,
load_audio,
    load_audio_chunk,
    OpenaiApiASR,
)
from model import dict_to_segment, get_raw_words_from_segments
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s: %(message)s',
handlers=[logging.StreamHandler(sys.stdout)], force=True)
logger = logging.getLogger(__name__)
def convert_to_mono_16k(input_wav: str, output_wav: str) -> None:
"""
Converts any .wav file to mono 16 kHz.
Args:
input_wav (str): Path to the input .wav file.
output_wav (str): Path to save the output .wav file with mono 16 kHz.
"""
# Step 1: Load the audio file with librosa
audio_data, original_sr = librosa.load(input_wav, sr=None, mono=False) # Load at original sampling rate
logger.info("Loaded audio with shape: %s, original sampling rate: %d" % (audio_data.shape, original_sr))
# Step 2: If the audio has multiple channels, average them to make it mono
if audio_data.ndim > 1:
audio_data = librosa.to_mono(audio_data)
# Step 3: Resample the audio to 16 kHz
target_sr = 16000
resampled_audio = librosa.resample(audio_data, orig_sr=original_sr, target_sr=target_sr)
# Step 4: Save the resampled audio as a .wav file in mono at 16 kHz
sf.write(output_wav, resampled_audio, target_sr)
logger.info(f"Converted audio saved to {output_wav}")
# Example usage:
# convert_to_mono_16k('input_audio.wav', 'output_audio_16k_mono.wav')
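# As a quick sanity check on the conversion (hypothetical file names), the
# output should read back as a mono 16 kHz signal:
#   data, sr = sf.read('output_audio_16k_mono.wav')
#   assert sr == 16000 and data.ndim == 1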
# Define the RemoteFasterWhisperASR class
class RemoteFasterWhisperASR(ASRBase):
"""Uses a remote FasterWhisper model via WebSocket."""
sep = "" # Same as FasterWhisperASR
def load_model(self, *args, **kwargs):
import websocket
self.ws = websocket.WebSocket()
        # Update with the actual server address before running.
        server_address = "ws://localhost:8000/ws_transcribe_streaming"
self.ws.connect(server_address)
logger.info(f"Connected to remote ASR server at {server_address}")
def transcribe(self, audio, init_prompt=""):
# Convert audio data to WAV bytes
if isinstance(audio, str):
# If audio is a filename, read the file
with open(audio, 'rb') as f:
audio_bytes = f.read()
elif isinstance(audio, np.ndarray):
# Write audio data to a buffer in WAV format
audio_bytes_io = io.BytesIO()
sf.write(audio_bytes_io, audio, samplerate=16000, format='WAV', subtype='PCM_16')
audio_bytes = audio_bytes_io.getvalue()
else:
raise ValueError("Unsupported audio input type")
# Encode to base64
audio_b64 = base64.b64encode(audio_bytes).decode('utf-8')
data = {
'audio': audio_b64,
'init_prompt': init_prompt
}
self.ws.send(json.dumps(data))
response = self.ws.recv()
segments = json.loads(response)
segments = [dict_to_segment(s) for s in segments]
logger.info(get_raw_words_from_segments(segments))
return segments
    def ts_words(self, segments):
        o = []
        for segment in segments:
            # Skip segments that Whisper flags as likely non-speech.
            if segment.no_speech_prob > 0.9:
                continue
            for word in segment.words:
                # Not stripping the spaces -- they must be preserved so that
                # words can be joined with sep="".
                w = word.word
                t = (word.start, word.end, w)
                o.append(t)
        return o
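    # Example of the resulting tuple stream (hypothetical values):
    #   ts_words(segments) -> [(0.0, 0.42, " hello"), (0.42, 0.88, " world")]
    # i.e. (start_sec, end_sec, word), with Whisper's leading spaces kept so
    # OnlineASRProcessor can join the words with sep="".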
def segments_end_ts(self, res):
return [s.end for s in res]
def use_vad(self):
self.transcribe_kargs["vad_filter"] = True
def set_translate_task(self):
self.transcribe_kargs["task"] = "translate"
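# The server side of ws_transcribe_streaming is not part of this file. As a
# rough, hypothetical sketch of the protocol the client above assumes
# (request: {"audio": <base64 WAV>, "init_prompt": <str>}; response: a JSON
# list of segment dicts), a matching FastAPI endpoint could look like the
# following; `transcribe_fn` stands in for the real server-side model call:
def build_demo_transcription_server(transcribe_fn):
    from fastapi import FastAPI, WebSocket  # assumed server-side dependency

    app = FastAPI()

    @app.websocket("/ws_transcribe_streaming")
    async def ws_transcribe_streaming(ws: WebSocket):
        await ws.accept()
        while True:
            data = json.loads(await ws.receive_text())
            audio_bytes = base64.b64decode(data["audio"])  # 16 kHz mono WAV
            # transcribe_fn(audio_bytes, init_prompt) -> list of segment dicts
            segments = transcribe_fn(audio_bytes, data.get("init_prompt", ""))
            await ws.send_text(json.dumps(segments))

    return app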
# Update asr_factory to include RemoteFasterWhisperASR
def asr_factory(args, logfile=sys.stderr):
"""
Creates and configures an ASR and Online ASR Processor instance based on the specified backend and arguments.
"""
    backend = args.backend
    if backend == "openai-api":
        logger.debug("Using OpenAI API.")
        asr = OpenaiApiASR(lan=args.lan)
    elif backend == "remote-faster-whisper":
        t = time.time()
        logger.info(f"Initializing Remote Faster Whisper ASR for language '{args.lan}'...")
        asr = RemoteFasterWhisperASR(
            modelsize=args.model, lan=args.lan,
            cache_dir=args.model_cache_dir, model_dir=args.model_dir,
        )
        e = time.time()
        logger.info(f"Initialization done. It took {round(e - t, 2)} seconds.")
    else:
        # Use the original asr_factory for all other backends.
        return original_asr_factory(args, logfile)
# Apply common configurations
if getattr(args, 'vad', False): # Checks if VAD argument is present and True
logger.info("Setting VAD filter")
asr.use_vad()
language = args.lan
if args.task == "translate":
asr.set_translate_task()
tgt_language = "en" # Whisper translates into English
else:
tgt_language = language # Whisper transcribes in this language
# Create the tokenizer
if args.buffer_trimming == "sentence":
tokenizer = create_tokenizer(tgt_language)
else:
tokenizer = None
# Create the OnlineASRProcessor
if args.vac:
online = VACOnlineASRProcessor(
args.min_chunk_size,
asr,
tokenizer,
logfile=logfile,
buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec)
)
else:
online = OnlineASRProcessor(
asr,
tokenizer,
logfile=logfile,
buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec)
)
return asr, online
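# Example usage (hypothetical chunk; mirrors the loop in main() below):
#   asr, online = asr_factory(args, logfile=sys.stderr)
#   online.insert_audio_chunk(chunk)          # chunk: 16 kHz float32 ndarray
#   beg, end, text = online.process_iter()    # (None, None, "") if no output yet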
# Now, write the main function that uses RemoteFasterWhisperASR
def main():
# Download the audio file if not already present
AUDIO_FILE_URL = "https://raw.githubusercontent.com/AshDavid12/runpod-serverless-forked/main/test_hebrew.wav"
audio_file_path = "test_hebrew.wav"
mono16k_audio_file_path = "mono16k." + audio_file_path
if not os.path.exists(audio_file_path):
        response = requests.get(AUDIO_FILE_URL)
        response.raise_for_status()  # Fail loudly if the download did not succeed
        with open(audio_file_path, 'wb') as f:
            f.write(response.content)
if not os.path.exists(mono16k_audio_file_path):
convert_to_mono_16k(audio_file_path, mono16k_audio_file_path)
# Set up arguments
class Args:
def __init__(self):
self.audio_path = mono16k_audio_file_path
self.lan = 'he'
self.model = None # Not used in RemoteFasterWhisperASR
self.model_cache_dir = None
self.model_dir = None
self.backend = 'remote-faster-whisper'
self.task = 'transcribe'
self.vad = False
self.vac = True # Use VAC as default
self.buffer_trimming = 'segment'
self.buffer_trimming_sec = 15
self.min_chunk_size = 1.0
self.vac_chunk_size = 0.04
self.start_at = 0.0
self.offline = False
self.comp_unaware = False
self.log_level = 'DEBUG'
args = Args()
# Set up logging
set_logging(args, logger)
audio_path = args.audio_path
SAMPLING_RATE = 16000
duration = len(load_audio(audio_path)) / SAMPLING_RATE
logger.info("Audio duration is: %2.2f seconds" % duration)
asr, online = asr_factory(args, logfile=sys.stderr)
if args.vac:
min_chunk = args.vac_chunk_size
else:
min_chunk = args.min_chunk_size
# Load the audio into the LRU cache before we start the timer
a = load_audio_chunk(audio_path, 0, 1)
# Warm up the ASR because the very first transcribe takes more time
asr.transcribe(a)
beg = args.start_at
start = time.time() - beg
    def output_transcript(o, now=None):
        # Output format on stdout is like:
        # 4186.3606 0 1720 Takhle to je
        # - The first three numbers are:
        #   - Emission time from the beginning of processing, in milliseconds
        #   - Begin and end timestamps of the text segment, as estimated by the Whisper model
        # - The remaining words are the segment transcript
        if now is None:
            now = time.time() - start
        if o[0] is not None:
            print("%1.4f %1.0f %1.0f %s" % (now * 1000, o[0] * 1000, o[1] * 1000, o[2]), flush=True)
        # When o[0] is None there is no text, so nothing is printed.
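    # A downstream consumer could split each printed line back into fields,
    # e.g. (hypothetical parsing, not used by this script):
    #   emission_ms, seg_beg_ms, seg_end_ms, *words = line.split()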
end = 0
while True:
now = time.time() - start
if now < end + min_chunk:
time.sleep(min_chunk + end - now)
end = time.time() - start
a = load_audio_chunk(audio_path, beg, end)
beg = end
online.insert_audio_chunk(a)
try:
o = online.process_iter()
        except AssertionError as e:
            logger.error(f"Assertion error: {e}")
else:
output_transcript(o)
now = time.time() - start
logger.debug(f"## Last processed {end:.2f} s, now is {now:.2f}, latency is {now - end:.2f}")
if end >= duration:
break
    o = online.finish()
    output_transcript(o, now=None)
if __name__ == "__main__":
main()