import logging
from typing import Dict, Any
from transformers import AutoProcessor, MusicgenForConditionalGeneration
import torch
import gc

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class EndpointHandler:
    def __init__(self, path=""):
        # Enable CUDA optimizations: TF32 matmuls and cuDNN kernel autotuning
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.benchmark = True
        
        # Load processor with optimizations
        logger.info("Loading processor...")
        self.processor = AutoProcessor.from_pretrained(
            path,
            use_fast=True  # Use faster tokenizer
        )
        
        logger.info("Loading model...")
        self.model = MusicgenForConditionalGeneration.from_pretrained(
            path,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True
        ).to("cuda")
        
        # Set model to eval mode
        self.model.eval()
        
        # Cache sampling rate
        self.sampling_rate = self.model.config.audio_encoder.sampling_rate
        
        # Clear CUDA cache
        torch.cuda.empty_cache()
        gc.collect()
        
        # Quick warmup
        logger.info("Warming up model...")
        self._warmup()

    def _warmup(self):
        """Perform a minimal forward pass to warm up the model"""
        try:
            with torch.no_grad():
                dummy_input = self.processor(
                    text=["test"],
                    padding=True,
                    return_tensors="pt"
                ).to("cuda")
                
                # Minimal generation
                self.model.generate(
                    **dummy_input,
                    max_new_tokens=10,
                    do_sample=False
                )
        except Exception as e:
            logger.warning(f"Warmup failed (non-critical): {e}")

    def __call__(self, data: Dict[str, Any]) -> Any:
        try:
            # Extract inputs and parameters
            inputs = data.get("inputs", data)
            parameters = data.get("parameters", {})
            
            # Efficient input handling
            if isinstance(inputs, dict):
                prompt = inputs.get("text") or inputs.get("prompt")
                duration = inputs.get("duration", 10)
            else:
                prompt = inputs if isinstance(inputs, str) else None
                duration = 10
                
            if "duration" in parameters:
                duration = parameters.pop("duration")
                
            if not prompt:
                return {"error": "No prompt provided."}
            
            # Preprocess with optimized settings
            model_inputs = self.processor(
                text=[prompt],
                padding=True,
                return_tensors="pt",
                truncation=True,
                max_length=512  # Limit input length
            ).to("cuda")
            
            # Optimized generation settings
            gen_kwargs = {
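                # MusicGen's audio codec emits ~50 tokens per second of audio,
                # so duration (seconds) * 50 approximates the token budget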
                "max_new_tokens": int(duration * 50),
                "use_cache": True,  # Enable KV-cache
                "do_sample": True,
                "temperature": 0.8,
                "top_k": 50,
                "top_p": 0.95
            }
            
            # Add any custom parameters
            supported_params = [
                "max_length", "min_length", "do_sample", "early_stopping", "num_beams",
                "temperature", "top_k", "top_p", "repetition_penalty", "bad_words_ids",
                "num_return_sequences", "attention_mask"
            ]
            for param in supported_params:
                if param in parameters:
                    gen_kwargs[param] = parameters[param]
                    
            logger.info(f"Generating with prompt: {prompt}")
            logger.info(f"Generation parameters: {gen_kwargs}")
            
            # Generate with optimized settings
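            # inference_mode() drops autograd bookkeeping entirely; autocast
            # selects fp16 kernels where it is numerically safe to do so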
            with torch.inference_mode(), torch.autocast("cuda"):
                outputs = self.model.generate(**model_inputs, **gen_kwargs)
            
            # Convert output
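            # generate() returns audio values shaped (batch, channels, samples);
            # take the first (and only) batch item and move it off the GPU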
            audio_tensor = outputs[0].cpu()
            audio_list = audio_tensor.numpy().tolist()
            
            # Clear cache
            torch.cuda.empty_cache()
            
            return [{
                "generated_audio": audio_list,
                "sample_rate": self.sampling_rate,
            }]
            
        except Exception as e:
            logger.error(f"Generation failed: {e}")
            return {"error": str(e)}