nit
audiocraft/models/musicgen.py CHANGED
@@ -299,7 +299,8 @@ class MusicGen:
             if prompt_tokens is not None:
                 all_tokens.append(prompt_tokens)
 
-
+            time_offset = 0
+            while time_offset < self.duration:
                 chunk_duration = min(self.duration - time_offset, self.max_duration)
                 max_gen_len = int(chunk_duration * self.frame_rate)
                 for attr, ref_wav in zip(attributes, ref_wavs):
@@ -323,6 +324,9 @@
                 else:
                     all_tokens.append(gen_tokens[:, :, prompt_tokens.shape[-1]:])
                 prompt_tokens = gen_tokens[:, :, stride_tokens:]
+                current_gen_offset += stride_tokens
+                time_offset += self.extend_stride
+
             gen_tokens = torch.cat(all_tokens, dim=-1)
 
         # generate audio
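For orientation, a minimal sketch of the bookkeeping this extension loop performs. Plain integers stand in for the token tensors, and every numeric value below (frame rate, durations, stride) is an assumed example rather than a MusicGen default; only the stride arithmetic mirrors the diff: each pass generates a window of at most max_duration seconds, everything past the first stride_tokens of that window carries over as the prompt for the next pass, and time_offset advances by extend_stride until the requested duration is covered.

# Illustrative sketch only: integers stand in for token tensors, and all
# numeric values are assumptions, not MusicGen defaults.
frame_rate = 50         # assumed tokens per second of audio
duration = 40.0         # assumed total length requested, in seconds
max_duration = 30.0     # assumed longest window the model generates at once
extend_stride = 20.0    # assumed seconds of fresh audio added per extension step

stride_tokens = int(frame_rate * extend_stride)

all_tokens = []         # per-chunk counts of newly generated tokens
prompt_length = 0       # length of the prompt carried into the next window
current_gen_offset = 0  # token offset where the current window starts
time_offset = 0.0

while time_offset < duration:
    chunk_duration = min(duration - time_offset, max_duration)
    max_gen_len = int(chunk_duration * frame_rate)

    # Stand-in for the LM call: assume the window comes back at its maximum
    # length, prompt included.
    gen_len = max_gen_len

    # Keep only the tokens that extend past the prompt, as the diff does with
    # gen_tokens[:, :, prompt_tokens.shape[-1]:].
    all_tokens.append(gen_len - prompt_length)

    # Everything past the first stride_tokens becomes the next prompt,
    # mirroring gen_tokens[:, :, stride_tokens:].
    prompt_length = max(gen_len - stride_tokens, 0)

    # Mirrors the diff's bookkeeping; its consumer lies outside this excerpt.
    current_gen_offset += stride_tokens
    time_offset += extend_stride

total_new_tokens = sum(all_tokens)
print(total_new_tokens, int(duration * frame_rate))  # both 2000 with these numbers

The overlap kept as the prompt is what preserves continuity between windows, while each step contributes only extend_stride seconds of genuinely new audio.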