initial implem
Browse files
audiocraft/models/musicgen.py
CHANGED
@@ -36,10 +36,12 @@ class MusicGen:
|
|
36 |
used to map audio to invertible discrete representations.
|
37 |
lm (LMModel): Language model over discrete representations.
|
38 |
"""
|
39 |
-
def __init__(self, name: str, compression_model: CompressionModel, lm: LMModel
|
|
|
40 |
self.name = name
|
41 |
self.compression_model = compression_model
|
42 |
self.lm = lm
|
|
|
43 |
self.device = next(iter(lm.parameters())).device
|
44 |
self.generation_params: dict = {}
|
45 |
self.set_generation_params(duration=15) # 15 seconds by default
|
@@ -113,11 +115,10 @@ class MusicGen:
|
|
113 |
should we extend the audio each time. Larger values will mean less context is
|
114 |
preserved, and shorter value will require extra computations.
|
115 |
"""
|
116 |
-
|
117 |
-
assert extend_stride <= 25, "Keep at least 5 seconds of overlap!"
|
118 |
self.extend_stride = extend_stride
|
|
|
119 |
self.generation_params = {
|
120 |
-
'max_gen_len': int(duration * self.frame_rate),
|
121 |
'use_sampling': use_sampling,
|
122 |
'temp': temperature,
|
123 |
'top_k': top_k,
|
@@ -268,8 +269,12 @@ class MusicGen:
|
|
268 |
Returns:
|
269 |
torch.Tensor: Generated audio, of shape [B, C, T], T is defined by the generation params.
|
270 |
"""
|
|
|
|
|
|
|
|
|
271 |
def _progress_callback(generated_tokens: int, tokens_to_generate: int):
|
272 |
-
print(f'{generated_tokens: 6d} / {
|
273 |
|
274 |
if prompt_tokens is not None:
|
275 |
assert self.generation_params['max_gen_len'] > prompt_tokens.shape[-1], \
|
@@ -279,9 +284,46 @@ class MusicGen:
|
|
279 |
if progress:
|
280 |
callback = _progress_callback
|
281 |
|
282 |
-
|
283 |
-
|
284 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
285 |
|
286 |
# generate audio
|
287 |
assert gen_tokens.dim() == 3
|
|
|
36 |
used to map audio to invertible discrete representations.
|
37 |
lm (LMModel): Language model over discrete representations.
|
38 |
"""
|
39 |
+
def __init__(self, name: str, compression_model: CompressionModel, lm: LMModel,
|
40 |
+
max_duration: float = 30):
|
41 |
self.name = name
|
42 |
self.compression_model = compression_model
|
43 |
self.lm = lm
|
44 |
+
self.max_duration = max_duration
|
45 |
self.device = next(iter(lm.parameters())).device
|
46 |
self.generation_params: dict = {}
|
47 |
self.set_generation_params(duration=15) # 15 seconds by default
|
|
|
115 |
should we extend the audio each time. Larger values will mean less context is
|
116 |
preserved, and shorter value will require extra computations.
|
117 |
"""
|
118 |
+
assert extend_stride <= self.max_duration - 5, "Keep at least 5 seconds of overlap!"
|
|
|
119 |
self.extend_stride = extend_stride
|
120 |
+
self.duration = duration
|
121 |
self.generation_params = {
|
|
|
122 |
'use_sampling': use_sampling,
|
123 |
'temp': temperature,
|
124 |
'top_k': top_k,
|
|
|
269 |
Returns:
|
270 |
torch.Tensor: Generated audio, of shape [B, C, T], T is defined by the generation params.
|
271 |
"""
|
272 |
+
total_gen_len = int(self.duration * self.frame_rate)
|
273 |
+
|
274 |
+
current_gen_offset = 0
|
275 |
+
|
276 |
def _progress_callback(generated_tokens: int, tokens_to_generate: int):
|
277 |
+
print(f'{current_gen_offset + generated_tokens: 6d} / {total_gen_len: 6d}', end='\r')
|
278 |
|
279 |
if prompt_tokens is not None:
|
280 |
assert self.generation_params['max_gen_len'] > prompt_tokens.shape[-1], \
|
|
|
284 |
if progress:
|
285 |
callback = _progress_callback
|
286 |
|
287 |
+
if self.duration <= self.max_duration:
|
288 |
+
# generate by sampling from LM, simple case.
|
289 |
+
with self.autocast:
|
290 |
+
gen_tokens = self.lm.generate(
|
291 |
+
prompt_tokens, attributes,
|
292 |
+
callback=callback, max_gen_len=total_gen_len, **self.generation_params)
|
293 |
+
|
294 |
+
else:
|
295 |
+
# now this gets a bit messier, we need to handle prompts,
|
296 |
+
# melody conditioning etc.
|
297 |
+
ref_wavs = [attr.wav['self_wav'] for attr in attributes]
|
298 |
+
all_tokens = []
|
299 |
+
if prompt_tokens is not None:
|
300 |
+
all_tokens.append(prompt_tokens)
|
301 |
+
|
302 |
+
for time_offset in range(0, self.duration, self.extend_stride):
|
303 |
+
chunk_duration = min(self.duration - time_offset, self.max_duration)
|
304 |
+
max_gen_len = int(chunk_duration * self.frame_rate)
|
305 |
+
for attr, ref_wav in zip(attributes, ref_wavs):
|
306 |
+
wav_length = ref_wav.length.item()
|
307 |
+
if wav_length == 0:
|
308 |
+
continue
|
309 |
+
# We will extend the wav periodically if it not long enough.
|
310 |
+
# we have to do it here before it is too late.
|
311 |
+
initial_position = int(time_offset * self.sample_rate)
|
312 |
+
wav_target_length = int(chunk_duration * self.sample_rate)
|
313 |
+
positions = torch.arange(initial_position,
|
314 |
+
initial_position + wav_target_length, device=self.device)
|
315 |
+
attr.wav['self_wav'] = ref_wav[:, positions % wav_length]
|
316 |
+
with self.autocast:
|
317 |
+
gen_tokens = self.lm.generate(
|
318 |
+
prompt_tokens, attributes,
|
319 |
+
callback=callback, max_gen_len=max_gen_len, **self.generation_params)
|
320 |
+
stride_tokens = int(self.frame_rate * self.extend_stride)
|
321 |
+
if prompt_tokens is None:
|
322 |
+
all_tokens.append(gen_tokens)
|
323 |
+
else:
|
324 |
+
all_tokens.append(gen_tokens[:, :, prompt_tokens.shape[-1]:])
|
325 |
+
prompt_tokens = gen_tokens[:, :, stride_tokens]
|
326 |
+
gen_tokens = torch.cat(all_tokens, dim=-1)
|
327 |
|
328 |
# generate audio
|
329 |
assert gen_tokens.dim() == 3
|