plop
Browse files
audiocraft/models/musicgen.py
CHANGED
@@ -96,7 +96,7 @@ class MusicGen:
|
|
96 |
def set_generation_params(self, use_sampling: bool = True, top_k: int = 250,
|
97 |
top_p: float = 0.0, temperature: float = 1.0,
|
98 |
duration: float = 30.0, cfg_coef: float = 3.0,
|
99 |
-
two_step_cfg: bool = False):
|
100 |
"""Set the generation parameters for MusicGen.
|
101 |
|
102 |
Args:
|
@@ -109,8 +109,13 @@ class MusicGen:
|
|
109 |
two_step_cfg (bool, optional): If True, performs 2 forward passes for Classifier Free Guidance,
|
110 |
instead of batching together the two. This has some impact on how things
|
111 |
are padded but seems to have little impact in practice.
|
|
|
|
|
|
|
112 |
"""
|
113 |
-
assert duration <= 30, "The MusicGen cannot generate more than 30 seconds"
|
|
|
|
|
114 |
self.generation_params = {
|
115 |
'max_gen_len': int(duration * self.frame_rate),
|
116 |
'use_sampling': use_sampling,
|
|
|
96 |
def set_generation_params(self, use_sampling: bool = True, top_k: int = 250,
|
97 |
top_p: float = 0.0, temperature: float = 1.0,
|
98 |
duration: float = 30.0, cfg_coef: float = 3.0,
|
99 |
+
two_step_cfg: bool = False, extend_stride: float = 15):
|
100 |
"""Set the generation parameters for MusicGen.
|
101 |
|
102 |
Args:
|
|
|
109 |
two_step_cfg (bool, optional): If True, performs 2 forward passes for Classifier Free Guidance,
|
110 |
instead of batching together the two. This has some impact on how things
|
111 |
are padded but seems to have little impact in practice.
|
112 |
+
extend_stride (float, optional): When doing extended generation (i.e. more than 30 seconds), by how much
|
113 |
+
should we extend the audio each time. Larger values will mean less context is
|
114 |
+
preserved, and smaller values will require extra computations.
|
115 |
"""
|
116 |
+
# assert duration <= 30, "The MusicGen cannot generate more than 30 seconds"
|
117 |
+
assert extend_stride <= 25, "Keep at least 5 seconds of overlap!"
|
118 |
+
self.extend_stride = extend_stride
|
119 |
self.generation_params = {
|
120 |
'max_gen_len': int(duration * self.frame_rate),
|
121 |
'use_sampling': use_sampling,
|