animikhaich committed on
Commit
d8d2011
·
1 Parent(s): 2deb721

Added: Server Health Check Endpoint

Browse files
Files changed (3) hide show
  1. client.py +26 -1
  2. requirements.txt +1 -0
  3. server.py +31 -2
client.py CHANGED
@@ -19,6 +19,9 @@ parser.add_argument(
19
  parser.add_argument(
20
  "--duration", type=int, default=10, help="Duration of generated music in seconds"
21
  )
 
 
 
22
 
23
  args = parser.parse_args()
24
 
@@ -36,5 +39,27 @@ def generate_music(server_url, prompts, duration, output_file):
36
  else:
37
  print(f"Failed to generate music: {response.status_code}, {response.text}")
38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  if __name__ == "__main__":
40
- generate_music(args.server_url, args.prompts, args.duration, args.output_file)
 
 
 
 
19
  parser.add_argument(
20
  "--duration", type=int, default=10, help="Duration of generated music in seconds"
21
  )
22
+ parser.add_argument(
23
+ "--check_health", action='store_true', help="Check server health"
24
+ )
25
 
26
  args = parser.parse_args()
27
 
 
39
  else:
40
  print(f"Failed to generate music: {response.status_code}, {response.text}")
41
 
42
def check_server_health(server_url):
    """Query the server's /health endpoint and print a human-readable report.

    Args:
        server_url: Base URL of the music-generation server, e.g. "http://host:port".
    """
    url = f"{server_url}/health"
    try:
        # Timeout keeps the CLI from hanging forever on an unreachable server.
        response = requests.get(url, timeout=10)
    except requests.exceptions.RequestException as e:
        # Surface network failures as a friendly message instead of a traceback.
        print(f"Failed to reach server at {url}: {e}")
        return

    if response.status_code == 200:
        health_status = response.json()
        print("Server Health Check:")
        print(f"Server Running: {health_status['server_running']}")
        print(f"Model Loaded: {health_status['model_loaded']}")
        print(f"CPU Usage: {health_status['cpu_usage_percent']}%")
        print(f"RAM Usage: {health_status['ram_usage_percent']}%")
        # GPU fields are only present when the server runs on CUDA.
        if 'gpu_memory_allocated' in health_status:
            gpu_memory_allocated_gb = health_status['gpu_memory_allocated'] / (1024 ** 3)
            gpu_memory_reserved_gb = health_status['gpu_memory_reserved'] / (1024 ** 3)
            print(f"GPU Memory Allocated: {gpu_memory_allocated_gb:.2f} GB")
            print(f"GPU Memory Reserved: {gpu_memory_reserved_gb:.2f} GB")
    else:
        print(f"Failed to check server health: {response.status_code}, {response.text}")
60
+
61
if __name__ == "__main__":
    # Route the CLI invocation: a health probe takes precedence over
    # music generation when --check_health was passed.
    if not args.check_health:
        generate_music(args.server_url, args.prompts, args.duration, args.output_file)
    else:
        check_server_health(args.server_url)
requirements.txt CHANGED
@@ -6,3 +6,4 @@ Requests==2.32.3
6
  scipy==1.13.1
7
  torch==2.1.0
8
  uvicorn==0.30.1
 
 
6
  scipy==1.13.1
7
  torch==2.1.0
8
  uvicorn==0.30.1
9
+ psutil==6.0.0
server.py CHANGED
@@ -4,12 +4,14 @@ from fastapi import FastAPI, HTTPException
4
  from pydantic import BaseModel
5
  from typing import List, Optional
6
  import torch
 
7
  from audiocraft.models import musicgen
8
  import numpy as np
9
  import io
10
- from fastapi.responses import StreamingResponse
11
  from scipy.io.wavfile import write as wav_write
12
  import uvicorn
 
13
 
14
  warnings.simplefilter('ignore')
15
 
@@ -33,7 +35,12 @@ else:
33
  args.model_name = f"facebook/{args.model}"
34
 
35
  # Load the model with the provided arguments
36
- musicgen_model = musicgen.MusicGen.get_pretrained(args.model_name, device=args.device)
 
 
 
 
 
37
 
38
  class MusicRequest(BaseModel):
39
  prompts: List[str]
@@ -41,6 +48,9 @@ class MusicRequest(BaseModel):
41
 
42
  @app.post("/generate_music")
43
  def generate_music(request: MusicRequest):
 
 
 
44
  try:
45
  musicgen_model.set_generation_params(duration=request.duration)
46
  result = musicgen_model.generate(request.prompts, progress=False)
@@ -57,5 +67,24 @@ def generate_music(request: MusicRequest):
57
  except Exception as e:
58
  raise HTTPException(status_code=500, detail=str(e))
59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  if __name__ == "__main__":
61
  uvicorn.run(app, host=args.host, port=args.port)
 
4
  from pydantic import BaseModel
5
  from typing import List, Optional
6
  import torch
7
+ from torch.cuda import memory_allocated, memory_reserved
8
  from audiocraft.models import musicgen
9
  import numpy as np
10
  import io
11
+ from fastapi.responses import StreamingResponse, JSONResponse
12
  from scipy.io.wavfile import write as wav_write
13
  import uvicorn
14
+ import psutil
15
 
16
  warnings.simplefilter('ignore')
17
 
 
35
  args.model_name = f"facebook/{args.model}"
36
 
37
# Load the model with the provided arguments. Keep the server process alive
# even when loading fails so the /health endpoint can report the failure
# instead of the process dying at startup.
try:
    musicgen_model = musicgen.MusicGen.get_pretrained(args.model_name, device=args.device)
    model_loaded = True
except Exception as e:
    # Report the cause at startup rather than swallowing it silently.
    print(f"Failed to load model '{args.model_name}': {e}")
    musicgen_model = None
    model_loaded = False
44
 
45
  class MusicRequest(BaseModel):
46
  prompts: List[str]
 
48
 
49
  @app.post("/generate_music")
50
  def generate_music(request: MusicRequest):
51
+ if not model_loaded:
52
+ raise HTTPException(status_code=500, detail="Model is not loaded.")
53
+
54
  try:
55
  musicgen_model.set_generation_params(duration=request.duration)
56
  result = musicgen_model.generate(request.prompts, progress=False)
 
67
  except Exception as e:
68
  raise HTTPException(status_code=500, detail=str(e))
69
 
70
@app.get("/health")
def health_check():
    """Report server liveness, model state, and resource usage.

    Returns:
        JSON with server_running, model_loaded, cpu_usage_percent,
        ram_usage_percent, and — when running on CUDA — GPU memory
        counters in bytes (gpu_memory_allocated, gpu_memory_reserved).
    """
    # interval=None is non-blocking: it reports CPU usage since the previous
    # call instead of stalling every health probe for a full sampling second.
    # (The first call after startup may report 0.0 — see psutil docs.)
    cpu_usage = psutil.cpu_percent(interval=None)
    ram_usage = psutil.virtual_memory().percent
    stats = {
        "server_running": True,
        "model_loaded": model_loaded,
        "cpu_usage_percent": cpu_usage,
        "ram_usage_percent": ram_usage
    }
    # startswith() also matches explicit device indices such as "cuda:0".
    if args.device.startswith("cuda") and torch.cuda.is_available():
        stats.update({
            "gpu_memory_allocated": memory_allocated(),  # bytes
            "gpu_memory_reserved": memory_reserved()     # bytes
        })
    return JSONResponse(content=stats)
88
+
89
if __name__ == "__main__":
    # Start the ASGI server on the host/port supplied via the CLI.
    uvicorn.run(app, host=args.host, port=args.port)