nit
Browse files
audiocraft/models/musicgen.py
CHANGED
@@ -115,7 +115,7 @@ class MusicGen:
|
|
115 |
should we extend the audio each time. Larger values will mean less context is
|
116 |
preserved, and shorter value will require extra computations.
|
117 |
"""
|
118 |
-
assert extend_stride
|
119 |
self.extend_stride = extend_stride
|
120 |
self.duration = duration
|
121 |
self.generation_params = {
|
|
|
115 |
should we extend the audio each time. Larger values will mean less context is
|
116 |
preserved, and shorter value will require extra computations.
|
117 |
"""
|
118 |
+
assert extend_stride < self.max_duration, "Cannot stride by more than max generation duration."
|
119 |
self.extend_stride = extend_stride
|
120 |
self.duration = duration
|
121 |
self.generation_params = {
|
tests/models/test_musicgen.py
CHANGED
@@ -13,7 +13,7 @@ from audiocraft.models import MusicGen
|
|
13 |
class TestSEANetModel:
|
14 |
def get_musicgen(self):
|
15 |
mg = MusicGen.get_pretrained(name='debug', device='cpu')
|
16 |
-
mg.set_generation_params(duration=2.0)
|
17 |
return mg
|
18 |
|
19 |
def test_base(self):
|
@@ -51,7 +51,7 @@ class TestSEANetModel:
|
|
51 |
|
52 |
def test_generate_long(self):
|
53 |
mg = self.get_musicgen()
|
54 |
-
mg.set_generation_params(duration=4.)
|
55 |
wav = mg.generate(
|
56 |
['youpi', 'lapin dort'])
|
57 |
assert list(wav.shape) == [2, 1, 32000 * 4]
|
|
|
13 |
class TestSEANetModel:
|
14 |
def get_musicgen(self):
|
15 |
mg = MusicGen.get_pretrained(name='debug', device='cpu')
|
16 |
+
mg.set_generation_params(duration=2.0, stride_extend=2.)
|
17 |
return mg
|
18 |
|
19 |
def test_base(self):
|
|
|
51 |
|
52 |
def test_generate_long(self):
|
53 |
mg = self.get_musicgen()
|
54 |
+
mg.set_generation_params(duration=4., stride_extend=2.)
|
55 |
wav = mg.generate(
|
56 |
['youpi', 'lapin dort'])
|
57 |
assert list(wav.shape) == [2, 1, 32000 * 4]
|