VidTune-Gradio / server.py
animikhaich's picture
Added: Server Health Check Endpoint
d8d2011
raw
history blame
3.11 kB
import warnings
import argparse
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Optional
import torch
from torch.cuda import memory_allocated, memory_reserved
from audiocraft.models import musicgen
import numpy as np
import io
from fastapi.responses import StreamingResponse, JSONResponse
from scipy.io.wavfile import write as wav_write
import uvicorn
import psutil
warnings.simplefilter('ignore')

# --- Command-line configuration -------------------------------------------
parser = argparse.ArgumentParser(description="Music Generation Server")
parser.add_argument("--model", type=str, default="musicgen-stereo-small", help="Pretrained model name")
parser.add_argument("--device", type=str, default="cuda", help="Device to load the model on")
parser.add_argument("--duration", type=int, default=10, help="Duration of generated music in seconds")
parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to run the server on")
parser.add_argument("--port", type=int, default=8000, help="Port to run the server on")
args = parser.parse_args()

# Initialize the FastAPI app
app = FastAPI()

# Normalize the model name to a fully-qualified HuggingFace id
# (e.g. "musicgen-stereo-small" -> "facebook/musicgen-stereo-small").
if args.model.startswith('facebook/'):
    args.model_name = args.model
else:
    args.model_name = f"facebook/{args.model}"

# Load the model eagerly at startup. On failure the server still starts,
# but endpoints report the model as unavailable via `model_loaded`.
try:
    musicgen_model = musicgen.MusicGen.get_pretrained(args.model_name, device=args.device)
    model_loaded = True
except Exception as e:
    # Previously the exception was swallowed silently, which made startup
    # problems (bad device string, missing weights, no network) invisible.
    print(f"Failed to load model '{args.model_name}': {e}")
    musicgen_model = None
    model_loaded = False
class MusicRequest(BaseModel):
    """Request payload for the /generate_music endpoint (validated by pydantic)."""
    # prompts: one text description per clip to generate.
    prompts: List[str]
    # duration: length of each generated clip in seconds.
    duration: Optional[int] = 10  # Default duration is 10 seconds if not provided
@app.post("/generate_music")
def generate_music(request: MusicRequest):
if not model_loaded:
raise HTTPException(status_code=500, detail="Model is not loaded.")
try:
musicgen_model.set_generation_params(duration=request.duration)
result = musicgen_model.generate(request.prompts, progress=False)
result = result.squeeze().cpu().numpy().T
sample_rate = musicgen_model.sample_rate
buffer = io.BytesIO()
wav_write(buffer, sample_rate, result)
buffer.seek(0)
return StreamingResponse(buffer, media_type="audio/wav")
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get("/health")
def health_check():
cpu_usage = psutil.cpu_percent(interval=1)
ram_usage = psutil.virtual_memory().percent
stats = {
"server_running": True,
"model_loaded": model_loaded,
"cpu_usage_percent": cpu_usage,
"ram_usage_percent": ram_usage
}
if args.device == "cuda" and torch.cuda.is_available():
gpu_memory_allocated = memory_allocated()
gpu_memory_reserved = memory_reserved()
stats.update({
"gpu_memory_allocated": gpu_memory_allocated,
"gpu_memory_reserved": gpu_memory_reserved
})
return JSONResponse(content=stats)
if __name__ == "__main__":
uvicorn.run(app, host=args.host, port=args.port)