Spaces:
Sleeping
Sleeping
import argparse | |
import json | |
import threading | |
import time | |
from pathlib import Path | |
from typing import List | |
import websocket | |
import os | |
import librosa | |
import numpy as np | |
# Default WebSocket endpoint used when --url is not given.
DEFAULT_WS_URL = "ws://localhost:8000/v1/ws_transcribe_streaming"


def parse_arguments(argv=None):
    """Parse the command-line options of the streaming client.

    Args:
        argv: Optional list of argument strings to parse instead of the real
            command line. ``None`` (the default) keeps the standard argparse
            behaviour of reading ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``audio_file``, ``url``, ``model``,
        ``language``, ``response_format``, ``temperature``, ``vad_filter``
        and ``chunk_duration`` attributes.
    """
    parser = argparse.ArgumentParser(description="Stream audio to the transcription WebSocket endpoint.")
    parser.add_argument("audio_file", help="Path to the input audio file.")
    parser.add_argument("--url", default=DEFAULT_WS_URL, help="WebSocket endpoint URL.")
    parser.add_argument("--model", type=str, help="Model name to use for transcription.")
    parser.add_argument("--language", type=str, help="Language code for transcription.")
    parser.add_argument(
        "--response_format",
        type=str,
        default="verbose_json",
        choices=["text", "json", "verbose_json"],
        help="Response format.",
    )
    parser.add_argument("--temperature", type=float, default=0.0, help="Temperature for transcription.")
    parser.add_argument("--vad_filter", action="store_true", help="Enable voice activity detection filter.")
    parser.add_argument("--chunk_duration", type=float, default=1.0, help="Duration of each audio chunk in seconds.")
    return parser.parse_args(argv)
# def preprocess_audio(audio_file, target_sr=16000): | |
# """ | |
# Load the audio file, convert to mono 16kHz, and return the audio data. | |
# """ | |
# if audio_file.endswith(".mp3"): | |
# # Convert MP3 to WAV using ffmpeg | |
# wav_file = audio_file.replace(".mp3", ".wav") | |
# if not os.path.exists(wav_file): | |
# command = f'ffmpeg -i "{audio_file}" -ac 1 -ar {target_sr} "{wav_file}"' | |
# print(f"Converting MP3 to WAV: {command}") | |
# os.system(command) | |
# audio_file = wav_file | |
# | |
# print(f"Loading audio file {audio_file}") | |
# audio_data, sr = librosa.load(audio_file, sr=target_sr, mono=True) | |
# return audio_data, sr | |
# | |
# def chunk_audio(audio_data, sr, chunk_duration): | |
# """ | |
# Split the audio data into chunks of specified duration. | |
# """ | |
# chunk_samples = int(chunk_duration * sr) | |
# total_samples = len(audio_data) | |
# chunks = [ | |
# audio_data[i:i + chunk_samples] | |
# for i in range(0, total_samples, chunk_samples) | |
# ] | |
# print(f"Split audio into {len(chunks)} chunks of {chunk_duration} seconds each.") | |
# return chunks | |
def _ensure_wav(audio_file, target_sr):
    """Return a WAV path for ``audio_file``, converting non-WAV input with ffmpeg if needed."""
    import subprocess

    if str(audio_file).endswith(".wav"):
        return audio_file
    wav_file = Path(audio_file).with_suffix(".wav")
    if not wav_file.exists():
        # Pass an argument list (no shell) so quotes or shell metacharacters in
        # the file name cannot break or hijack the command.
        command = ["ffmpeg", "-i", str(audio_file), "-ac", "1", "-ar", str(target_sr), str(wav_file)]
        print(f"Converting to WAV: {' '.join(command)}")
        try:
            subprocess.run(command, check=True)
        except subprocess.CalledProcessError:
            # A failed run may leave a partial file that the exists() check
            # above would silently reuse on the next run.
            if wav_file.exists():
                wav_file.unlink()
            raise
    return wav_file


