animikhaich commited on
Commit
3c30d79
·
1 Parent(s): fcbf149

Added: Music Duration Control from the client end

Browse files
Files changed (2) hide show
  1. client.py +6 -5
  2. server.py +3 -3
client.py CHANGED
@@ -16,14 +16,16 @@ parser.add_argument(
16
  parser.add_argument(
17
  "--output_file", type=str, default="output.wav", help="Output file name"
18
  )
 
 
 
19
 
20
  args = parser.parse_args()
21
 
22
-
23
- def generate_music(server_url, prompts, output_file):
24
  url = f"{server_url}/generate_music"
25
  headers = {"Content-Type": "application/json"}
26
- data = {"prompts": prompts}
27
 
28
  response = requests.post(url, json=data, headers=headers)
29
 
@@ -34,6 +36,5 @@ def generate_music(server_url, prompts, output_file):
34
  else:
35
  print(f"Failed to generate music: {response.status_code}, {response.text}")
36
 
37
-
38
  if __name__ == "__main__":
39
- generate_music(args.server_url, args.prompts, args.output_file)
 
16
  parser.add_argument(
17
  "--output_file", type=str, default="output.wav", help="Output file name"
18
  )
19
+ parser.add_argument(
20
+ "--duration", type=int, default=10, help="Duration of generated music in seconds"
21
+ )
22
 
23
  args = parser.parse_args()
24
 
25
+ def generate_music(server_url, prompts, duration, output_file):
 
26
  url = f"{server_url}/generate_music"
27
  headers = {"Content-Type": "application/json"}
28
+ data = {"prompts": prompts, "duration": duration}
29
 
30
  response = requests.post(url, json=data, headers=headers)
31
 
 
36
  else:
37
  print(f"Failed to generate music: {response.status_code}, {response.text}")
38
 
 
39
  if __name__ == "__main__":
40
+ generate_music(args.server_url, args.prompts, args.duration, args.output_file)
server.py CHANGED
@@ -2,7 +2,7 @@ import warnings
2
  import argparse
3
  from fastapi import FastAPI, HTTPException
4
  from pydantic import BaseModel
5
- from typing import List
6
  import torch
7
  from audiocraft.models import musicgen
8
  import numpy as np
@@ -17,7 +17,6 @@ warnings.simplefilter('ignore')
17
  parser = argparse.ArgumentParser(description="Music Generation Server")
18
  parser.add_argument("--model_name", type=str, default="small", help="Pretrained model name")
19
  parser.add_argument("--device", type=str, default="cuda", help="Device to load the model on")
20
- parser.add_argument("--duration", type=int, default=10, help="Duration of generated music in seconds")
21
  parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to run the server on")
22
  parser.add_argument("--port", type=int, default=8000, help="Port to run the server on")
23
 
@@ -28,14 +27,15 @@ app = FastAPI()
28
 
29
  # Load the model with the provided arguments
30
  musicgen_model = musicgen.MusicGen.get_pretrained(args.model_name, device=args.device)
31
- musicgen_model.set_generation_params(duration=args.duration)
32
 
33
  class MusicRequest(BaseModel):
34
  prompts: List[str]
 
35
 
36
  @app.post("/generate_music")
37
  def generate_music(request: MusicRequest):
38
  try:
 
39
  result = musicgen_model.generate(request.prompts, progress=False)
40
  result = result.squeeze().cpu().numpy()
41
 
 
2
  import argparse
3
  from fastapi import FastAPI, HTTPException
4
  from pydantic import BaseModel
5
+ from typing import List, Optional
6
  import torch
7
  from audiocraft.models import musicgen
8
  import numpy as np
 
17
  parser = argparse.ArgumentParser(description="Music Generation Server")
18
  parser.add_argument("--model_name", type=str, default="small", help="Pretrained model name")
19
  parser.add_argument("--device", type=str, default="cuda", help="Device to load the model on")
 
20
  parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to run the server on")
21
  parser.add_argument("--port", type=int, default=8000, help="Port to run the server on")
22
 
 
27
 
28
  # Load the model with the provided arguments
29
  musicgen_model = musicgen.MusicGen.get_pretrained(args.model_name, device=args.device)
 
30
 
31
  class MusicRequest(BaseModel):
32
  prompts: List[str]
33
+ duration: Optional[int] = 10 # Default duration is 10 seconds if not provided
34
 
35
  @app.post("/generate_music")
36
  def generate_music(request: MusicRequest):
37
  try:
38
+ musicgen_model.set_generation_params(duration=request.duration)
39
  result = musicgen_model.generate(request.prompts, progress=False)
40
  result = result.squeeze().cpu().numpy()
41