# handler.py
from typing import Any, Dict, List

import torch
from transformers import AutoProcessor, MusicgenForConditionalGeneration

class EndpointHandler:
    def __init__(self, path: str = ""):
        # Load the processor and model from the repository path. The model
        # is loaded in float16 and moved to the GPU; this handler assumes a
        # CUDA device is available.
        self.processor = AutoProcessor.from_pretrained(path)
        self.model = MusicgenForConditionalGeneration.from_pretrained(
            path,
            torch_dtype=torch.float16,
        ).to("cuda")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Args:
            data (Dict): The request data, containing:
                - inputs (Dict): Contains 'prompt' and an optional
                  'duration' in seconds
                - parameters (Dict, optional): Overrides for the default
                  generation parameters passed to `model.generate`

        Returns:
            A list containing a single dict with the generated audio.
        """
        # Extract inputs and parameters; if the payload has no "inputs"
        # key, treat the whole payload as the inputs dict
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", {})
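
        # For illustration, a request body might look like this
        # (hypothetical values):
        #   {"inputs": {"prompt": "lo-fi beat", "duration": 15},
        #    "parameters": {"temperature": 1.0}}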

        # Get the prompt and requested duration (in seconds)
        prompt = inputs.get("prompt", "")
        duration = inputs.get("duration", 30)  # Default: 30 seconds

        # Convert the duration into a token budget. MusicGen's EnCodec
        # codec produces audio at a 32 kHz sampling rate and a 50 Hz frame
        # rate (one token per 640 samples), so read both values from the
        # model config rather than hard-coding them.
        sampling_rate = self.model.config.audio_encoder.sampling_rate
        frame_rate = self.model.config.audio_encoder.frame_rate
        max_new_tokens = int(duration * frame_rate)
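
        # Worked example: duration=30 at a 50 Hz frame rate gives
        # int(30 * 50) = 1500 tokens, which decodes to roughly 30 s of
        # 32 kHz audio.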

        # Tokenize the prompt and move the tensors to the GPU
        model_inputs = self.processor(
            text=[prompt],
            padding=True,
            return_tensors="pt",
        ).to("cuda")

        # Set default generation parameters
        generation_params = {
            "do_sample": True,
            "guidance_scale": 3,
            "max_new_tokens": max_new_tokens,
        }

        # Update with any user-provided parameters
        generation_params.update(parameters)

        # Generate audio (torch.cuda.amp.autocast is deprecated in recent
        # PyTorch releases in favor of torch.autocast)
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            outputs = self.model.generate(**model_inputs, **generation_params)

        # Convert to a nested list for JSON serialization, and include the
        # sampling rate so clients can reconstruct the waveform
        generated_audio = outputs.cpu().float().numpy().tolist()
        return [{"generated_audio": generated_audio, "sample_rate": sampling_rate}]
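

# A minimal local smoke test: a sketch only, since Hugging Face Inference
# Endpoints instantiate EndpointHandler themselves. The model path below is
# a placeholder, and running this requires a CUDA GPU.
if __name__ == "__main__":
    handler = EndpointHandler(path="facebook/musicgen-small")  # placeholder
    result = handler({
        "inputs": {"prompt": "an upbeat acoustic guitar melody", "duration": 5},
        "parameters": {"temperature": 1.0},
    })
    audio = result[0]["generated_audio"]  # shape: [batch][channel][samples]
    print(f"Generated {len(audio[0][0])} samples at {result[0]['sample_rate']} Hz")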