Phoenixak99 committed
Commit 06c68e1 · verified · 1 Parent(s): c98fa01

Update handler.py

Files changed (1):
  1. handler.py +45 -47
handler.py CHANGED
@@ -2,60 +2,58 @@
 from typing import Dict, Any
 from transformers import AutoProcessor, MusicgenForConditionalGeneration
 import torch
-import numpy as np
 
 class EndpointHandler:
     def __init__(self, path=""):
-        # Load model and processor from path
+        """Initialize the model and processor."""
         self.processor = AutoProcessor.from_pretrained(path)
         self.model = MusicgenForConditionalGeneration.from_pretrained(
             path,
-            torch_dtype=torch.float16
+            torch_dtype=torch.float16,
+            device_map="auto"  # Added for better GPU management
         ).to("cuda")
 
     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
-        """
-        Args:
-            data (Dict): The request data, containing:
-                - inputs (Dict): Contains 'prompt' and optional 'duration'
-                - parameters (Dict, optional): Generation parameters
-        """
-        # Extract inputs and parameters
-        inputs = data.pop("inputs", data)
-        parameters = data.pop("parameters", {})
-
-        # Get prompt and duration
-        prompt = inputs.get("prompt", "")
-        duration = inputs.get("duration", 30)  # Default 30 seconds
-
-        # Calculate max_new_tokens based on duration
-        # MusicGen generates audio at 32000 Hz, with each token representing 1024 samples
-        samples_per_token = 1024
-        sampling_rate = 32000
-        max_new_tokens = int((duration * sampling_rate) / samples_per_token)
-
-        # Process input text
-        inputs = self.processor(
-            text=[prompt],
-            padding=True,
-            return_tensors="pt"
-        ).to("cuda")
-
-        # Set default generation parameters
-        generation_params = {
-            "do_sample": True,
-            "guidance_scale": 3,
-            "max_new_tokens": max_new_tokens
-        }
-
-        # Update with any user-provided parameters
-        generation_params.update(parameters)
-
-        # Generate audio
-        with torch.cuda.amp.autocast():
-            outputs = self.model.generate(**inputs, **generation_params)
-
-        # Convert to list for JSON serialization
-        generated_audio = outputs.cpu().numpy().tolist()
-
-        return [{"generated_audio": generated_audio}]
+        """Process the input data and generate audio."""
+        try:
+            # Extract inputs and parameters
+            inputs = data.pop("inputs", data)
+            parameters = data.pop("parameters", {})
+
+            # Get prompt and duration
+            prompt = inputs.get("prompt", "")
+            duration = inputs.get("duration", 30)
+
+            # Calculate max_new_tokens based on duration
+            samples_per_token = 1024
+            sampling_rate = 32000
+            max_new_tokens = int((duration * sampling_rate) / samples_per_token)
+
+            # Process input text
+            model_inputs = self.processor(
+                text=[prompt],
+                padding=True,
+                return_tensors="pt"
+            ).to("cuda")
+
+            # Set default generation parameters
+            generation_params = {
+                "do_sample": True,
+                "guidance_scale": 3,
+                "max_new_tokens": max_new_tokens
+            }
+
+            # Update with any user-provided parameters
+            generation_params.update(parameters)
+
+            # Generate audio with autocast for memory efficiency
+            with torch.cuda.amp.autocast():
+                audio_values = self.model.generate(**model_inputs, **generation_params)
+
+            # Convert to list for JSON serialization
+            audio_data = audio_values.cpu().numpy().tolist()
+
+            return [{"generated_audio": audio_data}]
+
+        except Exception as e:
+            return {"error": str(e)}