final changes
Browse files- app.py +1 -1
- tests/models/test_musicgen.py +2 -2
app.py
CHANGED
@@ -25,7 +25,7 @@ from audiocraft.models import MusicGen
|
|
25 |
|
26 |
MODEL = None # Last used model
|
27 |
IS_BATCHED = "facebook/MusicGen" in os.environ.get('SPACE_ID', '')
|
28 |
-
MAX_BATCH_SIZE =
|
29 |
BATCHED_DURATION = 15
|
30 |
INTERRUPTING = False
|
31 |
# We have to wrap subprocess call to clean a bit the log when using gr.make_waveform
|
|
|
25 |
|
26 |
MODEL = None # Last used model
|
27 |
IS_BATCHED = "facebook/MusicGen" in os.environ.get('SPACE_ID', '')
|
28 |
+
MAX_BATCH_SIZE = 12
|
29 |
BATCHED_DURATION = 15
|
30 |
INTERRUPTING = False
|
31 |
# We have to wrap subprocess call to clean a bit the log when using gr.make_waveform
|
tests/models/test_musicgen.py
CHANGED
@@ -13,7 +13,7 @@ from audiocraft.models import MusicGen
|
|
13 |
class TestSEANetModel:
|
14 |
def get_musicgen(self):
|
15 |
mg = MusicGen.get_pretrained(name='debug', device='cpu')
|
16 |
-
mg.set_generation_params(duration=2.0,
|
17 |
return mg
|
18 |
|
19 |
def test_base(self):
|
@@ -52,7 +52,7 @@ class TestSEANetModel:
|
|
52 |
def test_generate_long(self):
|
53 |
mg = self.get_musicgen()
|
54 |
mg.max_duration = 3.
|
55 |
-
mg.set_generation_params(duration=4.,
|
56 |
wav = mg.generate(
|
57 |
['youpi', 'lapin dort'])
|
58 |
assert list(wav.shape) == [2, 1, 32000 * 4]
|
|
|
13 |
class TestSEANetModel:
|
14 |
def get_musicgen(self):
|
15 |
mg = MusicGen.get_pretrained(name='debug', device='cpu')
|
16 |
+
mg.set_generation_params(duration=2.0, extend_stride=2.)
|
17 |
return mg
|
18 |
|
19 |
def test_base(self):
|
|
|
52 |
def test_generate_long(self):
|
53 |
mg = self.get_musicgen()
|
54 |
mg.max_duration = 3.
|
55 |
+
mg.set_generation_params(duration=4., extend_stride=2.)
|
56 |
wav = mg.generate(
|
57 |
['youpi', 'lapin dort'])
|
58 |
assert list(wav.shape) == [2, 1, 32000 * 4]
|