import logging from typing import Dict, Any from transformers import AutoProcessor, MusicgenForConditionalGeneration import torch # Set up logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) class EndpointHandler: def __init__(self, path=""): # Load the processor and model from the specified path self.processor = AutoProcessor.from_pretrained(path) self.model = MusicgenForConditionalGeneration.from_pretrained( path, torch_dtype=torch.float16 ).to("cuda") self.sampling_rate = self.model.config.audio_encoder.sampling_rate def __call__(self, data: Dict[str, Any]) -> Any: """ Args: data (dict): The payload with the text prompt and generation parameters. """ try: # Extract inputs and parameters from the payload inputs = data.get("inputs", data) parameters = data.get("parameters", {}) # Handle inputs if isinstance(inputs, str): prompt = inputs duration = 10 # Default duration elif isinstance(inputs, dict): prompt = inputs.get("text") or inputs.get("prompt") duration = inputs.get("duration", 10) else: prompt = None duration = 10 # Override duration if provided in parameters if 'duration' in parameters: duration = parameters.pop('duration') # Validate the prompt if not prompt: return {"error": "No prompt provided."} # Preprocess the prompt input_ids = self.processor( text=[prompt], padding=True, return_tensors="pt", ).to("cuda") # Set generation parameters gen_kwargs = { "max_new_tokens": int(duration * 50), # MusicGen uses 50 tokens per second } # Filter out unsupported parameters supported_params = [ "max_length", "min_length", "do_sample", "early_stopping", "num_beams", "temperature", "top_k", "top_p", "repetition_penalty", "bad_words_ids", "num_return_sequences", "attention_mask" ] for param in supported_params: if param in parameters: gen_kwargs[param] = parameters[param] logger.info(f"Received prompt: {prompt}") logger.info(f"Generation parameters: {gen_kwargs}") # Generate audio with torch.autocast("cuda"): outputs = self.model.generate(**input_ids, **gen_kwargs) # Convert the output audio tensor to a list of lists (channel-wise) audio_tensor = outputs[0].cpu() # Shape: [num_channels, seq_len] audio_list = audio_tensor.numpy().tolist() # [[channel1_data], [channel2_data]] return [ { "generated_audio": audio_list, "sample_rate": self.sampling_rate, } ] except Exception as e: logger.error(f"Exception during generation: {e}") return {"error": str(e)}