DEBUG: cross_attention_src = query or key?
Files changed:
- audiocraft/audiogen.py      +13  −47
- audiocraft/conditioners.py   +7  −48
- audiocraft/lm.py            +96 −112
- audiocraft/streaming.py      +0 −131
- audiocraft/transformer.py  +137 −394
- demo.py                      +1   −1
audiocraft/audiogen.py  CHANGED

@@ -4,11 +4,6 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
-"""
-Main model for using AudioGen. This will combine all the required components
-and provide easy access to the generation API.
-"""
-
 import typing as tp
 import torch
 from audiocraft.loaders import load_compression_model, load_lm_model
@@ -87,51 +82,25 @@ class BaseGenModel(ABC):
         """Sample rate of the generated audio."""
         return self.compression_model.sample_rate
 
-    @property
-    def audio_channels(self) -> int:
-        """Audio channels of the generated audio."""
-        return self.compression_model.channels
-
-    @torch.no_grad()
-    def _prepare_tokens_and_attributes(
-            self,
-            descriptions,
-            prompt,
-    ):
-        attributes = [
-            ConditioningAttributes(text={'description': description}) for description in descriptions]
-        prompt_tokens = None
-        return attributes, prompt_tokens
-
-    def generate_unconditional(self,
-                               num_samples,
-                               progress=False,
-                               return_tokens=False):
-        descriptions: tp.List[tp.Optional[str]] = [None] * num_samples
-        attributes, _ = self._prepare_tokens_and_attributes(descriptions, None)
-        tokens = self._generate_tokens(attributes)
-        if return_tokens:
-            return self.generate_audio(tokens), tokens
-        return self.generate_audio(tokens)
 
-
-        attributes
+    def generate(self, descriptions):
+        attributes = [
+            ConditioningAttributes(text={'description': d}) for d in descriptions]
         tokens = self._generate_tokens(attributes)
-        if return_tokens:
-            return self.generate_audio(tokens), tokens
         return self.generate_audio(tokens)
 
-    def _generate_tokens(self, attributes
-                         prompt_tokens=None,
-                         progress=False):
+    def _generate_tokens(self, attributes):
 
         total_gen_len = int(self.duration * self.frame_rate)
-        max_prompt_len = int(min(self.duration, self.max_duration) * self.frame_rate)
-        current_gen_offset: int = 0
+
+        # # print(f'{self.generation_params=}')
+        # self.generation_params={'use_sampling': True,
+        #                         'temp': 1.0, 'top_k': 250,
+        #                         'top_p': 0.0, 'cfg_coef': 2.4, 'two_step_cfg': False}
 
 
@@ -140,10 +109,7 @@ class BaseGenModel(ABC):
         # generate by sampling from LM, simple case.
 
         with self.autocast:
-            gen_tokens = self.lm.generate(conditions=attributes,
-                                          callback=None,
-                                          max_gen_len=total_gen_len,
-                                          **self.generation_params)
+            gen_tokens = self.lm.generate(conditions=attributes, max_gen_len=total_gen_len)
         else:
             print('<>Long gen ?<>')
             # print(f'{gen_tokens.shape=}')  # [5,4,35]
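The trimmed-down BaseGenModel path above boils down to: wrap each text prompt in a ConditioningAttributes, sample tokens from the LM inside the autocast context, and decode them to audio. A minimal sketch of that flow, assuming ConditioningAttributes is importable from audiocraft.conditioners as this fork's file layout suggests; the model argument is a placeholder for the loaded AudioGen wrapper:

import torch
from audiocraft.conditioners import ConditioningAttributes  # assumed location in this fork

@torch.no_grad()
def generate(model, descriptions):
    # One ConditioningAttributes per text prompt, exactly as in the new generate().
    attributes = [ConditioningAttributes(text={'description': d}) for d in descriptions]
    total_gen_len = int(model.duration * model.frame_rate)
    with model.autocast:
        gen_tokens = model.lm.generate(conditions=attributes, max_gen_len=total_gen_len)
    return model.generate_audio(gen_tokens)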
audiocraft/conditioners.py  CHANGED

@@ -8,7 +8,7 @@ import soundfile
 from transformers import T5EncoderModel, T5Tokenizer  # type: ignore
 import torch
 from torch import nn
-
+
 
 from .utils.autocast import TorchAutocast
 
@@ -126,17 +126,7 @@ class BaseConditioner(nn.Module):
         """
         raise NotImplementedError()
 
-
-        """Gets input that should be used as conditioning (e.g, genre, description or a waveform).
-        Outputs a ConditionType, after the input data was embedded as a dense vector.
-
-        Returns:
-            ConditionType:
-                - A tensor of size [B, T, D] where B is the batch size, T is the length of the
-                  output embedding and D is the dimension of the embedding.
-                - And a mask indicating where the padding tokens.
-        """
-        raise NotImplementedError()
+
 
 
 class TextConditioner(BaseConditioner):
@@ -239,6 +229,9 @@ class T5Conditioner(TextConditioner):
         embeds = self.t5(**inputs).last_hidden_state
         embeds = self.output_proj(embeds.to(self.output_proj.weight))
         embeds = (embeds * mask.unsqueeze(-1))
+
+        # T5 torch.Size([2, 4, 1536]) dict_keys(['input_ids', 'attention_mask'])
+        # print(f'{inputs["input_ids"].shape=}')  # inputs["input_ids"].shape=torch.Size([2, 4])
         return embeds, mask
 
 
@@ -352,21 +345,8 @@ class ConditioningProvider(nn.Module):
 
 
 
-class ConditionFuser(
-    """Condition fuser handles the logic to combine the different conditions
-    to the actual model input.
+class ConditionFuser(nn.Module):
 
-    Args:
-        fuse2cond (tp.Dict[str, str]): A dictionary that says how to fuse
-            each condition. For example:
-            {
-                "prepend": ["description"],
-                "sum": ["genre", "bpm"],
-                "cross": ["description"],
-            }
-        cross_attention_pos_emb (bool, optional): Use positional embeddings in cross attention.
-        cross_attention_pos_emb_scale (int): Scale for positional embeddings in cross attention if used.
-    """
     FUSING_METHODS = ["sum", "prepend", "cross", "input_interpolate"]
 
     def __init__(self, fuse2cond: tp.Dict[str, tp.List[str]], cross_attention_pos_emb: bool = False,
@@ -387,25 +367,4 @@ class ConditionFuser(StreamingModule):
             self,
             input,
             conditions):
-
-        B, T, _ = input.shape
-
-
-        first_step = True
-        offsets = torch.zeros(input.shape[0], dtype=torch.long, device=input.device)
-
-
-        cross_attention_output = None
-        for cond_type, (cond, cond_mask) in conditions.items():
-            # print(f'{self.cond2fuse=}')  - self.cond2fuse={'description': 'cross'}
-
-            cross_attention_output = cond
-            # print(f'{cross_attention_output.shape=} for {input.sum()=}')
-            # cross_attention_output.shape=torch.Size([2, 5, 1536]) for input.sum()=tensor(-0.0650, device='cuda:0')
-            # cross_attention_output.shape=torch.Size([2, 5, 1536]) for input.sum()=tensor(3.7672, device='cuda:0')
-
-
-        if self._is_streaming:
-            self._streaming_state['offsets'] = offsets + T
-
-        return input, cross_attention_output
+        return input, conditions['description'][0]  # cross_attention_output
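ConditionFuser.forward is now a pass-through: it returns the input unchanged plus the embedded 'description' condition, which becomes cross_attention_src in the LM. A small sketch of the shapes involved, using random stand-in tensors with the sizes from the debug prints above:

import torch

# conditions maps condition name -> (embedding, mask), as produced by ConditioningProvider.
B, T_text, D = 2, 5, 1536          # shapes from the debug prints above
conditions = {'description': (torch.randn(B, T_text, D), torch.ones(B, T_text))}
input_ = torch.randn(B, 1, D)      # summed codebook embeddings for the current step

# New fuser behaviour: pass input through, expose the description embedding
# so the transformer can use it as cross_attention_src (keys/values).
output, cross_attention_src = input_, conditions['description'][0]
assert cross_attention_src.shape == (B, T_text, D)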
audiocraft/lm.py  CHANGED

@@ -6,7 +6,6 @@ import re
 import typing as tp
 import torch
 import torch.nn.functional as F
-from audiocraft.streaming import StreamingModule
 from audiocraft.transformer import StreamingTransformer, create_norm_fn
 from dataclasses import dataclass
 from functools import partial
@@ -109,7 +108,7 @@ class LMOutput:
     mask: torch.Tensor  # [B, K, T]
 
 
-class LMModel(
+class LMModel(nn.Module):
     """Transformer-based language model on multiple streams of codes.
 
     Args:
@@ -148,7 +147,7 @@ class LMModel(StreamingModule):
         super().__init__()
        self.cfg_coef = cfg_coef
 
-        self.n_draw =
+        self.n_draw = 5
         self.condition_provider = condition_provider
         self.fuser = fuser
         self.card = card  # 2048 ?
@@ -160,9 +159,26 @@ class LMModel(StreamingModule):
         self.emb = nn.ModuleList([ScaledEmbedding(embed_dim, dim, lr=emb_lr) for _ in range(n_q)])
         if 'activation' in kwargs:
             kwargs['activation'] = get_activation_fn(kwargs['activation'])
+        # ========================================================================
+        # {
+        #  'dtype': torch.float16, 'device': 'cuda',
+        #  'num_layers': 48, 'dropout': 0.0, 'activation': 'gelu',
+        #  'bias_ff': False, 'bias_attn': False,
+        #  'past_context': None, 'causal': True,
+        #  'custom': False, 'memory_efficient': True,
+        #  'attention_as_float32': False, 'positional_embedding': 'sin', 'xpos': False,
+        #  'checkpointing': 'none', 'cross_attention': True, 'qk_layer_norm': False,
+        #  'qk_layer_norm_cross': False, 'attention_dropout': None, 'kv_repeat': 1
+        # }
+        # ==========================================================================
+        kwargs.pop('layer_scale')  # nn.Indentity()
+
         self.transformer = StreamingTransformer(
-            d_model=dim,
-
+            d_model=dim,
+            num_heads=num_heads,
+            dim_feedforward=int(hidden_scale * dim),
+            norm=norm,
+            norm_first=norm_first, **kwargs)
         self.out_norm: tp.Optional[nn.Module] = None
         if norm_first:
             self.out_norm = create_norm_fn(norm, dim)
@@ -199,7 +215,10 @@ class LMModel(StreamingModule):
                 depth = layer_idx + 1
             elif depthwise_init == 'global':
                 depth = len(self.transformer.layers)
-            init_fn = partial(init_layer,
+            init_fn = partial(init_layer,
+                              method=weight_init,
+                              init_depth=depth,
+                              zero_bias_init=zero_bias_init)
             tr_layer.apply(init_fn)
 
         for linear in self.linears:
@@ -215,91 +234,55 @@ class LMModel(StreamingModule):
 
     def forward(self,
                 sequence,
-                conditions,
                 condition_tensors=None,
                 stage = -1):
-        B, K, S = sequence.shape
-
+        B, K, S = sequence.shape  # linears are n_q
 
         input_ = sum([self.emb[k](sequence[:, k]) for k in range(K)])
 
 
-        input_, cross_attention_input = self.fuser(input_, condition_tensors)
-
-        # print(f'{input_.shape=} {cross_attention_input.shape=} FUSER LLM
-
+        # input_, cross_attention_input = self.fuser(input_, condition_tensors)
+        cross_attention_input = condition_tensors['description'][0]
+        # print(f'{input_.shape=} {cross_attention_input.shape=} FUSER LLM')
+
 
         out = self.transformer(input_, cross_attention_src=cross_attention_input,
                                src_mask=(self.attn_mask_per_stage[stage] if stage >= 0 else None))
         if self.out_norm:
             out = self.out_norm(out)
+        # K = 2 because of llm producing 2 tokens?
+        # so only 2 x sel.flinear() of 4 are used ?
+        # WHy torch.stack is in dim=1
         logits = torch.stack([self.linears[k](out) for k in range(K)], dim=1)  # [B, K, S, card]
-
+        print(f'{input_.shape=} {out.shape=} {cross_attention_input.shape=} {logits.shape=} FUSER LLM')
         # remove the prefix from the model outputs
-        if len(self.fuser.fuse2cond['prepend']) > 0:
-
-
+        # if len(self.fuser.fuse2cond['prepend']) > 0:
+        #     logits = logits[:, :, -S:]
+        #     print('==========================================PRESFIX')
 
         return logits  # [B, K, S, card]
 
 
-    def _sample_next_token(self,
-                           sequence,
-                           cfg_conditions,
-                           unconditional_state):
-        """self.n_draw"""
-        B = sequence.shape[0]
-
-        model = self if self._fsdp is None else self._fsdp
-
-        condition_tensors = cfg_conditions
-        # logits = [2, 4, 1, 2048]
-        logits = model(
-            sequence,  # cond_logits = wav condition
-            conditions=[], condition_tensors=condition_tensors)  # uncond_logits already see the text
-
-
-        # use cfg
-        # logits = (3 * logits[1, :, :, :] - 2.4 * logits[0, :, :, :]).transpose(1,0)
-
-        # or use 1 of logits
-        logits = logits[0, :, :, :].transpose(1,0)  # [2,4,1, 2048] -> [1,4,2048]
-
-
-        # print(f'{B=}, {logits.shape=} SAMPLER {top_k=}')
-        next_token = utils.sample_top_k(logits, n_draw=self.n_draw)  # [1,4,2048] logits
-        return next_token
-
     # GENERATE class revert_codebook_patterns()
     @torch.no_grad()
     def generate(self,
                  prompt = None,
                  conditions = [],
-                 num_samples = 1,
-                 max_gen_len=256
-                 use_sampling: bool = True,
-                 **kwargs):
+                 num_samples = 1,  # N next token
+                 max_gen_len=256):
 
-        print(f'{
+        print(f'{prompt=} {conditions=}')
         first_param = next(iter(self.parameters()))
         device = first_param.device
 
-
-        # we then do 1 forward pass instead of 2.
-        # the reason for that is two-fold:
-        # 1. it is about x2 faster than doing 2 forward passes
-        # 2. avoid the streaming API treating the 2 passes as part of different time steps
-        # We also support doing two different passes, in particular to ensure that
-        # the padding structure is exactly the same between train and test.
-        # With a batch size of 1, this can be slower though.
-        cfg_conditions: CFGConditions
-        # two_step_cfg = self.two_step_cfg if two_step_cfg is None else two_step_cfg
-
-        null_conditions = conditions
-        conditions = conditions + null_conditions
+
+
         tokenized = self.condition_provider.tokenize(conditions)
+        # print('TOKENIZ', tokenized)  # 'description'
+        # TOKENIZ {'description': {'input_ids': tensor([[3887,   16, 2815,    1],
+        #         [3887,   16, 2815,    1]], device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1],
+        #         [1, 1, 1, 1]], device='cuda:0')}}
+
         cfg_conditions = self.condition_provider(tokenized)
 
 
@@ -326,58 +309,59 @@ class LMModel(StreamingModule):
 
 
 
-
-
-
+
+
+        # --
+        # print(mask.shape, mask.sum(), 'MSK LM')
+        # torch.Size([4, 39]) tensor(140, device='cuda:0') MSK LM ? Fully 1 normal no special token
+        # --\
 
-        # --
-        #
-        #
-        #
-
-
-        #
+        # list - Elongation for take-5 next tokens - n_draw 5 tokens at each time-step
+        # append them at end of sequence
+        duplicate_draw = [
+            _gen_sequence[:, :, 0:1].repeat(self.n_draw, 1, 1)
+        ]
+
+
-        for offset in range(1, gen_sequence_len):  # start_offset_sequence=1
-            # print(f'{_gen_sequence.shape=}')  # [1,4,16]
-            # starts from 1 not 0 thus uses the 0:1 as curr sequence
-            # although this is empty contains -1 ?
+        for offset in range(1, _gen_sequence.shape[2]):  # gen_sequence shape is [B, K, S])
+            # print(f'{_gen_sequence.shape=}')  # [1,4,16]
+            # starts from 1 not 0 thus uses the 0:1 as curr sequence
+            # although this is empty contains -1 ?
 
-
-
-            next_token = self._sample_next_token(
-                curr_sequence,
-                cfg_conditions,
-                unconditional_state)  # [5, 4, 1]
-
-
-
-
-            # RUNS with = 2047 just different of self.special_token_id = 2047 = alwayssingletoken = drill noise
-            # special_token_id is filler for CODEBOOK_PATTERN ?
-
-            # next_token[:] = self.special_token_id  # seanet.embed torch.embedding does not have this - out of bounds in detokenize
-
-            _gen_sequence[..., offset:offset+1] = next_token[0, :, :]  # gen_sequence.shape=torch.Size([1, 4, 39])
-
-            duplicate_draw.append(next_token)
-
-            prev_offset = offset
-
-
-        unconditional_state.clear()
-
-        gen_sequence = torch.cat(duplicate_draw, 2)  # [self.n_draw, 4, len_seq]
+            # ====================== SAMPLE NEXT TOK
+            # next_token = self._sample_next_token(
+            #     _gen_sequence[..., :offset],
+            #     cfg_conditions)  # [5, 4, 1]
+            # --
+            # def _sample_next_token(self,
+            #                        sequence,
+            #                        cfg_conditions):
+            model = self if self._fsdp is None else self._fsdp
+
+            logits = model(_gen_sequence[..., :offset],
+                           condition_tensors=cfg_conditions)
+            # print(logits.shape, 'Next Logits')  # [1, 4, 2, 2048] why 2 tokens on query
+
+            # use cfg
+            # logits = (3 * logits[1, :, :, :] - 2.4 * logits[0, :, :, :]).transpose(1,0)
+
+            # or use 1 of logits
+            logits = logits[0, :, 0:1, :]  # [1,4,2048]
+            next_token = utils.sample_top_k(logits, n_draw=self.n_draw)  # [1,4,2048] logits
+            # =================================
+
+
+            _gen_sequence[:, :, offset] = next_token[0, :, 0]  # gen_sequence.shape=torch.Size([1, 4, 39])
+
+            duplicate_draw.append(next_token)
 
+
+
+        gen_sequence = torch.cat(duplicate_draw, 2)  # RESHAPE -> N_DRAW -> TIME
 
         # revert codes as "batch"
 
@@ -415,4 +399,4 @@ class LMModel(StreamingModule):
 
 
 
-        return out_codes  #
+        return out_codes  #
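In the rewritten generate() loop, each step calls utils.sample_top_k with n_draw, writes only the first draw back into _gen_sequence, and keeps all draws so they can be concatenated along the time axis afterwards. The fork's own sample_top_k (with the n_draw argument) is not shown in this diff, so the sketch below is an illustrative reimplementation of the idea, not the actual utility:

import torch

def sample_top_k(logits: torch.Tensor, k: int = 250, n_draw: int = 5) -> torch.Tensor:
    """Illustrative top-k sampler: logits [K, 1, card] -> tokens [n_draw, K, 1]."""
    K, T, card = logits.shape
    topk_vals, topk_idx = logits.topk(k, dim=-1)               # keep the k best logits per codebook
    probs = torch.softmax(topk_vals, dim=-1)                   # renormalise over the k survivors
    flat = probs.reshape(K * T, k)
    draws = torch.multinomial(flat, num_samples=n_draw, replacement=True)   # [K*T, n_draw]
    tokens = torch.gather(topk_idx.reshape(K * T, k), 1, draws)             # map back to vocabulary ids
    return tokens.reshape(K, T, n_draw).permute(2, 0, 1)       # [n_draw, K, T]

# One decoding step as in the loop above: 4 codebooks, one time-step, 2048-way vocabulary.
logits = torch.randn(4, 1, 2048)
next_token = sample_top_k(logits, k=250, n_draw=5)   # [5, 4, 1]; only next_token[0] re-enters the sequence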
audiocraft/streaming.py  DELETED

@@ -1,131 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-"""
-Streaming module API that should be implemented by all Streaming components,
-"""
-
-from contextlib import contextmanager
-import typing as tp
-from torch import nn
-import torch
-
-
-State = tp.Dict[str, torch.Tensor]
-
-
-class StreamingModule(nn.Module):
-    """Common API for streaming components.
-
-    Each streaming component has a streaming state, which is just a dict[str, Tensor].
-    By convention, the first dim of each tensor must be the batch size.
-    Don't use dots in the key names, as this would clash with submodules
-    (like in state_dict).
-
-    If `self._is_streaming` is True, the component should use and remember
-    the proper state inside `self._streaming_state`.
-
-    To set a streaming component in streaming state, use
-
-        with module.streaming():
-            ...
-
-    This will automatically reset the streaming state when exiting the context manager.
-    This also automatically propagates to all streaming children module.
-
-    Some module might also implement the `StreamingModule.flush` method, although
-    this one is trickier, as all parents module must be StreamingModule and implement
-    it as well for it to work properly. See `StreamingSequential` after.
-    """
-    def __init__(self) -> None:
-        super().__init__()
-        self._streaming_state: State = {}
-        self._is_streaming = False
-
-    def _apply_named_streaming(self, fn: tp.Any):
-        for name, module in self.named_modules():
-            if isinstance(module, StreamingModule):
-                fn(name, module)
-
-    def _set_streaming(self, streaming: bool):
-        def _set_streaming(name, module):
-            module._is_streaming = streaming
-        self._apply_named_streaming(_set_streaming)
-
-    @contextmanager
-    def streaming(self):
-        """Context manager to enter streaming mode. Reset streaming state on exit."""
-        self._set_streaming(True)
-        try:
-            yield
-        finally:
-            self._set_streaming(False)
-            self.reset_streaming()
-
-    def reset_streaming(self):
-        """Reset the streaming state."""
-        def _reset(name: str, module: StreamingModule):
-            module._streaming_state.clear()
-
-        self._apply_named_streaming(_reset)
-
-    def get_streaming_state(self) -> State:
-        """Return the streaming state, including that of sub-modules."""
-        state: State = {}
-
-        def _add(name: str, module: StreamingModule):
-            if name:
-                name += "."
-            for key, value in module._streaming_state.items():
-                state[name + key] = value
-
-        self._apply_named_streaming(_add)
-        return state
-
-    def set_streaming_state(self, state: State):
-        """Set the streaming state, including that of sub-modules."""
-        state = dict(state)
-
-        def _set(name: str, module: StreamingModule):
-            if name:
-                name += "."
-            module._streaming_state.clear()
-            for key, value in list(state.items()):
-                # complexity is not ideal here, but probably fine.
-                if key.startswith(name):
-                    local_key = key[len(name):]
-                    if '.' not in local_key:
-                        module._streaming_state[local_key] = value
-                        del state[key]
-
-        self._apply_named_streaming(_set)
-        assert len(state) == 0, list(state.keys())
-
-    def flush(self, x: tp.Optional[torch.Tensor] = None):
-        """Flush any remaining outputs that were waiting for completion.
-        Typically, for convolutions, this will add the final padding
-        and process the last buffer.
-
-        This should take an optional argument `x`, which will be provided
-        if a module before this one in the streaming pipeline has already
-        spitted out a flushed out buffer.
-        """
-        if x is None:
-            return None
-        else:
-            return self(x)
-
-
-class StreamingSequential(StreamingModule, nn.Sequential):
-    """A streaming compatible alternative of `nn.Sequential`.
-    """
-    def flush(self, x: tp.Optional[torch.Tensor] = None):
-        for module in self:
-            if isinstance(module, StreamingModule):
-                x = module.flush(x)
-            elif x is not None:
-                x = module(x)
-        return x
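With streaming.py deleted, every class that previously inherited from StreamingModule (LMModel, ConditionFuser, StreamingTransformer, StreamingMultiheadAttention) now derives from plain nn.Module, and the caller-side pattern described in the deleted docstring goes away. For reference, that pattern was the context manager shown below; model and chunks are placeholders:

import torch

# Usage pattern of the removed StreamingModule API, as described in its docstring.
def stream_decode(model, chunks):
    outputs = []
    with model.streaming():                # enter streaming mode; state is reset automatically on exit
        for chunk in chunks:
            outputs.append(model(chunk))   # each call may read/update model._streaming_state
    return torch.cat(outputs, dim=-1)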
audiocraft/transformer.py
CHANGED
@@ -1,30 +1,10 @@
|
|
1 |
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
-
# All rights reserved.
|
3 |
-
#
|
4 |
-
# This source code is licensed under the license found in the
|
5 |
-
# LICENSE file in the root directory of this source tree.
|
6 |
-
|
7 |
-
"""
|
8 |
-
Transformer model, with streaming support, xformer attention support
|
9 |
-
and easy causal attention with a potentially finite receptive field.
|
10 |
-
|
11 |
-
See `StreamingTransformer` for more information.
|
12 |
-
|
13 |
-
Unlike regular PyTorch Transformer, we make the hard choice that batches are first.
|
14 |
-
"""
|
15 |
-
|
16 |
import typing as tp
|
17 |
-
|
18 |
from einops import rearrange
|
19 |
import torch
|
20 |
import torch.nn as nn
|
21 |
from torch.nn import functional as F
|
22 |
-
from torch.utils.checkpoint import checkpoint as torch_checkpoint
|
23 |
from xformers import ops
|
24 |
|
25 |
-
from .rope import RotaryEmbedding
|
26 |
-
from .streaming import StreamingModule
|
27 |
-
|
28 |
_efficient_attention_backend: str = 'torch'
|
29 |
|
30 |
|
@@ -35,14 +15,10 @@ def set_efficient_attention_backend(backend: str = 'torch'):
|
|
35 |
_efficient_attention_backend = backend
|
36 |
|
37 |
|
38 |
-
def _get_attention_time_dimension(memory_efficient: bool) -> int:
|
39 |
-
if _efficient_attention_backend == 'torch' and memory_efficient:
|
40 |
-
return 2
|
41 |
-
else:
|
42 |
-
return 1
|
43 |
|
44 |
|
45 |
-
|
|
|
46 |
# Return true if we are currently running with a xformers profiler activated.
|
47 |
try:
|
48 |
from xformers.profiler import profiler
|
@@ -51,16 +27,8 @@ def _is_profiled() -> bool:
|
|
51 |
return profiler._Profiler._CURRENT_PROFILER is not None
|
52 |
|
53 |
|
54 |
-
def create_norm_fn(norm_type
|
55 |
-
"""Create normalization module for transformer encoder layer.
|
56 |
|
57 |
-
Args:
|
58 |
-
norm_type (str): Normalization method.
|
59 |
-
dim (int): Dimension of the normalized layer.
|
60 |
-
**kwargs (dict): Additional parameters for normalization layer.
|
61 |
-
Returns:
|
62 |
-
nn.Module: Normalization module.
|
63 |
-
"""
|
64 |
if norm_type == 'layer_norm':
|
65 |
return nn.LayerNorm(dim, eps=1e-5, **kwargs)
|
66 |
else:
|
@@ -86,87 +54,26 @@ def create_sin_embedding(positions: torch.Tensor, dim: int, max_period: float =
|
|
86 |
adim = torch.arange(half_dim, device=positions.device, dtype=dtype).view(1, 1, -1)
|
87 |
max_period_tensor = torch.full([], max_period, device=positions.device, dtype=dtype) # avoid sync point
|
88 |
phase = positions / (max_period_tensor ** (adim / (half_dim - 1)))
|
89 |
-
print('==============CONCAT 3 ============'
|
90 |
-
return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)
|
91 |
|
92 |
|
93 |
-
def expand_repeated_kv(x: torch.Tensor, n_rep: int, memory_efficient: bool) -> torch.Tensor:
|
94 |
-
"""torch.repeat_interleave(x, dim=2, repeats=n_rep) from xlformers."""
|
95 |
-
if n_rep == 1:
|
96 |
-
return x
|
97 |
-
if _efficient_attention_backend == 'torch' and memory_efficient:
|
98 |
-
bs, n_kv_heads, slen, head_dim = x.shape
|
99 |
-
return (
|
100 |
-
x[:, :, None, :, :]
|
101 |
-
.expand(bs, n_kv_heads, n_rep, slen, head_dim)
|
102 |
-
.reshape(bs, n_kv_heads * n_rep, slen, head_dim)
|
103 |
-
)
|
104 |
-
else:
|
105 |
-
bs, slen, n_kv_heads, head_dim = x.shape
|
106 |
-
return (
|
107 |
-
x[:, :, :, None, :]
|
108 |
-
.expand(bs, slen, n_kv_heads, n_rep, head_dim)
|
109 |
-
.reshape(bs, slen, n_kv_heads * n_rep, head_dim)
|
110 |
-
)
|
111 |
|
112 |
|
113 |
-
class LayerScale(nn.Module):
|
114 |
-
"""Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf).
|
115 |
-
This rescales diagonally the residual outputs close to 0, with a learnt scale.
|
116 |
|
117 |
-
Args:
|
118 |
-
channels (int): Number of channels.
|
119 |
-
init (float): Initial scale.
|
120 |
-
channel_last (bool): If True, expect `[*, C]` shaped tensors, otherwise, `[*, C, T]`.
|
121 |
-
device (torch.device or str, optional): Device on which to initialize the module.
|
122 |
-
dtype (torch.dtype, optional): dtype to use to initialize the module.
|
123 |
-
"""
|
124 |
-
def __init__(self, channels: int, init: float = 1e-4, channel_last: bool = True,
|
125 |
-
device=None, dtype=None):
|
126 |
-
super().__init__()
|
127 |
-
self.channel_last = channel_last
|
128 |
-
self.scale = nn.Parameter(
|
129 |
-
torch.full((channels,), init,
|
130 |
-
requires_grad=True, device=device, dtype=dtype))
|
131 |
-
|
132 |
-
def forward(self, x: torch.Tensor):
|
133 |
-
if self.channel_last:
|
134 |
-
return self.scale * x
|
135 |
-
else:
|
136 |
-
return self.scale[:, None] * x
|
137 |
|
138 |
|
139 |
-
class StreamingMultiheadAttention(StreamingModule):
|
140 |
-
"""Similar to `nn.MultiheadAttention` but with support for streaming, causal evaluation.
|
141 |
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
past_context (int, optional): Receptive field for the causal mask, infinite if None.
|
149 |
-
custom (bool): Use custom MHA implementation, for testing / benchmarking.
|
150 |
-
memory_efficient (bool): Use xformers based memory efficient attention.
|
151 |
-
attention_as_float32 (bool): Perform the attention as float32
|
152 |
-
(especially important with memory_efficient as autocast won't do this automatically).
|
153 |
-
rope (`RotaryEmbedding`, optional): Rope embedding to use.
|
154 |
-
cross_attention: Should be true when used as a cross attention.
|
155 |
-
All keys and values must be available at once, streaming is only for the queries.
|
156 |
-
Cannot be used with `causal` or `rope` (as it wouldn't make sens to
|
157 |
-
interpret the time steps in the keys relative to those in the queries).
|
158 |
-
safe_streaming (bool): Bug fix, will go away with xformers update.
|
159 |
-
qk_layer_norm (bool): Layer normalization applied to queries and keys before dot product.
|
160 |
-
kv_repeat (int): If > 1, will repeat keys and queries multiple times (need to divide num_heads).
|
161 |
-
This will lead to faster decoding time on A100 or other GPUs with tensorcore.
|
162 |
-
device (torch.device, optional): Device on which to initialize.
|
163 |
-
dtype (torch.dtype, optional): dtype to use.
|
164 |
-
"""
|
165 |
-
def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, bias: bool = True,
|
166 |
causal: bool = False, past_context: tp.Optional[int] = None, custom: bool = False,
|
167 |
memory_efficient: bool = False, attention_as_float32: bool = False,
|
168 |
-
|
169 |
-
|
170 |
device=None, dtype=None):
|
171 |
super().__init__()
|
172 |
factory_kwargs = {'device': device, 'dtype': dtype}
|
@@ -178,15 +85,15 @@ class StreamingMultiheadAttention(StreamingModule):
|
|
178 |
self.past_context = past_context
|
179 |
self.memory_efficient = memory_efficient
|
180 |
self.attention_as_float32 = attention_as_float32
|
181 |
-
|
182 |
self.cross_attention = cross_attention
|
183 |
-
|
184 |
self.num_heads = num_heads
|
185 |
self.dropout = dropout
|
186 |
self.kv_repeat = kv_repeat
|
187 |
if cross_attention:
|
188 |
assert not causal, "Causal cannot work with cross attention."
|
189 |
-
|
190 |
|
191 |
if memory_efficient:
|
192 |
_verify_xformers_memory_efficient_compat()
|
@@ -231,123 +138,42 @@ class StreamingMultiheadAttention(StreamingModule):
|
|
231 |
state_dict[prefix + "mha." + key] = state_dict.pop(prefix + key)
|
232 |
super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
|
233 |
|
234 |
-
def _get_mask(self, current_steps: int, device: torch.device, dtype: torch.dtype):
|
235 |
-
# Return a causal mask, accounting for potentially stored past keys/values
|
236 |
-
# We actually return a bias for the attention score, as this has the same
|
237 |
-
# convention both in the builtin MHA in Pytorch, and Xformers functions.
|
238 |
-
time_dim = _get_attention_time_dimension(self.memory_efficient)
|
239 |
-
if self.memory_efficient:
|
240 |
-
from xformers.ops import LowerTriangularMask
|
241 |
-
if current_steps == 1:
|
242 |
-
# If we only have one step, then we do not need a mask.
|
243 |
-
return None
|
244 |
-
elif 'past_keys' in self._streaming_state:
|
245 |
-
raise RuntimeError("Not supported at the moment")
|
246 |
-
else:
|
247 |
-
# Then we can safely use a lower triangular mask
|
248 |
-
return LowerTriangularMask()
|
249 |
-
if self._streaming_state:
|
250 |
-
past_keys = self._streaming_state['past_keys']
|
251 |
-
past_steps = past_keys.shape[time_dim]
|
252 |
-
else:
|
253 |
-
past_steps = 0
|
254 |
-
|
255 |
-
queries_pos = torch.arange(
|
256 |
-
past_steps, current_steps + past_steps, device=device).view(-1, 1)
|
257 |
-
keys_pos = torch.arange(past_steps + current_steps, device=device).view(1, -1)
|
258 |
-
delta = queries_pos - keys_pos
|
259 |
-
valid = delta >= 0
|
260 |
-
if self.past_context is not None:
|
261 |
-
valid &= (delta <= self.past_context)
|
262 |
-
return torch.where(
|
263 |
-
valid,
|
264 |
-
torch.zeros([], device=device, dtype=dtype),
|
265 |
-
torch.full([], float('-inf'), device=device, dtype=dtype))
|
266 |
-
|
267 |
-
def _complete_kv(self, k, v):
|
268 |
-
|
269 |
-
time_dim = _get_attention_time_dimension(self.memory_efficient)
|
270 |
-
if self.cross_attention:
|
271 |
-
# With cross attention we assume all keys and values
|
272 |
-
# are already available, and streaming is with respect
|
273 |
-
# to the queries only.
|
274 |
-
return k, v
|
275 |
-
# Complete the key/value pair using the streaming state.
|
276 |
-
if self._streaming_state:
|
277 |
-
pk = self._streaming_state['past_keys']
|
278 |
-
nk = torch.cat([pk, k], dim=time_dim)
|
279 |
-
print('==============CONCAT 1===============')
|
280 |
-
if v is k:
|
281 |
-
nv = nk
|
282 |
-
else:
|
283 |
-
pv = self._streaming_state['past_values']
|
284 |
-
nv = torch.cat([pv, v], dim=time_dim)
|
285 |
-
print('==============CONCAT 2================')
|
286 |
-
else:
|
287 |
-
nk = k
|
288 |
-
nv = v
|
289 |
-
|
290 |
-
assert nk.shape[time_dim] == nv.shape[time_dim]
|
291 |
-
offset = 0
|
292 |
-
if self.past_context is not None:
|
293 |
-
offset = max(0, nk.shape[time_dim] - self.past_context)
|
294 |
-
if self._is_streaming:
|
295 |
-
self._streaming_state['past_keys'] = nk[:, offset:]
|
296 |
-
if v is not k:
|
297 |
-
self._streaming_state['past_values'] = nv[:, offset:]
|
298 |
-
if 'offset' in self._streaming_state:
|
299 |
-
self._streaming_state['offset'] += offset
|
300 |
-
else:
|
301 |
-
self._streaming_state['offset'] = torch.tensor(0)
|
302 |
-
return nk, nv
|
303 |
-
|
304 |
-
def _apply_rope(self, query: torch.Tensor, key: torch.Tensor):
|
305 |
-
time_dim = _get_attention_time_dimension(self.memory_efficient)
|
306 |
-
# Apply rope embeddings to query and key tensors.
|
307 |
-
assert self.rope is not None
|
308 |
-
if 'past_keys' in self._streaming_state:
|
309 |
-
past_keys_offset = self._streaming_state['past_keys'].shape[1]
|
310 |
-
else:
|
311 |
-
past_keys_offset = 0
|
312 |
-
if 'offset' in self._streaming_state:
|
313 |
-
past_context_offset = int(self._streaming_state['offset'].item())
|
314 |
-
else:
|
315 |
-
past_context_offset = 0
|
316 |
-
streaming_offset = past_context_offset + past_keys_offset
|
317 |
-
return self.rope.rotate_qk(query, key, start=streaming_offset, time_dim=time_dim)
|
318 |
|
319 |
-
|
320 |
-
|
321 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
322 |
assert not is_causal, ("New param added in torch 2.0.1 not supported, "
|
323 |
"use the causal args in the constructor.")
|
324 |
-
|
325 |
-
time_dim =
|
326 |
if time_dim == 2:
|
327 |
layout = "b h t d"
|
328 |
else:
|
329 |
layout = "b t h d"
|
330 |
dtype = query.dtype
|
331 |
-
|
332 |
-
assert self.causal or self.cross_attention, \
|
333 |
-
"Streaming only available for causal or cross attention"
|
334 |
|
335 |
custom_attn_mask = attn_mask is not None
|
336 |
|
337 |
-
if self.causal:
|
338 |
-
assert attn_mask is None
|
339 |
-
# At the moment we specialize only for the self-attention case.
|
340 |
-
assert query.shape[1] == key.shape[1], "Causal only for same length query / key / value"
|
341 |
-
assert value.shape[1] == key.shape[1], "Causal only for same length query / key / value"
|
342 |
-
attn_mask = self._get_mask(query.shape[1], query.device, query.dtype)
|
343 |
-
|
344 |
if self.custom:
|
345 |
# custom implementation
|
346 |
assert need_weights is False
|
347 |
assert key_padding_mask is None
|
348 |
if self.cross_attention:
|
349 |
-
#
|
350 |
-
|
|
|
351 |
dim = self.in_proj_weight.shape[0] // 3
|
352 |
if self.in_proj_bias is None:
|
353 |
bias_q, bias_k, bias_v = None, None, None
|
@@ -356,14 +182,23 @@ class StreamingMultiheadAttention(StreamingModule):
|
|
356 |
bias_k = self.in_proj_bias[dim: 2 * dim]
|
357 |
bias_v = self.in_proj_bias[2 * dim:]
|
358 |
q = nn.functional.linear(query, self.in_proj_weight[:dim], bias_q)
|
|
|
359 |
# todo: when streaming, we could actually save k, v and check the shape actually match.
|
360 |
k = nn.functional.linear(key, self.in_proj_weight[dim: 2 * dim], bias_k)
|
361 |
v = nn.functional.linear(value, self.in_proj_weight[2 * dim:], bias_v)
|
362 |
if self.qk_layer_norm is True:
|
363 |
q = self.q_layer_norm(q)
|
364 |
k = self.k_layer_norm(k)
|
|
|
365 |
q, k, v = [rearrange(x, f"b t (h d) -> {layout}", h=self.num_heads) for x in [q, k, v]]
|
|
|
366 |
else:
|
|
|
|
|
|
|
|
|
|
|
|
|
367 |
if not _is_profiled():
|
368 |
# profiling breaks that propertysomehow.
|
369 |
assert query is key, "specialized implementation"
|
@@ -374,8 +209,13 @@ class StreamingMultiheadAttention(StreamingModule):
|
|
374 |
bound_layout = "b h p t d"
|
375 |
else:
|
376 |
bound_layout = "b t p h d"
|
|
|
377 |
packed = rearrange(projected, f"b t (p h d) -> {bound_layout}", p=3, h=self.num_heads)
|
|
|
|
|
|
|
378 |
q, k, v = ops.unbind(packed, dim=2)
|
|
|
379 |
else:
|
380 |
embed_dim = self.embed_dim
|
381 |
per_head_dim = (embed_dim // self.num_heads)
|
@@ -395,12 +235,12 @@ class StreamingMultiheadAttention(StreamingModule):
|
|
395 |
q = self.q_layer_norm(q)
|
396 |
k = self.k_layer_norm(k)
|
397 |
q, k = [rearrange(x, f"b t (h d) -> {layout}", h=self.num_heads) for x in [q, k]]
|
398 |
-
|
399 |
-
|
400 |
-
k, v = self._complete_kv(k, v)
|
401 |
if self.kv_repeat > 1:
|
402 |
-
|
403 |
-
|
|
|
404 |
if self.attention_as_float32:
|
405 |
q, k, v = [x.float() for x in [q, k, v]]
|
406 |
if self.memory_efficient:
|
@@ -429,11 +269,8 @@ class StreamingMultiheadAttention(StreamingModule):
|
|
429 |
q = q / q.shape[-1] ** 0.5
|
430 |
key_layout = layout.replace('t', 'k')
|
431 |
query_layout = layout
|
432 |
-
|
433 |
-
|
434 |
-
pre_w = torch.einsum(f"{query_layout},{key_layout}-> b h t k", q, k)
|
435 |
-
else:
|
436 |
-
pre_w = torch.einsum(f"{query_layout},{key_layout}-> b h t k", q, k)
|
437 |
if attn_mask is not None:
|
438 |
pre_w = pre_w + attn_mask
|
439 |
w = torch.softmax(pre_w, dim=-1)
|
@@ -444,58 +281,24 @@ class StreamingMultiheadAttention(StreamingModule):
|
|
444 |
x = rearrange(x, f"{layout} -> b t (h d)", h=self.num_heads)
|
445 |
x = self.out_proj(x)
|
446 |
else:
|
447 |
-
|
448 |
-
if self.attention_as_float32:
|
449 |
-
query, key, value = [x.float() for x in [query, key, value]]
|
450 |
-
x, _ = self.mha(
|
451 |
-
query, key, value, key_padding_mask,
|
452 |
-
need_weights, attn_mask, average_attn_weights)
|
453 |
-
x = x.to(dtype)
|
454 |
|
455 |
return x, None
|
456 |
|
457 |
|
458 |
class StreamingTransformerLayer(nn.TransformerEncoderLayer):
|
459 |
-
|
460 |
-
|
461 |
-
|
462 |
-
|
463 |
-
|
464 |
-
d_model (int): Dimension of the data.
|
465 |
-
num_heads (int): Number of heads.
|
466 |
-
dim_feedforward (int): Intermediate dimension of FF module.
|
467 |
-
dropout (float): Dropout both for MHA and FF.
|
468 |
-
bias_ff (bool): Use bias for FF.
|
469 |
-
bias_attn (bool): Use bias for MHA.
|
470 |
-
causal (bool): Causal mask applied automatically.
|
471 |
-
past_context (int, optional): Receptive field for the causal mask, infinite if None.
|
472 |
-
custom (bool): Use custom MHA implementation, for testing / benchmarking.
|
473 |
-
memory_efficient (bool): Use xformers based memory efficient attention.
|
474 |
-
attention_as_float32 (bool): Perform the attention as float32
|
475 |
-
(especially important with memory_efficient as autocast won't do this automatically).
|
476 |
-
qk_layer_norm (bool): Layer normalization applied to queries and keys before dot product in attention.
|
477 |
-
qk_layer_norm_cross (bool): Same for the cross attention.
|
478 |
-
cross_attention (bool): If True, expect to get secondary input for cross-attention.
|
479 |
-
Cross attention will use the default MHA, as it typically won't require
|
480 |
-
special treatment.
|
481 |
-
layer_scale (float, optional): If not None, LayerScale will be used with
|
482 |
-
the given value as initial scale.
|
483 |
-
rope (`RotaryEmbedding`, optional): Rope embedding to use.
|
484 |
-
attention_dropout (float, optional): If not None, separate the value of the dimension dropout
|
485 |
-
in FFN and of the attention dropout.
|
486 |
-
kv_repeat (int): If > 1, will repeat keys and queries multiple times (need to divide num_heads).
|
487 |
-
This will lead to faster decoding time on A100 or other GPUs with tensorcore.
|
488 |
-
device (torch.device, optional): Device on which to initialize.
|
489 |
-
dtype (torch.dtype, optional): dtype to use.
|
490 |
-
**kwargs: See `nn.TransformerEncoderLayer`.
|
491 |
-
"""
|
492 |
-
def __init__(self, d_model: int, num_heads: int, dim_feedforward: int = 2048, dropout: float = 0.1,
|
493 |
bias_ff: bool = True, bias_attn: bool = True, causal: bool = False,
|
494 |
past_context: tp.Optional[int] = None, custom: bool = False,
|
495 |
memory_efficient: bool = False, attention_as_float32: bool = False,
|
496 |
qk_layer_norm: bool = False, qk_layer_norm_cross: bool = False,
|
497 |
-
cross_attention: bool = False,
|
498 |
-
|
|
|
499 |
kv_repeat: int = 1, norm: str = 'layer_norm', device=None, dtype=None, **kwargs):
|
500 |
super().__init__(d_model, num_heads, dim_feedforward, dropout,
|
501 |
device=device, dtype=dtype, batch_first=True, **kwargs)
|
@@ -511,22 +314,17 @@ class StreamingTransformerLayer(nn.TransformerEncoderLayer):
|
|
511 |
'attention_as_float32': attention_as_float32,
|
512 |
}
|
513 |
self.self_attn: StreamingMultiheadAttention = StreamingMultiheadAttention(
|
514 |
-
causal=causal, past_context=past_context,
|
|
|
|
|
515 |
kv_repeat=kv_repeat, **attn_kwargs, **factory_kwargs) # type: ignore
|
516 |
# Redefine feedforward layers to expose bias parameter
|
517 |
self.linear1 = nn.Linear(d_model, dim_feedforward, bias=bias_ff, **factory_kwargs)
|
518 |
self.linear2 = nn.Linear(dim_feedforward, d_model, bias=bias_ff, **factory_kwargs)
|
519 |
|
520 |
-
|
521 |
-
self.layer_scale_2: nn.Module
|
522 |
-
if layer_scale is None:
|
523 |
-
self.layer_scale_1 = nn.Identity()
|
524 |
-
self.layer_scale_2 = nn.Identity()
|
525 |
-
else:
|
526 |
-
self.layer_scale_1 = LayerScale(d_model, layer_scale, **factory_kwargs)
|
527 |
-
self.layer_scale_2 = LayerScale(d_model, layer_scale, **factory_kwargs)
|
528 |
|
529 |
-
self.cross_attention
|
530 |
if cross_attention:
|
531 |
self.cross_attention = StreamingMultiheadAttention(
|
532 |
cross_attention=True, qk_layer_norm=qk_layer_norm_cross,
|
@@ -535,98 +333,69 @@ class StreamingTransformerLayer(nn.TransformerEncoderLayer):
|
|
535 |
self.dropout_cross = nn.Dropout(dropout)
|
536 |
# eps value matching that used in PyTorch reference implementation.
|
537 |
self.norm_cross = nn.LayerNorm(d_model, eps=1e-5, **factory_kwargs)
|
538 |
-
|
539 |
-
if layer_scale is None:
|
540 |
-
self.layer_scale_cross = nn.Identity()
|
541 |
-
else:
|
542 |
-
self.layer_scale_cross = LayerScale(d_model, layer_scale, **factory_kwargs)
|
543 |
self.norm1 = create_norm_fn(norm, d_model, **factory_kwargs) # type: ignore
|
544 |
self.norm2 = create_norm_fn(norm, d_model, **factory_kwargs) # type: ignore
|
545 |
|
546 |
-
def _cross_attention_block(self,
|
547 |
-
|
548 |
-
|
|
|
549 |
# queries are from src, keys and values from cross_attention_src.
|
550 |
x = self.cross_attention(
|
551 |
src, cross_attention_src, cross_attention_src, need_weights=False)[0]
|
552 |
return self.dropout_cross(x) # type: ignore
|
553 |
|
554 |
-
def forward(self,
|
555 |
-
|
556 |
-
|
557 |
-
|
558 |
-
|
559 |
-
|
560 |
-
|
561 |
x = src
|
562 |
if self.norm_first:
|
563 |
-
|
564 |
-
|
|
|
|
|
|
|
|
|
565 |
if cross_attention_src is not None:
|
566 |
-
x = x + self.
|
567 |
-
|
568 |
-
|
569 |
-
|
|
|
|
|
|
|
|
|
570 |
else:
|
571 |
-
|
572 |
-
|
573 |
-
if cross_attention_src is not None:
|
574 |
-
x = self.norm_cross(
|
575 |
-
x + self.layer_scale_cross(
|
576 |
-
self._cross_attention_block(src, cross_attention_src)))
|
577 |
-
x = self.norm2(x + self.layer_scale_2(self._ff_block(x)))
|
578 |
return x
|
579 |
|
580 |
|
581 |
-
class StreamingTransformer(
|
582 |
-
|
583 |
-
|
584 |
-
Args:
|
585 |
-
d_model (int): Dimension of the data.
|
586 |
-
num_heads (int): Number of heads.
|
587 |
-
dim_feedforward (int): Intermediate dimension of FF module.
|
588 |
-
dropout (float): Dropout both for MHA and FF.
|
589 |
-
bias_ff (bool): Use bias for FF.
|
590 |
-
bias_attn (bool): Use bias for MHA.
|
591 |
-
causal (bool): Causal mask applied automatically.
|
592 |
-
past_context (int, optional): Receptive field for the causal mask, infinite if None.
|
593 |
-
custom (bool): Use custom MHA implementation, for testing / benchmarking.
|
594 |
-
memory_efficient (bool): Use xformers based memory efficient attention.
|
595 |
-
attention_as_float32 (bool): Perform the attention as float32
|
596 |
-
(especially important with memory_efficient as autocast won't do this automatically).
|
597 |
-
cross_attention (bool): If True, expect to get secondary input for cross-attention.
|
598 |
-
layer_scale (float, optional): If not None, LayerScale will be used
|
599 |
-
with the given value as initial scale.
|
600 |
-
positional_embedding (str): Positional embedding strategy (sin, rope, or sin_rope).
|
601 |
-
max_period (float): Maximum period of the time embedding.
|
602 |
-
positional_scale (float): Scale of positional embedding, set to 0 to deactivate.
|
603 |
-
xpos (bool): Apply xpos exponential decay to positional embedding (rope only).
|
604 |
-
lr (float, optional): learning rate override through the `make_optim_group` API.
|
605 |
-
weight_decay (float, optional): Weight_decay override through the `make_optim_group` API.
|
606 |
-
layer_class: (subclass of `StreamingTransformerLayer): class to use
|
607 |
-
to initialize the layers, allowing further customization outside of AudioCraft.
|
608 |
-
checkpointing (str): Checkpointing strategy to reduce memory usage.
|
609 |
-
No checkpointing if set to 'none'. Per layer checkpointing using PyTorch
|
610 |
-
if set to 'torch' (entire layer checkpointed, i.e. linears are evaluated twice,
|
611 |
-
minimal memory usage, but maximal runtime). Finally, `xformers_default` provide
|
612 |
-
a policy for opting-out some operations of the checkpointing like
|
613 |
-
linear layers and attention, providing a middle ground between speed and memory.
|
614 |
-
device (torch.device, optional): Device on which to initialize.
|
615 |
-
dtype (torch.dtype, optional): dtype to use.
|
616 |
-
**kwargs: See `nn.TransformerEncoderLayer`.
|
617 |
-
"""
|
618 |
def __init__(self, d_model: int, num_heads: int, num_layers: int, dim_feedforward: int = 2048,
|
619 |
dropout: float = 0.1, bias_ff: bool = True, bias_attn: bool = True,
|
620 |
causal: bool = False, past_context: tp.Optional[int] = None,
|
621 |
custom: bool = False, memory_efficient: bool = False, attention_as_float32: bool = False,
|
622 |
-
cross_attention: bool = False,
|
623 |
positional_embedding: str = 'sin', max_period: float = 10_000, positional_scale: float = 1.,
|
624 |
-
xpos
|
625 |
-
|
626 |
-
|
|
|
|
|
|
|
|
|
|
|
627 |
super().__init__()
|
628 |
assert d_model % num_heads == 0
|
629 |
-
|
630 |
self.positional_embedding = positional_embedding
|
631 |
self.max_period = max_period
|
632 |
self.positional_scale = positional_scale
|
@@ -634,12 +403,6 @@ class StreamingTransformer(StreamingModule):
|
|
634 |
self.lr = lr
|
635 |
|
636 |
assert positional_embedding in ['sin', 'rope', 'sin_rope']
|
637 |
-
self.rope: tp.Optional[RotaryEmbedding] = None
|
638 |
-
if self.positional_embedding in ['rope', 'sin_rope']:
|
639 |
-
assert _is_custom(custom, memory_efficient)
|
640 |
-
self.rope = RotaryEmbedding(d_model // num_heads, max_period=max_period,
|
641 |
-
xpos=xpos, scale=positional_scale, device=device)
|
642 |
-
|
643 |
self.checkpointing = checkpointing
|
644 |
|
645 |
assert checkpointing in ['none', 'torch', 'xformers_default', 'xformers_mm']
|
@@ -654,7 +417,8 @@ class StreamingTransformer(StreamingModule):
|
|
654 |
dropout=dropout, bias_ff=bias_ff, bias_attn=bias_attn,
|
655 |
causal=causal, past_context=past_context, custom=custom,
|
656 |
memory_efficient=memory_efficient, attention_as_float32=attention_as_float32,
|
657 |
-
cross_attention=cross_attention,
|
|
|
658 |
device=device, dtype=dtype, **kwargs))
|
659 |
|
660 |
if self.checkpointing != 'none':
|
@@ -663,58 +427,37 @@ class StreamingTransformer(StreamingModule):
|
|
663 |
# backward hook inside of FSDP...
|
664 |
layer._magma_checkpointed = True # type: ignore
|
665 |
|
666 |
-
|
667 |
-
method = self.checkpointing
|
668 |
-
if method == 'none':
|
669 |
-
return layer(*args, **kwargs)
|
670 |
-
elif method == 'torch':
|
671 |
-
return torch_checkpoint(layer, *args, use_reentrant=False, **kwargs)
|
672 |
-
elif method.startswith('xformers'):
|
673 |
-
from xformers.checkpoint_fairinternal import checkpoint, _get_default_policy
|
674 |
-
if method == 'xformers_default':
|
675 |
-
# those operations will be saved, and not recomputed.
|
676 |
-
# According to Francisco we can get smarter policies but this is a good start.
|
677 |
-
allow_list = [
|
678 |
-
"xformers.efficient_attention_forward_cutlass.default",
|
679 |
-
"xformers_flash.flash_fwd.default",
|
680 |
-
"aten.addmm.default",
|
681 |
-
"aten.mm.default",
|
682 |
-
]
|
683 |
-
elif method == 'xformers_mm':
|
684 |
-
# those operations will be saved, and not recomputed.
|
685 |
-
# According to Francisco we can get smarter policies but this is a good start.
|
686 |
-
allow_list = [
|
687 |
-
"aten.addmm.default",
|
688 |
-
"aten.mm.default",
|
689 |
-
]
|
690 |
-
else:
|
691 |
-
raise ValueError(f"xformers checkpointing xformers policy {method} is not known.")
|
692 |
-
policy_fn = _get_default_policy(allow_list)
|
693 |
-
return checkpoint(layer, *args, policy_fn=policy_fn, **kwargs)
|
694 |
-
else:
|
695 |
-
raise ValueError(f"Checkpointing method {method} is unknown.")
|
696 |
|
697 |
def forward(self, x: torch.Tensor, *args, **kwargs):
|
698 |
-
|
|
|
699 |
B, T, C = x.shape
|
700 |
|
701 |
-
|
702 |
-
|
703 |
-
else:
|
704 |
-
offsets = torch.zeros(B, dtype=torch.long, device=x.device)
|
705 |
|
706 |
-
if self.positional_embedding in ['sin',
|
|
|
707 |
positions = torch.arange(T, device=x.device).view(1, -1, 1)
|
708 |
-
|
709 |
pos_emb = create_sin_embedding(positions, C, max_period=self.max_period, dtype=x.dtype)
|
710 |
x = x + self.positional_scale * pos_emb
|
711 |
-
|
712 |
-
for
|
713 |
-
|
714 |
-
|
715 |
-
|
716 |
-
|
717 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
718 |
return x
|
719 |
|
720 |
def make_optim_group(self):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import typing as tp
|
|
|
2 |
from einops import rearrange
|
3 |
import torch
|
4 |
import torch.nn as nn
|
5 |
from torch.nn import functional as F
|
|
|
6 |
from xformers import ops
|
7 |
|
|
|
|
|
|
|
8 |
_efficient_attention_backend: str = 'torch'
|
9 |
|
10 |
|
|
|
15 |
_efficient_attention_backend = backend
|
16 |
|
17 |
|
|
|
|
|
|
|
|
|
|
|
18 |
|
19 |
|
20 |
+
|
21 |
+
def _is_profiled():
|
22 |
# Return true if we are currently running with a xformers profiler activated.
|
23 |
try:
|
24 |
from xformers.profiler import profiler
|
|
|
27 |
return profiler._Profiler._CURRENT_PROFILER is not None
|
28 |
|
29 |
|
30 |
+
def create_norm_fn(norm_type, dim, **kwargs):
|
|
|
31 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
if norm_type == 'layer_norm':
|
33 |
return nn.LayerNorm(dim, eps=1e-5, **kwargs)
|
34 |
else:
|
|
|
54 |
adim = torch.arange(half_dim, device=positions.device, dtype=dtype).view(1, 1, -1)
|
55 |
max_period_tensor = torch.full([], max_period, device=positions.device, dtype=dtype) # avoid sync point
|
56 |
phase = positions / (max_period_tensor ** (adim / (half_dim - 1)))
|
57 |
+
# print('==============CONCAT 3 ============'
|
58 |
+
return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)
|
59 |
|
60 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
61 |
|
62 |
|
|
|
|
|
|
|
63 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
|
65 |
|
|
|
|
|
66 |
|
67 |
+
class StreamingMultiheadAttention(nn.Module):
|
68 |
+
|
69 |
+
def __init__(self,
|
70 |
+
embed_dim,
|
71 |
+
num_heads,
|
72 |
+
dropout=0.0, bias: bool = True,
|
73 |
causal: bool = False, past_context: tp.Optional[int] = None, custom: bool = False,
|
74 |
memory_efficient: bool = False, attention_as_float32: bool = False,
|
75 |
+
cross_attention: bool = False,
|
76 |
+
qk_layer_norm: bool = False, kv_repeat: int = 1,
|
77 |
device=None, dtype=None):
|
78 |
super().__init__()
|
79 |
factory_kwargs = {'device': device, 'dtype': dtype}
|
|
|
85 |
self.past_context = past_context
|
86 |
self.memory_efficient = memory_efficient
|
87 |
self.attention_as_float32 = attention_as_float32
|
88 |
+
|
89 |
self.cross_attention = cross_attention
|
90 |
+
|
91 |
self.num_heads = num_heads
|
92 |
self.dropout = dropout
|
93 |
self.kv_repeat = kv_repeat
|
94 |
if cross_attention:
|
95 |
assert not causal, "Causal cannot work with cross attention."
|
96 |
+
|
97 |
|
98 |
if memory_efficient:
|
99 |
_verify_xformers_memory_efficient_compat()
|
|
|
138 |
state_dict[prefix + "mha." + key] = state_dict.pop(prefix + key)
|
139 |
super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
|
140 |
|
141 |
|
142 |
+
|
143 |
+
|
144 |
+
|
145 |
+
|
146 |
+
|
147 |
+
def forward(self,
|
148 |
+
query,
|
149 |
+
key,
|
150 |
+
value,
|
151 |
+
key_padding_mask=None,
|
152 |
+
need_weights=False,
|
153 |
+
attn_mask=None,
|
154 |
+
is_causal=False):
|
155 |
+
|
156 |
assert not is_causal, ("New param added in torch 2.0.1 not supported, "
|
157 |
"use the causal args in the constructor.")
|
158 |
+
# print(f'{query.shape=} {key.shape=} {value.shape=} MHA')
|
159 |
+
time_dim = 2
|
160 |
if time_dim == 2:
|
161 |
layout = "b h t d"
|
162 |
else:
|
163 |
layout = "b t h d"
|
164 |
dtype = query.dtype
|
165 |
+
|
|
|
|
|
166 |
|
167 |
custom_attn_mask = attn_mask is not None
|
168 |
|
169 |
if self.custom:
|
170 |
# custom implementation
|
171 |
assert need_weights is False
|
172 |
assert key_padding_mask is None
|
173 |
if self.cross_attention:
|
174 |
+
# print('\n\n\n\nCROSS\n\n\n\n')
|
175 |
+
|
176 |
+
|
177 |
dim = self.in_proj_weight.shape[0] // 3
|
178 |
if self.in_proj_bias is None:
|
179 |
bias_q, bias_k, bias_v = None, None, None
|
|
|
182 |
bias_k = self.in_proj_bias[dim: 2 * dim]
|
183 |
bias_v = self.in_proj_bias[2 * dim:]
|
184 |
q = nn.functional.linear(query, self.in_proj_weight[:dim], bias_q)
|
185 |
+
# print(f'{q.shape=} TRANSF FORW who concatenates?')
|
186 |
# todo: when streaming, we could actually save k, v and check the shape actually match.
|
187 |
k = nn.functional.linear(key, self.in_proj_weight[dim: 2 * dim], bias_k)
|
188 |
v = nn.functional.linear(value, self.in_proj_weight[2 * dim:], bias_v)
|
189 |
if self.qk_layer_norm is True:
|
190 |
q = self.q_layer_norm(q)
|
191 |
k = self.k_layer_norm(k)
|
192 |
+
|
193 |
q, k, v = [rearrange(x, f"b t (h d) -> {layout}", h=self.num_heads) for x in [q, k, v]]
|
194 |
+
# print(f'{q.shape=} {k.shape=} {v.shape=} after rearrange')
|
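Before the self-attention branch below, note that this projection split is what settles the title question: in_proj_weight is cut into thirds, q is projected from query (the decoder stream), and k and v are both projected from key/value, which for cross-attention is the conditioning tensor. A hedged sketch with a hypothetical helper name:

import torch.nn.functional as F

def split_qkv_proj(query, key, value, in_proj_weight, in_proj_bias=None):
    # One packed weight of shape [3*C, C]; rows [:C] make q, [C:2C] make k, [2C:] make v.
    dim = in_proj_weight.shape[0] // 3
    b = (None, None, None) if in_proj_bias is None else in_proj_bias.split(dim)
    q = F.linear(query, in_proj_weight[:dim], b[0])
    k = F.linear(key, in_proj_weight[dim:2 * dim], b[1])
    v = F.linear(value, in_proj_weight[2 * dim:], b[2])
    return q, k, v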
195 |
else:
|
196 |
+
# print('\n\n\n\nSELF\n\n\n\n')
|
197 |
+
#
|
198 |
+
# 47x Transformers selfattn followed by crossattn
|
199 |
+
#
|
200 |
+
# is self-attn over the history (previous keys) or only over the last token?
|
201 |
+
|
202 |
if not _is_profiled():
|
203 |
# profiling breaks that property somehow.
|
204 |
assert query is key, "specialized implementation"
|
|
|
209 |
bound_layout = "b h p t d"
|
210 |
else:
|
211 |
bound_layout = "b t p h d"
|
212 |
+
|
213 |
packed = rearrange(projected, f"b t (p h d) -> {bound_layout}", p=3, h=self.num_heads)
|
214 |
+
|
215 |
+
|
216 |
+
# print(f'{query.shape=} before unbind') # [2, 1, 4 , 2048] already bs=2
|
217 |
q, k, v = ops.unbind(packed, dim=2)
|
218 |
+
# print(f'{q.shape=} {v.shape=} @L331 transformer.py') # packed is bs=2
|
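A hedged restatement of the packed self-attention path above (hypothetical helper, assuming query is key is value): one matmul against the full in_proj_weight produces q, k and v at once, which are then unbound along the p axis.

import torch.nn.functional as F
from einops import rearrange

def packed_qkv(x, in_proj_weight, num_heads):
    projected = F.linear(x, in_proj_weight)                              # [B, T, 3*C]
    packed = rearrange(projected, "b t (p h d) -> b h p t d", p=3, h=num_heads)
    q, k, v = packed.unbind(dim=2)                                       # each [B, H, T, D]
    return q, k, v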
219 |
else:
|
220 |
embed_dim = self.embed_dim
|
221 |
per_head_dim = (embed_dim // self.num_heads)
|
|
|
235 |
q = self.q_layer_norm(q)
|
236 |
k = self.k_layer_norm(k)
|
237 |
q, k = [rearrange(x, f"b t (h d) -> {layout}", h=self.num_heads) for x in [q, k]]
|
238 |
+
|
239 |
+
|
|
|
240 |
if self.kv_repeat > 1:
|
241 |
+
#
|
242 |
+
print('Expand repeat 2')
|
243 |
+
|
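With the default kv_repeat=1 this branch is skipped, so the print above never fires in this run. If kv_repeat were larger, the usual approach is to repeat the smaller set of k/v heads across the query heads, GQA-style; a sketch with a hypothetical helper, not the repo's exact code:

def expand_kv(k, v, kv_repeat: int):
    # k, v: [B, H_kv, T, D] -> [B, H_kv * kv_repeat, T, D]
    if kv_repeat == 1:
        return k, v
    return k.repeat_interleave(kv_repeat, dim=1), v.repeat_interleave(kv_repeat, dim=1)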
244 |
if self.attention_as_float32:
|
245 |
q, k, v = [x.float() for x in [q, k, v]]
|
246 |
if self.memory_efficient:
|
|
|
269 |
q = q / q.shape[-1] ** 0.5
|
270 |
key_layout = layout.replace('t', 'k')
|
271 |
query_layout = layout
|
272 |
+
|
273 |
+
pre_w = torch.einsum(f"{query_layout},{key_layout}-> b h t k", q, k)
|
|
|
|
|
|
|
274 |
if attn_mask is not None:
|
275 |
pre_w = pre_w + attn_mask
|
276 |
w = torch.softmax(pre_w, dim=-1)
|
|
|
281 |
x = rearrange(x, f"{layout} -> b t (h d)", h=self.num_heads)
|
282 |
x = self.out_proj(x)
|
283 |
else:
|
284 |
+
raise NotImplementedError
|
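A hedged sketch of the fallback attention math used when memory_efficient is off: scale q, build the [b, h, t, k] score matrix with an einsum like the one above, add the additive mask, softmax over k, then contract with v:

import torch

def plain_attention(q, k, v, attn_mask=None):
    # q: [B, H, T_q, D]; k, v: [B, H, T_k, D]
    q = q / q.shape[-1] ** 0.5
    pre_w = torch.einsum("bhtd,bhkd->bhtk", q, k)
    if attn_mask is not None:
        pre_w = pre_w + attn_mask
    w = torch.softmax(pre_w, dim=-1)
    return torch.einsum("bhtk,bhkd->bhtd", w, v)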
285 |
|
286 |
return x, None
|
287 |
|
288 |
|
289 |
class StreamingTransformerLayer(nn.TransformerEncoderLayer):
|
290 |
+
def __init__(self,
|
291 |
+
d_model,
|
292 |
+
num_heads,
|
293 |
+
dim_feedforward=2048,
|
294 |
+
dropout=0.1,
|
295 |
bias_ff: bool = True, bias_attn: bool = True, causal: bool = False,
|
296 |
past_context: tp.Optional[int] = None, custom: bool = False,
|
297 |
memory_efficient: bool = False, attention_as_float32: bool = False,
|
298 |
qk_layer_norm: bool = False, qk_layer_norm_cross: bool = False,
|
299 |
+
cross_attention: bool = False,
|
300 |
+
# rope=None,
|
301 |
+
attention_dropout: tp.Optional[float] = None,
|
302 |
kv_repeat: int = 1, norm: str = 'layer_norm', device=None, dtype=None, **kwargs):
|
303 |
super().__init__(d_model, num_heads, dim_feedforward, dropout,
|
304 |
device=device, dtype=dtype, batch_first=True, **kwargs)
|
|
|
314 |
'attention_as_float32': attention_as_float32,
|
315 |
}
|
316 |
self.self_attn: StreamingMultiheadAttention = StreamingMultiheadAttention(
|
317 |
+
causal=causal, past_context=past_context,
|
318 |
+
# rope=rope,
|
319 |
+
qk_layer_norm=qk_layer_norm,
|
320 |
kv_repeat=kv_repeat, **attn_kwargs, **factory_kwargs) # type: ignore
|
321 |
# Redefine feedforward layers to expose bias parameter
|
322 |
self.linear1 = nn.Linear(d_model, dim_feedforward, bias=bias_ff, **factory_kwargs)
|
323 |
self.linear2 = nn.Linear(dim_feedforward, d_model, bias=bias_ff, **factory_kwargs)
|
324 |
|
325 |
+
|
|
|
326 |
|
327 |
+
self.cross_attention = None # default
|
328 |
if cross_attention:
|
329 |
self.cross_attention = StreamingMultiheadAttention(
|
330 |
cross_attention=True, qk_layer_norm=qk_layer_norm_cross,
|
|
|
333 |
self.dropout_cross = nn.Dropout(dropout)
|
334 |
# eps value matching that used in PyTorch reference implementation.
|
335 |
self.norm_cross = nn.LayerNorm(d_model, eps=1e-5, **factory_kwargs)
|
336 |
+
|
|
|
|
|
|
|
|
|
337 |
self.norm1 = create_norm_fn(norm, d_model, **factory_kwargs) # type: ignore
|
338 |
self.norm2 = create_norm_fn(norm, d_model, **factory_kwargs) # type: ignore
|
339 |
|
340 |
+
def _cross_attention_block(self,
|
341 |
+
src,
|
342 |
+
cross_attention_src):
|
343 |
+
|
344 |
# queries are from src, keys and values from cross_attention_src.
|
345 |
x = self.cross_attention(
|
346 |
src, cross_attention_src, cross_attention_src, need_weights=False)[0]
|
347 |
return self.dropout_cross(x) # type: ignore
|
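This block is the answer to the title question: the first argument (the query) is src, the decoder stream, while cross_attention_src is passed twice, as key and as value. A minimal equivalent with plain nn.MultiheadAttention, dimensions taken from the debug prints below and an arbitrary head count:

import torch
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=1536, num_heads=8, batch_first=True)
src = torch.randn(2, 2, 1536)                  # decoder tokens      [B, T_q, C]
cross_attention_src = torch.randn(2, 4, 1536)  # text conditioning   [B, T_k, C]
out, _ = mha(src, cross_attention_src, cross_attention_src, need_weights=False)
print(out.shape)  # torch.Size([2, 2, 1536]), i.e. same length as the queries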
348 |
|
349 |
+
def forward(self,
|
350 |
+
src,
|
351 |
+
src_mask=None,
|
352 |
+
src_key_padding_mask=None, # key = value = looooong; I think I pass them inverted
|
353 |
+
cross_attention_src=None):
|
354 |
+
|
355 |
+
|
356 |
x = src
|
357 |
if self.norm_first:
|
358 |
+
# print('selfattn', x.shape, src_mask, src_key_padding_mask)
|
359 |
+
x = x + self._sa_block(self.norm1(x),
|
360 |
+
src_mask, #None
|
361 |
+
src_key_padding_mask # None
|
362 |
+
) # Internal nn
|
363 |
+
# print('crossattn', x.shape, cross_attention_src.shape)
|
364 |
if cross_attention_src is not None:
|
365 |
+
x = x + self._cross_attention_block(
|
366 |
+
self.norm_cross(x),
|
367 |
+
cross_attention_src)
|
368 |
+
# selfattn torch.Size([2, 2, 1536]) None None NO 4D TOKEN!
|
369 |
+
# crossattn torch.Size([2, 2, 1536]) torch.Size([2, 4, 1536])
|
370 |
+
else:
|
371 |
+
raise NotImplementedError # all layers have a self & cross?
|
372 |
+
x = x + self._ff_block(self.norm2(x))
|
373 |
else:
|
374 |
+
print('NLAST')
|
375 |
+
# print('NT', x.shape) # [1,2 ,1536]
|
376 |
return x
|
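Restated as a hedged pure-function sketch, the norm_first path above is the standard pre-norm ordering, each sub-block behind a residual connection:

def prenorm_layer(x, cond, sa_block, ca_block, ff_block, norm1, norm_cross, norm2):
    x = x + sa_block(norm1(x))              # self-attention over the decoder tokens only
    x = x + ca_block(norm_cross(x), cond)   # queries from x, keys/values from cond
    x = x + ff_block(norm2(x))              # position-wise feed-forward
    return x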
377 |
|
378 |
|
379 |
+
class StreamingTransformer(nn.Module):
|
380 |
+
'''layer_class=<class 'audiocraft.transformer.StreamingTransformerLayer'> StrTrnsf'''
|
381 |
+
|
382 |
def __init__(self, d_model: int, num_heads: int, num_layers: int, dim_feedforward: int = 2048,
|
383 |
dropout: float = 0.1, bias_ff: bool = True, bias_attn: bool = True,
|
384 |
causal: bool = False, past_context: tp.Optional[int] = None,
|
385 |
custom: bool = False, memory_efficient: bool = False, attention_as_float32: bool = False,
|
386 |
+
cross_attention: bool = False,
|
387 |
positional_embedding: str = 'sin', max_period: float = 10_000, positional_scale: float = 1.,
|
388 |
+
xpos=False,
|
389 |
+
lr=None,
|
390 |
+
weight_decay=None,
|
391 |
+
layer_class=StreamingTransformerLayer,
|
392 |
+
checkpointing='none',
|
393 |
+
device=None,
|
394 |
+
dtype=None,
|
395 |
+
**kwargs):
|
396 |
super().__init__()
|
397 |
assert d_model % num_heads == 0
|
398 |
+
|
399 |
self.positional_embedding = positional_embedding
|
400 |
self.max_period = max_period
|
401 |
self.positional_scale = positional_scale
|
|
|
403 |
self.lr = lr
|
404 |
|
405 |
assert positional_embedding in ['sin', 'rope', 'sin_rope']
|
406 |
self.checkpointing = checkpointing
|
407 |
|
408 |
assert checkpointing in ['none', 'torch', 'xformers_default', 'xformers_mm']
|
|
|
417 |
dropout=dropout, bias_ff=bias_ff, bias_attn=bias_attn,
|
418 |
causal=causal, past_context=past_context, custom=custom,
|
419 |
memory_efficient=memory_efficient, attention_as_float32=attention_as_float32,
|
420 |
+
cross_attention=cross_attention,
|
421 |
+
# rope=self.rope,
|
422 |
device=device, dtype=dtype, **kwargs))
|
423 |
|
424 |
if self.checkpointing != 'none':
|
|
|
427 |
# backward hook inside of FSDP...
|
428 |
layer._magma_checkpointed = True # type: ignore
|
429 |
|
430 |
+
|
431 |
|
432 |
def forward(self, x: torch.Tensor, *args, **kwargs):
|
433 |
+
# print(f'{x.shape=} StreamingTransf') # [1, 1, 1536] Always no batch==2 here
|
434 |
+
# why is this called with time-len = 1? Shouldn't it be called with context?
|
435 |
B, T, C = x.shape
|
436 |
|
437 |
+
|
438 |
+
|
|
|
|
|
439 |
|
440 |
+
if self.positional_embedding in ['sin',
|
441 |
+
'sin_rope']:
|
442 |
positions = torch.arange(T, device=x.device).view(1, -1, 1)
|
443 |
+
|
444 |
pos_emb = create_sin_embedding(positions, C, max_period=self.max_period, dtype=x.dtype)
|
445 |
x = x + self.positional_scale * pos_emb
|
446 |
+
# UNTIL HERE BATCH=1
|
447 |
+
for _, lay in enumerate(self.layers):
|
448 |
+
# if _ < 2:
|
449 |
+
# L=0 [1,1,1536]
|
450 |
+
# L=1 [2,1,1536]
|
451 |
+
|
452 |
+
print(f'L={_} {args=} {kwargs["cross_attention_src"].shape=} {x.shape=} StreamTransf ForLoop') # [2, 1, 1536] BATCH=2
|
453 |
+
# x = self._apply_layer(layer, x, *args, **kwargs)
|
454 |
+
# x = lay(x, **kwargs)
|
455 |
+
x = lay(x,
|
456 |
+
cross_attention_src=kwargs["cross_attention_src"],
|
457 |
+
src_mask=kwargs['src_mask'])
|
458 |
+
# concat old token to query? no, not here; that happens in lm generate
|
459 |
+
print('OUT OF Tall', x.shape) # [1,2,1536] # why does this get filled with the sequence 1,2...
|
460 |
+
# should be 1 query
|
461 |
return x
|
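On the question in the comments about T growing 1, 2, ...: without a streaming KV cache the generation loop re-feeds the whole token prefix at every step, so this forward sees a longer sequence on each call. A hedged sketch of that outer loop (hypothetical step_fn, not the repo's lm.generate):

import torch

def naive_generate(step_fn, first_token: torch.Tensor, max_len: int) -> torch.Tensor:
    seq = first_token                       # [B, 1]
    for _ in range(max_len - 1):
        logits = step_fn(seq)               # forward over the full prefix, T = seq.shape[1]
        next_tok = logits[:, -1:].argmax(dim=-1)
        seq = torch.cat([seq, next_tok], dim=1)
    return seq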
462 |
|
463 |
def make_optim_group(self):
|
demo.py
CHANGED
@@ -7,7 +7,7 @@ print('\n\n\n\n___________________')
|
|
7 |
txt = 'dogs in street'
|
8 |
|
9 |
sound_generator = AudioGen.get_pretrained('facebook/audiogen-medium')
|
10 |
-
sound_generator.set_generation_params(duration
|
11 |
|
12 |
x = sound_generator.generate([txt])[0].detach().cpu().numpy()[0, :]
|
13 |
x /= np.abs(x).max() + 1e-7
|
|
|
7 |
txt = 'dogs in street'
|
8 |
|
9 |
sound_generator = AudioGen.get_pretrained('facebook/audiogen-medium')
|
10 |
+
sound_generator.set_generation_params(duration=1.24) # why is the generation so long, at 14 seconds?
|
11 |
|
12 |
x = sound_generator.generate([txt])[0].detach().cpu().numpy()[0, :]
|
13 |
x /= np.abs(x).max() + 1e-7
|
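Not part of the diff: one hedged way to inspect the normalized result, assuming the soundfile package is installed and that sound_generator.sample_rate exposes the codec sample rate:

import soundfile as sf
sf.write('dogs_in_street.wav', x, sound_generator.sample_rate)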