Phoenixak99
committed on
Commit • 30b75e1
Parent(s): 247afcc
Update handler.py
handler.py +67 -33

handler.py CHANGED
@@ -1,7 +1,12 @@
+import logging
 from typing import Dict, Any
 from transformers import AutoProcessor, MusicgenForConditionalGeneration
 import torch
 
+# Set up logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
 class EndpointHandler:
     def __init__(self, path=""):
         # Load the processor and model from the specified path
@@ -11,45 +16,74 @@ class EndpointHandler:
         ).to("cuda")
         self.sampling_rate = self.model.config.audio_encoder.sampling_rate
 
-    def __call__(self, data: Dict[str, Any]) ->
+    def __call__(self, data: Dict[str, Any]) -> Any:
         """
         Args:
             data (dict): The payload with the text prompt and generation parameters.
         """
-        ...
-        ).to("cuda")
-        ...
-            **parameters,
-        }
-        ...
-        {
-            "
-            "sample_rate": self.sampling_rate,
-        }
+        try:
+            # Extract inputs and parameters from the payload
+            inputs = data.get("inputs", data)
+            parameters = data.get("parameters", {})
+
+            # Handle inputs
+            if isinstance(inputs, str):
+                prompt = inputs
+                duration = 10  # Default duration
+            elif isinstance(inputs, dict):
+                prompt = inputs.get("text") or inputs.get("prompt")
+                duration = inputs.get("duration", 10)
+            else:
+                prompt = None
+                duration = 10
+
+            # Override duration if provided in parameters
+            if 'duration' in parameters:
+                duration = parameters.pop('duration')
+
+            # Validate the prompt
+            if not prompt:
+                return {"error": "No prompt provided."}
+
+            # Preprocess the prompt
+            input_ids = self.processor(
+                text=[prompt],
+                padding=True,
+                return_tensors="pt",
+            ).to("cuda")
+
+            # Set generation parameters
+            gen_kwargs = {
+                "max_new_tokens": int(duration * 50),  # MusicGen uses 50 tokens per second
+            }
+
+            # Filter out unsupported parameters
+            supported_params = [
+                "max_length", "min_length", "do_sample", "early_stopping", "num_beams",
+                "temperature", "top_k", "top_p", "repetition_penalty", "bad_words_ids",
+                "num_return_sequences", "attention_mask"
+            ]
+            for param in supported_params:
+                if param in parameters:
+                    gen_kwargs[param] = parameters[param]
+
+            logger.info(f"Received prompt: {prompt}")
+            logger.info(f"Generation parameters: {gen_kwargs}")
+
+            # Generate audio
+            with torch.autocast("cuda"):
+                outputs = self.model.generate(**input_ids, **gen_kwargs)
+
+            # Convert the output audio tensor to a list of lists (channel-wise)
+            audio_tensor = outputs[0].cpu()  # Shape: [num_channels, seq_len]
+            audio_list = audio_tensor.numpy().tolist()  # [[channel1_data], [channel2_data]]
+
+            return [
+                {
+                    "generated_audio": audio_list,
+                    "sample_rate": self.sampling_rate,
+                }
+            ]
+        except Exception as e:
+            logger.error(f"Exception during generation: {e}")
+            return {"error": str(e)}
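
For anyone exercising the updated handler locally, a minimal smoke test is sketched below. It is not part of the commit, and it assumes a CUDA GPU with the MusicGen weights stored alongside handler.py. Note the duration math: with duration=5, the handler requests int(5 * 50) = 250 new tokens, per the 50-tokens-per-second comment in the code.

# Hypothetical local smoke test (not part of this commit).
# Assumes a CUDA GPU and model weights in the current directory.
from handler import EndpointHandler

handler = EndpointHandler(path=".")
payload = {
    "inputs": {"text": "lo-fi beat with mellow piano", "duration": 5},
    "parameters": {"do_sample": True, "top_k": 250, "temperature": 1.0},
}
result = handler(payload)
print(result[0]["sample_rate"])           # MusicGen's audio_encoder sampling rate
print(len(result[0]["generated_audio"]))  # number of audio channels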
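
Deployed as a Hugging Face Inference Endpoint, the same payload travels as the JSON body of a POST request. A sketch, with the endpoint URL and token as placeholders:

# Hypothetical client call to the deployed endpoint (URL and token are placeholders).
import requests

ENDPOINT_URL = "https://YOUR-ENDPOINT.endpoints.huggingface.cloud"
headers = {"Authorization": "Bearer hf_xxx", "Content-Type": "application/json"}
payload = {
    "inputs": {"text": "epic orchestral theme", "duration": 10},
    "parameters": {"do_sample": True},
}
response = requests.post(ENDPOINT_URL, headers=headers, json=payload)
result = response.json()  # same structure the handler returns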
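
Because generated_audio comes back as nested lists ([channels, samples]) rather than binary audio, the client has to rebuild the waveform itself. One way, taking result from either snippet above and assuming numpy and scipy are installed:

# Sketch: write the handler's response out as a WAV file (assumes numpy/scipy).
import numpy as np
from scipy.io import wavfile

audio = np.array(result[0]["generated_audio"], dtype=np.float32)  # [channels, samples]
wavfile.write("musicgen_out.wav", rate=result[0]["sample_rate"], data=audio[0])  # first channel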