AshDavid12 committed
Commit e37aac1 · Parent: 544d523
changing chunks-try get logs in hf
client.py
CHANGED
@@ -16,7 +16,7 @@ async def send_audio(websocket):
     if response.status_code == 200:
         print("Starting to stream audio file...")

-        for chunk in response.iter_content(chunk_size=
+        for chunk in response.iter_content(chunk_size=1024):  # Stream in chunks
             if chunk:
                 audio_buffer.extend(chunk)
                 #print(f"Received audio chunk of size {len(chunk)} bytes.")
@@ -26,7 +26,7 @@ async def send_audio(websocket):
                 await websocket.send(audio_buffer)
                 #print(f"Sent {len(audio_buffer)} bytes of audio data.")
                 audio_buffer.clear()
-                await asyncio.sleep(0.
+                await asyncio.sleep(0.01)

         print("Finished sending audio.")
     else:
@@ -37,12 +37,12 @@ async def receive_transcription(websocket):
     while True:
         try:

-            transcription = await
+            transcription = await websocket.recv()
             # Receive transcription from the server
             print(f"Transcription: {transcription}")
         except Exception as e:
             print(f"Error receiving transcription: {e}")
-            await asyncio.sleep(30)
+            #await asyncio.sleep(30)
             break

@@ -54,7 +54,7 @@ async def send_heartbeat(websocket):
         except websockets.ConnectionClosed:
             print("Connection closed, stopping heartbeat")
             break
-        await asyncio.sleep(
+        await asyncio.sleep(30)  # Send ping every 30 seconds (adjust as needed)


 async def run_client():
@@ -64,19 +64,18 @@ async def run_client():
     ssl_context.verify_mode = ssl.CERT_NONE
     while True:
         try:
-            async with websockets.connect(uri, ssl=ssl_context, ping_timeout=
+            async with websockets.connect(uri, ssl=ssl_context, ping_timeout=1000, ping_interval=50) as websocket:
                 await asyncio.gather(
                     send_audio(websocket),
                     receive_transcription(websocket),
-
+                    send_heartbeat(websocket)
                 )
         except websockets.ConnectionClosedError as e:
             print(f"WebSocket closed with error: {e}")
-        except Exception as e:
-
-
-            print("Reconnecting in 5 seconds...")
-            await asyncio.sleep(5)  # Wait 5 seconds before reconnecting
+        # except Exception as e:
+        #     print(f"Unexpected error: {e}")
+        #
+        #     print("Reconnecting in 5 seconds...")
+        #     await asyncio.sleep(5)  # Wait 5 seconds before reconnecting

-
-asyncio.run(run_client())
+asyncio.run(run_client())
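Note on the new constants: with chunk_size=1024 and asyncio.sleep(0.01) after every send, the client tops out near 1024 / 0.01 = 100 KiB/s, comfortably above the ~32,000 bytes/s that 16 kHz 16-bit mono audio needs in real time. A minimal, self-contained sketch of that pacing arithmetic (the stub send and the 100-chunk count are illustrative, not from the commit):

import asyncio
import time

CHUNK_SIZE = 1024  # bytes, matching the new iter_content(chunk_size=1024)
PAUSE = 0.01       # seconds, matching the new asyncio.sleep(0.01)

async def fake_send(chunk: bytes) -> None:
    # Stand-in for websocket.send(); a real client would await the socket here.
    await asyncio.sleep(0)

async def main() -> None:
    n_chunks = 100  # 100 KiB of dummy audio (illustrative)
    start = time.perf_counter()
    for _ in range(n_chunks):
        await fake_send(bytes(CHUNK_SIZE))
        await asyncio.sleep(PAUSE)
    elapsed = time.perf_counter() - start
    rate_kib = n_chunks * CHUNK_SIZE / elapsed / 1024
    print(f"sent {n_chunks * CHUNK_SIZE} bytes in {elapsed:.2f}s (~{rate_kib:.0f} KiB/s)")

asyncio.run(main())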
infer.py
CHANGED
@@ -13,8 +13,8 @@ import sys
 import asyncio

 # Configure logging
-
-logging.getLogger("asyncio").setLevel(logging.DEBUG)
+logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s: %(message)s',handlers=[logging.StreamHandler(sys.stdout)], force=True)
+#logging.getLogger("asyncio").setLevel(logging.DEBUG)
 device = 'cuda' if torch.cuda.is_available() else 'cpu'
 logging.info(f'Device selected: {device}')
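The change that actually surfaces the logs is force=True: since Python 3.8, logging.basicConfig removes and closes any handlers already attached to the root logger before installing the new ones; without it, the call is silently ignored whenever the hosting framework has configured logging first. A small illustration (the WARNING pre-configuration stands in for whatever handler the framework installs; it is an assumption, not from the commit):

import logging
import sys

# Pretend the serving framework already configured the root logger.
logging.basicConfig(level=logging.WARNING)

# Without force=True this second call would be a no-op,
# because the root logger already has a handler.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s %(levelname)s: %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)],
    force=True,  # Python 3.8+: drop existing root handlers first
)

logging.info("now visible on stdout")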
@@ -79,61 +79,6 @@ async def read_root():
     return {"message": "This is the Ivrit AI Streaming service."}


-@app.post("/transcribe")
-async def transcribe(input_data: InputData):
-    logging.info(f'Received transcription request with data: {input_data}')
-    datatype = input_data.type
-    if not datatype:
-        logging.error('datatype field not provided')
-        raise HTTPException(status_code=400, detail="datatype field not provided. Should be 'blob' or 'url'.")
-
-    if datatype not in ['blob', 'url']:
-        logging.error(f'Invalid datatype: {datatype}')
-        raise HTTPException(status_code=400, detail=f"datatype should be 'blob' or 'url', but is {datatype} instead.")
-
-    with tempfile.TemporaryDirectory() as d:
-        audio_file = f'{d}/audio.mp3'
-        logging.debug(f'Created temporary directory: {d}')
-
-        if datatype == 'blob':
-            if not input_data.data:
-                logging.error("Missing 'data' for 'blob' input")
-                raise HTTPException(status_code=400, detail="Missing 'data' for 'blob' input.")
-            logging.info('Decoding base64 blob data')
-            mp3_bytes = base64.b64decode(input_data.data)
-            open(audio_file, 'wb').write(mp3_bytes)
-            logging.info(f'Audio file written: {audio_file}')
-        elif datatype == 'url':
-            if not input_data.url:
-                logging.error("Missing 'url' for 'url' input")
-                raise HTTPException(status_code=400, detail="Missing 'url' for 'url' input.")
-            logging.info(f'Downloading file from URL: {input_data.url}')
-            success = download_file(input_data.url, MAX_PAYLOAD_SIZE, audio_file, None)
-            if not success:
-                logging.error(f"Error downloading data from {input_data.url}")
-                raise HTTPException(status_code=400, detail=f"Error downloading data from {input_data.url}")
-
-        result = transcribe_core(audio_file)
-        return {"result": result}
-
-
-def transcribe_core(audio_file):
-    logging.info('Starting transcription...')
-    ret = {'segments': []}
-
-    segs, _ = model.transcribe(audio_file, language='he', word_timestamps=True)
-    logging.info('Transcription completed')
-
-    for s in segs:
-        words = [{'start': w.start, 'end': w.end, 'word': w.word, 'probability': w.probability} for w in s.words]
-        seg = {
-            'id': s.id, 'seek': s.seek, 'start': s.start, 'end': s.end, 'text': s.text, 'avg_logprob': s.avg_logprob,
-            'compression_ratio': s.compression_ratio, 'no_speech_prob': s.no_speech_prob, 'words': words
-        }
-        logging.info(f'Transcription segment: {seg}')
-        ret['segments'].append(seg)
-
-    return ret


 def transcribe_core_ws(audio_file, last_transcribed_time):
@@ -219,14 +164,7 @@ async def websocket_transcribe(websocket: WebSocket):
             # Estimate the duration of the chunk based on its size (e.g., 16kHz audio)
             chunk_duration = len(audio_chunk) / (16000 * 2)  # Assuming 16kHz mono WAV (2 bytes per sample)
             accumulated_audio_time += chunk_duration
-            #logging.info(f"Received and buffered {len(audio_chunk)} bytes, total buffered: {accumulated_audio_size} bytes, total time: {accumulated_audio_time:.2f} seconds")
-
-            # Transcribe when enough time (audio) is accumulated (e.g., at least 5 seconds of audio)
-            #if accumulated_audio_time >= min_transcription_time:
-                #logging.info("Buffered enough audio time, starting transcription.")
-

-            # Call the transcription function with the last processed time
             partial_result, last_transcribed_time = transcribe_core_ws(temp_audio_file.name, last_transcribed_time)
             accumulated_audio_time = 0  # Reset the accumulated audio time
             processed_segments.extend(partial_result['new_segments'])
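For reference, the duration estimate retained above is straight arithmetic: 16 kHz mono at 2 bytes per sample is 16000 * 2 = 32,000 bytes of audio per second, so a chunk's duration is its byte count divided by 32,000. A worked example (the 4096-byte chunk size is illustrative, not taken from the code):

BYTES_PER_SECOND = 16000 * 2  # 16 kHz mono, 16-bit PCM (2 bytes per sample)

def chunk_duration(n_bytes: int) -> float:
    # Mirrors: len(audio_chunk) / (16000 * 2)
    return n_bytes / BYTES_PER_SECOND

print(chunk_duration(4096))   # 0.128 -> a 4096-byte chunk is ~0.128 s of audio
print(chunk_duration(32000))  # 1.0   -> 32,000 bytes is exactly one second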