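"""Custom handler for a Hugging Face Inference Endpoint serving MusicGen text-to-music generation."""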
import gc
import logging
from typing import Any, Dict

import torch
from transformers import AutoProcessor, MusicgenForConditionalGeneration

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class EndpointHandler:
    def __init__(self, path=""):
        # Enable TF32 matmuls and cuDNN autotuning for faster CUDA kernels
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.benchmark = True

        # Load the processor with the fast tokenizer
        logger.info("Loading processor...")
        self.processor = AutoProcessor.from_pretrained(
            path,
            use_fast=True
        )

        logger.info("Loading model...")
        self.model = MusicgenForConditionalGeneration.from_pretrained(
            path,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True
        ).to("cuda")

        # Set model to eval mode
        self.model.eval()

        # Cache the audio encoder's sampling rate for the response payload
        self.sampling_rate = self.model.config.audio_encoder.sampling_rate

        # Clear CUDA cache
        torch.cuda.empty_cache()
        gc.collect()

        # Quick warmup
        logger.info("Warming up model...")
        self._warmup()
    def _warmup(self):
        """Perform a minimal forward pass to warm up the model."""
        try:
            with torch.no_grad():
                dummy_input = self.processor(
                    text=["test"],
                    padding=True,
                    return_tensors="pt"
                ).to("cuda")
                # Minimal generation
                self.model.generate(
                    **dummy_input,
                    max_new_tokens=10,
                    do_sample=False
                )
        except Exception as e:
            logger.warning(f"Warmup failed (non-critical): {e}")
    def __call__(self, data: Dict[str, Any]) -> Any:
        try:
            # Extract inputs and parameters from the request payload
            inputs = data.get("inputs", data)
            parameters = data.get("parameters", {})

            # Accept either a plain string prompt or a dict with text/duration
            if isinstance(inputs, dict):
                prompt = inputs.get("text") or inputs.get("prompt")
                duration = inputs.get("duration", 10)
            else:
                prompt = inputs if isinstance(inputs, str) else None
                duration = 10

            # A duration passed in "parameters" overrides the one in "inputs"
            if "duration" in parameters:
                duration = parameters.pop("duration")

            if not prompt:
                return {"error": "No prompt provided."}
            # Tokenize the prompt, capping the input length
            model_inputs = self.processor(
                text=[prompt],
                padding=True,
                return_tensors="pt",
                truncation=True,
                max_length=512
            ).to("cuda")

            # Default generation settings; MusicGen produces roughly 50 audio
            # tokens per second, so max_new_tokens scales with the requested duration
            gen_kwargs = {
                "max_new_tokens": int(duration * 50),
                "use_cache": True,  # Enable KV-cache
                "do_sample": True,
                "temperature": 0.8,
                "top_k": 50,
                "top_p": 0.95
            }

            # Pass through any supported generation parameters from the request
            supported_params = [
                "max_length", "min_length", "do_sample", "early_stopping", "num_beams",
                "temperature", "top_k", "top_p", "repetition_penalty", "bad_words_ids",
                "num_return_sequences", "attention_mask"
            ]
            for param in supported_params:
                if param in parameters:
                    gen_kwargs[param] = parameters[param]

            logger.info(f"Generating with prompt: {prompt}")
            logger.info(f"Generation parameters: {gen_kwargs}")

            # Generate under inference mode with CUDA autocast
            with torch.inference_mode(), torch.autocast("cuda"):
                outputs = self.model.generate(**model_inputs, **gen_kwargs)

            # outputs has shape (batch, channels, samples); take the first item
            audio_tensor = outputs[0].cpu()
            audio_list = audio_tensor.numpy().tolist()

            # Clear cache
            torch.cuda.empty_cache()

            return [{
                "generated_audio": audio_list,
                "sample_rate": self.sampling_rate,
            }]

        except Exception as e:
            logger.error(f"Generation failed: {e}")
            return {"error": str(e)}