Phoenixak99
commited on
Commit
•
3570981
1
Parent(s):
9d2438d
Update handler.py
Browse files- handler.py +41 -68
handler.py
CHANGED
@@ -1,78 +1,51 @@
|
|
1 |
-
# app.py
|
2 |
-
from fastapi import FastAPI, Request
|
3 |
-
from handler import EndpointHandler
|
4 |
-
import json
|
5 |
-
|
6 |
-
app = FastAPI()
|
7 |
-
handler = None
|
8 |
-
|
9 |
-
@app.on_event("startup")
|
10 |
-
async def startup_event():
|
11 |
-
global handler
|
12 |
-
handler = EndpointHandler()
|
13 |
-
|
14 |
-
@app.post("/")
|
15 |
-
async def process_request(request: Request):
|
16 |
-
body = await request.json()
|
17 |
-
response = handler(body)
|
18 |
-
return response
|
19 |
-
|
20 |
-
# handler.py
|
21 |
from typing import Dict, Any
|
22 |
from transformers import AutoProcessor, MusicgenForConditionalGeneration
|
23 |
import torch
|
24 |
|
25 |
class EndpointHandler:
|
26 |
-
def __init__(self, path="
|
27 |
-
|
28 |
self.processor = AutoProcessor.from_pretrained(path)
|
29 |
self.model = MusicgenForConditionalGeneration.from_pretrained(
|
30 |
-
path,
|
31 |
-
torch_dtype=torch.float16,
|
32 |
-
device_map="auto"
|
33 |
).to("cuda")
|
34 |
-
|
|
|
35 |
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
36 |
-
"""
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
63 |
}
|
64 |
-
|
65 |
-
# Update with any user-provided parameters
|
66 |
-
generation_params.update(parameters)
|
67 |
-
|
68 |
-
# Generate audio with autocast for memory efficiency
|
69 |
-
with torch.cuda.amp.autocast():
|
70 |
-
audio_values = self.model.generate(**model_inputs, **generation_params)
|
71 |
-
|
72 |
-
# Convert to list for JSON serialization
|
73 |
-
audio_data = audio_values.cpu().numpy().tolist()
|
74 |
-
|
75 |
-
return [{"generated_audio": audio_data}]
|
76 |
-
|
77 |
-
except Exception as e:
|
78 |
-
return {"error": str(e)}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
from typing import Dict, Any
|
2 |
from transformers import AutoProcessor, MusicgenForConditionalGeneration
|
3 |
import torch
|
4 |
|
5 |
class EndpointHandler:
|
6 |
+
def __init__(self, path=""):
|
7 |
+
# Load the processor and model from the specified path
|
8 |
self.processor = AutoProcessor.from_pretrained(path)
|
9 |
self.model = MusicgenForConditionalGeneration.from_pretrained(
|
10 |
+
path, torch_dtype=torch.float16
|
|
|
|
|
11 |
).to("cuda")
|
12 |
+
self.sampling_rate = self.model.config.audio_encoder.sampling_rate
|
13 |
+
|
14 |
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
15 |
+
"""
|
16 |
+
Args:
|
17 |
+
data (dict): The payload with the text prompt and generation parameters.
|
18 |
+
"""
|
19 |
+
# Extract inputs and parameters from the payload
|
20 |
+
inputs = data.get("inputs", {})
|
21 |
+
prompt = inputs.get("prompt", "")
|
22 |
+
duration = inputs.get("duration", 10)
|
23 |
+
parameters = data.get("parameters", {})
|
24 |
+
|
25 |
+
# Preprocess the prompt
|
26 |
+
input_ids = self.processor(
|
27 |
+
text=[prompt],
|
28 |
+
padding=True,
|
29 |
+
return_tensors="pt",
|
30 |
+
).to("cuda")
|
31 |
+
|
32 |
+
# Set generation parameters
|
33 |
+
gen_kwargs = {
|
34 |
+
"max_new_tokens": int(duration * 50), # MusicGen uses 50 tokens per second
|
35 |
+
**parameters,
|
36 |
+
}
|
37 |
+
|
38 |
+
# Generate audio
|
39 |
+
with torch.autocast("cuda"):
|
40 |
+
outputs = self.model.generate(**input_ids, **gen_kwargs)
|
41 |
+
|
42 |
+
# Convert the output audio tensor to a list of lists (channel-wise)
|
43 |
+
audio_tensor = outputs[0].cpu() # Shape: [num_channels, seq_len]
|
44 |
+
audio_list = audio_tensor.numpy().tolist() # [[channel1_data], [channel2_data]]
|
45 |
+
|
46 |
+
return [
|
47 |
+
{
|
48 |
+
"generated_audio": audio_list,
|
49 |
+
"sample_rate": self.sampling_rate,
|
50 |
}
|
51 |
+
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|