def _float_to_pcm16le(audio_data):
    """Scale float samples in [-1, 1] to 16-bit PCM with little-endian byte order."""
    import sys

    if sys.byteorder == "little":
        print("No byte swap needed. Already little-endian.")
    else:
        print("Byte swap performed to convert to little-endian.")
    # Clip first so any sample outside [-1, 1] cannot overflow int16 and wrap around.
    clipped = np.clip(audio_data, -1.0, 1.0)
    # '<i2' pins little-endian int16 on any host (numpy swaps bytes on big-endian machines).
    return (clipped * 32767).astype("<i2")


def read_audio_in_chunks(audio_file, target_sr=16000, chunk_duration=1) -> List[np.ndarray]:
    """
    Load an audio file as mono 16-bit little-endian PCM and split it into fixed-length chunks.

    Input that does not end in ``.wav`` is first converted with ``ffmpeg`` to a
    mono WAV at ``target_sr`` next to the input (same stem, ``.wav`` suffix).
    An existing WAV with that name is reused without reconverting.

    Args:
        audio_file (str | Path): Path to the audio file.
        target_sr (int): Required sample rate in Hz (16000 by default).
        chunk_duration (float): Duration of each chunk in seconds (1 by default).
            Fractional values are allowed; the chunk length is rounded to the
            nearest whole number of samples.

    Returns:
        List of numpy arrays of dtype ``<i2`` (little-endian int16). Every chunk
        holds ``round(target_sr * chunk_duration)`` samples except possibly the
        last, which holds the remainder.

    Raises:
        ValueError: If ``chunk_duration`` yields fewer than one sample per chunk,
            or if the loaded audio's sample rate differs from ``target_sr``.
        FileNotFoundError: If a conversion is needed but ``ffmpeg`` is not installed.
        subprocess.CalledProcessError: If the ``ffmpeg`` conversion fails.
    """
    # Slicing and range() need an int; a float duration (the CLI default is
    # 1.0) would otherwise raise TypeError.
    samples_per_chunk = int(round(target_sr * chunk_duration))
    if samples_per_chunk <= 0:
        raise ValueError(f"chunk_duration={chunk_duration} gives no samples per chunk at {target_sr} Hz.")

    audio_file = _ensure_wav(audio_file, target_sr)
    # sr=None requests no resampling, so the check below sees the file's own rate.
    audio_data, sr = librosa.load(audio_file, sr=None, mono=True)
    if sr != target_sr:
        raise ValueError(f"Unexpected sample rate {sr}. Expected {target_sr}.")

    pcm = _float_to_pcm16le(audio_data)
    return [pcm[start:start + samples_per_chunk] for start in range(0, len(pcm), samples_per_chunk)]
def build_query_params(args):
    """
    Collect the transcription options to send as WebSocket query parameters.

    ``model``, ``language`` and ``response_format`` are included only when
    truthy, ``temperature`` whenever it is not ``None`` (as a string), and
    ``vad_filter`` as ``"true"`` only when the flag is set. Keys appear in that
    order in the returned dict.
    """
    query = {}
    for option in ("model", "language", "response_format"):
        value = getattr(args, option)
        if value:
            query[option] = value
    if args.temperature is not None:
        query["temperature"] = str(args.temperature)
    if args.vad_filter:
        query["vad_filter"] = "true"
    return query
def websocket_url_with_params(base_url, params):
    """
    Append URL-encoded query parameters to a WebSocket URL.

    If ``base_url`` already carries a query string, the new parameters are
    joined with ``&`` instead of adding a second ``?``. A base URL ending in
    ``?`` or ``&`` gets no extra separator.

    Args:
        base_url (str): The endpoint URL.
        params (dict | None): Parameters to encode; empty or ``None`` returns
            ``base_url`` unchanged.

    Returns:
        str: The URL with the encoded parameters appended.
    """
    from urllib.parse import urlencode

    if not params:
        return base_url
    query_string = urlencode(params)
    if "?" not in base_url:
        separator = "?"
    elif base_url.endswith(("?", "&")):
        separator = ""
    else:
        separator = "&"
    return f"{base_url}{separator}{query_string}"
