Phoenixak99 commited on
Commit
7288895
1 Parent(s): dc480b5

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +44 -88
handler.py CHANGED
@@ -2,104 +2,62 @@ import logging
2
  from typing import Dict, Any
3
  from transformers import AutoProcessor, MusicgenForConditionalGeneration
4
  import torch
5
- import gc
6
 
 
7
  logging.basicConfig(level=logging.INFO)
8
  logger = logging.getLogger(__name__)
9
 
10
  class EndpointHandler:
11
  def __init__(self, path=""):
12
- # Enable CUDA optimization
13
- torch.backends.cuda.matmul.allow_tf32 = True
14
- torch.backends.cudnn.benchmark = True
15
-
16
- # Load processor with optimizations
17
- logger.info("Loading processor...")
18
- self.processor = AutoProcessor.from_pretrained(
19
- path,
20
- use_fast=True # Use faster tokenizer
21
- )
22
-
23
- logger.info("Loading model...")
24
  self.model = MusicgenForConditionalGeneration.from_pretrained(
25
- path,
26
- torch_dtype=torch.float16,
27
- low_cpu_mem_usage=True
28
  ).to("cuda")
29
-
30
- # Set model to eval mode
31
- self.model.eval()
32
-
33
- # Cache sampling rate
34
  self.sampling_rate = self.model.config.audio_encoder.sampling_rate
35
-
36
- # Clear CUDA cache
37
- torch.cuda.empty_cache()
38
- gc.collect()
39
-
40
- # Quick warmup
41
- logger.info("Warming up model...")
42
- self._warmup()
43
-
44
- def _warmup(self):
45
- """Perform a minimal forward pass to warm up the model"""
46
- try:
47
- with torch.no_grad():
48
- dummy_input = self.processor(
49
- text=["test"],
50
- padding=True,
51
- return_tensors="pt"
52
- ).to("cuda")
53
-
54
- # Minimal generation
55
- self.model.generate(
56
- **dummy_input,
57
- max_new_tokens=10,
58
- do_sample=False
59
- )
60
- except Exception as e:
61
- logger.warning(f"Warmup failed (non-critical): {e}")
62
 
63
  def __call__(self, data: Dict[str, Any]) -> Any:
 
 
 
 
64
  try:
65
- # Extract inputs and parameters
66
  inputs = data.get("inputs", data)
67
  parameters = data.get("parameters", {})
68
-
69
- # Efficient input handling
70
- if isinstance(inputs, dict):
 
 
 
71
  prompt = inputs.get("text") or inputs.get("prompt")
72
  duration = inputs.get("duration", 10)
73
  else:
74
- prompt = inputs if isinstance(inputs, str) else None
75
  duration = 10
76
-
 
77
  if 'duration' in parameters:
78
  duration = parameters.pop('duration')
79
-
 
80
  if not prompt:
81
  return {"error": "No prompt provided."}
82
-
83
- # Preprocess with optimized settings
84
  input_ids = self.processor(
85
  text=[prompt],
86
  padding=True,
87
  return_tensors="pt",
88
- truncation=True,
89
- max_length=512 # Limit input length
90
  ).to("cuda")
91
-
92
- # Optimized generation settings
93
  gen_kwargs = {
94
- "max_new_tokens": int(duration * 50),
95
- "use_cache": True, # Enable KV-cache
96
- "do_sample": True,
97
- "temperature": 0.8,
98
- "top_k": 50,
99
- "top_p": 0.95
100
  }
