# Music generation server: a FastAPI wrapper around Meta's MusicGen (audiocraft).
# Standard library
import argparse
import io
import warnings
from typing import List, Optional

# Third-party
import numpy as np
import psutil
import torch
import uvicorn
from audiocraft.models import musicgen
from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel
from scipy.io.wavfile import write as wav_write
from torch.cuda import memory_allocated, memory_reserved

# Silence library warnings (audiocraft/torch are noisy at import and inference time).
warnings.simplefilter('ignore')
# Command-line configuration, parsed once at startup.
parser = argparse.ArgumentParser(description="Music Generation Server")
parser.add_argument("--model", type=str, default="musicgen-stereo-small",
                    help="Pretrained model name")
parser.add_argument("--device", type=str, default="cuda",
                    help="Device to load the model on")
# NOTE(review): --duration is accepted but never read elsewhere in this file;
# clip length comes from the request body instead — confirm whether it should
# seed the request default.
parser.add_argument("--duration", type=int, default=10,
                    help="Duration of generated music in seconds")
parser.add_argument("--host", type=str, default="0.0.0.0",
                    help="Host to run the server on")
parser.add_argument("--port", type=int, default=8000,
                    help="Port to run the server on")
args = parser.parse_args()
# Create the ASGI application.
app = FastAPI()

# Fully qualify the model id: accept either a bare name ("musicgen-stereo-small")
# or an already-qualified "facebook/..." id.
args.model_name = (
    args.model if args.model.startswith('facebook/') else f"facebook/{args.model}"
)
# Load the model eagerly at startup.  On failure, keep the process alive so
# the health endpoint can report the problem instead of the server dying.
try:
    musicgen_model = musicgen.MusicGen.get_pretrained(args.model_name, device=args.device)
    model_loaded = True
except Exception as e:
    # Previously the exception was caught and silently discarded, leaving no
    # trace of why the model is unavailable.  Surface it for operators.
    print(f"Failed to load model {args.model_name!r}: {e}")
    musicgen_model = None
    model_loaded = False
class MusicRequest(BaseModel):
    """Request body for music generation."""

    # One clip is generated per text prompt.
    prompts: List[str]
    # Clip length in seconds; falls back to 10 when the client omits it.
    duration: Optional[int] = 10
# NOTE(review): this handler was never registered on the app, so the server
# exposed no generation route at all; the decorator below fixes that.  The
# path "/generate" is inferred from the function's purpose — confirm against
# any existing clients.
@app.post("/generate")
def generate_music(request: MusicRequest):
    """Generate music from text prompts and stream it back as a WAV file.

    Returns a ``StreamingResponse`` with ``audio/wav`` content, or raises
    ``HTTPException`` (500) when the model is unavailable or generation fails.
    """
    if not model_loaded:
        raise HTTPException(status_code=500, detail="Model is not loaded.")
    try:
        musicgen_model.set_generation_params(duration=request.duration)
        result = musicgen_model.generate(request.prompts, progress=False)
        # Transpose to samples-first layout for scipy's WAV writer.
        # NOTE(review): squeeze() only removes the batch dimension cleanly for
        # a single prompt; with multiple prompts the resulting array layout is
        # questionable — TODO confirm intended multi-prompt behavior.
        result = result.squeeze().cpu().numpy().T
        sample_rate = musicgen_model.sample_rate
        buffer = io.BytesIO()
        wav_write(buffer, sample_rate, result)
        buffer.seek(0)
        return StreamingResponse(buffer, media_type="audio/wav")
    except Exception as e:
        # Chain the cause so the original traceback is preserved in logs.
        raise HTTPException(status_code=500, detail=str(e)) from e
# NOTE(review): this handler was never registered on the app; the decorator
# below fixes that.  The path "/health" is inferred — confirm against clients.
@app.get("/health")
def health_check():
    """Report server liveness, model status, CPU/RAM usage and GPU memory.

    Returns a ``JSONResponse``.  GPU memory stats are included only when the
    server was configured for CUDA and a CUDA device is actually available.
    """
    cpu_usage = psutil.cpu_percent(interval=1)  # blocks ~1s to sample CPU load
    ram_usage = psutil.virtual_memory().percent
    stats = {
        "server_running": True,
        "model_loaded": model_loaded,
        "cpu_usage_percent": cpu_usage,
        "ram_usage_percent": ram_usage
    }
    if args.device == "cuda" and torch.cuda.is_available():
        # Raw byte counts from the default CUDA device.
        stats.update({
            "gpu_memory_allocated": memory_allocated(),
            "gpu_memory_reserved": memory_reserved()
        })
    return JSONResponse(content=stats)
if __name__ == "__main__":
    # Launch the ASGI server on the configured host/port.
    uvicorn.run(app, host=args.host, port=args.port)