animikhaich commited on
Commit
2deb721
·
1 Parent(s): 3c30d79

Added: Stereo Music Generation Support

Browse files
Files changed (1) hide show
  1. server.py +10 -2
server.py CHANGED
@@ -15,8 +15,9 @@ warnings.simplefilter('ignore')
15
 
16
  # Parse command line arguments
17
  parser = argparse.ArgumentParser(description="Music Generation Server")
18
- parser.add_argument("--model_name", type=str, default="small", help="Pretrained model name")
19
  parser.add_argument("--device", type=str, default="cuda", help="Device to load the model on")
 
20
  parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to run the server on")
21
  parser.add_argument("--port", type=int, default=8000, help="Port to run the server on")
22
 
@@ -25,6 +26,12 @@ args = parser.parse_args()
25
  # Initialize the FastAPI app
26
  app = FastAPI()
27
 
 
 
 
 
 
 
28
  # Load the model with the provided arguments
29
  musicgen_model = musicgen.MusicGen.get_pretrained(args.model_name, device=args.device)
30
 
@@ -37,12 +44,13 @@ def generate_music(request: MusicRequest):
37
  try:
38
  musicgen_model.set_generation_params(duration=request.duration)
39
  result = musicgen_model.generate(request.prompts, progress=False)
40
- result = result.squeeze().cpu().numpy()
41
 
42
  sample_rate = musicgen_model.sample_rate
43
 
44
  buffer = io.BytesIO()
45
  wav_write(buffer, sample_rate, result)
 
46
  buffer.seek(0)
47
 
48
  return StreamingResponse(buffer, media_type="audio/wav")
 
15
 
16
  # Parse command line arguments
17
  parser = argparse.ArgumentParser(description="Music Generation Server")
18
+ parser.add_argument("--model", type=str, default="musicgen-stereo-small", help="Pretrained model name")
19
  parser.add_argument("--device", type=str, default="cuda", help="Device to load the model on")
20
+ parser.add_argument("--duration", type=int, default=10, help="Duration of generated music in seconds")
21
  parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to run the server on")
22
  parser.add_argument("--port", type=int, default=8000, help="Port to run the server on")
23
 
 
26
  # Initialize the FastAPI app
27
  app = FastAPI()
28
 
29
+ # Build the model name based on the provided arguments
30
+ if args.model.startswith('facebook/'):
31
+ args.model_name = args.model
32
+ else:
33
+ args.model_name = f"facebook/{args.model}"
34
+
35
  # Load the model with the provided arguments
36
  musicgen_model = musicgen.MusicGen.get_pretrained(args.model_name, device=args.device)
37
 
 
44
  try:
45
  musicgen_model.set_generation_params(duration=request.duration)
46
  result = musicgen_model.generate(request.prompts, progress=False)
47
+ result = result.squeeze().cpu().numpy().T
48
 
49
  sample_rate = musicgen_model.sample_rate
50
 
51
  buffer = io.BytesIO()
52
  wav_write(buffer, sample_rate, result)
53
+
54
  buffer.seek(0)
55
 
56
  return StreamingResponse(buffer, media_type="audio/wav")