def on_message(ws, message):
    """
    Handle one message from the transcription server.

    The message is parsed as JSON. With ``response_format == "verbose_json"``
    the payload's ``segments`` list is appended to ``ws.transcriptions``.
    Otherwise the payload is appended when it is an object with a ``text``
    field. Received text is printed. Non-JSON messages and JSON payloads of an
    unexpected shape are printed and otherwise ignored instead of raising
    inside the callback.

    Args:
        ws: The WebSocketApp; must carry ``args`` (parsed CLI options) and
            ``transcriptions`` (a list) attributes.
        message: Raw message received from the server.
    """
    try:
        data = json.loads(message)
    except json.JSONDecodeError:
        print(f"Received non-JSON message: {message}")
        return

    # A bare JSON number/string/list has no fields to read.
    if not isinstance(data, dict):
        print(f"Received unexpected message: {message}")
        return

    if ws.args.response_format == "verbose_json":
        segments = data.get("segments") or []
        ws.transcriptions.extend(segments)
        for segment in segments:
            text = segment.get("text", "") if isinstance(segment, dict) else segment
            print(f"Received segment: {text}")
    elif "text" in data:
        # For 'json' or 'text' format
        ws.transcriptions.append(data)
        print(f"Transcription: {data['text']}")
    else:
        # e.g. an error payload without a transcription
        print(f"Received unexpected message: {message}")
def on_error(ws, error):
    """Report an error raised by the WebSocket client; ``ws`` itself is not used."""
    message = "WebSocket error: {}".format(error)
    print(message)
def on_close(ws, close_status_code, close_msg):
    """Announce that the connection has ended; the status code and reason are not reported."""
    notice = "WebSocket connection closed"
    print(notice)
def on_open(ws):
    """Announce the new connection and give ``ws`` a fresh, empty ``transcriptions`` list."""
    print("WebSocket connection opened")
    # Received transcriptions are accumulated into this list by on_message.
    setattr(ws, "transcriptions", list())
def send_audio_chunks(ws, audio_chunks, sr):
    """
    Stream audio chunks to the server as binary frames, then close the socket.

    Each chunk is serialized as little-endian 16-bit PCM (``<i2``), the format
    ``read_audio_in_chunks`` produces. The function pauses 0.1 s between frames.
    After the last frame it waits 2 s so remaining messages can arrive before
    the connection is closed. The socket is closed even if a send fails, so the
    client's ``run_forever`` thread can finish.

    Args:
        ws: A connected WebSocketApp.
        audio_chunks: Sequence of 1-D numpy arrays of int16 samples.
        sr: Sample rate of the chunks (currently unused).
    """
    total = len(audio_chunks)
    try:
        for number, chunk in enumerate(audio_chunks, start=1):
            # Send 16-bit little-endian PCM. Casting to '<f4' here would
            # re-encode the int16 samples as 32-bit floats (values up to
            # +/-32767) and double the payload.
            audio_bytes = chunk.astype("<i2").tobytes()
            ws.send(audio_bytes, opcode=websocket.ABNF.OPCODE_BINARY)
            print(f"Sent chunk {number}/{total}")
            time.sleep(0.1)  # Small delay to simulate real-time streaming
        print("All audio chunks sent")
        # Optionally, wait to receive remaining messages
        time.sleep(2)
    finally:
        ws.close()
    print("Closed WebSocket connection")
def format_timestamp(seconds):
    """
    Convert a time offset in seconds to an SRT timestamp (HH:MM:SS,mmm).

    The value is rounded to the nearest millisecond so floating-point noise
    (``1.001 * 1000 == 1000.999...``) does not drop a millisecond. Negative
    inputs are clamped to zero because SRT timestamps cannot be negative.
    Hours are zero-padded to two digits but not capped.

    Args:
        seconds (float): Offset in seconds.

    Returns:
        str: The formatted timestamp.
    """
    total_ms = max(0, int(round(seconds * 1000)))
    total_secs, milliseconds = divmod(total_ms, 1000)
    total_mins, secs = divmod(total_secs, 60)
    hours, minutes = divmod(total_mins, 60)
    return f"{hours:02}:{minutes:02}:{secs:02},{milliseconds:03}"
