AshDavid12 commited on
Commit
1317fe0
·
1 Parent(s): 23abdce

added docker timeout flag

Browse files
Files changed (3) hide show
  1. Dockerfile +1 -1
  2. client.py +17 -7
  3. infer.py +2 -2
Dockerfile CHANGED
@@ -28,4 +28,4 @@ RUN pip install --no-cache-dir -r requirements.txt
28
  COPY . .
29
 
30
  # Run FastAPI with Uvicorn
31
- CMD ["uvicorn", "infer:app", "--host", "0.0.0.0", "--port","7860"]
 
28
  COPY . .
29
 
30
  # Run FastAPI with Uvicorn
31
+ CMD ["uvicorn", "infer:app", "--host", "0.0.0.0", "--port","7860","--timeout-keep-alive","300"]
client.py CHANGED
@@ -6,8 +6,10 @@ import ssl
6
  # Parameters for reading and sending the audio
7
  #AUDIO_FILE_URL = "https://raw.githubusercontent.com/AshDavid12/runpod-serverless-forked/main/test_hebrew.wav" # Use WAV file
8
  AUDIO_FILE_URL = "https://raw.githubusercontent.com/AshDavid12/hugging_face_ivrit_streaming/main/long_hebrew.wav"
 
 
9
  async def send_audio(websocket):
10
- buffer_size = 512*1024 #HAVE TO HAVE 512!!
11
  audio_buffer = bytearray()
12
 
13
  with requests.get(AUDIO_FILE_URL, stream=True, allow_redirects=False) as response:
@@ -30,18 +32,20 @@ async def send_audio(websocket):
30
  else:
31
  print(f"Failed to download audio file. Status code: {response.status_code}")
32
 
 
33
  async def receive_transcription(websocket):
34
  while True:
35
  try:
36
 
37
- transcription = await asyncio.wait_for(websocket.recv(),timeout=300)
38
- # Receive transcription from the server
39
  print(f"Transcription: {transcription}")
40
  except Exception as e:
41
  print(f"Error receiving transcription: {e}")
42
  await asyncio.sleep(30)
43
  break
44
 
 
45
  async def send_heartbeat(websocket):
46
  while True:
47
  try:
@@ -60,13 +64,19 @@ async def run_client():
60
  ssl_context.verify_mode = ssl.CERT_NONE
61
  while True:
62
  try:
63
- async with websockets.connect(uri, ssl=ssl_context, ping_timeout=120,ping_interval=None) as websocket:
64
  await asyncio.gather(
65
  send_audio(websocket),
66
  receive_transcription(websocket),
67
- send_heartbeat(websocket)
68
  )
69
  except websockets.ConnectionClosedError as e:
70
- print(f"web closed :{e}")
 
 
 
 
 
71
 
72
- asyncio.run(run_client())
 
 
6
  # Parameters for reading and sending the audio
7
  #AUDIO_FILE_URL = "https://raw.githubusercontent.com/AshDavid12/runpod-serverless-forked/main/test_hebrew.wav" # Use WAV file
8
  AUDIO_FILE_URL = "https://raw.githubusercontent.com/AshDavid12/hugging_face_ivrit_streaming/main/long_hebrew.wav"
9
+
10
+
11
  async def send_audio(websocket):
12
+ buffer_size = 16 * 1024 #HAVE TO HAVE 512!!
13
  audio_buffer = bytearray()
14
 
15
  with requests.get(AUDIO_FILE_URL, stream=True, allow_redirects=False) as response:
 
32
  else:
33
  print(f"Failed to download audio file. Status code: {response.status_code}")
34
 
35
+
36
  async def receive_transcription(websocket):
37
  while True:
38
  try:
39
 
40
+ transcription = await asyncio.wait_for(websocket.recv(), timeout=300)
41
+ # Receive transcription from the server
42
  print(f"Transcription: {transcription}")
43
  except Exception as e:
44
  print(f"Error receiving transcription: {e}")
45
  await asyncio.sleep(30)
46
  break
47
 
48
+
49
  async def send_heartbeat(websocket):
50
  while True:
51
  try:
 
64
  ssl_context.verify_mode = ssl.CERT_NONE
65
  while True:
66
  try:
67
+ async with websockets.connect(uri, ssl=ssl_context, ping_timeout=500, ping_interval=20) as websocket:
68
  await asyncio.gather(
69
  send_audio(websocket),
70
  receive_transcription(websocket),
71
+ #send_heartbeat(websocket)
72
  )
73
  except websockets.ConnectionClosedError as e:
74
+ print(f"WebSocket closed with error: {e}")
75
+ except Exception as e:
76
+ print(f"Unexpected error: {e}")
77
+
78
+ print("Reconnecting in 5 seconds...")
79
+ await asyncio.sleep(5) # Wait 5 seconds before reconnecting
80
 
81
+ if __name__ == "__main__":
82
+ asyncio.run(run_client())
infer.py CHANGED
@@ -13,8 +13,8 @@ import sys
13
  import asyncio
14
 
15
  # Configure logging
16
- logging.basicConfig(level=logging.DEBUG, format='%(asctime)s %(levelname)s: %(message)s',handlers=[logging.StreamHandler(sys.stdout)])
17
-
18
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
19
  logging.info(f'Device selected: {device}')
20
 
 
13
  import asyncio
14
 
15
  # Configure logging
16
+ #logging.basicConfig(level=logging.DEBUG, format='%(asctime)s %(levelname)s: %(message)s',handlers=[logging.StreamHandler(sys.stdout)])
17
+ logging.getLogger("asyncio").setLevel(logging.DEBUG)
18
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
19
  logging.info(f'Device selected: {device}')
20