Spaces:
Sleeping
Sleeping
AshDavid12
committed on
Commit
·
5a62402
1
Parent(s):
aab7acf
complete change
Browse files
client.py
CHANGED
@@ -2,80 +2,173 @@ import asyncio
|
|
2 |
import websockets
|
3 |
import requests
|
4 |
import ssl
|
|
|
|
|
|
|
5 |
|
6 |
# Parameters for reading and sending the audio
|
7 |
AUDIO_FILE_URL = "https://raw.githubusercontent.com/AshDavid12/runpod-serverless-forked/main/test_hebrew.wav" # Use WAV file
|
8 |
#AUDIO_FILE_URL = "https://raw.githubusercontent.com/AshDavid12/hugging_face_ivrit_streaming/main/long_hebrew.wav"
|
9 |
|
10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
async def send_audio(websocket):
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
async
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
import websockets
|
3 |
import requests
|
4 |
import ssl
|
5 |
+
import wave
|
6 |
+
import logging
|
7 |
+
import sys
|
8 |
|
9 |
# Parameters for reading and sending the audio
|
10 |
AUDIO_FILE_URL = "https://raw.githubusercontent.com/AshDavid12/runpod-serverless-forked/main/test_hebrew.wav" # Use WAV file
|
11 |
#AUDIO_FILE_URL = "https://raw.githubusercontent.com/AshDavid12/hugging_face_ivrit_streaming/main/long_hebrew.wav"
|
12 |
|
13 |
|
14 |
+
|
15 |
+
# Set up logging
# DEBUG-level root logger writing to stdout; force=True replaces any handlers
# a previously imported library may already have installed on the root logger.
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s %(levelname)s: %(message)s',
                    handlers=[logging.StreamHandler(sys.stdout)], force=True)
# Module-level logger used by all coroutines below.
logger = logging.getLogger(__name__)
|
19 |
+
|
20 |
+
async def send_receive():
    """Connect to the streaming server and run upload/download concurrently.

    Opens a single WebSocket connection, then drives `send_audio` and
    `receive_transcriptions` over it at the same time. Any connection
    failure is logged rather than propagated.
    """
    uri = "wss://gigaverse-ivrit-ai-streaming.hf.space/ws"  # Update with your server's address if needed
    logger.info(f"Connecting to server at {uri}")
    try:
        async with websockets.connect(uri) as websocket:
            logger.info("WebSocket connection established")
            # gather() wraps each coroutine in a task, so both directions
            # run concurrently over the same connection.
            await asyncio.gather(
                send_audio(websocket),
                receive_transcriptions(websocket),
            )
    except Exception as e:
        logger.error(f"WebSocket connection error: {e}")
|
32 |
+
|
33 |
async def send_audio(websocket):
    """Stream a local WAV file to the server over an open WebSocket.

    The file must be mono, 16-bit, 16 kHz PCM; it is sent in 0.1 s chunks
    with a matching sleep between sends to approximate real-time capture.
    Errors are logged, not raised.

    :param websocket: An open client WebSocket connection to the server.
    """
    wav_file = 'path/to/your/audio.wav'  # Replace with the path to your WAV file
    logger.info(f"Opening WAV file: {wav_file}")

    # Bug fix: initialize before the try block. The original called
    # wf.close() unconditionally in finally, so if wave.open() raised
    # (e.g. file not found), `wf` was unbound and the finally block
    # itself raised NameError, masking the real error.
    wf = None
    try:
        # Open the WAV file
        wf = wave.open(wav_file, 'rb')

        # Log WAV file parameters
        channels = wf.getnchannels()
        sampwidth = wf.getsampwidth()
        framerate = wf.getframerate()
        nframes = wf.getnframes()
        duration = nframes / framerate
        logger.debug(f"WAV file parameters: channels={channels}, sample_width={sampwidth}, framerate={framerate}, frames={nframes}, duration={duration:.2f}s")

        # Ensure the WAV file has the expected parameters
        if channels != 1 or sampwidth != 2 or framerate != 16000:
            logger.error("WAV file must be mono channel, 16-bit samples, 16kHz sampling rate")
            return

        chunk_duration = 0.1  # in seconds
        chunk_size = int(framerate * chunk_duration)
        logger.info(f"Starting to send audio data in chunks of {chunk_duration}s ({chunk_size} frames)")

        total_chunks = 0
        total_bytes_sent = 0

        while True:
            data = wf.readframes(chunk_size)
            if not data:  # readframes returns b'' at end of file
                logger.info("End of WAV file reached")
                break
            await websocket.send(data)
            total_chunks += 1
            total_bytes_sent += len(data)
            logger.debug(f"Sent chunk {total_chunks}: {len(data)} bytes")
            await asyncio.sleep(chunk_duration)  # Simulate real-time streaming

        logger.info(f"Finished sending audio data: {total_chunks} chunks sent, total bytes sent: {total_bytes_sent}")
    except Exception as e:
        logger.error(f"Send audio error: {e}")
    finally:
        if wf is not None:
            wf.close()
            logger.info("WAV file closed")
