import logging
from typing import Dict, Any

import torch
import gc
from transformers import AutoProcessor, MusicgenForConditionalGeneration

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class EndpointHandler:
    def __init__(self, path=""):
        # Enable CUDA optimizations
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.benchmark = True

        # Load processor with optimizations
        logger.info("Loading processor...")
        self.processor = AutoProcessor.from_pretrained(
            path,
            use_fast=True  # Use the faster tokenizer implementation
        )

        logger.info("Loading model...")
        self.model = MusicgenForConditionalGeneration.from_pretrained(
            path,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True
        ).to("cuda")

        # Set model to eval mode
        self.model.eval()

        # Cache sampling rate
        self.sampling_rate = self.model.config.audio_encoder.sampling_rate

        # Clear CUDA cache
        torch.cuda.empty_cache()
        gc.collect()

        # Quick warmup
        logger.info("Warming up model...")
        self._warmup()

    def _warmup(self):
        """Perform a minimal forward pass to warm up the model."""
        try:
            with torch.no_grad():
                dummy_input = self.processor(
                    text=["test"],
                    padding=True,
                    return_tensors="pt"
                ).to("cuda")
                # Minimal generation
                self.model.generate(
                    **dummy_input,
                    max_new_tokens=10,
                    do_sample=False
                )
        except Exception as e:
            logger.warning(f"Warmup failed (non-critical): {e}")

    def __call__(self, data: Dict[str, Any]) -> Any:
        try:
            # Extract inputs and parameters
            inputs = data.get("inputs", data)
            parameters = data.get("parameters", {})

            # Efficient input handling
            if isinstance(inputs, dict):
                prompt = inputs.get("text") or inputs.get("prompt")
                duration = inputs.get("duration", 10)
            else:
                prompt = inputs if isinstance(inputs, str) else None
                duration = 10

            if "duration" in parameters:
                duration = parameters.pop("duration")

            if not prompt:
                return {"error": "No prompt provided."}

            # Preprocess with optimized settings
            input_ids = self.processor(
                text=[prompt],
                padding=True,
                return_tensors="pt",
                truncation=True,
                max_length=512  # Limit input length
            ).to("cuda")

            # Optimized generation settings
            gen_kwargs = {
                "max_new_tokens": int(duration * 50),
                "use_cache": True,  # Enable KV-cache
                "do_sample": True,
                "temperature": 0.8,
                "top_k": 50,
                "top_p": 0.95,
            }

            # Add any custom parameters
            supported_params = [
                "max_length", "min_length", "do_sample", "early_stopping",
                "num_beams", "temperature", "top_k", "top_p",
                "repetition_penalty", "bad_words_ids", "num_return_sequences",
                "attention_mask",
            ]
            for param in supported_params:
                if param in parameters:
                    gen_kwargs[param] = parameters[param]

            logger.info(f"Generating with prompt: {prompt}")
            logger.info(f"Generation parameters: {gen_kwargs}")

            # Generate with optimized settings
            with torch.inference_mode(), torch.autocast("cuda"):
                outputs = self.model.generate(**input_ids, **gen_kwargs)

            # Convert output
            audio_tensor = outputs[0].cpu()
            audio_list = audio_tensor.numpy().tolist()

            # Clear cache
            torch.cuda.empty_cache()

            return [{
                "generated_audio": audio_list,
                "sample_rate": self.sampling_rate,
            }]

        except Exception as e:
            logger.error(f"Generation failed: {e}")
            return {"error": str(e)}
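

# ---------------------------------------------------------------------------
# Minimal local usage sketch (illustrative only). On Inference Endpoints the
# platform instantiates and calls EndpointHandler automatically; this block
# just shows the expected request/response shape. The checkpoint path
# "facebook/musicgen-small", the example prompt, and the parameter values are
# assumptions, and a CUDA-capable GPU is required by the handler above.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    handler = EndpointHandler(path="facebook/musicgen-small")
    response = handler({
        "inputs": {"text": "lo-fi beat with warm piano chords", "duration": 5},
        "parameters": {"temperature": 0.7},
    })
    if isinstance(response, list):
        # generated_audio is a nested list (channels x samples) for MusicGen's
        # mono output; this indexing is an assumption about that shape.
        sr = response[0]["sample_rate"]
        n_samples = len(response[0]["generated_audio"][0])
        print(f"Generated {n_samples / sr:.1f}s of audio at {sr} Hz")
    else:
        print(f"Error: {response['error']}")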