"""
Command-line client that streams an audio file in chunks to a WebSocket
transcription endpoint and prints the received transcription as SRT.
"""
import argparse
import json
import threading
import time
from pathlib import Path
from typing import List
import websocket
import os
import librosa
import numpy as np
# Default WebSocket endpoint; used as the fallback value of the --url option.
DEFAULT_WS_URL = "ws://localhost:8000/v1/audio/transcriptions"
def parse_arguments():
    """
    Parse the command-line options of the streaming transcription client.

    Returns:
        argparse.Namespace: The positional ``audio_file`` plus the optional
        ``url``, ``model``, ``language``, ``response_format``,
        ``temperature``, ``vad_filter`` and ``chunk_duration`` values.
    """
    # (flag, add_argument keyword arguments), listed in the order in which
    # the options are registered with the parser.
    option_specs = [
        ("audio_file", {"help": "Path to the input audio file."}),
        ("--url", {"default": DEFAULT_WS_URL, "help": "WebSocket endpoint URL."}),
        ("--model", {"type": str, "help": "Model name to use for transcription."}),
        ("--language", {"type": str, "help": "Language code for transcription."}),
        (
            "--response_format",
            {
                "type": str,
                "default": "verbose_json",
                "choices": ["text", "json", "verbose_json"],
                "help": "Response format.",
            },
        ),
        ("--temperature", {"type": float, "default": 0.0, "help": "Temperature for transcription."}),
        ("--vad_filter", {"action": "store_true", "help": "Enable voice activity detection filter."}),
        ("--chunk_duration", {"type": float, "default": 1.0, "help": "Duration of each audio chunk in seconds."}),
    ]

    cli = argparse.ArgumentParser(description="Stream audio to the transcription WebSocket endpoint.")
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli.parse_args()
# def preprocess_audio(audio_file, target_sr=16000):
# """
# Load the audio file, convert to mono 16kHz, and return the audio data.
# """
# if audio_file.endswith(".mp3"):
# # Convert MP3 to WAV using ffmpeg
# wav_file = audio_file.replace(".mp3", ".wav")
# if not os.path.exists(wav_file):
# command = f'ffmpeg -i "{audio_file}" -ac 1 -ar {target_sr} "{wav_file}"'
# print(f"Converting MP3 to WAV: {command}")
# os.system(command)
# audio_file = wav_file
#
# print(f"Loading audio file {audio_file}")
# audio_data, sr = librosa.load(audio_file, sr=target_sr, mono=True)
# return audio_data, sr
#
# def chunk_audio(audio_data, sr, chunk_duration):
# """
# Split the audio data into chunks of specified duration.
# """
# chunk_samples = int(chunk_duration * sr)
# total_samples = len(audio_data)
# chunks = [
# audio_data[i:i + chunk_samples]
# for i in range(0, total_samples, chunk_samples)
# ]
# print(f"Split audio into {len(chunks)} chunks of {chunk_duration} seconds each.")
# return chunks
def read_audio_in_chunks(audio_file, target_sr=16000, chunk_duration=1) -> List[np.ndarray]:
    """
    Load an audio file and split it into fixed-duration 16-bit PCM chunks.

    An input whose name does not end in ".wav" is first converted with ffmpeg
    to a mono WAV file at ``target_sr``, written next to the input with a
    ".wav" suffix. An already existing converted file is reused.

    Args:
        audio_file (str | Path): Path to the audio file.
        target_sr (int): Sample rate in Hz the (converted) file must have
            (16000 by default).
        chunk_duration (int | float): Duration of each chunk in seconds
            (1 second by default).

    Returns:
        List of numpy arrays: Consecutive chunks of the audio as little-endian
        16-bit integers. The last chunk may be shorter than the others.

    Raises:
        ValueError: If ``chunk_duration`` yields less than one sample per
            chunk, or if the file's sample rate differs from ``target_sr``.
        subprocess.CalledProcessError: If the ffmpeg conversion fails.
    """
    import subprocess  # Local import: only needed for the ffmpeg conversion.

    # Slice bounds and range steps must be integers, so a float duration
    # (e.g. 0.5) has to be turned into a whole number of samples.
    samples_per_chunk = int(round(target_sr * chunk_duration))
    if samples_per_chunk <= 0:
        raise ValueError(
            f"chunk_duration {chunk_duration} gives no samples per chunk at {target_sr} Hz."
        )

    if not str(audio_file).endswith(".wav"):
        # Convert the input to WAV using ffmpeg
        wav_file = Path(audio_file).with_suffix(".wav")
        if not wav_file.exists():
            # Argument list instead of a shell string: file names containing
            # quotes or shell metacharacters are passed through verbatim.
            command = [
                "ffmpeg", "-i", str(audio_file),
                "-ac", "1", "-ar", str(target_sr),
                str(wav_file),
            ]
            print(f"Converting MP3 to WAV: {' '.join(command)}")
            # check=True surfaces a failed conversion immediately instead of
            # failing later while loading a missing or partial file.
            subprocess.run(command, check=True)
        audio_file = wav_file

    # Load at the file's native rate (sr=None) so a mismatch can be detected;
    # mono=True makes librosa return a single channel.
    audio_data, sr = librosa.load(audio_file, sr=None, mono=True)
    if sr != target_sr:
        raise ValueError(f"Unexpected sample rate {sr}. Expected {target_sr}.")

    # Convert to 16-bit PCM. Clipping first keeps samples outside [-1, 1]
    # from wrapping around in the integer cast.
    audio_data_int16 = (np.clip(audio_data, -1.0, 1.0) * 32767).astype(np.int16)

    # Force little-endian storage. On a little-endian host "<i2" is the native
    # layout and astype(copy=False) returns the array unchanged; on a
    # big-endian host it produces a byte-swapped copy with the same values.
    if np.dtype("<i2").isnative:
        print("No byte swap needed. Already little-endian.")
    else:
        print("Byte swap performed to convert to little-endian.")
    audio_data_little_endian = audio_data_int16.astype("<i2", copy=False)

    # Split the audio into consecutive chunks
    return [
        audio_data_little_endian[start:start + samples_per_chunk]
        for start in range(0, len(audio_data_little_endian), samples_per_chunk)
    ]
