Spaces:
Running
on
Zero
Running
on
Zero
Commit
·
3c30d79
1
Parent(s):
fcbf149
Added: Music Duration Control from the client end
Browse files
client.py
CHANGED
@@ -16,14 +16,16 @@ parser.add_argument(
|
|
16 |
parser.add_argument(
|
17 |
"--output_file", type=str, default="output.wav", help="Output file name"
|
18 |
)
|
|
|
|
|
|
|
19 |
|
20 |
args = parser.parse_args()
|
21 |
|
22 |
-
|
23 |
-
def generate_music(server_url, prompts, output_file):
|
24 |
url = f"{server_url}/generate_music"
|
25 |
headers = {"Content-Type": "application/json"}
|
26 |
-
data = {"prompts": prompts}
|
27 |
|
28 |
response = requests.post(url, json=data, headers=headers)
|
29 |
|
@@ -34,6 +36,5 @@ def generate_music(server_url, prompts, output_file):
|
|
34 |
else:
|
35 |
print(f"Failed to generate music: {response.status_code}, {response.text}")
|
36 |
|
37 |
-
|
38 |
if __name__ == "__main__":
|
39 |
-
generate_music(args.server_url, args.prompts, args.output_file)
|
|
|
16 |
parser.add_argument(
|
17 |
"--output_file", type=str, default="output.wav", help="Output file name"
|
18 |
)
|
19 |
+
parser.add_argument(
|
20 |
+
"--duration", type=int, default=10, help="Duration of generated music in seconds"
|
21 |
+
)
|
22 |
|
23 |
args = parser.parse_args()
|
24 |
|
25 |
+
def generate_music(server_url, prompts, duration, output_file):
|
|
|
26 |
url = f"{server_url}/generate_music"
|
27 |
headers = {"Content-Type": "application/json"}
|
28 |
+
data = {"prompts": prompts, "duration": duration}
|
29 |
|
30 |
response = requests.post(url, json=data, headers=headers)
|
31 |
|
|
|
36 |
else:
|
37 |
print(f"Failed to generate music: {response.status_code}, {response.text}")
|
38 |
|
|
|
39 |
if __name__ == "__main__":
|
40 |
+
generate_music(args.server_url, args.prompts, args.duration, args.output_file)
|
server.py
CHANGED
@@ -2,7 +2,7 @@ import warnings
|
|
2 |
import argparse
|
3 |
from fastapi import FastAPI, HTTPException
|
4 |
from pydantic import BaseModel
|
5 |
-
from typing import List
|
6 |
import torch
|
7 |
from audiocraft.models import musicgen
|
8 |
import numpy as np
|
@@ -17,7 +17,6 @@ warnings.simplefilter('ignore')
|
|
17 |
parser = argparse.ArgumentParser(description="Music Generation Server")
|
18 |
parser.add_argument("--model_name", type=str, default="small", help="Pretrained model name")
|
19 |
parser.add_argument("--device", type=str, default="cuda", help="Device to load the model on")
|
20 |
-
parser.add_argument("--duration", type=int, default=10, help="Duration of generated music in seconds")
|
21 |
parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to run the server on")
|
22 |
parser.add_argument("--port", type=int, default=8000, help="Port to run the server on")
|
23 |
|
@@ -28,14 +27,15 @@ app = FastAPI()
|
|
28 |
|
29 |
# Load the model with the provided arguments
|
30 |
musicgen_model = musicgen.MusicGen.get_pretrained(args.model_name, device=args.device)
|
31 |
-
musicgen_model.set_generation_params(duration=args.duration)
|
32 |
|
33 |
class MusicRequest(BaseModel):
|
34 |
prompts: List[str]
|
|
|
35 |
|
36 |
@app.post("/generate_music")
|
37 |
def generate_music(request: MusicRequest):
|
38 |
try:
|
|
|
39 |
result = musicgen_model.generate(request.prompts, progress=False)
|
40 |
result = result.squeeze().cpu().numpy()
|
41 |
|
|
|
2 |
import argparse
|
3 |
from fastapi import FastAPI, HTTPException
|
4 |
from pydantic import BaseModel
|
5 |
+
from typing import List, Optional
|
6 |
import torch
|
7 |
from audiocraft.models import musicgen
|
8 |
import numpy as np
|
|
|
17 |
parser = argparse.ArgumentParser(description="Music Generation Server")
|
18 |
parser.add_argument("--model_name", type=str, default="small", help="Pretrained model name")
|
19 |
parser.add_argument("--device", type=str, default="cuda", help="Device to load the model on")
|
|
|
20 |
parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to run the server on")
|
21 |
parser.add_argument("--port", type=int, default=8000, help="Port to run the server on")
|
22 |
|
|
|
27 |
|
28 |
# Load the model with the provided arguments
|
29 |
musicgen_model = musicgen.MusicGen.get_pretrained(args.model_name, device=args.device)
|
|
|
30 |
|
31 |
class MusicRequest(BaseModel):
|
32 |
prompts: List[str]
|
33 |
+
duration: Optional[int] = 10 # Default duration is 10 seconds if not provided
|
34 |
|
35 |
@app.post("/generate_music")
|
36 |
def generate_music(request: MusicRequest):
|
37 |
try:
|
38 |
+
musicgen_model.set_generation_params(duration=request.duration)
|
39 |
result = musicgen_model.generate(request.prompts, progress=False)
|
40 |
result = result.squeeze().cpu().numpy()
|
41 |
|