AshDavid12 committed on
Commit
9ccc5b5
·
1 Parent(s): e37aac1

changed temp file into loop

Browse files
Files changed (2) hide show
  1. client.py +2 -2
  2. infer.py +16 -16
client.py CHANGED
@@ -4,8 +4,8 @@ import requests
4
  import ssl
5
 
6
  # Parameters for reading and sending the audio
7
- #AUDIO_FILE_URL = "https://raw.githubusercontent.com/AshDavid12/runpod-serverless-forked/main/test_hebrew.wav" # Use WAV file
8
- AUDIO_FILE_URL = "https://raw.githubusercontent.com/AshDavid12/hugging_face_ivrit_streaming/main/long_hebrew.wav"
9
 
10
 
11
  async def send_audio(websocket):
 
4
  import ssl
5
 
6
  # Parameters for reading and sending the audio
7
+ AUDIO_FILE_URL = "https://raw.githubusercontent.com/AshDavid12/runpod-serverless-forked/main/test_hebrew.wav" # Use WAV file
8
+ #AUDIO_FILE_URL = "https://raw.githubusercontent.com/AshDavid12/hugging_face_ivrit_streaming/main/long_hebrew.wav"
9
 
10
 
11
  async def send_audio(websocket):
infer.py CHANGED
@@ -5,7 +5,7 @@ import torch
5
  import time
6
  import requests
7
  import logging
8
- from fastapi import FastAPI, HTTPException, WebSocket,WebSocketDisconnect
9
  import websockets
10
  from pydantic import BaseModel
11
  from typing import Optional
@@ -13,7 +13,8 @@ import sys
13
  import asyncio
14
 
15
  # Configure logging
16
- logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s: %(message)s',handlers=[logging.StreamHandler(sys.stdout)], force=True)
 
17
  #logging.getLogger("asyncio").setLevel(logging.DEBUG)
18
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
19
  logging.info(f'Device selected: {device}')
@@ -74,14 +75,13 @@ def download_file(url, max_size_bytes, output_filename, api_key=None):
74
  logging.error(f"Error downloading file: {e}")
75
  return False
76
 
 
77
  @app.get("/")
78
  async def read_root():
79
  return {"message": "This is the Ivrit AI Streaming service."}
80
 
81
 
82
-
83
-
84
- def transcribe_core_ws(audio_file, last_transcribed_time):
85
  """
86
  Transcribe the audio file and return only the segments that have not been processed yet.
87
 
@@ -97,7 +97,8 @@ def transcribe_core_ws(audio_file, last_transcribed_time):
97
  try:
98
  # Transcribe the entire audio file
99
  logging.debug(f"Initiating model transcription for file: {audio_file}")
100
- segs, _ = model.transcribe(audio_file, language='he', word_timestamps=True)
 
101
  logging.info('Transcription completed successfully.')
102
  except Exception as e:
103
  logging.error(f"Error during transcription: {e}")
@@ -145,11 +146,11 @@ async def websocket_transcribe(websocket: WebSocket):
145
  #min_transcription_time = 5.0 # Minimum duration of audio in seconds before transcription starts
146
 
147
  # A temporary file to store the growing audio data
148
- with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_audio_file:
149
- logging.info(f"Temporary audio file created at {temp_audio_file.name}")
150
 
151
- while True:
152
- try:
 
 
153
  # Receive the next chunk of audio data
154
  audio_chunk = await websocket.receive_bytes()
155
  if not audio_chunk:
@@ -165,7 +166,8 @@ async def websocket_transcribe(websocket: WebSocket):
165
  chunk_duration = len(audio_chunk) / (16000 * 2) # Assuming 16kHz mono WAV (2 bytes per sample)
166
  accumulated_audio_time += chunk_duration
167
 
168
- partial_result, last_transcribed_time = transcribe_core_ws(temp_audio_file.name, last_transcribed_time)
 
169
  accumulated_audio_time = 0 # Reset the accumulated audio time
170
  processed_segments.extend(partial_result['new_segments'])
171
 
@@ -180,9 +182,9 @@ async def websocket_transcribe(websocket: WebSocket):
180
  logging.info(f"Sending {len(partial_result['new_segments'])} new segments to the client.")
181
  await websocket.send_json(response)
182
 
183
- except WebSocketDisconnect:
184
- logging.info("WebSocket connection closed by the client.")
185
- break
186
 
187
  except Exception as e:
188
  logging.error(f"Unexpected error during WebSocket transcription: {e}")
@@ -190,5 +192,3 @@ async def websocket_transcribe(websocket: WebSocket):
190
 
191
  finally:
192
  logging.info("Cleaning up and closing WebSocket connection.")
193
-
194
-
 
5
  import time
6
  import requests
7
  import logging
8
+ from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
9
  import websockets
10
  from pydantic import BaseModel
11
  from typing import Optional
 
13
  import asyncio
14
 
15
  # Configure logging
16
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s: %(message)s',
17
+ handlers=[logging.StreamHandler(sys.stdout)], force=True)
18
  #logging.getLogger("asyncio").setLevel(logging.DEBUG)
19
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
20
  logging.info(f'Device selected: {device}')
 
75
  logging.error(f"Error downloading file: {e}")
76
  return False
77
 
78
+
79
  @app.get("/")
80
  async def read_root():
81
  return {"message": "This is the Ivrit AI Streaming service."}
82
 
83
 
84
+ async def transcribe_core_ws(audio_file, last_transcribed_time):
 
 
85
  """
86
  Transcribe the audio file and return only the segments that have not been processed yet.
87
 
 
97
  try:
98
  # Transcribe the entire audio file
99
  logging.debug(f"Initiating model transcription for file: {audio_file}")
100
+
101
+ segs, _ = await asyncio.to_thread(model.transcribe, audio_file, language='he', word_timestamps=True)
102
  logging.info('Transcription completed successfully.')
103
  except Exception as e:
104
  logging.error(f"Error during transcription: {e}")
 
146
  #min_transcription_time = 5.0 # Minimum duration of audio in seconds before transcription starts
147
 
148
  # A temporary file to store the growing audio data
 
 
149
 
150
+ while True:
151
+ try:
152
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_audio_file:
153
+ logging.info(f"Temporary audio file created at {temp_audio_file.name}")
154
  # Receive the next chunk of audio data
155
  audio_chunk = await websocket.receive_bytes()
156
  if not audio_chunk:
 
166
  chunk_duration = len(audio_chunk) / (16000 * 2) # Assuming 16kHz mono WAV (2 bytes per sample)
167
  accumulated_audio_time += chunk_duration
168
 
169
+ partial_result, last_transcribed_time = transcribe_core_ws(temp_audio_file.name,
170
+ last_transcribed_time)
171
  accumulated_audio_time = 0 # Reset the accumulated audio time
172
  processed_segments.extend(partial_result['new_segments'])
173
 
 
182
  logging.info(f"Sending {len(partial_result['new_segments'])} new segments to the client.")
183
  await websocket.send_json(response)
184
 
185
+ except WebSocketDisconnect:
186
+ logging.info("WebSocket connection closed by the client.")
187
+ break
188
 
189
  except Exception as e:
190
  logging.error(f"Unexpected error during WebSocket transcription: {e}")
 
192
 
193
  finally:
194
  logging.info("Cleaning up and closing WebSocket connection.")