nit
Browse files
audiocraft/models/musicgen.py
CHANGED
@@ -270,8 +270,7 @@ class MusicGen:
|
|
270 |
torch.Tensor: Generated audio, of shape [B, C, T], T is defined by the generation params.
|
271 |
"""
|
272 |
total_gen_len = int(self.duration * self.frame_rate)
|
273 |
-
|
274 |
-
current_gen_offset = 0
|
275 |
|
276 |
def _progress_callback(generated_tokens: int, tokens_to_generate: int):
|
277 |
print(f'{current_gen_offset + generated_tokens: 6d} / {total_gen_len: 6d}', end='\r')
|
@@ -299,7 +298,7 @@ class MusicGen:
|
|
299 |
if prompt_tokens is not None:
|
300 |
all_tokens.append(prompt_tokens)
|
301 |
|
302 |
-
time_offset = 0
|
303 |
while time_offset < self.duration:
|
304 |
chunk_duration = min(self.duration - time_offset, self.max_duration)
|
305 |
max_gen_len = int(chunk_duration * self.frame_rate)
|
@@ -308,12 +307,15 @@ class MusicGen:
|
|
308 |
if wav_length == 0:
|
309 |
continue
|
310 |
# We will extend the wav periodically if it not long enough.
|
311 |
-
# we have to do it here
|
|
|
312 |
initial_position = int(time_offset * self.sample_rate)
|
313 |
-
wav_target_length = int(
|
314 |
positions = torch.arange(initial_position,
|
315 |
initial_position + wav_target_length, device=self.device)
|
316 |
-
attr.wav['self_wav'] =
|
|
|
|
|
317 |
with self.autocast:
|
318 |
gen_tokens = self.lm.generate(
|
319 |
prompt_tokens, attributes,
|
|
|
270 |
torch.Tensor: Generated audio, of shape [B, C, T], T is defined by the generation params.
|
271 |
"""
|
272 |
total_gen_len = int(self.duration * self.frame_rate)
|
273 |
+
current_gen_offset: int = 0
|
|
|
274 |
|
275 |
def _progress_callback(generated_tokens: int, tokens_to_generate: int):
|
276 |
print(f'{current_gen_offset + generated_tokens: 6d} / {total_gen_len: 6d}', end='\r')
|
|
|
298 |
if prompt_tokens is not None:
|
299 |
all_tokens.append(prompt_tokens)
|
300 |
|
301 |
+
time_offset = 0.
|
302 |
while time_offset < self.duration:
|
303 |
chunk_duration = min(self.duration - time_offset, self.max_duration)
|
304 |
max_gen_len = int(chunk_duration * self.frame_rate)
|
|
|
307 |
if wav_length == 0:
|
308 |
continue
|
309 |
# We will extend the wav periodically if it not long enough.
|
310 |
+
# we have to do it here rather than in conditioners.py as otherwise
|
311 |
+
# we wouldn't have the full wav.
|
312 |
initial_position = int(time_offset * self.sample_rate)
|
313 |
+
wav_target_length = int(self.max_duration * self.sample_rate)
|
314 |
positions = torch.arange(initial_position,
|
315 |
initial_position + wav_target_length, device=self.device)
|
316 |
+
attr.wav['self_wav'] = WavCondition(
|
317 |
+
ref_wav[0][:, positions % wav_length],
|
318 |
+
torch.full_like(ref_wav[1], wav_target_length))
|
319 |
with self.autocast:
|
320 |
gen_tokens = self.lm.generate(
|
321 |
prompt_tokens, attributes,
|