def generate_srt(transcriptions):
    """
    Print transcription segments as SRT subtitle cues.

    Entries are expected to be dicts with ``start`` and ``end`` (seconds) and
    ``text``. Entries lacking a timing field cannot form a cue and are skipped,
    for example whole payloads collected in 'json' mode. Cue numbers stay
    consecutive over the entries that are printed. Cue text is stripped of
    surrounding whitespace.

    Args:
        transcriptions: Iterable of segment dicts.
    """
    print("\nGenerated SRT:")
    timed_segments = (
        segment
        for segment in transcriptions
        if isinstance(segment, dict) and "start" in segment and "end" in segment
    )
    for idx, segment in enumerate(timed_segments, 1):
        start_time = format_timestamp(segment["start"])
        end_time = format_timestamp(segment["end"])
        text = str(segment.get("text", "")).strip()
        print(f"{idx}")
        print(f"{start_time} --> {end_time}")
        print(f"{text}\n")
def run_websocket_client(args):
    """
    Stream ``args.audio_file`` to ``args.url`` and report the transcriptions.

    The audio is loaded and chunked with ``read_audio_in_chunks`` at its default
    chunk duration. A ``websocket.WebSocketApp`` then runs on a background
    thread. Once its socket reports connected, the chunks are sent with
    ``send_audio_chunks``.

    Errors during setup or streaming are printed rather than raised. In that
    case the socket is closed so the background thread can finish. After the
    thread exits, collected segments are printed as SRT for ``verbose_json``.
    Other formats carry no segment timings, so their texts are printed instead.

    Args:
        args: Parsed command-line options (see ``parse_arguments``).
    """
    ws = None
    ws_thread = None
    try:
        audio_chunks = read_audio_in_chunks(args.audio_file)
        # NOTE: while the two lines below stay commented out, --model, --language,
        # --response_format, --temperature and --vad_filter are not sent to the
        # server; --chunk_duration is likewise not forwarded to read_audio_in_chunks.
        # params = build_query_params(args)
        # ws_url = websocket_url_with_params(args.url, params)
        ws_url = args.url
        ws = websocket.WebSocketApp(
            ws_url,
            on_open=on_open,
            on_message=on_message,
            on_error=on_error,
            on_close=on_close,
        )
        ws.args = args  # Attach args to ws to access inside callbacks
        # Run the WebSocket in a separate thread to allow sending and receiving simultaneously
        ws_thread = threading.Thread(target=ws.run_forever)
        ws_thread.start()
        _wait_until_connected(ws, ws_thread)
        send_audio_chunks(ws, audio_chunks, 16000)
    except Exception as e:
        print(f"An error occurred: {e}")
        if ws is not None:
            # Stop run_forever so the join below cannot block indefinitely.
            ws.close()

    # Wait for the WebSocket thread to finish (it was never started if setup failed early).
    if ws_thread is not None:
        ws_thread.join()

    transcriptions = getattr(ws, "transcriptions", None)
    if not transcriptions:
        print("No transcriptions received.")
    elif args.response_format == "verbose_json":
        generate_srt(transcriptions)
    else:
        print("\nFull transcription:")
        for item in transcriptions:
            print(item.get("text", "") if isinstance(item, dict) else item)


def _wait_until_connected(ws, ws_thread, poll_interval=0.1):
    """Block until ``ws.sock`` is connected; raise ConnectionError if the client thread exits first."""
    while True:
        # Read once so both checks see the same object; presumably run_forever
        # assigns/clears ws.sock concurrently on the client thread.
        sock = ws.sock
        if sock is not None and sock.connected:
            return
        if not ws_thread.is_alive():
            # run_forever has returned without a connected socket (e.g. connection refused).
            raise ConnectionError("WebSocket connection could not be established")
        time.sleep(poll_interval)
if __name__ == "__main__":
    # Script entry point: parse the command line, then stream and transcribe.
    run_websocket_client(parse_arguments())