def build_query_params(args):
    """
    Collect the transcription options that should be sent as URL query parameters.

    An option is included only when it is set: ``model``, ``language`` and
    ``response_format`` when truthy, ``temperature`` (as a string) when it is
    not None, and ``vad_filter`` (as "true") when the flag is enabled.

    Returns:
        dict: Query parameter names mapped to their values.
    """
    # None marks "leave this parameter out"; order matches the output order.
    candidates = [
        ("model", args.model if args.model else None),
        ("language", args.language if args.language else None),
        ("response_format", args.response_format if args.response_format else None),
        ("temperature", None if args.temperature is None else str(args.temperature)),
        ("vad_filter", "true" if args.vad_filter else None),
    ]
    return {name: value for name, value in candidates if value is not None}
def websocket_url_with_params(base_url, params):
    """
    Append query parameters to the WebSocket URL.

    Args:
        base_url (str): Endpoint URL, with or without an existing query string.
        params (dict | None): Query parameters to add; empty or None leaves
            the URL unchanged.

    Returns:
        str: ``base_url`` followed by the URL-encoded parameters.
    """
    from urllib.parse import urlencode

    if not params:
        return base_url

    query_string = urlencode(params)
    if "?" not in base_url:
        return f"{base_url}?{query_string}"

    # The URL already carries a query string: extend it with "&" rather than
    # adding a second "?", which would produce a malformed URL.
    separator = "" if base_url.endswith(("?", "&")) else "&"
    return f"{base_url}{separator}{query_string}"
def on_message(ws, message):
    """
    Handle one server message: decode it, store the result on ``ws`` and echo it.

    With the "verbose_json" response format the message's 'segments' list is
    appended to ``ws.transcriptions`` and each segment's text is printed; for
    any other format the whole decoded message is stored and its 'text' is
    printed. Messages that are not valid JSON are only reported.
    """
    try:
        payload = json.loads(message)
    except json.JSONDecodeError:
        print(f"Received non-JSON message: {message}")
        return

    if ws.args.response_format != "verbose_json":
        # 'json' or 'text' format: keep the decoded message as a single entry.
        ws.transcriptions.append(payload)
        print(f"Transcription: {payload['text']}")
        return

    # 'verbose_json' format: accumulate the individual segments.
    new_segments = payload.get('segments', [])
    ws.transcriptions.extend(new_segments)
    for item in new_segments:
        print(f"Received segment: {item['text']}")
def on_error(ws, error):
    """
    Report an error raised on the WebSocket connection by printing it.
    """
    report = f"WebSocket error: {error}"
    print(report)
def on_close(ws, close_status_code, close_msg):
    """
    Announce that the WebSocket connection was closed.

    The close status code and message are accepted but not used.
    """
    notice = "WebSocket connection closed"
    print(notice)
def on_open(ws):
    """
    Announce the opened connection and give ``ws`` an empty result list.
    """
    print("WebSocket connection opened")
    # Received transcriptions are accumulated on the connection object itself.
    ws.transcriptions = list()
