File size: 3,311 Bytes
30b75e1
48e4064
 
 
 
30b75e1
 
 
 
48e4064
3570981
 
48e4064
 
3570981
48e4064
3570981
 
30b75e1
3570981
 
 
 
30b75e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3570981
30b75e1
 
 
3570981
30b75e1
 
 
3570981
30b75e1
 
 
 
 
 
3570981
30b75e1
 
 
06c68e1
30b75e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import logging
from typing import Dict, Any
from transformers import AutoProcessor, MusicgenForConditionalGeneration
import torch

# Module logger for this handler. basicConfig is a no-op when the root logger
# already has handlers; otherwise it installs a stderr handler at INFO level.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

class EndpointHandler:
    """Turn text prompts into raw audio samples with a MusicGen model.

    The processor and model are loaded once at construction. The model is kept
    on CUDA in float16, and each call serves one generation request.
    """

    # Clip length, in seconds, used when the payload gives no duration.
    DEFAULT_DURATION = 10

    # Keys from ``data["parameters"]`` that are forwarded to ``generate``.
    # ``attention_mask`` is intentionally excluded: the processor output already
    # carries one, so forwarding a second copy made ``generate`` fail with
    # "got multiple values for keyword argument 'attention_mask'".
    _SUPPORTED_GEN_PARAMS = (
        "max_length", "min_length", "do_sample", "early_stopping", "num_beams",
        "temperature", "top_k", "top_p", "repetition_penalty", "bad_words_ids",
        "num_return_sequences",
    )

    def __init__(self, path=""):
        """Load the processor and the float16 model from ``path`` onto CUDA.

        Args:
            path: Model location handed to ``from_pretrained``.
        """
        self.processor = AutoProcessor.from_pretrained(path)
        self.model = MusicgenForConditionalGeneration.from_pretrained(
            path, torch_dtype=torch.float16
        ).to("cuda")
        audio_config = self.model.config.audio_encoder
        self.sampling_rate = audio_config.sampling_rate
        # Audio-codec frames per second, i.e. decoding steps per second of audio.
        # Read it from the config when exposed; 50 is the value previously
        # hard-coded for MusicGen.
        self.frame_rate = getattr(audio_config, "frame_rate", 50)

    def __call__(self, data: Dict[str, Any]) -> Any:
        """Generate audio for the prompt carried in ``data``.

        Args:
            data (dict): Request payload. ``data["inputs"]`` is either the prompt
                string or a dict with ``"text"`` (or ``"prompt"``) and an optional
                ``"duration"`` in seconds. When ``"inputs"`` is absent, ``data``
                itself is read as that dict. ``data["parameters"]`` may hold a
                ``"duration"`` that overrides the one in ``inputs``, plus any key
                listed in ``_SUPPORTED_GEN_PARAMS``. The payload is not modified.

        Returns:
            On success, a list with one ``{"generated_audio", "sample_rate"}`` dict
            per generated sequence. ``generated_audio`` is a nested list with one
            inner list of samples per channel. On an invalid request or a
            generation failure, ``{"error": message}``.
        """
        try:
            inputs = data.get("inputs", data)
            # ``or {}`` also tolerates an explicit ``"parameters": null``.
            parameters = data.get("parameters") or {}

            if isinstance(inputs, str):
                prompt = inputs
                duration = self.DEFAULT_DURATION
            elif isinstance(inputs, dict):
                prompt = inputs.get("text") or inputs.get("prompt")
                duration = inputs.get("duration", self.DEFAULT_DURATION)
            else:
                prompt = None
                duration = self.DEFAULT_DURATION

            # A duration in ``parameters`` takes precedence. Read it with ``get``
            # rather than ``pop`` so the caller's payload dict is not mutated.
            duration = parameters.get("duration", duration)

            if not prompt:
                return {"error": "No prompt provided."}

            # JSON clients may send the duration as a string. Multiplying a str
            # by the frame rate would repeat it instead of scaling it, so
            # convert explicitly.
            raw_duration = duration
            try:
                duration = float(raw_duration)
            except (TypeError, ValueError):
                return {"error": f"Invalid duration: {raw_duration!r}."}
            # The chained comparison rejects zero, negatives, NaN and infinity.
            if not 0 < duration < float("inf"):
                return {"error": f"Invalid duration: {raw_duration!r}."}

            model_inputs = self.processor(
                text=[prompt],
                padding=True,
                return_tensors="pt",
            ).to("cuda")

            # One decoding step per audio-codec frame; always request at least one.
            gen_kwargs = {
                "max_new_tokens": max(1, int(duration * self.frame_rate)),
            }
            gen_kwargs.update(
                {
                    name: parameters[name]
                    for name in self._SUPPORTED_GEN_PARAMS
                    if name in parameters
                }
            )

            logger.info("Received prompt: %s", prompt)
            logger.info("Generation parameters: %s", gen_kwargs)

            with torch.autocast("cuda"):
                audio_values = self.model.generate(**model_inputs, **gen_kwargs)

            # Each row of ``audio_values`` is one generated sequence shaped
            # [num_channels, num_samples]. Return all of them; the first entry
            # is exactly what was returned before.
            return [
                {
                    "generated_audio": waveform.cpu().numpy().tolist(),
                    "sample_rate": self.sampling_rate,
                }
                for waveform in audio_values
            ]
        except Exception as e:
            # Top-level request boundary: report the failure to the client and
            # keep the full traceback in the log.
            logger.exception("Exception during generation: %s", e)
            return {"error": str(e)}