revert pattern preserves 4

- audiocraft/builders.py +6 -18
- audiocraft/conditioners.py +1 -31
- audiocraft/lm.py +86 -142
- demo.py +5 -3
audiocraft/builders.py CHANGED

@@ -11,7 +11,6 @@ from .lm import LMModel
 from .seanet import SEANetDecoder
 from .codebooks_patterns import DelayedPatternProvider
 from .conditioners import (
-    ConditionFuser,
     ConditioningProvider,
     T5Conditioner,
     ConditioningAttributes
@@ -78,11 +77,9 @@ class AudioGen(nn.Module):
             ConditioningAttributes(text={'description': d}) for d in descriptions]
         gen_tokens = self.lm.generate(
             conditions=attributes,
-            max_gen_len=int(self.duration * self.frame_rate)) #[
-        x = self.compression_model.decode(gen_tokens, None) #[
-
-        x = x.reshape(1, n_draw * n_time_samples) # linearise n_draw
-        print('______________\nGENTOk 5', gen_tokens)
+            max_gen_len=int(self.duration * self.frame_rate)) # [bs, 4, 37 * self.lm.n_draw]
+        x = self.compression_model.decode(gen_tokens, None) #[bs, 1, 11840]
+        print('______________\nGENTOk 5', gen_tokens.shape)
         print('GENAUD 5', x.sum())
         return x
@@ -147,13 +144,13 @@ class AudioGen(nn.Module):
         attribute_dropout = dict_from_config(getattr(cfg, 'attribute_dropout'))
         cls_free_guidance = dict_from_config(getattr(cfg, 'classifier_free_guidance'))
         cfg_prob, cfg_coef = cls_free_guidance['training_dropout'], cls_free_guidance['inference_coef']
-
+
         condition_provider = self.get_conditioner_provider(kwargs["dim"], cfg
                                                            ).to(self.device)


-        if len(fuser.fuse2cond['cross']) > 0: # enforce cross-att programmatically
-
+        # if len(fuser.fuse2cond['cross']) > 0: # enforce cross-att programmatically
+        kwargs['cross_attention'] = True
         if codebooks_pattern_cfg.modeling is None:
             print('Q MODELING\n=\n=><')
             assert q_modeling is not None, \
@@ -166,7 +163,6 @@ class AudioGen(nn.Module):
         return LMModel(
             pattern_provider=pattern_provider,
             condition_provider=condition_provider,
-            fuser=fuser,
             cfg_dropout=cfg_prob,
             cfg_coef=cfg_coef,
             attribute_dropout=attribute_dropout,
@@ -202,14 +198,6 @@ class AudioGen(nn.Module):
         return ConditioningProvider(conditioners)


-    def get_condition_fuser(self, cfg):
-        """Instantiate a condition fuser object."""
-        fuser_cfg = getattr(cfg, 'fuser')
-        fuser_methods = ['sum', 'cross', 'prepend', 'input_interpolate']
-        fuse2cond = {k: fuser_cfg[k] for k in fuser_methods}
-        kwargs = {k: v for k, v in fuser_cfg.items() if k not in fuser_methods}
-        fuser = ConditionFuser(fuse2cond=fuse2cond, **kwargs)
-        return fuser


     def get_codebooks_pattern_provider(self, n_q, cfg):
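Note on the shape comments in the second hunk: the token count comes from duration * frame_rate, and the decoded sample count from the compression model's hop size. A quick check of that arithmetic, assuming a 50 Hz frame rate and a 320-sample hop for the 16 kHz EnCodec-style codec (both values are assumptions, not stated in this commit):

    # duration=.74 is the value demo.py now passes; frame_rate=50 and hop=320 are assumed
    duration, frame_rate, hop = 0.74, 50, 320
    max_gen_len = int(duration * frame_rate)  # 37 token frames, as in '37 * self.lm.n_draw'
    n_samples = max_gen_len * hop             # 11840 samples, as in '[bs, 1, 11840]'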
audiocraft/conditioners.py CHANGED

@@ -4,7 +4,6 @@ import logging
 import random
 import typing as tp
 import warnings
-import soundfile
 from transformers import T5EncoderModel, T5Tokenizer # type: ignore
 import torch
 from torch import nn
@@ -243,33 +242,4 @@ class ConditioningProvider(nn.Module):
         for text in texts:
             for condition in self.text_conditions:
                 out[condition].append(text[condition])
-        return out
-
-
-
-
-
-
-class ConditionFuser(nn.Module):
-
-    FUSING_METHODS = ["sum", "prepend", "cross", "input_interpolate"]
-
-    def __init__(self, fuse2cond: tp.Dict[str, tp.List[str]], cross_attention_pos_emb: bool = False,
-                 cross_attention_pos_emb_scale: float = 1.0):
-        super().__init__()
-        assert all(
-            [k in self.FUSING_METHODS for k in fuse2cond.keys()]
-        ), f"Got invalid fuse method, allowed methods: {self.FUSING_METHODS}"
-        self.cross_attention_pos_emb = cross_attention_pos_emb
-        self.cross_attention_pos_emb_scale = cross_attention_pos_emb_scale
-        self.fuse2cond: tp.Dict[str, tp.List[str]] = fuse2cond
-        self.cond2fuse: tp.Dict[str, str] = {}
-        for fuse_method, conditions in fuse2cond.items():
-            for condition in conditions:
-                self.cond2fuse[condition] = fuse_method
-
-    def forward(
-        self,
-        input,
-        conditions):
-        return input, conditions['description'][0] #cross_attention_output
+        return out
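Note: the deleted ConditionFuser had already degenerated into a pass-through; its forward ignored every fusing method and returned the input unchanged plus the T5 'description' embedding as the cross-attention source. That is why lm.py can now index the provider output directly. A minimal sketch of the equivalence (batch size, text length, and the 1536 width follow the comments in lm.py; the tensors here are dummies):

    import torch

    condition_tensors = {'description': [torch.randn(2, 7, 1536)]}  # [bs, textlen, 1536]
    input_ = torch.randn(2, 37, 1536)

    # old: input_, cross = self.fuser(input_, condition_tensors)
    # new: take the cross-attention source straight from the provider output
    cross_attention_src = condition_tensors['description'][0]
    assert cross_attention_src.shape == (2, 7, 1536)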
audiocraft/lm.py CHANGED

@@ -10,31 +10,7 @@ from functools import partial
 from torch import nn
 from audiocraft.activations import get_activation_fn
 
-def sample_top_k(p, k=1, n_draw=None):
-    """
-    p probabs 2048 ?
-    num_draw : how many tokens to sample (for duplicate elongation)
-    """
 
-    p = torch.softmax(p, dim=-1) # p/temp
-
-
-
-    top_k_value, i250 = torch.topk(p, k, dim=-1) # probs: [1, 4, 2048]
-    # print('\n_____TOPK________\n', top_k_value.shape, top_k_value[0, 0, :10], '\n___________END_TOPK____________\n')
-    min_value_top_k = top_k_value[..., [-1]] #
-    p *= (p >= min_value_top_k).float()
-    p.div_(p.sum(dim=-1, keepdim=True))
-    # -- next_token = multinomial(probs, num_samples=num_draw)
-
-    # RESHAPED into bs, 4, 250
-    p_ = p.reshape(-1, p.shape[-1])
-
-
-    out = torch.multinomial(p_,
-                            num_samples=n_draw,
-                            replacement=False) # [4, num_draw]
-    return out.transpose(0, 1)[:, :, None] # [num_draw, 4, 1]
 
 
 
@@ -160,21 +136,26 @@ class LMModel(nn.Module):
     def __init__(self,
                  pattern_provider,
                  condition_provider,
-
-
-
-
-
-
-
+                 n_q: int = 8,
+                 card: int = 1024,
+                 dim: int = 128,
+                 num_heads: int = 8,
+                 hidden_scale: int = 4,
+                 norm: str = 'layer_norm',
+                 norm_first: bool = False,
+                 emb_lr: tp.Optional[float] = None,
+                 bias_proj: bool = True,
+                 weight_init: tp.Optional[str] = None,
+                 depthwise_init: tp.Optional[str] = None,
+                 zero_bias_init: bool = False, cfg_dropout: float = 0,
+                 cfg_coef: float = 1.0,
+                 two_step_cfg: bool = False,
                  **kwargs):
         super().__init__()
         self.cfg_coef = cfg_coef
-
-        self.n_draw = 1
         self.condition_provider = condition_provider
-        self.fuser = fuser
         self.card = card # 2048 ?
+        self.n_draw = 8 # replicate so many times the generation of each text in batch
         embed_dim = self.card + 1
         self.n_q = n_q
         self.dim = dim
@@ -251,37 +232,39 @@ class LMModel(nn.Module):
     @property
     def special_token_id(self) -> int:
         return self.card
+
+    def sample_top_k(self, p, k=249):
+        bs, _, _, hidden = p.shape # logits [3, 4, 1, 2048]
 
-
-
-
+        p = torch.softmax(p, dim=3)
+        top_k_value, i250 = torch.topk(p, k, dim=3) # [3, 4, 1, k]
+        min_value_top_k = top_k_value[:, :, :, -1:]
+        p *= (p >= min_value_top_k).float() # zero low probs
+        p.div_(p.sum(dim=-1, keepdim=True)) # renormalise on non-zero probs
+
+
+        # BRING THE nq = 4 IN BATCH
+        p = p.reshape(bs * self.n_q, hidden)
+        out = torch.multinomial(p, # p=[bs,2048], out=[bs, num_samples]
+                                num_samples=self.n_draw,
+                                replacement=False) # [bs*4, self.n_draw]
+        return out.reshape(bs, self.n_q, self.n_draw).transpose(1, 2) # [bs, self.n_draw, 4]
 
     def forward(self,
                 sequence,
                 condition_tensors=None,
                 token_count=None):
-
-        input_ = sum([self.emb[k](sequence[:, k]) for k in range(
-        # input_, cross_attention_input = self.fuser(input_, condition_tensors)
-        cross_attention_input = condition_tensors['description'][0]
-
-        # print(f'{input_.shape=}')
+
+        input_ = sum([self.emb[k](sequence[:, k]) for k in range(self.n_q)])
         out = self.transformer(input_,
-                               cross_attention_src=
+                               cross_attention_src=condition_tensors['description'][0],
                                token_count=token_count)
         if self.out_norm:
             out = self.out_norm(out)
-        # K = 2 because of llm producing 2 tokens?
-        # so only 2 x sel.flinear() of 4 are used ?
-        # WHy torch.stack is in dim=1
-        logits = torch.stack([self.linears[k](out) for k in range(K)], dim=1) # [B, K, S, card]
-        # print(f'{input_.shape=} {out.shape=} {cross_attention_input.shape=} {logits.shape=} FUSER LLM')
-        # remove the prefix from the model outputs
-        # if len(self.fuser.fuse2cond['prepend']) > 0:
-        #     logits = logits[:, :, -S:]
-        #     print('==========================================PRESFIX')
 
-
+        logits = torch.stack([self.linears[k](out) for k in range(self.n_q)], dim=1)
+
+        return logits # [bs, 4, 1, 2048]
 
 
     # GENERATE class revert_codebook_patterns()
@@ -289,7 +272,6 @@ class LMModel(nn.Module):
     def generate(self,
                  prompt = None,
                  conditions = [],
-                 num_samples = 1, # N next token
                  max_gen_len=256):
 
         print(f'{prompt=} {conditions=}')
@@ -299,7 +281,8 @@ class LMModel(nn.Module):
 
 
         tokenized = self.condition_provider.tokenize(conditions)
-
+
+        # print(f'TOKENIZ, {tokenized.keys()=}, {tokenized=}') # 'description'
         # TOKENIZ {'description': {'input_ids': tensor([[3887, 16, 2815, 1],
         #         [3887, 16, 2815, 1]], device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1],
         #         [1, 1, 1, 1]], device='cuda:0')}}
@@ -307,105 +290,66 @@ class LMModel(nn.Module):
         cfg_conditions = self.condition_provider(tokenized)
 
 
-
-
-
-
-
-
-
-
-
-        pattern = self.pattern_provider.get_pattern(max_gen_len) # duplicate sequence
-        # this token is used as default value for codes that are not generated yet ?
-        unknown_token = -1
-
-
-        gen_codes = torch.full((B, K, max_gen_len), unknown_token, dtype=torch.long, device=device)
-
-        gen_codes[..., :start_offset] = prompt # place 0
-
-        _gen_sequence, _, mask = pattern.build_pattern_sequence(gen_codes, self.special_token_id)
+        # print(f'CFGcon, {cfg_conditions.keys()=}, {cfg_conditions["description"][0].shape=}')
+        # USE THIS ATTENTION MASK IF NOT SAME LEN;
+        bs, _7, _1536 = cfg_conditions['description'][0].shape # [bs, textlen, 1536]
+        pattern = self.pattern_provider.get_pattern(max_gen_len)
+        gen_codes = torch.full((bs,
+                                self.n_q,
+                                max_gen_len), -1, dtype=torch.long, device=device)
 
+        gen_sequence, _, mask = pattern.build_pattern_sequence(gen_codes, self.special_token_id)
+        _, _, audiodur = gen_sequence.shape # bs, 4, 7=audiodur
 
+        # print(gen_sequence.shape, mask.shape, 'F') # mask has no batch = [4,audio_duration]
+        # print(f'{mask=}')
+        #
+        # torch.Size([3, 4, 7]) torch.Size([4, 7]) F
+        # mask=tensor([[False, True, True, True, False, False, False],
+        #              [False, False, True, True, True, False, False],
+        #              [False, False, False, True, True, True, False],
+        #              [False, False, False, False, True, True, True]], device='cuda:0')
 
+        mask = mask[None, None, :, :].repeat(bs, self.n_draw, 1, 1) # [bs, n_draw, 4, audio duration]
+        gen_sequence = gen_sequence[:, None, :, :].repeat(1, self.n_draw, 1, 1) # bs,n_draw,4,dur
 
 
 
-
-        # print(mask.shape, mask.sum(), 'MSK LM')
-        # torch.Size([4, 39]) tensor(140, device='cuda:0') MSK LM ? Fully 1 normal no special token
-        # --\
-
-        # list - Elongation for take-5 next tokens - n_draw 5 tokens at each time-step
-        # append them at end of sequence
-        duplicate_draw = [
-            _gen_sequence[:, :, 0:1].repeat(self.n_draw, 1, 1)
-        ]
-
-
-        for offset in range(1, _gen_sequence.shape[2]):
-
-
+        for offset in range(1, audiodur):
 
-
-
-
-            logits = self.forward(_gen_sequence[:, :, offset-1:offset], # bs/n_draw, 4, 1
+            # pass only 0-th draw in forward
+            logits = self.forward(gen_sequence[:, 0, :, offset-1:offset],
                                   condition_tensors=cfg_conditions,
-                                  token_count=offset)
-
-            # print(f'BEF {logits.shape=} BEF utils.SampleTop5') # AGREES 4 BEF logits.shape=torch.Size([1, 4, 1, 2048]) BEF utils.SampleTop5
-            next_token = sample_top_k(logits, n_draw=self.n_draw) # [1,4,2048] logits
+                                  token_count=offset) # [bs, 4, 1, 2048]
 
 
-
-            _gen_sequence[:, :, offset] = next_token[0, :, 0] # next_token=[1,4,6] gen_seq=[1, 4, 39]
-
-            duplicate_draw.append(next_token)
+            next_token = self.sample_top_k(logits) # [bs, n_draw, 4]
 
-
+            # MASK is not full 1---- HAS 4 x audioduration PATTERN
+            m = mask[:, :, :, offset]
+            next_token[~m] = self.special_token_id
+            gen_sequence[:, :, :, offset] = torch.where(
+                gen_sequence[:, :, :, offset] == -1, #unknown_token,
+                next_token,
+                gen_sequence[:, :, :, offset]
            )
 
 
-
-
-        #
-
-
-
-
-
-
-
-        print(
-
-        #
-        out_codes, _, _ = pattern.revert_pattern_sequence(gen_sequence,
-                                                          special_token=unknown_token)
-
-
-        # set(out_codes.unique().tolist()) - set(gen_sequence.unique().tolist()) # set()
-
-        # UNIQUE are the SAME ---------------?> is it rearrange
-
-
-
-        # ARE SOME PARTS IGNORED OR RE-ARRANGED
-
-        # print(f'{unknown_token=} {gen_sequence.shape=} {out_codes.shape=}')
-        # -> unknown tokn = -1 or 2048
-        # unknown_token=-1
-
-        print(f' <=> CODES {out_codes.shape=} {out_codes.min()} {out_codes.max()}\n') # ARRIVES here also if special
-
-        # unknown_token=-1 gen_sequence.shape=torch.Size([1, 4, 39]) out_codes.shape=torch.Size([1, 4, 35])
-        # <=> CODES out_codes.shape=torch.Size([1, 4, 35]) 30 2024
-
-
-
-        # Clean Transformer MHA k_history v_history
+        # 1. reshape n_draw as bs * n_draw
+        # 2. invert all short-sequences
+        # 3. reshape bs * n_draw -> bs, n_draw * audiodur ELONGATION
+        out_codes, _, _ = pattern.revert_pattern_sequence(
+            gen_sequence.reshape(bs * self.n_draw, 4, audiodur), # [3,8,4,7]
+            special_token=-1)
+        # print(f'{gen_sequence.shape=} {out_codes.shape=} Ha') # REVERT PATTERN REDUCES DURATION?
+        _, _, new_len = out_codes.shape # 4 IS PRESERVED AFTER REVERT!
+        out_codes = out_codes.reshape(bs, self.n_draw, 4, new_len)
+        out_codes = out_codes.transpose(1, 2).reshape(bs, 4, self.n_draw * new_len)
+        print(out_codes.shape, 'o')
+
+        # Clear Transformer k/v history (Different history is kept by 48x selfattn)
         for lay in self.transformer.layers:
             lay.self_attn.k_history = None
             lay.self_attn.v_history = None
-
-        return out_codes
+
+        return out_codes
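Two notes on the new generate loop. First, the mask printed in the comments is the staircase of the delayed codebook pattern: with 4 codebooks, codebook k only carries a real token k steps after codebook 0, which is why invalid positions are overwritten with special_token_id before being committed to gen_sequence. Second, the new sample_top_k method folds the 4 codebooks into the batch dimension so that a single multinomial call draws n_draw candidates per codebook. A self-contained sketch of both, using the sizes from the printed shapes (bs=3, n_draw=8, card=2048); the mask formula is a reconstruction, assuming one leading special step and 3 real frames over a padded duration of 7:

    import torch

    # 1) the [4, 7] staircase mask from the comments: codebook k is valid at steps k+1 .. k+3
    n_q, dur = 4, 7
    steps = torch.arange(dur)[None, :]                     # [1, 7]
    delay = torch.arange(n_q)[:, None]                     # [4, 1]
    mask = (steps > delay) & (steps <= delay + dur - n_q)  # True = real token, False = special

    # 2) batched top-k sampling, mirroring the new LMModel.sample_top_k
    def sample_top_k(p, n_q=4, n_draw=8, k=249):
        bs, _, _, hidden = p.shape                     # logits [bs, n_q, 1, card]
        p = torch.softmax(p, dim=3)
        top_k_value, _ = torch.topk(p, k, dim=3)       # [bs, n_q, 1, k]
        p *= (p >= top_k_value[:, :, :, -1:]).float()  # zero probs below the k-th largest
        p.div_(p.sum(dim=-1, keepdim=True))            # renormalise the survivors
        p = p.reshape(bs * n_q, hidden)                # bring the n_q codebooks into the batch
        out = torch.multinomial(p, num_samples=n_draw, replacement=False)
        return out.reshape(bs, n_q, n_draw).transpose(1, 2)  # [bs, n_draw, n_q]

    print(mask.int())                                       # matches the tensor printed above
    print(sample_top_k(torch.randn(3, 4, 1, 2048)).shape)   # torch.Size([3, 8, 4])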
demo.py CHANGED

@@ -1,11 +1,13 @@
 import audiofile
 import numpy as np
 from audiocraft import AudioGen
-
+text_list = ['dogs barging in the street', 'people po']
 
-sound_generator = AudioGen(duration=.
+sound_generator = AudioGen(duration=.74,
                            device='cuda:0').to('cuda:0').eval()
-x = sound_generator.generate(
+x = sound_generator.generate(text_list) # [bs, 1, 7680]
+# print('demo', x.shape)
+x = x[0, :, :].detach().cpu().numpy()
 x /= np.abs(x).max() + 1e-7
 
 audiofile.write('del_seane.wav', x, 16000)
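Note: generate() returns the n_draw variants concatenated along the time axis (the 'linearise n_draw' comment removed from builders.py), so del_seane.wav contains all draws back to back. A hypothetical post-processing step to split them apart, assuming the decoded length divides evenly by n_draw and ignoring codec edge effects at the seams:

    n_draw = 8                     # must match LMModel.n_draw in this commit
    draws = x.reshape(n_draw, -1)  # one row per sampled variant
    for i, d in enumerate(draws):
        audiofile.write(f'del_seane_{i}.wav', d, 16000)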