Spaces:
Running
on
T4
Running
on
T4
Update Overlap Action in Melody
Browse files- app.py +5 -3
- audiocraft/models/musicgen.py +55 -2
- audiocraft/utils/extend.py +36 -23
app.py
CHANGED
@@ -100,6 +100,8 @@ def predict(model, text, melody, duration, dimension, topk, topp, temperature, c
|
|
100 |
temperature=temperature,
|
101 |
cfg_coef=cfg_coef,
|
102 |
duration=segment_duration,
|
|
|
|
|
103 |
)
|
104 |
|
105 |
if melody:
|
@@ -201,7 +203,7 @@ def ui(**kwargs):
|
|
201 |
include_settings = gr.Checkbox(label="Add Settings to background", value=True, interactive=True)
|
202 |
with gr.Row():
|
203 |
title = gr.Textbox(label="Title", value="UnlimitedMusicGen", interactive=True)
|
204 |
-
settings_font = gr.Text(label="Settings Font", value="arial.ttf", interactive=True)
|
205 |
settings_font_color = gr.ColorPicker(label="Settings Font Color", value="#ffffff", interactive=True)
|
206 |
with gr.Row():
|
207 |
model = gr.Radio(["melody", "medium", "small", "large"], label="Model", value="melody", interactive=True)
|
@@ -212,8 +214,8 @@ def ui(**kwargs):
|
|
212 |
with gr.Row():
|
213 |
topk = gr.Number(label="Top-k", value=250, interactive=True)
|
214 |
topp = gr.Number(label="Top-p", value=0, interactive=True)
|
215 |
-
temperature = gr.Number(label="Randomness Temperature", value=
|
216 |
-
cfg_coef = gr.Number(label="Classifier Free Guidance", value=5.
|
217 |
with gr.Row():
|
218 |
seed = gr.Number(label="Seed", value=-1, precision=0, interactive=True)
|
219 |
gr.Button('\U0001f3b2\ufe0f').style(full_width=False).click(fn=lambda: -1, outputs=[seed], queue=False)
|
|
|
100 |
temperature=temperature,
|
101 |
cfg_coef=cfg_coef,
|
102 |
duration=segment_duration,
|
103 |
+
two_step_cfg=False,
|
104 |
+
rep_penalty=0.5
|
105 |
)
|
106 |
|
107 |
if melody:
|
|
|
203 |
include_settings = gr.Checkbox(label="Add Settings to background", value=True, interactive=True)
|
204 |
with gr.Row():
|
205 |
title = gr.Textbox(label="Title", value="UnlimitedMusicGen", interactive=True)
|
206 |
+
settings_font = gr.Text(label="Settings Font", value="./assets/arial.ttf", interactive=True)
|
207 |
settings_font_color = gr.ColorPicker(label="Settings Font Color", value="#ffffff", interactive=True)
|
208 |
with gr.Row():
|
209 |
model = gr.Radio(["melody", "medium", "small", "large"], label="Model", value="melody", interactive=True)
|
|
|
214 |
with gr.Row():
|
215 |
topk = gr.Number(label="Top-k", value=250, interactive=True)
|
216 |
topp = gr.Number(label="Top-p", value=0, interactive=True)
|
217 |
+
temperature = gr.Number(label="Randomness Temperature", value=0.75, precision=None, interactive=True)
|
218 |
+
cfg_coef = gr.Number(label="Classifier Free Guidance", value=5.5, precision=None, interactive=True)
|
219 |
with gr.Row():
|
220 |
seed = gr.Number(label="Seed", value=-1, precision=0, interactive=True)
|
221 |
gr.Button('\U0001f3b2\ufe0f').style(full_width=False).click(fn=lambda: -1, outputs=[seed], queue=False)
|
audiocraft/models/musicgen.py
CHANGED
@@ -97,7 +97,7 @@ class MusicGen:
|
|
97 |
def set_generation_params(self, use_sampling: bool = True, top_k: int = 250,
|
98 |
top_p: float = 0.0, temperature: float = 1.0,
|
99 |
duration: float = 30.0, cfg_coef: float = 3.0,
|
100 |
-
two_step_cfg: bool = False):
|
101 |
"""Set the generation parameters for MusicGen.
|
102 |
|
103 |
Args:
|
@@ -110,6 +110,7 @@ class MusicGen:
|
|
110 |
two_step_cfg (bool, optional): If True, performs 2 forward for Classifier Free Guidance,
|
111 |
instead of batching together the two. This has some impact on how things
|
112 |
are padded but seems to have little impact in practice.
|
|
|
113 |
"""
|
114 |
assert duration <= 30, "The MusicGen cannot generate more than 30 seconds"
|
115 |
self.generation_params = {
|
@@ -119,7 +120,7 @@ class MusicGen:
|
|
119 |
'top_k': top_k,
|
120 |
'top_p': top_p,
|
121 |
'cfg_coef': cfg_coef,
|
122 |
-
'two_step_cfg': two_step_cfg,
|
123 |
}
|
124 |
|
125 |
def generate_unconditional(self, num_samples: int, progress: bool = False) -> torch.Tensor:
|
@@ -177,6 +178,58 @@ class MusicGen:
|
|
177 |
assert prompt_tokens is None
|
178 |
return self._generate_tokens(attributes, prompt_tokens, progress)
|
179 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
180 |
def generate_continuation(self, prompt: torch.Tensor, prompt_sample_rate: int,
|
181 |
descriptions: tp.Optional[tp.List[tp.Optional[str]]] = None,
|
182 |
progress: bool = False) -> torch.Tensor:
|
|
|
97 |
def set_generation_params(self, use_sampling: bool = True, top_k: int = 250,
|
98 |
top_p: float = 0.0, temperature: float = 1.0,
|
99 |
duration: float = 30.0, cfg_coef: float = 3.0,
|
100 |
+
two_step_cfg: bool = False, rep_penalty: float = None):
|
101 |
"""Set the generation parameters for MusicGen.
|
102 |
|
103 |
Args:
|
|
|
110 |
two_step_cfg (bool, optional): If True, performs 2 forward for Classifier Free Guidance,
|
111 |
instead of batching together the two. This has some impact on how things
|
112 |
are padded but seems to have little impact in practice.
|
113 |
+
rep_penalty (float, optional): If set, use repetition penalty during generation. Not Implemented.
|
114 |
"""
|
115 |
assert duration <= 30, "The MusicGen cannot generate more than 30 seconds"
|
116 |
self.generation_params = {
|
|
|
120 |
'top_k': top_k,
|
121 |
'top_p': top_p,
|
122 |
'cfg_coef': cfg_coef,
|
123 |
+
'two_step_cfg': two_step_cfg,
|
124 |
}
|
125 |
|
126 |
def generate_unconditional(self, num_samples: int, progress: bool = False) -> torch.Tensor:
|
|
|
178 |
assert prompt_tokens is None
|
179 |
return self._generate_tokens(attributes, prompt_tokens, progress)
|
180 |
|
181 |
+
def generate_with_all(self, descriptions: tp.List[str], melody_wavs: MelodyType,
|
182 |
+
sample_rate: int, progress: bool = False, prompt: tp.Optional[torch.Tensor] = None) -> torch.Tensor:
|
183 |
+
"""Generate samples conditioned on text and melody and audio prompts.
|
184 |
+
Args:
|
185 |
+
descriptions (tp.List[str]): A list of strings used as text conditioning.
|
186 |
+
melody_wavs: (torch.Tensor or list of Tensor): A batch of waveforms used as
|
187 |
+
melody conditioning. Should have shape [B, C, T] with B matching the description length,
|
188 |
+
C=1 or 2. It can be [C, T] if there is a single description. It can also be
|
189 |
+
a list of [C, T] tensors.
|
190 |
+
sample_rate: (int): Sample rate of the melody waveforms.
|
191 |
+
progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
|
192 |
+
prompt (torch.Tensor): A batch of waveforms used for continuation.
|
193 |
+
Prompt should be [B, C, T], or [C, T] if only one sample is generated.
|
194 |
+
"""
|
195 |
+
if isinstance(melody_wavs, torch.Tensor):
|
196 |
+
if melody_wavs.dim() == 2:
|
197 |
+
melody_wavs = melody_wavs[None]
|
198 |
+
if melody_wavs.dim() != 3:
|
199 |
+
raise ValueError("Melody wavs should have a shape [B, C, T].")
|
200 |
+
melody_wavs = list(melody_wavs)
|
201 |
+
else:
|
202 |
+
for melody in melody_wavs:
|
203 |
+
if melody is not None:
|
204 |
+
assert melody.dim() == 2, "One melody in the list has the wrong number of dims."
|
205 |
+
|
206 |
+
melody_wavs = [
|
207 |
+
convert_audio(wav, sample_rate, self.sample_rate, self.audio_channels)
|
208 |
+
if wav is not None else None
|
209 |
+
for wav in melody_wavs]
|
210 |
+
attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions=descriptions, prompt=None,
|
211 |
+
melody_wavs=melody_wavs)
|
212 |
+
|
213 |
+
if prompt is not None:
|
214 |
+
if prompt.dim() == 2:
|
215 |
+
prompt = prompt[None]
|
216 |
+
if prompt.dim() != 3:
|
217 |
+
raise ValueError("prompt should have 3 dimensions: [B, C, T] (C = 1).")
|
218 |
+
prompt = convert_audio(prompt, sample_rate, self.sample_rate, self.audio_channels)
|
219 |
+
if descriptions is None:
|
220 |
+
descriptions = [None] * len(prompt)
|
221 |
+
|
222 |
+
if prompt is not None:
|
223 |
+
attributes_gen, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, prompt)
|
224 |
+
|
225 |
+
#attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions=descriptions, prompt=prompt,
|
226 |
+
# melody_wavs=melody_wavs)
|
227 |
+
if prompt is not None:
|
228 |
+
assert prompt_tokens is not None
|
229 |
+
else:
|
230 |
+
assert prompt_tokens is None
|
231 |
+
return self._generate_tokens(attributes, prompt_tokens, progress)
|
232 |
+
|
233 |
def generate_continuation(self, prompt: torch.Tensor, prompt_sample_rate: int,
|
234 |
descriptions: tp.Optional[tp.List[tp.Optional[str]]] = None,
|
235 |
progress: bool = False) -> torch.Tensor:
|
audiocraft/utils/extend.py
CHANGED
@@ -22,12 +22,15 @@ def separate_audio_segments(audio, segment_duration=30, overlap=1):
|
|
22 |
start_sample = 0
|
23 |
|
24 |
while total_samples >= segment_samples:
|
|
|
|
|
|
|
25 |
end_sample = start_sample + segment_samples
|
26 |
segment = audio_data[start_sample:end_sample]
|
27 |
segments.append((sr, segment))
|
28 |
|
29 |
start_sample += segment_samples - overlap_samples
|
30 |
-
total_samples -= segment_samples
|
31 |
|
32 |
# Collect the final segment
|
33 |
if total_samples > 0:
|
@@ -38,17 +41,16 @@ def separate_audio_segments(audio, segment_duration=30, overlap=1):
|
|
38 |
|
39 |
def generate_music_segments(text, melody, MODEL, seed, duration:int=10, overlap:int=1, segment_duration:int=30):
|
40 |
# generate audio segments
|
41 |
-
melody_segments = separate_audio_segments(melody, segment_duration,
|
42 |
|
43 |
# Create a list to store the melody tensors for each segment
|
44 |
melodys = []
|
45 |
output_segments = []
|
|
|
|
|
46 |
|
47 |
# Calculate the total number of segments
|
48 |
total_segments = max(math.ceil(duration / segment_duration),1)
|
49 |
-
# account for overlap
|
50 |
-
duration = duration + (max((total_segments - 1),0) * overlap)
|
51 |
-
total_segments = max(math.ceil(duration / segment_duration),1)
|
52 |
#calc excess duration
|
53 |
excess_duration = segment_duration - (total_segments * segment_duration - duration)
|
54 |
print(f"total Segments to Generate: {total_segments} for {duration} seconds. Each segment is {segment_duration} seconds. Excess {excess_duration}")
|
@@ -76,11 +78,15 @@ def generate_music_segments(text, melody, MODEL, seed, duration:int=10, overlap:
|
|
76 |
torch.manual_seed(seed)
|
77 |
for idx, verse in enumerate(melodys):
|
78 |
print(f"Generating New Melody Segment {idx + 1}: {text}\r")
|
79 |
-
|
|
|
|
|
|
|
80 |
descriptions=[text],
|
81 |
melody_wavs=verse,
|
82 |
-
|
83 |
-
progress=True
|
|
|
84 |
)
|
85 |
|
86 |
# Append the generated output to the list of segments
|
@@ -151,24 +157,31 @@ def load_font(font_name, font_size=16):
|
|
151 |
Example:
|
152 |
font = load_font("Arial.ttf", font_size=20)
|
153 |
"""
|
154 |
-
|
155 |
-
|
156 |
-
font = ImageFont.truetype(font_name, font_size)
|
157 |
-
except (FileNotFoundError, OSError):
|
158 |
try:
|
159 |
font = ImageFont.truetype(font_name, font_size)
|
160 |
-
|
161 |
-
|
|
|
162 |
try:
|
163 |
-
|
164 |
-
|
165 |
-
print("Font not found.
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
172 |
return font
|
173 |
|
174 |
|
|
|
22 |
start_sample = 0
|
23 |
|
24 |
while total_samples >= segment_samples:
|
25 |
+
# Collect the segment
|
26 |
+
# the end sample is the start sample plus the segment samples,
|
27 |
+
# the start sample, after 0, is minus the overlap samples to account for the overlap
|
28 |
end_sample = start_sample + segment_samples
|
29 |
segment = audio_data[start_sample:end_sample]
|
30 |
segments.append((sr, segment))
|
31 |
|
32 |
start_sample += segment_samples - overlap_samples
|
33 |
+
total_samples -= segment_samples
|
34 |
|
35 |
# Collect the final segment
|
36 |
if total_samples > 0:
|
|
|
41 |
|
42 |
def generate_music_segments(text, melody, MODEL, seed, duration:int=10, overlap:int=1, segment_duration:int=30):
|
43 |
# generate audio segments
|
44 |
+
melody_segments = separate_audio_segments(melody, segment_duration, 0)
|
45 |
|
46 |
# Create a list to store the melody tensors for each segment
|
47 |
melodys = []
|
48 |
output_segments = []
|
49 |
+
last_chunk = []
|
50 |
+
text += ", seed=" + str(seed)
|
51 |
|
52 |
# Calculate the total number of segments
|
53 |
total_segments = max(math.ceil(duration / segment_duration),1)
|
|
|
|
|
|
|
54 |
#calc excess duration
|
55 |
excess_duration = segment_duration - (total_segments * segment_duration - duration)
|
56 |
print(f"total Segments to Generate: {total_segments} for {duration} seconds. Each segment is {segment_duration} seconds. Excess {excess_duration}")
|
|
|
78 |
torch.manual_seed(seed)
|
79 |
for idx, verse in enumerate(melodys):
|
80 |
print(f"Generating New Melody Segment {idx + 1}: {text}\r")
|
81 |
+
if output_segments:
|
82 |
+
# If this isn't the first segment, use the last chunk of the previous segment as the input
|
83 |
+
last_chunk = output_segments[-1][:, :, -overlap*MODEL.sample_rate:]
|
84 |
+
output = MODEL.generate_with_all(
|
85 |
descriptions=[text],
|
86 |
melody_wavs=verse,
|
87 |
+
sample_rate=sr,
|
88 |
+
progress=True,
|
89 |
+
prompt=last_chunk if len(last_chunk) > 0 else None,
|
90 |
)
|
91 |
|
92 |
# Append the generated output to the list of segments
|
|
|
157 |
Example:
|
158 |
font = load_font("Arial.ttf", font_size=20)
|
159 |
"""
|
160 |
+
font = None
|
161 |
+
if not "http" in font_name:
|
|
|
|
|
162 |
try:
|
163 |
font = ImageFont.truetype(font_name, font_size)
|
164 |
+
except (FileNotFoundError, OSError):
|
165 |
+
print("Font not found. Trying to download from local assets folder...\n")
|
166 |
+
if font is None:
|
167 |
try:
|
168 |
+
font = ImageFont.truetype("assets/" + font_name, font_size)
|
169 |
+
except (FileNotFoundError, OSError):
|
170 |
+
print("Font not found. Trying to download from URL...\n")
|
171 |
+
|
172 |
+
if font is None:
|
173 |
+
try:
|
174 |
+
req = requests.get(font_name)
|
175 |
+
font = ImageFont.truetype(BytesIO(req.content), font_size)
|
176 |
+
except (FileNotFoundError, OSError):
|
177 |
+
print(f"Font found: {font_name} Using Hugging Face download font\n")
|
178 |
+
|
179 |
+
if font is None:
|
180 |
+
try:
|
181 |
+
font = ImageFont.truetype(hf_hub_download("assets", font_name), encoding="UTF-8")
|
182 |
+
except (FileNotFoundError, OSError):
|
183 |
+
font = ImageFont.load_default()
|
184 |
+
print(f"Font not found: {font_name} Using default font\n")
|
185 |
return font
|
186 |
|
187 |
|