def send_audio_chunks(ws, audio_chunks, sr):
    """
    Stream every audio chunk to the server as a binary frame, then close.

    Args:
        ws: Connected WebSocket object providing ``send`` and ``close``.
        audio_chunks: Sequence of numpy arrays, one per chunk.
        sr: Sample rate of the audio; accepted but not used here.
    """
    position = 0
    for samples in audio_chunks:
        position += 1
        # Serialize the samples as little-endian 32-bit floats.
        # NOTE(review): read_audio_in_chunks yields int16 sample values, so
        # the floats sent here span the int16 range rather than [-1, 1] --
        # confirm this is the format the server expects.
        payload = samples.astype('<f4').tobytes()
        ws.send(payload, opcode=websocket.ABNF.OPCODE_BINARY)
        print(f"Sent chunk {position}/{len(audio_chunks)}")
        # Short pause between frames to simulate real-time streaming.
        time.sleep(0.1)

    print("All audio chunks sent")
    # Leave the server some time to deliver the remaining messages.
    time.sleep(2)
    ws.close()
    print("Closed WebSocket connection")
def format_timestamp(seconds):
    """
    Convert seconds to SRT timestamp format (HH:MM:SS,mmm).

    Args:
        seconds (int | float): Time offset in seconds.

    Returns:
        str: The offset formatted as ``HH:MM:SS,mmm``.
    """
    # Round instead of truncating: the float product is often a hair below
    # the exact value (1.001 * 1000 == 1000.9999999999999), and truncation
    # would then drop a whole millisecond.
    total_milliseconds = round(seconds * 1000)
    total_seconds, milliseconds = divmod(total_milliseconds, 1000)
    total_minutes, secs = divmod(total_seconds, 60)
    hours, minutes = divmod(total_minutes, 60)
    return f"{hours:02}:{minutes:02}:{secs:02},{milliseconds:03}"
def generate_srt(transcriptions):
    """
    Print the given segments to stdout as numbered SRT entries.

    Each segment must provide 'start' and 'end' (passed to format_timestamp)
    and 'text'. Entries are numbered from 1.
    """
    print("\nGenerated SRT:")
    entry_number = 0
    for seg in transcriptions:
        entry_number += 1
        begin = format_timestamp(seg['start'])
        finish = format_timestamp(seg['end'])
        caption = seg['text']
        # One write per entry: index line, timing line, text, blank separator.
        print(f"{entry_number}\n{begin} --> {finish}\n{caption}\n")
def run_websocket_client(args):
    """
    Run the WebSocket client to stream audio and receive transcriptions.

    Loads and chunks the audio file, opens the WebSocket connection on a
    background thread, streams the chunks, waits for the connection to shut
    down and finally prints the collected transcriptions as SRT. Errors are
    reported on stdout rather than raised.

    Args:
        args: Parsed command-line options. ``audio_file`` and ``url`` are read
            here; the remaining options are consumed by build_query_params
            and by the message callback.
    """
    # Bound up front so the cleanup code below is safe even when a failure
    # happens before the connection or its thread has been created.
    ws = None
    ws_thread = None
    try:
        # NOTE(review): args.chunk_duration is parsed but not forwarded here,
        # so chunks always use the reader's default duration.
        audio_chunks = read_audio_in_chunks(args.audio_file)
        params = build_query_params(args)
        ws_url = websocket_url_with_params(args.url, params)
        ws = websocket.WebSocketApp(
            ws_url,
            on_open=on_open,
            on_message=on_message,
            on_error=on_error,
            on_close=on_close,
        )
        ws.args = args  # Attach args to ws to access inside callbacks

        # Run the WebSocket in a separate thread to allow sending and receiving simultaneously
        ws_thread = threading.Thread(target=ws.run_forever)
        ws_thread.start()

        # Wait for the connection to open
        while True:
            sock = ws.sock  # Snapshot: the worker thread may reset ws.sock.
            if sock and sock.connected:
                break
            if not ws_thread.is_alive():
                # run_forever has returned without a usable connection;
                # waiting any longer would spin forever.
                raise ConnectionError(f"Could not connect to {ws_url}")
            time.sleep(0.1)

        # Send the audio chunks
        send_audio_chunks(ws, audio_chunks, 16000)
    except Exception as e:
        print(f"An error occurred: {e}")
        if ws is not None:
            # Make sure run_forever stops; otherwise the join below would
            # block on a connection nobody is going to close.
            ws.close()

    # Wait for the WebSocket thread to finish
    if ws_thread is not None:
        ws_thread.join()

    # Generate SRT if transcriptions are available
    if ws is not None and hasattr(ws, 'transcriptions') and ws.transcriptions:
        generate_srt(ws.transcriptions)
    else:
        print("No transcriptions received.")
if __name__ == "__main__":
    # Script entry point: parse the CLI options and hand them to the client.
    run_websocket_client(parse_arguments())
|