Phoenixak99 committed on
Commit
30b75e1
1 Parent(s): 247afcc

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +67 -33
handler.py CHANGED
@@ -1,7 +1,12 @@
 
1
  from typing import Dict, Any
2
  from transformers import AutoProcessor, MusicgenForConditionalGeneration
3
  import torch
4
 
 
 
 
 
5
  class EndpointHandler:
6
  def __init__(self, path=""):
7
  # Load the processor and model from the specified path
@@ -11,45 +16,74 @@ class EndpointHandler:
11
  ).to("cuda")
12
  self.sampling_rate = self.model.config.audio_encoder.sampling_rate
13
 
14
- def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
15
  """
16
  Args:
17
  data (dict): The payload with the text prompt and generation parameters.
18
  """
19
- # Extract inputs and parameters from the payload
20
- inputs = data.get("inputs", {})
21
- prompt = inputs.get("prompt", "")
22
- duration = inputs.get("duration", 10)
23
- parameters = data.get("parameters", {})
24
-
25
- # Validate the prompt
26
- if not prompt:
27
- return {"error": "No prompt provided."}
28
-
29
- # Preprocess the prompt
30
- input_ids = self.processor(
31
- text=[prompt],
32
- padding=True,
33
- return_tensors="pt",
34
- ).to("cuda")
35
 
36
- # Set generation parameters
37
- gen_kwargs = {
38
- "max_new_tokens": int(duration * 50), # MusicGen uses 50 tokens per second
39
- **parameters,
40
- }
41
 
42
- # Generate audio
43
- with torch.autocast("cuda"):
44
- outputs = self.model.generate(**input_ids, **gen_kwargs)
45
 
46
- # Convert the output audio tensor to a list of lists (channel-wise)
47
- audio_tensor = outputs[0].cpu() # Shape: [num_channels, seq_len]
48
- audio_list = audio_tensor.numpy().tolist() # [[channel1_data], [channel2_data]]
 
 
 
49
 
50
- return [
51
- {
52
- "generated_audio": audio_list,
53
- "sample_rate": self.sampling_rate,
54
  }
55
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
  from typing import Dict, Any
3
  from transformers import AutoProcessor, MusicgenForConditionalGeneration
4
  import torch
5
 
6
+ # Set up logging
7
+ logging.basicConfig(level=logging.INFO)
8
+ logger = logging.getLogger(__name__)
9
+
10
  class EndpointHandler:
11
  def __init__(self, path=""):
12
  # Load the processor and model from the specified path
 
16
  ).to("cuda")
17
  self.sampling_rate = self.model.config.audio_encoder.sampling_rate
18
 
19
+ def __call__(self, data: Dict[str, Any]) -> Any:
20
  """
21
  Args:
22
  data (dict): The payload with the text prompt and generation parameters.
23
  """
24
+ try:
25
+ # Extract inputs and parameters from the payload
26
+ inputs = data.get("inputs", data)
27
+ parameters = data.get("parameters", {})
28
+
29
+ # Handle inputs
30
+ if isinstance(inputs, str):
31
+ prompt = inputs
32
+ duration = 10 # Default duration
33
+ elif isinstance(inputs, dict):
34
+ prompt = inputs.get("text") or inputs.get("prompt")
35
+ duration = inputs.get("duration", 10)
36
+ else:
37
+ prompt = None
38
+ duration = 10
 
39
 
40
+ # Override duration if provided in parameters
41
+ if 'duration' in parameters:
42
+ duration = parameters.pop('duration')
 
 
43
 
44
+ # Validate the prompt
45
+ if not prompt:
46
+ return {"error": "No prompt provided."}
47
 
48
+ # Preprocess the prompt
49
+ input_ids = self.processor(
50
+ text=[prompt],
51
+ padding=True,
52
+ return_tensors="pt",
53
+ ).to("cuda")
54
 
55
+ # Set generation parameters
56
+ gen_kwargs = {
57
+ "max_new_tokens": int(duration * 50), # MusicGen uses 50 tokens per second
 
58
  }
59
+
60
+ # Filter out unsupported parameters
61
+ supported_params = [
62
+ "max_length", "min_length", "do_sample", "early_stopping", "num_beams",
63
+ "temperature", "top_k", "top_p", "repetition_penalty", "bad_words_ids",
64
+ "num_return_sequences", "attention_mask"
65
+ ]
66
+ for param in supported_params:
67
+ if param in parameters:
68
+ gen_kwargs[param] = parameters[param]
69
+
70
+ logger.info(f"Received prompt: {prompt}")
71
+ logger.info(f"Generation parameters: {gen_kwargs}")
72
+
73
+ # Generate audio
74
+ with torch.autocast("cuda"):
75
+ outputs = self.model.generate(**input_ids, **gen_kwargs)
76
+
77
+ # Convert the output audio tensor to a list of lists (channel-wise)
78
+ audio_tensor = outputs[0].cpu() # Shape: [num_channels, seq_len]
79
+ audio_list = audio_tensor.numpy().tolist() # [[channel1_data], [channel2_data]]
80
+
81
+ return [
82
+ {
83
+ "generated_audio": audio_list,
84
+ "sample_rate": self.sampling_rate,
85
+ }
86
+ ]
87
+ except Exception as e:
88
+ logger.error(f"Exception during generation: {e}")
89
+ return {"error": str(e)}