import gc
import logging
from typing import Any, Dict

import torch
from transformers import AutoProcessor, MusicgenForConditionalGeneration

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class EndpointHandler:
    """Inference Endpoints handler for MusicGen text-to-audio generation."""

    def __init__(self, path=""):
        # Enable CUDA optimizations: TF32 matmuls and cuDNN autotuning
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.benchmark = True

        # Load processor with the fast tokenizer
        logger.info("Loading processor...")
        self.processor = AutoProcessor.from_pretrained(
            path,
            use_fast=True,  # use the faster tokenizer
        )

        logger.info("Loading model...")
        self.model = MusicgenForConditionalGeneration.from_pretrained(
            path,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
        ).to("cuda")

        # Set model to eval mode for inference
        self.model.eval()

        # Cache the audio encoder's sampling rate for the response payload
        self.sampling_rate = self.model.config.audio_encoder.sampling_rate

        # Release memory held during loading
        torch.cuda.empty_cache()
        gc.collect()

        # Quick warmup so the first real request is not the slowest one
        logger.info("Warming up model...")
        self._warmup()
    def _warmup(self):
        """Perform a minimal forward pass to warm up the model."""
        try:
            with torch.no_grad():
                dummy_input = self.processor(
                    text=["test"],
                    padding=True,
                    return_tensors="pt",
                ).to("cuda")
                # Minimal greedy generation; the output is discarded
                self.model.generate(
                    **dummy_input,
                    max_new_tokens=10,
                    do_sample=False,
                )
        except Exception as e:
            logger.warning(f"Warmup failed (non-critical): {e}")
    def __call__(self, data: Dict[str, Any]) -> Any:
        try:
            # Extract inputs and parameters from the request payload
            inputs = data.get("inputs", data)
            parameters = data.get("parameters", {})

            # Accept either a {"text"/"prompt", "duration"} dict or a bare string
            if isinstance(inputs, dict):
                prompt = inputs.get("text") or inputs.get("prompt")
                duration = inputs.get("duration", 10)
            else:
                prompt = inputs if isinstance(inputs, str) else None
                duration = 10
            if "duration" in parameters:
                duration = parameters.pop("duration")

            if not prompt:
                return {"error": "No prompt provided."}

            # Tokenize the prompt; `model_inputs` holds input_ids and attention_mask
            model_inputs = self.processor(
                text=[prompt],
                padding=True,
                return_tensors="pt",
                truncation=True,
                max_length=512,  # limit input length
            ).to("cuda")

            # Default generation settings; MusicGen emits ~50 audio tokens per second
            gen_kwargs = {
                "max_new_tokens": int(duration * 50),
                "use_cache": True,  # enable KV-cache
                "do_sample": True,
                "temperature": 0.8,
                "top_k": 50,
                "top_p": 0.95,
            }

            # Override defaults with any supported custom parameters
            supported_params = [
                "max_length", "min_length", "do_sample", "early_stopping", "num_beams",
                "temperature", "top_k", "top_p", "repetition_penalty", "bad_words_ids",
                "num_return_sequences", "attention_mask",
            ]
            for param in supported_params:
                if param in parameters:
                    gen_kwargs[param] = parameters[param]

            logger.info(f"Generating with prompt: {prompt}")
            logger.info(f"Generation parameters: {gen_kwargs}")

            # Generate under inference mode with mixed-precision autocast
            with torch.inference_mode(), torch.autocast("cuda"):
                outputs = self.model.generate(**model_inputs, **gen_kwargs)

            # Move the waveform to the CPU and serialize it as a nested list
            audio_tensor = outputs[0].cpu()
            audio_list = audio_tensor.numpy().tolist()

            # Release cached GPU memory before returning
            torch.cuda.empty_cache()

            return [{
                "generated_audio": audio_list,
                "sample_rate": self.sampling_rate,
            }]
        except Exception as e:
            logger.error(f"Generation failed: {e}")
            return {"error": str(e)}