|
78 |
+
|
79 |
+
async def receive_transcriptions(websocket):
    """Log and print every transcription message pushed by the server.

    Iterating the connection yields each incoming message (equivalent to
    repeated recv()) and exits cleanly when the server closes the socket;
    other failures are logged.
    """
    try:
        logger.info("Starting to receive transcriptions")
        async for transcription in websocket:  # This is the same as websocket.recv()
            logger.info(f"Received transcription: {transcription}")
            print(f"Transcription: {transcription}")
    except Exception as e:
        logger.error(f"Receive transcription error: {e}")
|
87 |
+
|
88 |
+
# Script entry point: run the full send/receive session until the
# connection closes or an error is logged.
if __name__ == "__main__":
    asyncio.run(send_receive())
|
90 |
+
|
91 |
+
|
92 |
+
|
93 |
+
|
94 |
+
|
95 |
+
|
96 |
+
|
97 |
+
|
98 |
+
|
99 |
+
|
100 |
+
|
101 |
+
|
102 |
+
|
103 |
+
|
104 |
+
# async def send_audio(websocket):
|
105 |
+
# buffer_size = 512 * 1024 #HAVE TO HAVE 512!!
|
106 |
+
# audio_buffer = bytearray()
|
107 |
+
#
|
108 |
+
# with requests.get(AUDIO_FILE_URL, stream=True, allow_redirects=False) as response:
|
109 |
+
# if response.status_code == 200:
|
110 |
+
# print("Starting to stream audio file...")
|
111 |
+
#
|
112 |
+
# for chunk in response.iter_content(chunk_size=1024): # Stream in chunks
|
113 |
+
# if chunk:
|
114 |
+
# audio_buffer.extend(chunk)
|
115 |
+
# #print(f"Received audio chunk of size {len(chunk)} bytes.")
|
116 |
+
#
|
117 |
+
# # Send buffered audio data once it's large enough
|
118 |
+
# if len(audio_buffer) >= buffer_size:
|
119 |
+
# await websocket.send(audio_buffer)
|
120 |
+
# #print(f"Sent {len(audio_buffer)} bytes of audio data.")
|
121 |
+
# audio_buffer.clear()
|
122 |
+
# await asyncio.sleep(0.01)
|
123 |
+
#
|
124 |
+
# print("Finished sending audio.")
|
125 |
+
# else:
|
126 |
+
# print(f"Failed to download audio file. Status code: {response.status_code}")
|
127 |
+
#
|
128 |
+
#
|
129 |
+
# async def receive_transcription(websocket):
|
130 |
+
# while True:
|
131 |
+
# try:
|
132 |
+
#
|
133 |
+
# transcription = await websocket.recv()
|
134 |
+
# # Receive transcription from the server
|
135 |
+
# print(f"Transcription: {transcription}")
|
136 |
+
# except Exception as e:
|
137 |
+
# print(f"Error receiving transcription: {e}")
|
138 |
+
# #await asyncio.sleep(30)
|
139 |
+
# break
|
140 |
+
#
|
141 |
+
#
|
142 |
+
# async def send_heartbeat(websocket):
|
143 |
+
# while True:
|
144 |
+
# try:
|
145 |
+
# await websocket.ping()
|
146 |
+
# print("Sent keepalive ping")
|
147 |
+
# except websockets.ConnectionClosed:
|
148 |
+
# print("Connection closed, stopping heartbeat")
|
149 |
+
# break
|
150 |
+
# await asyncio.sleep(30) # Send ping every 30 seconds (adjust as needed)
|
151 |
+
#
|
152 |
+
#
|
153 |
+
# async def run_client():
|
154 |
+
# uri = ("wss://gigaverse-ivrit-ai-streaming.hf.space/wtranscribe") # WebSocket URL
|
155 |
+
# ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
|
156 |
+
# ssl_context.check_hostname = False
|
157 |
+
# ssl_context.verify_mode = ssl.CERT_NONE
|
158 |
+
# while True:
|
159 |
+
# try:
|
160 |
+
# async with websockets.connect(uri, ssl=ssl_context, ping_timeout=1000, ping_interval=50) as websocket:
|
161 |
+
# await asyncio.gather(
|
162 |
+
# send_audio(websocket),
|
163 |
+
# receive_transcription(websocket),
|
164 |
+
# send_heartbeat(websocket)
|
165 |
+
# )
|
166 |
+
# except websockets.ConnectionClosedError as e:
|
167 |
+
# print(f"WebSocket closed with error: {e}")
|
168 |
+
# # except Exception as e:
|
169 |
+
# # print(f"Unexpected error: {e}")
|
170 |
+
# #
|
171 |
+
# # print("Reconnecting in 5 seconds...")
|
172 |
+
# # await asyncio.sleep(5) # Wait 5 seconds before reconnecting
|
173 |
+
#
|
174 |
+
# asyncio.run(run_client())
|
infer.py
CHANGED
@@ -1,6 +1,8 @@
|
|
1 |
import base64
|
2 |
import faster_whisper
|
3 |
import tempfile
|
|
|
|
|
4 |
import torch
|
5 |
import time
|
6 |
import requests
|
@@ -15,6 +17,7 @@ import asyncio
|
|
15 |
# Configure logging
|
16 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s: %(message)s',
|
17 |
handlers=[logging.StreamHandler(sys.stdout)], force=True)
|
|
|
18 |
#logging.getLogger("asyncio").setLevel(logging.DEBUG)
|
19 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
20 |
logging.info(f'Device selected: {device}')
|
@@ -31,49 +34,49 @@ logging.info(f'Max payload size set to: {MAX_PAYLOAD_SIZE} bytes')
|
|
31 |
app = FastAPI()
|
32 |
|
33 |
|
34 |
-
class InputData(BaseModel):
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
def download_file(url, max_size_bytes, output_filename, api_key=None):
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
|
78 |
|
79 |
@app.get("/")
|
@@ -81,97 +84,208 @@ async def read_root():
|
|
81 |
return {"message": "This is the Ivrit AI Streaming service."}
|
82 |
|
83 |
|
84 |
-
async def transcribe_core_ws(audio_file):
|
85 |
-
|
86 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
87 |
|
88 |
-
:param audio_file: Path to the growing audio file.
|
89 |
-
:param last_transcribed_time: The last time (in seconds) that was transcribed.
|
90 |
-
:return: Newly transcribed segments and the updated last transcribed time.
|
91 |
-
"""
|
92 |
-
ret = {'segments': []}
|
93 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
94 |
try:
|
95 |
-
# Transcribe the entire audio file
|
96 |
-
logging.debug(f"Initiating model transcription for file: {audio_file}")
|
97 |
|
98 |
-
|
99 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
100 |
except Exception as e:
|
101 |
-
|
102 |
-
|
103 |
|
104 |
-
|
105 |
-
|
106 |
-
|
|
|
|
|
107 |
|
108 |
-
|
109 |
-
|
110 |
-
words = [{'start': w.start, 'end': w.end, 'word': w.word, 'probability': w.probability} for w in s.words]
|
111 |
|
112 |
-
|
113 |
-
'id': s.id, 'seek': s.seek, 'start': s.start, 'end': s.end, 'text': s.text,
|
114 |
-
'avg_logprob': s.avg_logprob, 'compression_ratio': s.compression_ratio,
|
115 |
-
'no_speech_prob': s.no_speech_prob, 'words': words
|
116 |
-
}
|
117 |
-
logging.info(f'Adding new transcription segment: {seg}')
|
118 |
-
ret['segements'].append(seg)
|
119 |
|
120 |
-
# Update the last transcribed time to the end of the current segment
|
121 |
|
122 |
|
123 |
-
#logging.info(f"Returning {len(ret['new_segments'])} new segments and updated last transcribed time.")
|
124 |
-
return ret
|
125 |
|
126 |
|
127 |
-
import tempfile
|
128 |
|
129 |
|
130 |
-
@app.websocket("/wtranscribe")
|
131 |
-
async def websocket_transcribe(websocket: WebSocket):
|
132 |
-
logging.info("New WebSocket connection request received.")
|
133 |
-
await websocket.accept()
|
134 |
-
logging.info("WebSocket connection established successfully.")
|
135 |
|
136 |
-
try:
|
137 |
-
processed_segments = [] # Keeps track of the segments already transcribed
|
138 |
-
accumulated_audio_size = 0 # Track how much audio data has been buffered
|
139 |
-
accumulated_audio_time = 0 # Track the total audio duration accumulated
|
140 |
-
last_transcribed_time = 0.0
|
141 |
-
#min_transcription_time = 5.0 # Minimum duration of audio in seconds before transcription starts
|
142 |
-
|
143 |
-
# A temporary file to store the growing audio data
|
144 |
-
|
145 |
-
while True:
|
146 |
-
try:
|
147 |
-
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_audio_file: ##new temp file for every chunk
|
148 |
-
logging.info(f"Temporary audio file created at {temp_audio_file.name}")
|
149 |
-
# Receive the next chunk of audio data
|
150 |
-
audio_chunk = await websocket.receive_bytes()
|
151 |
-
if not audio_chunk:
|
152 |
-
logging.warning("Received empty audio chunk, skipping processing.")
|
153 |
-
continue
|
154 |
-
|
155 |
-
# Write audio chunk to file and accumulate size and time
|
156 |
-
temp_audio_file.write(audio_chunk)
|
157 |
-
temp_audio_file.flush()
|
158 |
-
accumulated_audio_size += len(audio_chunk)
|
159 |
-
|
160 |
-
# Estimate the duration of the chunk based on its size (e.g., 16kHz audio)
|
161 |
-
chunk_duration = len(audio_chunk) / (16000 * 2) # Assuming 16kHz mono WAV (2 bytes per sample)
|
162 |
-
accumulated_audio_time += chunk_duration
|
163 |
-
|
164 |
-
partial_result = await transcribe_core_ws(temp_audio_file)
|
165 |
-
accumulated_audio_time = 0 # Reset the accumulated audio time
|
166 |
-
await websocket.send_json(partial_result)
|
167 |
-
|
168 |
-
except WebSocketDisconnect:
|
169 |
-
logging.info("WebSocket connection closed by the client.")
|
170 |
-
break
|
171 |
|
172 |
-
except Exception as e:
|
173 |
-
logging.error(f"Unexpected error during WebSocket transcription: {e}")
|
174 |
-
await websocket.send_json({"error": str(e)})
|
175 |
|
176 |
-
|
177 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import base64
|
2 |
import faster_whisper
|
3 |
import tempfile
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
import torch
|
7 |
import time
|
8 |
import requests
|
|
|
17 |
# Configure logging
|
18 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s: %(message)s',
|
19 |
handlers=[logging.StreamHandler(sys.stdout)], force=True)
|
20 |
+
logger = logging.getLogger(__name__)
|
21 |
#logging.getLogger("asyncio").setLevel(logging.DEBUG)
|
22 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
23 |
logging.info(f'Device selected: {device}')
|
|
|
34 |
app = FastAPI()
|
35 |
|
36 |
|
37 |
+
# class InputData(BaseModel):
|
38 |
+
# type: str
|
39 |
+
# data: Optional[str] = None # Used for blob input
|
40 |
+
# url: Optional[str] = None # Used for url input
|
41 |
+
#
|
42 |
+
#
|
43 |
+
# def download_file(url, max_size_bytes, output_filename, api_key=None):
|
44 |
+
# """
|
45 |
+
# Download a file from a given URL with size limit and optional API key.
|
46 |
+
# """
|
47 |
+
# logging.debug(f'Starting file download from URL: {url}')
|
48 |
+
# try:
|
49 |
+
# headers = {}
|
50 |
+
# if api_key:
|
51 |
+
# headers['Authorization'] = f'Bearer {api_key}'
|
52 |
+
# logging.debug('API key provided, added to headers')
|
53 |
+
#
|
54 |
+
# response = requests.get(url, stream=True, headers=headers)
|
55 |
+
# response.raise_for_status()
|
56 |
+
#
|
57 |
+
# file_size = int(response.headers.get('Content-Length', 0))
|
58 |
+
# logging.info(f'File size: {file_size} bytes')
|
59 |
+
#
|
60 |
+
# if file_size > max_size_bytes:
|
61 |
+
# logging.error(f'File size exceeds limit: {file_size} > {max_size_bytes}')
|
62 |
+
# return False
|
63 |
+
#
|
64 |
+
# downloaded_size = 0
|
65 |
+
# with open(output_filename, 'wb') as file:
|
66 |
+
# for chunk in response.iter_content(chunk_size=8192):
|
67 |
+
# downloaded_size += len(chunk)
|
68 |
+
# logging.debug(f'Downloaded {downloaded_size} bytes')
|
69 |
+
# if downloaded_size > max_size_bytes:
|
70 |
+
# logging.error('Downloaded size exceeds maximum allowed payload size')
|
71 |
+
# return False
|
72 |
+
# file.write(chunk)
|
73 |
+
#
|
74 |
+
# logging.info(f'File downloaded successfully: {output_filename}')
|
75 |
+
# return True
|
76 |
+
#
|
77 |
+
# except requests.RequestException as e:
|
78 |
+
# logging.error(f"Error downloading file: {e}")
|
79 |
+
# return False
|
80 |
|
81 |
|
82 |
@app.get("/")
|
|
|
84 |
return {"message": "This is the Ivrit AI Streaming service."}
|
85 |
|
86 |
|
87 |
+
# async def transcribe_core_ws(audio_file):
|
88 |
+
# ret = {'segments': []}
|
89 |
+
#
|
90 |
+
# try:
|
91 |
+
#
|
92 |
+
# logging.debug(f"Initiating model transcription for file: {audio_file}")
|
93 |
+
#
|
94 |
+
# segs, _ = await asyncio.to_thread(model.transcribe, audio_file, language='he', word_timestamps=True)
|
95 |
+
# logging.info('Transcription completed successfully.')
|
96 |
+
# except Exception as e:
|
97 |
+
# logging.error(f"Error during transcription: {e}")
|
98 |
+
# raise e
|
99 |
+
#
|
100 |
+
# # Track the new segments and update the last transcribed time
|
101 |
+
# for s in segs:
|
102 |
+
# logging.info(f"Processing segment with start time: {s.start} and end time: {s.end}")
|
103 |
+
#
|
104 |
+
# # Only process segments that start after the last transcribed time
|
105 |
+
# logging.info(f"New segment found starting at {s.start} seconds.")
|
106 |
+
# words = [{'start': w.start, 'end': w.end, 'word': w.word, 'probability': w.probability} for w in s.words]
|
107 |
+
#
|
108 |
+
# seg = {
|
109 |
+
# 'id': s.id, 'seek': s.seek, 'start': s.start, 'end': s.end, 'text': s.text,
|
110 |
+
# 'avg_logprob': s.avg_logprob, 'compression_ratio': s.compression_ratio,
|
111 |
+
# 'no_speech_prob': s.no_speech_prob, 'words': words
|
112 |
+
# }
|
113 |
+
# logging.info(f'Adding new transcription segment: {seg}')
|
114 |
+
# ret['segements'].append(seg)
|
115 |
+
#
|
116 |
+
# # Update the last transcribed time to the end of the current segment
|
117 |
+
#
|
118 |
+
#
|
119 |
+
# #logging.info(f"Returning {len(ret['new_segments'])} new segments and updated last transcribed time.")
|
120 |
+
# return ret
|
121 |
+
|
122 |
+
|
123 |
+
import tempfile
|
124 |
|
|
|
|
|
|
|
|
|
|
|
125 |
|
126 |
+
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Accept a client WebSocket connection and delegate to the streaming pipeline.

    A client disconnect is treated as normal termination and logged at INFO;
    any other failure is logged as an error and the socket is closed
    explicitly (FastAPI does not close it for us in that case).
    """
    await websocket.accept()
    client_ip = websocket.client.host
    logger.info(f"Client connected: {client_ip}")
    try:
        await process_audio_stream(websocket)
    except WebSocketDisconnect:
        logger.info(f"Client disconnected: {client_ip}")
    except Exception as e:
        logger.error(f"Unexpected error: {e}")
        await websocket.close()
|
139 |
+
|
140 |
+
async def process_audio_stream(websocket: WebSocket):
    """Continuously receive audio chunks and initiate transcription tasks.

    Incoming bytes are assumed to be 16 kHz mono 16-bit PCM (see
    `process_received_audio`). Samples are buffered until at least
    `min_chunk_size` seconds have accumulated, then handed to a background
    transcription task — at most one task in flight at a time; while a task
    is running the buffer simply keeps growing.
    """
    sampling_rate = 16000
    min_chunk_size = 1  # in seconds
    audio_buffer = np.array([], dtype=np.float32)

    transcription_task = None
    chunk_counter = 0
    total_bytes_received = 0

    while True:
        try:
            # Receive audio data from client
            data = await websocket.receive_bytes()
            if not data:
                logger.info("No data received, closing connection")
                break
            chunk_counter += 1
            chunk_size = len(data)
            total_bytes_received += chunk_size
            logger.debug(f"Received chunk {chunk_counter}: {chunk_size} bytes")

            audio_chunk = process_received_audio(data)
            logger.debug(f"Processed audio chunk {chunk_counter}: {len(audio_chunk)} samples")

            audio_buffer = np.concatenate((audio_buffer, audio_chunk))
            logger.debug(f"Audio buffer size: {len(audio_buffer)} samples")
        except Exception as e:
            # Note: this also catches WebSocketDisconnect raised by
            # receive_bytes(), which ends the loop.
            logger.error(f"Error receiving data: {e}")
            break

        # Check if enough audio has been buffered
        if len(audio_buffer) >= min_chunk_size * sampling_rate:
            if transcription_task is None or transcription_task.done():
                # Start a new transcription task
                logger.info(f"Starting transcription task for {len(audio_buffer)} samples")
                transcription_task = asyncio.create_task(
                    transcribe_and_send(websocket, audio_buffer.copy())
                )
                audio_buffer = np.array([], dtype=np.float32)
                logger.debug("Audio buffer reset after starting transcription task")

    # Bug fix: the original returned while a transcription task could still be
    # running, abandoning the final result and triggering asyncio's
    # "Task was destroyed but it is pending!" warning. Let it finish first.
    if transcription_task is not None and not transcription_task.done():
        try:
            await transcription_task
        except Exception as e:
            logger.error(f"Error awaiting final transcription task: {e}")
|
181 |
+
|
182 |
+
async def transcribe_and_send(websocket: WebSocket, audio_data):
    """Run transcription off the event loop and push the result to the client.

    The blocking model call happens in a worker thread via asyncio.to_thread;
    an empty result is only logged, never sent.
    """
    logger.debug(f"Transcription task started for {len(audio_data)} samples")
    result = await asyncio.to_thread(sync_transcribe_audio, audio_data)

    # Guard clause: nothing to deliver.
    if not result:
        logger.warning("No transcription result to send")
        return

    try:
        # Send the result as JSON
        await websocket.send_json(result)
        logger.info("Transcription JSON sent to client")
    except Exception as e:
        logger.error(f"Error sending transcription: {e}")
|
195 |
+
|
196 |
+
def sync_transcribe_audio(audio_data):
    """Synchronously transcribe audio data using the ASR model and format the result.

    Returns ``{'segments': [...]}`` with every numeric field coerced to a
    plain int/float so the dict is JSON-serializable; returns ``{}`` on any
    failure (the error is logged).
    """
    try:
        logger.info('Starting transcription...')
        segments, info = model.transcribe(
            audio_data, language="he", beam_size=5, word_timestamps=True
        )
        logger.info('Transcription completed')

        result = {'segments': []}

        for segment in segments:
            logger.debug(f"Processing segment {segment.id} with start time: {segment.start} and end time: {segment.end}")

            # Per-word timings and confidence, coerced to builtin floats.
            word_list = [
                {
                    'start': float(w.start),
                    'end': float(w.end),
                    'word': w.word,
                    'probability': float(w.probability),
                }
                for w in segment.words
            ]

            entry = {
                'id': int(segment.id),
                'seek': int(segment.seek),
                'start': float(segment.start),
                'end': float(segment.end),
                'text': segment.text,
                'avg_logprob': float(segment.avg_logprob),
                'compression_ratio': float(segment.compression_ratio),
                'no_speech_prob': float(segment.no_speech_prob),
                'words': word_list,
            }
            logger.debug(f'Adding new transcription segment: {entry}')
            result['segments'].append(entry)

        logger.debug(f"Total segments in transcription result: {len(result['segments'])}")
        return result
    except Exception as e:
        logger.error(f"Transcription error: {e}")
        return {}
|
239 |
|
240 |
+
def process_received_audio(data):
    """Convert received bytes into normalized float32 NumPy array.

    Interprets the raw bytes as little-endian signed 16-bit PCM and scales
    the samples into [-1, 1] by dividing by 32768.
    """
    logger.debug(f"Processing received audio data of size {len(data)} bytes")

    samples = np.frombuffer(data, dtype=np.int16)
    logger.debug(f"Converted to int16 NumPy array with {len(samples)} samples")

    normalized = samples.astype(np.float32) / 32768.0  # Normalize to [-1, 1]
    logger.debug(f"Normalized audio data to float32 with {len(normalized)} samples")
    return normalized
|
|
|
|
|
|
|
|
|
|
|
|
|
250 |
|
|
|
251 |
|
252 |
|
|
|
|
|
253 |
|
254 |
|
|
|
255 |
|
256 |
|
|
|
|
|
|
|
|
|
|
|
257 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
258 |
|
|
|
|
|
|
|
259 |
|
260 |
+
# @app.websocket("/wtranscribe")
|
261 |
+
# async def websocket_transcribe(websocket: WebSocket):
|
262 |
+
# logging.info("New WebSocket connection request received.")
|
263 |
+
# await websocket.accept()
|
264 |
+
# logging.info("WebSocket connection established successfully.")
|
265 |
+
#
|
266 |
+
# try:
|
267 |
+
# while True:
|
268 |
+
# try:
|
269 |
+
# audio_chunk = await websocket.receive_bytes()
|
270 |
+
# if not audio_chunk:
|
271 |
+
# logging.warning("Received empty audio chunk, skipping processing.")
|
272 |
+
# continue
|
273 |
+
# with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_audio_file: ##new temp file for every chunk
|
274 |
+
# logging.info(f"Temporary audio file created at {temp_audio_file.name}")
|
275 |
+
# # Receive the next chunk of audio data
|
276 |
+
#
|
277 |
+
#
|
278 |
+
#
|
279 |
+
# partial_result = await transcribe_core_ws(temp_audio_file.name)
|
280 |
+
# await websocket.send_json(partial_result)
|
281 |
+
#
|
282 |
+
# except WebSocketDisconnect:
|
283 |
+
# logging.info("WebSocket connection closed by the client.")
|
284 |
+
# break
|
285 |
+
#
|
286 |
+
# except Exception as e:
|
287 |
+
# logging.error(f"Unexpected error during WebSocket transcription: {e}")
|
288 |
+
# await websocket.send_json({"error": str(e)})
|
289 |
+
#
|
290 |
+
# finally:
|
291 |
+
# logging.info("Cleaning up and closing WebSocket connection.")
|