# handler.py
from typing import Any, Dict, List

import torch
from transformers import AutoProcessor, MusicgenForConditionalGeneration


class EndpointHandler:
    def __init__(self, path=""):
        # Load the processor and model from the repository path; the model is
        # loaded in float16 and moved to the GPU.
        self.processor = AutoProcessor.from_pretrained(path)
        self.model = MusicgenForConditionalGeneration.from_pretrained(
            path, torch_dtype=torch.float16
        ).to("cuda")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Args:
            data (Dict): The request data, containing:
                - inputs (Dict): Contains 'prompt' and an optional 'duration' in seconds
                - parameters (Dict, optional): Overrides for the generation parameters

        Returns:
            A list with one dict holding the generated audio and its sampling rate.
        """
        # Extract inputs and parameters from the request payload
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", {})

        # Get the prompt and the requested duration
        prompt = inputs.get("prompt", "")
        duration = inputs.get("duration", 30)  # default: 30 seconds

        # Convert the duration into a token budget. MusicGen's 32 kHz EnCodec
        # codec emits 50 codebook frames per second (one token per 640 audio
        # samples), so roughly 50 new tokens correspond to one second of audio.
        frame_rate = self.model.config.audio_encoder.frame_rate  # 50 for the 32 kHz checkpoints
        max_new_tokens = int(duration * frame_rate)

        # Tokenize the input text
        model_inputs = self.processor(
            text=[prompt],
            padding=True,
            return_tensors="pt",
        ).to("cuda")

        # Default generation parameters
        generation_params = {
            "do_sample": True,
            "guidance_scale": 3,
            "max_new_tokens": max_new_tokens,
        }

        # Override the defaults with any user-provided parameters
        generation_params.update(parameters)

        # Generate audio under mixed precision
        with torch.autocast("cuda"):
            outputs = self.model.generate(**model_inputs, **generation_params)

        # `outputs` has shape (batch, channels, samples); convert it to nested
        # lists so the response is JSON-serializable.
        generated_audio = outputs.cpu().float().numpy().tolist()
        sampling_rate = self.model.config.audio_encoder.sampling_rate

        return [{"generated_audio": generated_audio, "sampling_rate": sampling_rate}]
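

# Local smoke test: a minimal sketch of how the handler can be exercised
# outside the endpoint. It assumes a CUDA-capable machine, the
# "facebook/musicgen-small" checkpoint, and the optional `scipy` dependency;
# the prompt and the output filename "sample.wav" are made-up examples.
if __name__ == "__main__":
    import numpy as np
    import scipy.io.wavfile

    handler = EndpointHandler(path="facebook/musicgen-small")  # assumed checkpoint
    result = handler(
        {
            "inputs": {"prompt": "lo-fi hip hop beat with mellow piano", "duration": 5},
            "parameters": {"guidance_scale": 3.0},
        }
    )

    # First batch item, first (mono) channel -> 1-D float32 waveform
    audio = np.asarray(result[0]["generated_audio"][0][0], dtype=np.float32)
    scipy.io.wavfile.write("sample.wav", rate=result[0]["sampling_rate"], data=audio)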