animikhaich commited on
Commit
4b810e5
·
1 Parent(s): 0391cf5

Added base Server and Client for MusicGen

Browse files
Files changed (4) hide show
  1. .gitignore +167 -0
  2. client.py +39 -0
  3. run_test.sh +28 -0
  4. server.py +53 -0
.gitignore ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # Byte-compiled / optimized / DLL files
3
+ __pycache__/
4
+ *.py[cod]
5
+ *$py.class
6
+
7
+ # C extensions
8
+ *.so
9
+
10
+ # Distribution / packaging
11
+ .Python
12
+ build/
13
+ develop-eggs/
14
+ dist/
15
+ downloads/
16
+ eggs/
17
+ .eggs/
18
+ lib/
19
+ lib64/
20
+ parts/
21
+ sdist/
22
+ var/
23
+ wheels/
24
+ share/python-wheels/
25
+ *.egg-info/
26
+ .installed.cfg
27
+ *.egg
28
+ MANIFEST
29
+
30
+ # PyInstaller
31
+ # Usually these files are written by a python script from a template
32
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
33
+ *.manifest
34
+ *.spec
35
+
36
+ # Installer logs
37
+ pip-log.txt
38
+ pip-delete-this-directory.txt
39
+
40
+ # Unit test / coverage reports
41
+ htmlcov/
42
+ .tox/
43
+ .nox/
44
+ .coverage
45
+ .coverage.*
46
+ .cache
47
+ nosetests.xml
48
+ coverage.xml
49
+ *.cover
50
+ *.py,cover
51
+ .hypothesis/
52
+ .pytest_cache/
53
+ cover/
54
+
55
+ # Translations
56
+ *.mo
57
+ *.pot
58
+
59
+ # Django stuff:
60
+ *.log
61
+ local_settings.py
62
+ db.sqlite3
63
+ db.sqlite3-journal
64
+
65
+ # Flask stuff:
66
+ instance/
67
+ .webassets-cache
68
+
69
+ # Scrapy stuff:
70
+ .scrapy
71
+
72
+ # Sphinx documentation
73
+ docs/_build/
74
+
75
+ # PyBuilder
76
+ .pybuilder/
77
+ target/
78
+
79
+ # Jupyter Notebook
80
+ .ipynb_checkpoints
81
+
82
+ # IPython
83
+ profile_default/
84
+ ipython_config.py
85
+
86
+ # pyenv
87
+ # For a library or package, you might want to ignore these files since the code is
88
+ # intended to run in multiple environments; otherwise, check them in:
89
+ # .python-version
90
+
91
+ # pipenv
92
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
93
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
94
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
95
+ # install all needed dependencies.
96
+ #Pipfile.lock
97
+
98
+ # poetry
99
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
100
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
101
+ # commonly ignored for libraries.
102
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
103
+ #poetry.lock
104
+
105
+ # pdm
106
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
107
+ #pdm.lock
108
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
109
+ # in version control.
110
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
111
+ .pdm.toml
112
+ .pdm-python
113
+ .pdm-build/
114
+
115
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
116
+ __pypackages__/
117
+
118
+ # Celery stuff
119
+ celerybeat-schedule
120
+ celerybeat.pid
121
+
122
+ # SageMath parsed files
123
+ *.sage.py
124
+
125
+ # Environments
126
+ .env
127
+ .venv
128
+ env/
129
+ venv/
130
+ ENV/
131
+ env.bak/
132
+ venv.bak/
133
+
134
+ # Spyder project settings
135
+ .spyderproject
136
+ .spyproject
137
+
138
+ # Rope project settings
139
+ .ropeproject
140
+
141
+ # mkdocs documentation
142
+ /site
143
+
144
+ # mypy
145
+ .mypy_cache/
146
+ .dmypy.json
147
+ dmypy.json
148
+
149
+ # Pyre type checker
150
+ .pyre/
151
+
152
+ # pytype static type analyzer
153
+ .pytype/
154
+
155
+ # Cython debug symbols
156
+ cython_debug/
157
+
158
+ # PyCharm
159
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
160
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
161
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
162
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
163
+ #.idea/
164
+
165
+
166
+ *.wav
167
+ *.mp3
client.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ import argparse
3
+
4
+ # Parse command line arguments
5
+ parser = argparse.ArgumentParser(description="Music Generation Client")
6
+ parser.add_argument(
7
+ "--server_url", type=str, default="http://localhost:8000", help="URL of the server"
8
+ )
9
+ parser.add_argument(
10
+ "--prompts",
11
+ nargs="+",
12
+ type=str,
13
+ default=["Lofi Music for Coding"],
14
+ help="Prompts for music generation",
15
+ )
16
+ parser.add_argument(
17
+ "--output_file", type=str, default="output.wav", help="Output file name"
18
+ )
19
+
20
+ args = parser.parse_args()
21
+
22
+
23
+ def generate_music(server_url, prompts, output_file):
24
+ url = f"{server_url}/generate_music"
25
+ headers = {"Content-Type": "application/json"}
26
+ data = {"prompts": prompts}
27
+
28
+ response = requests.post(url, json=data, headers=headers)
29
+
30
+ if response.status_code == 200:
31
+ with open(output_file, "wb") as f:
32
+ f.write(response.content)
33
+ print(f"Music saved to {output_file}")
34
+ else:
35
+ print(f"Failed to generate music: {response.status_code}, {response.text}")
36
+
37
+
38
+ if __name__ == "__main__":
39
+ generate_music(args.server_url, args.prompts, args.output_file)
run_test.sh ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ echo "Script started."
4
+
5
+ # Run server
6
+ echo "Starting server..."
7
+ python server.py --duration 10 &
8
+ echo "Server started."
9
+
10
+ # Sleep
11
+ echo "Waiting for the server to startup..."
12
+ sleep 10
13
+
14
+ # Run client
15
+ echo "Starting client..."
16
+ python client.py --server_url http://localhost:8000 --prompts "Lofi Music for Coding" --output_file output.wav
17
+ echo "Client finished."
18
+
19
+
20
+ # Kill server
21
+ echo "Killing server..."
22
+ kill $(ps aux | grep 'server.py' | awk '{print $2}')
23
+
24
+
25
+ # Done
26
+ sleep 5
27
+ echo "Script finished."
28
+
server.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+ import argparse
3
+ from fastapi import FastAPI, HTTPException
4
+ from pydantic import BaseModel
5
+ from typing import List
6
+ import torch
7
+ from audiocraft.models import musicgen
8
+ import numpy as np
9
+ import io
10
+ from fastapi.responses import StreamingResponse
11
+ from scipy.io.wavfile import write as wav_write
12
+ import uvicorn
13
+
14
+ warnings.simplefilter('ignore')
15
+
16
+ # Parse command line arguments
17
+ parser = argparse.ArgumentParser(description="Music Generation Server")
18
+ parser.add_argument("--model_name", type=str, default="small", help="Pretrained model name")
19
+ parser.add_argument("--device", type=str, default="cuda", help="Device to load the model on")
20
+ parser.add_argument("--duration", type=int, default=10, help="Duration of generated music in seconds")
21
+ parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to run the server on")
22
+ parser.add_argument("--port", type=int, default=8000, help="Port to run the server on")
23
+
24
+ args = parser.parse_args()
25
+
26
+ # Initialize the FastAPI app
27
+ app = FastAPI()
28
+
29
+ # Load the model with the provided arguments
30
+ musicgen_model = musicgen.MusicGen.get_pretrained(args.model_name, device=args.device)
31
+ musicgen_model.set_generation_params(duration=args.duration)
32
+
33
+ class MusicRequest(BaseModel):
34
+ prompts: List[str]
35
+
36
+ @app.post("/generate_music")
37
+ def generate_music(request: MusicRequest):
38
+ try:
39
+ result = musicgen_model.generate(request.prompts, progress=False)
40
+ result = result.squeeze().cpu().numpy()
41
+
42
+ sample_rate = musicgen_model.sample_rate
43
+
44
+ buffer = io.BytesIO()
45
+ wav_write(buffer, sample_rate, result)
46
+ buffer.seek(0)
47
+
48
+ return StreamingResponse(buffer, media_type="audio/wav")
49
+ except Exception as e:
50
+ raise HTTPException(status_code=500, detail=str(e))
51
+
52
+ if __name__ == "__main__":
53
+ uvicorn.run(app, host=args.host, port=args.port)