# remote_whisper.py
import sys
import time
import logging
import os
import requests
import json
import base64
import numpy as np
import soundfile as sf
import io
import librosa
# Import the necessary components from whisper_online.py
from libs.whisper_streaming.whisper_online import (
    ASRBase,
    OnlineASRProcessor,
    VACOnlineASRProcessor,
    add_shared_args,
    asr_factory as original_asr_factory,
    set_logging,
    create_tokenizer,
    load_audio,
    load_audio_chunk,
    OpenaiApiASR,
)
from model import dict_to_segment, get_raw_words_from_segments
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s: %(message)s',
                    handlers=[logging.StreamHandler(sys.stdout)], force=True)
logger = logging.getLogger(__name__)


def convert_to_mono_16k(input_wav: str, output_wav: str) -> None:
    """
    Converts any .wav file to mono 16 kHz.

    Args:
        input_wav (str): Path to the input .wav file.
        output_wav (str): Path to save the output mono 16 kHz .wav file.
    """
    # Step 1: Load the audio file with librosa at its original sampling rate
    audio_data, original_sr = librosa.load(input_wav, sr=None, mono=False)
    logger.info("Loaded audio with shape: %s, original sampling rate: %d" % (audio_data.shape, original_sr))

    # Step 2: If the audio has multiple channels, average them to make it mono
    if audio_data.ndim > 1:
        audio_data = librosa.to_mono(audio_data)

    # Step 3: Resample the audio to 16 kHz
    target_sr = 16000
    resampled_audio = librosa.resample(audio_data, orig_sr=original_sr, target_sr=target_sr)

    # Step 4: Save the resampled audio as a mono 16 kHz .wav file
    sf.write(output_wav, resampled_audio, target_sr)
    logger.info(f"Converted audio saved to {output_wav}")

# Example usage:
# convert_to_mono_16k('input_audio.wav', 'output_audio_16k_mono.wav')


# Define the RemoteFasterWhisperASR class
class RemoteFasterWhisperASR(ASRBase):
    """Uses a remote FasterWhisper model via WebSocket."""

    sep = ""  # Same as FasterWhisperASR

    def load_model(self, *args, **kwargs):
        import websocket

        self.ws = websocket.WebSocket()
        # Replace with your actual server address
        server_address = "ws://localhost:8000/ws_transcribe_streaming"
        self.ws.connect(server_address)
        logger.info(f"Connected to remote ASR server at {server_address}")
    def transcribe(self, audio, init_prompt=""):
        # Convert the audio input to WAV bytes
        if isinstance(audio, str):
            # If audio is a filename, read the file
            with open(audio, 'rb') as f:
                audio_bytes = f.read()
        elif isinstance(audio, np.ndarray):
            # Write the audio data to an in-memory buffer in WAV format
            audio_bytes_io = io.BytesIO()
            sf.write(audio_bytes_io, audio, samplerate=16000, format='WAV', subtype='PCM_16')
            audio_bytes = audio_bytes_io.getvalue()
        else:
            raise ValueError("Unsupported audio input type")
        # Encode to base64 and send together with the initial prompt
        audio_b64 = base64.b64encode(audio_bytes).decode('utf-8')
        data = {
            'audio': audio_b64,
            'init_prompt': init_prompt
        }
        self.ws.send(json.dumps(data))
        response = self.ws.recv()
        segments = json.loads(response)
        segments = [dict_to_segment(s) for s in segments]
        logger.info(get_raw_words_from_segments(segments))
        return segments
    def ts_words(self, segments):
        o = []
        for segment in segments:
            for word in segment.words:
                if segment.no_speech_prob > 0.9:
                    continue
                # not stripping the spaces -- should not be merged with them!
                w = word.word
                t = (word.start, word.end, w)
                o.append(t)
        return o

    def segments_end_ts(self, res):
        return [s.end for s in res]

    def use_vad(self):
        self.transcribe_kargs["vad_filter"] = True

    def set_translate_task(self):
        self.transcribe_kargs["task"] = "translate"
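
# Wire protocol assumed by RemoteFasterWhisperASR (a sketch of the contract,
# not an upstream whisper_streaming API): each transcribe() call sends one
# JSON message over the WebSocket and expects one JSON reply, e.g.
#
#   request:  {"audio": "<base64-encoded WAV bytes>", "init_prompt": "..."}
#   response: [{"start": 0.0, "end": 1.7, "no_speech_prob": 0.01,
#               "words": [{"word": " shalom", "start": 0.0, "end": 0.5}]}, ...]
#
# The response items must carry whatever fields model.dict_to_segment expects;
# the exact field names above are illustrative assumptions based on how
# ts_words() and segments_end_ts() access the parsed segments.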


# Update asr_factory to include RemoteFasterWhisperASR
def asr_factory(args, logfile=sys.stderr):
    """
    Creates and configures an ASR and Online ASR Processor instance based on
    the specified backend and arguments.
    """
    backend = args.backend
    if backend == "openai-api":
        logger.debug("Using OpenAI API.")
        asr = OpenaiApiASR(lan=args.lan)
    elif backend == "remote-faster-whisper":
        t = time.time()
        logger.info(f"Initializing Remote Faster Whisper ASR for language '{args.lan}'...")
        asr = RemoteFasterWhisperASR(modelsize=args.model, lan=args.lan,
                                     cache_dir=args.model_cache_dir, model_dir=args.model_dir)
        e = time.time()
        logger.info(f"Initialization done. It took {round(e - t, 2)} seconds.")
    else:
        # Use the original asr_factory for other backends
        return original_asr_factory(args, logfile)
    # Apply common configurations
    if getattr(args, 'vad', False):  # Check whether the VAD argument is present and True
        logger.info("Setting VAD filter")
        asr.use_vad()

    language = args.lan
    if args.task == "translate":
        asr.set_translate_task()
        tgt_language = "en"  # Whisper translates into English
    else:
        tgt_language = language  # Whisper transcribes in this language

    # Create the tokenizer (only needed for sentence-based buffer trimming)
    if args.buffer_trimming == "sentence":
        tokenizer = create_tokenizer(tgt_language)
    else:
        tokenizer = None
    # Create the OnlineASRProcessor
    if args.vac:
        online = VACOnlineASRProcessor(
            args.min_chunk_size,
            asr,
            tokenizer,
            logfile=logfile,
            buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec),
        )
    else:
        online = OnlineASRProcessor(
            asr,
            tokenizer,
            logfile=logfile,
            buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec),
        )
    return asr, online
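
# A minimal usage sketch of the factory (it mirrors what main() does below;
# `args` is assumed to provide the same attributes as the Args class there):
#
#   asr, online = asr_factory(args)
#   online.insert_audio_chunk(chunk)        # chunk: 16 kHz mono float32 ndarray
#   beg, end, text = online.process_iter()  # beg is None while nothing is committed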


# Now, write the main function that uses RemoteFasterWhisperASR
def main():
    # Download the test audio file if it is not already present
    AUDIO_FILE_URL = "https://raw.githubusercontent.com/AshDavid12/runpod-serverless-forked/main/test_hebrew.wav"
    audio_file_path = "test_hebrew.wav"
    mono16k_audio_file_path = "mono16k." + audio_file_path
    if not os.path.exists(audio_file_path):
        response = requests.get(AUDIO_FILE_URL, timeout=60)
        response.raise_for_status()
        with open(audio_file_path, 'wb') as f:
            f.write(response.content)
    if not os.path.exists(mono16k_audio_file_path):
        convert_to_mono_16k(audio_file_path, mono16k_audio_file_path)
    # Set up arguments (hardcoded here instead of parsed from the command line)
    class Args:
        def __init__(self):
            self.audio_path = mono16k_audio_file_path
            self.lan = 'he'
            self.model = None  # Not used in RemoteFasterWhisperASR
            self.model_cache_dir = None
            self.model_dir = None
            self.backend = 'remote-faster-whisper'
            self.task = 'transcribe'
            self.vad = False
            self.vac = True  # Use VAC by default
            self.buffer_trimming = 'segment'
            self.buffer_trimming_sec = 15
            self.min_chunk_size = 1.0
            self.vac_chunk_size = 0.04
            self.start_at = 0.0
            self.offline = False
            self.comp_unaware = False
            self.log_level = 'DEBUG'

    args = Args()

    # Set up logging
    set_logging(args, logger)
    audio_path = args.audio_path
    SAMPLING_RATE = 16000
    duration = len(load_audio(audio_path)) / SAMPLING_RATE
    logger.info("Audio duration is: %2.2f seconds" % duration)

    asr, online = asr_factory(args, logfile=sys.stderr)
    if args.vac:
        min_chunk = args.vac_chunk_size
    else:
        min_chunk = args.min_chunk_size

    # Load the audio into the LRU cache before we start the timer
    a = load_audio_chunk(audio_path, 0, 1)
    # Warm up the ASR, because the very first transcribe takes more time
    asr.transcribe(a)

    beg = args.start_at
    start = time.time() - beg
    def output_transcript(o, now=None):
        # Output format on stdout, e.g.:
        #   4186.3606 0 1720 Takhle to je
        # - the first number is the emission time from the beginning of
        #   processing, in milliseconds
        # - the next two numbers are the begin and end timestamps of the text
        #   segment, as estimated by the Whisper model
        # - the rest of the line is the segment transcript
        if now is None:
            now = time.time() - start
        if o[0] is not None:
            print("%1.4f %1.0f %1.0f %s" % (now * 1000, o[0] * 1000, o[1] * 1000, o[2]), flush=True)
        # If o[0] is None there is no text, so nothing is printed
    end = 0
    while True:
        now = time.time() - start
        # Simulate real-time arrival: wait until the next chunk would be available
        if now < end + min_chunk:
            time.sleep(min_chunk + end - now)
        end = time.time() - start
        a = load_audio_chunk(audio_path, beg, end)
        beg = end
        online.insert_audio_chunk(a)
        try:
            o = online.process_iter()
        except AssertionError as e:
            logger.error(f"Assertion error: {e}")
        else:
            output_transcript(o)
        now = time.time() - start
        logger.debug(f"## Last processed {end:.2f} s, now is {now:.2f}, latency is {now - end:.2f}")
        if end >= duration:
            break

    now = None
    o = online.finish()
    output_transcript(o, now=now)


if __name__ == "__main__":
    main()
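
# To try this client end-to-end (assuming a compatible server is already
# listening on ws://localhost:8000/ws_transcribe_streaming, as hardcoded in
# RemoteFasterWhisperASR.load_model above):
#
#   python remote_whisper.py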