nit
Browse files
audiocraft/models/musicgen.py
CHANGED
@@ -79,7 +79,7 @@ class MusicGen:
|
|
79 |
# used only for unit tests
|
80 |
compression_model = get_debug_compression_model(device)
|
81 |
lm = get_debug_lm_model(device)
|
82 |
-
return MusicGen(name, compression_model, lm)
|
83 |
|
84 |
if name not in HF_MODEL_CHECKPOINTS_MAP:
|
85 |
raise ValueError(
|
@@ -270,13 +270,14 @@ class MusicGen:
|
|
270 |
torch.Tensor: Generated audio, of shape [B, C, T], T is defined by the generation params.
|
271 |
"""
|
272 |
total_gen_len = int(self.duration * self.frame_rate)
|
|
|
273 |
current_gen_offset: int = 0
|
274 |
|
275 |
def _progress_callback(generated_tokens: int, tokens_to_generate: int):
|
276 |
print(f'{current_gen_offset + generated_tokens: 6d} / {total_gen_len: 6d}', end='\r')
|
277 |
|
278 |
if prompt_tokens is not None:
|
279 |
-
assert
|
280 |
"Prompt is longer than audio to generate"
|
281 |
|
282 |
callback = None
|
|
|
79 |
# used only for unit tests
|
80 |
compression_model = get_debug_compression_model(device)
|
81 |
lm = get_debug_lm_model(device)
|
82 |
+
return MusicGen(name, compression_model, lm, max_duration=3.)
|
83 |
|
84 |
if name not in HF_MODEL_CHECKPOINTS_MAP:
|
85 |
raise ValueError(
|
|
|
270 |
torch.Tensor: Generated audio, of shape [B, C, T], T is defined by the generation params.
|
271 |
"""
|
272 |
total_gen_len = int(self.duration * self.frame_rate)
|
273 |
+
max_prompt_len = int(min(self.duration, self.max_duration) * self.frame_rate)
|
274 |
current_gen_offset: int = 0
|
275 |
|
276 |
def _progress_callback(generated_tokens: int, tokens_to_generate: int):
|
277 |
print(f'{current_gen_offset + generated_tokens: 6d} / {total_gen_len: 6d}', end='\r')
|
278 |
|
279 |
if prompt_tokens is not None:
|
280 |
+
assert max_prompt_len >= prompt_tokens.shape[-1], \
|
281 |
"Prompt is longer than audio to generate"
|
282 |
|
283 |
callback = None
|
tests/models/test_musicgen.py
CHANGED
@@ -48,3 +48,10 @@ class TestSEANetModel:
|
|
48 |
wav = mg.generate(
|
49 |
['youpi', 'lapin dort'])
|
50 |
assert list(wav.shape) == [2, 1, 64000]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
48 |
wav = mg.generate(
|
49 |
['youpi', 'lapin dort'])
|
50 |
assert list(wav.shape) == [2, 1, 64000]
|
51 |
+
|
52 |
+
def test_generate_long(self):
|
53 |
+
mg = self.get_musicgen()
|
54 |
+
mg.set_generation_params(duration=4.)
|
55 |
+
wav = mg.generate(
|
56 |
+
['youpi', 'lapin dort'])
|
57 |
+
assert list(wav.shape) == [2, 1, 32000 * 4]
|