101
-
102
- # Add any custom parameters
103
  supported_params = [
104
  "max_length", "min_length", "do_sample", "early_stopping", "num_beams",
105
  "temperature", "top_k", "top_p", "repetition_penalty", "bad_words_ids",
@@ -108,26 +66,24 @@ class EndpointHandler:
108
  for param in supported_params:
109
  if param in parameters:
110
  gen_kwargs[param] = parameters[param]
111
-
112
- logger.info(f"Generating with prompt: {prompt}")
113
  logger.info(f"Generation parameters: {gen_kwargs}")
114
-
115
- # Generate with optimized settings
116
- with torch.inference_mode(), torch.autocast("cuda"):
117
  outputs = self.model.generate(**input_ids, **gen_kwargs)
118
-
119
- # Convert output
120
- audio_tensor = outputs[0].cpu()
121
- audio_list = audio_tensor.numpy().tolist()
122
-
123
- # Clear cache
124
- torch.cuda.empty_cache()
125
-
126
- return [{
127
- "generated_audio": audio_list,
128
- "sample_rate": self.sampling_rate,
129
- }]
130
-
131
  except Exception as e:
132
- logger.error(f"Generation failed: {e}")
133
  return {"error": str(e)}
 
2
  from typing import Dict, Any
3
  from transformers import AutoProcessor, MusicgenForConditionalGeneration
4
  import torch
 
5
 
6
+ # Set up logging
7
  logging.basicConfig(level=logging.INFO)
8
  logger = logging.getLogger(__name__)
9
 
10
  class EndpointHandler:
11
  def __init__(self, path=""):
12
+ # Load the processor and model from the specified path
13
+ self.processor = AutoProcessor.from_pretrained(path)
 
 
 
 
 
 
 
 
 
 
14
  self.model = MusicgenForConditionalGeneration.from_pretrained(
15
+ path, torch_dtype=torch.float16
 
 
16
  ).to("cuda")
 
 
 
 
 
17
  self.sampling_rate = self.model.config.audio_encoder.sampling_rate
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
  def __call__(self, data: Dict[str, Any]) -> Any:
20
+ """
21
+ Args:
22
+ data (dict): The payload with the text prompt and generation parameters.
23
+ """
24
  try:
25
+ # Extract inputs and parameters from the payload
26
  inputs = data.get("inputs", data)
27
  parameters = data.get("parameters", {})
28
+
29
+ # Handle inputs
30
+ if isinstance(inputs, str):
31
+ prompt = inputs
32
+ duration = 10 # Default duration
33
+ elif isinstance(inputs, dict):
34
  prompt = inputs.get("text") or inputs.get("prompt")
35
  duration = inputs.get("duration", 10)
36
  else:
37
+ prompt = None
38
  duration = 10
39
+
40
+ # Override duration if provided in parameters
41
  if 'duration' in parameters:
42
  duration = parameters.pop('duration')
43
+
44
+ # Validate the prompt
45
  if not prompt:
46
  return {"error": "No prompt provided."}
47
+
48
+ # Preprocess the prompt
49
  input_ids = self.processor(
50
  text=[prompt],
51
  padding=True,
52
  return_tensors="pt",
 
 
53
  ).to("cuda")
54
+
55
+ # Set generation parameters
56
  gen_kwargs = {
57
+ "max_new_tokens": int(duration * 50), # MusicGen uses 50 tokens per second
 
 
 
 
 
58
  }
59
+
60
+ # Filter out unsupported parameters
61
  supported_params = [
62
  "max_length", "min_length", "do_sample", "early_stopping", "num_beams",
63
  "temperature", "top_k", "top_p", "repetition_penalty", "bad_words_ids",
 
66
  for param in supported_params:
67
  if param in parameters:
68
  gen_kwargs[param] = parameters[param]
69
+
70
+ logger.info(f"Received prompt: {prompt}")
71
  logger.info(f"Generation parameters: {gen_kwargs}")
72
+
73
+ # Generate audio
74
+ with torch.autocast("cuda"):
75
  outputs = self.model.generate(**input_ids, **gen_kwargs)
76
+
77
+ # Convert the output audio tensor to a list of lists (channel-wise)
78
+ audio_tensor = outputs[0].cpu() # Shape: [num_channels, seq_len]
79
+ audio_list = audio_tensor.numpy().tolist() # [[channel1_data], [channel2_data]]
80
+
81
+ return [
82
+ {
83
+ "generated_audio": audio_list,
84
+ "sample_rate": self.sampling_rate,
85
+ }
86
+ ]
 
 
87
  except Exception as e:
88
+ logger.error(f"Exception during generation: {e}")
89
  return {"error": str(e)}