Phoenixak99 committed
Commit 48e4064 · verified · 1 Parent(s): 15ccdc9

Create handler.py

Files changed (1)
  1. handler.py +61 -0
handler.py ADDED
@@ -0,0 +1,61 @@
+# handler.py
+from typing import Any, Dict, List
+from transformers import AutoProcessor, MusicgenForConditionalGeneration
+import torch
+import numpy as np
+
+class EndpointHandler:
+    def __init__(self, path=""):
+        # Load model and processor from path
+        self.processor = AutoProcessor.from_pretrained(path)
+        self.model = MusicgenForConditionalGeneration.from_pretrained(
+            path,
+            torch_dtype=torch.float16
+        ).to("cuda")
+
+    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
+        """
+        Args:
+            data (Dict): The request data, containing:
+                - inputs (Dict): Contains 'prompt' and optional 'duration'
+                - parameters (Dict, optional): Generation parameters
+        """
+        # Extract inputs and parameters
+        inputs = data.pop("inputs", data)
+        parameters = data.pop("parameters", {})
+
+        # Get prompt and duration
+        prompt = inputs.get("prompt", "")
+        duration = inputs.get("duration", 30)  # Default 30 seconds
+
+        # Calculate max_new_tokens based on duration
+        # MusicGen's 32 kHz EnCodec compresses 640 samples per token (~50 tokens per second)
+        samples_per_token = 640
+        sampling_rate = 32000
+        max_new_tokens = int((duration * sampling_rate) / samples_per_token)
+
+        # Process input text
+        inputs = self.processor(
+            text=[prompt],
+            padding=True,
+            return_tensors="pt"
+        ).to("cuda")
+
+        # Set default generation parameters
+        generation_params = {
+            "do_sample": True,
+            "guidance_scale": 3,
+            "max_new_tokens": max_new_tokens
+        }
+
+        # Update with any user-provided parameters
+        generation_params.update(parameters)
+
+        # Generate audio
+        with torch.cuda.amp.autocast():
+            outputs = self.model.generate(**inputs, **generation_params)
+
+        # Convert to list for JSON serialization
+        generated_audio = outputs.cpu().numpy().tolist()
+
+        return [{"generated_audio": generated_audio}]
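
For context, a client call against an Inference Endpoint running this handler could look roughly like the sketch below. It is not part of this commit: the endpoint URL and token are placeholders, the WAV post-processing is just one way to consume the response, and it assumes the returned nested list has shape (batch, channels, samples), which is what MusicGen's generate call normally yields at 32 kHz.

# client_example.py (illustrative sketch, not part of this commit)
import numpy as np
import requests
import scipy.io.wavfile

ENDPOINT_URL = "https://<your-endpoint>.endpoints.huggingface.cloud"  # placeholder
HF_TOKEN = "hf_..."  # placeholder token

# Payload shape matches what EndpointHandler.__call__ expects
payload = {
    "inputs": {"prompt": "lo-fi hip hop beat with soft piano", "duration": 10},
    "parameters": {"guidance_scale": 3},
}

response = requests.post(
    ENDPOINT_URL,
    headers={"Authorization": f"Bearer {HF_TOKEN}"},
    json=payload,
)
response.raise_for_status()

# Handler returns [{"generated_audio": nested_list}]; assume (batch, channels, samples)
audio = np.array(response.json()[0]["generated_audio"], dtype=np.float32)[0, 0]
scipy.io.wavfile.write("musicgen_out.wav", rate=32000, data=audio)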