Dionyssos committed
Commit a0ce150 · 1 Parent(s): d44fd96

DEBUG: cross_attention_src = query or key?

audiocraft/audiogen.py CHANGED
@@ -4,11 +4,6 @@
4
  # This source code is licensed under the license found in the
5
  # LICENSE file in the root directory of this source tree.
6
 
7
- """
8
- Main model for using AudioGen. This will combine all the required components
9
- and provide easy access to the generation API.
10
- """
11
-
12
  import typing as tp
13
  import torch
14
  from audiocraft.loaders import load_compression_model, load_lm_model
@@ -87,51 +82,25 @@ class BaseGenModel(ABC):
87
  """Sample rate of the generated audio."""
88
  return self.compression_model.sample_rate
89
 
90
- @property
91
- def audio_channels(self) -> int:
92
- """Audio channels of the generated audio."""
93
- return self.compression_model.channels
94
-
95
- @torch.no_grad()
96
- def _prepare_tokens_and_attributes(
97
- self,
98
- descriptions,
99
- prompt,
100
- ):
101
- attributes = [
102
- ConditioningAttributes(text={'description': description}) for description in descriptions]
103
- prompt_tokens = None
104
- return attributes, prompt_tokens
105
-
106
- def generate_unconditional(self,
107
- num_samples,
108
- progress=False,
109
- return_tokens=False):
110
- descriptions: tp.List[tp.Optional[str]] = [None] * num_samples
111
- attributes, _ = self._prepare_tokens_and_attributes(descriptions, None)
112
- tokens = self._generate_tokens(attributes)
113
- if return_tokens:
114
- return self.generate_audio(tokens), tokens
115
- return self.generate_audio(tokens)
116
 
117
- def generate(self,
118
- descriptions,
119
- progress=False,
120
- return_tokens=False):
121
- attributes, _ = self._prepare_tokens_and_attributes(descriptions, None)
 
122
  tokens = self._generate_tokens(attributes)
123
- if return_tokens:
124
- return self.generate_audio(tokens), tokens
125
  return self.generate_audio(tokens)
126
 
127
- def _generate_tokens(self, attributes,
128
- prompt_tokens=None,
129
- progress=False):
130
 
131
  total_gen_len = int(self.duration * self.frame_rate)
132
- max_prompt_len = int(min(self.duration, self.max_duration) * self.frame_rate)
133
- current_gen_offset: int = 0
134
135
 
136
 
137
 
@@ -140,10 +109,7 @@ class BaseGenModel(ABC):
140
  # generate by sampling from LM, simple case.
141
 
142
  with self.autocast:
143
- gen_tokens = self.lm.generate(conditions=attributes,
144
- callback=None,
145
- max_gen_len=total_gen_len,
146
- **self.generation_params)
147
  else:
148
  print('<>Long gen ?<>')
149
  # print(f'{gen_tokens.shape=}') # [5,4,35]
 
4
  # This source code is licensed under the license found in the
5
  # LICENSE file in the root directory of this source tree.
6
 
7
  import typing as tp
8
  import torch
9
  from audiocraft.loaders import load_compression_model, load_lm_model
 
82
  """Sample rate of the generated audio."""
83
  return self.compression_model.sample_rate
84
 
85
 
86
+
87
+
88
+
89
+ def generate(self, descriptions):
90
+ attributes = [
91
+ ConditioningAttributes(text={'description': d}) for d in descriptions]
92
  tokens = self._generate_tokens(attributes)
 
 
93
  return self.generate_audio(tokens)
94
 
95
+ def _generate_tokens(self, attributes):
 
 
96
 
97
  total_gen_len = int(self.duration * self.frame_rate)
 
 
98
 
99
+
100
+ # # print(f'{self.generation_params=}')
101
+ # self.generation_params={'use_sampling': True,
102
+ # 'temp': 1.0, 'top_k': 250,
103
+ # 'top_p': 0.0, 'cfg_coef': 2.4, 'two_step_cfg': False}
104
 
105
 
106
 
 
109
  # generate by sampling from LM, simple case.
110
 
111
  with self.autocast:
112
+ gen_tokens = self.lm.generate(conditions=attributes, max_gen_len=total_gen_len)
 
 
 
113
  else:
114
  print('<>Long gen ?<>')
115
  # print(f'{gen_tokens.shape=}') # [5,4,35]
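Note on the trimmed API above: BaseGenModel.generate(descriptions) now builds the ConditioningAttributes itself and returns decoded audio directly; prompts, progress callbacks, return_tokens and generate_unconditional are gone. A hedged usage sketch (AudioGen.get_pretrained is an assumption carried over from upstream audiocraft and may not exist in this fork):

# Hypothetical usage of the simplified generate() shown above.
from audiocraft.audiogen import AudioGen

model = AudioGen.get_pretrained('facebook/audiogen-medium')   # assumed loader name
wav = model.generate(['dogs barking', 'rain on a tin roof'])  # waveform tensor [B, C, T]
print(wav.shape, model.sample_rate)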
audiocraft/conditioners.py CHANGED
@@ -8,7 +8,7 @@ import soundfile
8
  from transformers import T5EncoderModel, T5Tokenizer # type: ignore
9
  import torch
10
  from torch import nn
11
- from .streaming import StreamingModule
12
 
13
  from .utils.autocast import TorchAutocast
14
 
@@ -126,17 +126,7 @@ class BaseConditioner(nn.Module):
126
  """
127
  raise NotImplementedError()
128
 
129
- def forward(self, inputs: tp.Any) -> ConditionType:
130
- """Gets input that should be used as conditioning (e.g, genre, description or a waveform).
131
- Outputs a ConditionType, after the input data was embedded as a dense vector.
132
-
133
- Returns:
134
- ConditionType:
135
- - A tensor of size [B, T, D] where B is the batch size, T is the length of the
136
- output embedding and D is the dimension of the embedding.
137
- - And a mask indicating where the padding tokens.
138
- """
139
- raise NotImplementedError()
140
 
141
 
142
  class TextConditioner(BaseConditioner):
@@ -239,6 +229,9 @@ class T5Conditioner(TextConditioner):
239
  embeds = self.t5(**inputs).last_hidden_state
240
  embeds = self.output_proj(embeds.to(self.output_proj.weight))
241
  embeds = (embeds * mask.unsqueeze(-1))
 
 
 
242
  return embeds, mask
243
 
244
 
@@ -352,21 +345,8 @@ class ConditioningProvider(nn.Module):
352
 
353
 
354
 
355
- class ConditionFuser(StreamingModule):
356
- """Condition fuser handles the logic to combine the different conditions
357
- to the actual model input.
358
 
359
- Args:
360
- fuse2cond (tp.Dict[str, str]): A dictionary that says how to fuse
361
- each condition. For example:
362
- {
363
- "prepend": ["description"],
364
- "sum": ["genre", "bpm"],
365
- "cross": ["description"],
366
- }
367
- cross_attention_pos_emb (bool, optional): Use positional embeddings in cross attention.
368
- cross_attention_pos_emb_scale (int): Scale for positional embeddings in cross attention if used.
369
- """
370
  FUSING_METHODS = ["sum", "prepend", "cross", "input_interpolate"]
371
 
372
  def __init__(self, fuse2cond: tp.Dict[str, tp.List[str]], cross_attention_pos_emb: bool = False,
@@ -387,25 +367,4 @@ class ConditionFuser(StreamingModule):
387
  self,
388
  input,
389
  conditions):
390
-
391
- B, T, _ = input.shape
392
-
393
-
394
- first_step = True
395
- offsets = torch.zeros(input.shape[0], dtype=torch.long, device=input.device)
396
-
397
-
398
- cross_attention_output = None
399
- for cond_type, (cond, cond_mask) in conditions.items():
400
- # print(f'{self.cond2fuse=}') - self.cond2fuse={'description': 'cross'}
401
-
402
- cross_attention_output = cond
403
- # print(f'{cross_attention_output.shape=} for {input.sum()=}')
404
- # cross_attention_output.shape=torch.Size([2, 5, 1536]) for input.sum()=tensor(-0.0650, device='cuda:0')
405
- # cross_attention_output.shape=torch.Size([2, 5, 1536]) for input.sum()=tensor(3.7672, device='cuda:0')
406
-
407
-
408
- if self._is_streaming:
409
- self._streaming_state['offsets'] = offsets + T
410
-
411
- return input, cross_attention_output
 
8
  from transformers import T5EncoderModel, T5Tokenizer # type: ignore
9
  import torch
10
  from torch import nn
11
+
12
 
13
  from .utils.autocast import TorchAutocast
14
 
 
126
  """
127
  raise NotImplementedError()
128
 
129
+
 
 
130
 
131
 
132
  class TextConditioner(BaseConditioner):
 
229
  embeds = self.t5(**inputs).last_hidden_state
230
  embeds = self.output_proj(embeds.to(self.output_proj.weight))
231
  embeds = (embeds * mask.unsqueeze(-1))
232
+
233
+ # T5 torch.Size([2, 4, 1536]) dict_keys(['input_ids', 'attention_mask'])
234
+ # print(f'{inputs["input_ids"].shape=}') # inputs["input_ids"].shape=torch.Size([2, 4])
235
  return embeds, mask
236
 
237
 
 
345
 
346
 
347
 
348
+ class ConditionFuser(nn.Module):
 
 
349
 
 
350
  FUSING_METHODS = ["sum", "prepend", "cross", "input_interpolate"]
351
 
352
  def __init__(self, fuse2cond: tp.Dict[str, tp.List[str]], cross_attention_pos_emb: bool = False,
 
367
  self,
368
  input,
369
  conditions):
370
+ return input, conditions['description'][0] #cross_attention_output
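Shape note for the trimmed ConditionFuser.forward above: it no longer fuses anything, it just passes the token stream through and hands back the T5 description embedding as the cross-attention source. A small sketch of that contract, with illustrative sizes taken from the debug prints in this commit ([2, 5, 1536] text embeddings):

import torch

B, T_seq, T_text, D = 2, 1, 5, 1536        # illustrative, from the printed shapes
input_ = torch.randn(B, T_seq, D)          # summed codebook embeddings (decoder stream)
conditions = {'description': (torch.randn(B, T_text, D),   # T5 embeddings -> keys/values
                              torch.ones(B, T_text))}      # tokenizer attention mask

out, cross_attention_src = input_, conditions['description'][0]   # what forward() returns
assert out.shape == (B, T_seq, D) and cross_attention_src.shape == (B, T_text, D)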
audiocraft/lm.py CHANGED
@@ -6,7 +6,6 @@ import re
6
  import typing as tp
7
  import torch
8
  import torch.nn.functional as F
9
- from audiocraft.streaming import StreamingModule
10
  from audiocraft.transformer import StreamingTransformer, create_norm_fn
11
  from dataclasses import dataclass
12
  from functools import partial
@@ -109,7 +108,7 @@ class LMOutput:
109
  mask: torch.Tensor # [B, K, T]
110
 
111
 
112
- class LMModel(StreamingModule):
113
  """Transformer-based language model on multiple streams of codes.
114
 
115
  Args:
@@ -148,7 +147,7 @@ class LMModel(StreamingModule):
148
  super().__init__()
149
  self.cfg_coef = cfg_coef
150
 
151
- self.n_draw = 24
152
  self.condition_provider = condition_provider
153
  self.fuser = fuser
154
  self.card = card # 2048 ?
@@ -160,9 +159,26 @@ class LMModel(StreamingModule):
160
  self.emb = nn.ModuleList([ScaledEmbedding(embed_dim, dim, lr=emb_lr) for _ in range(n_q)])
161
  if 'activation' in kwargs:
162
  kwargs['activation'] = get_activation_fn(kwargs['activation'])
163
  self.transformer = StreamingTransformer(
164
- d_model=dim, num_heads=num_heads, dim_feedforward=int(hidden_scale * dim),
165
- norm=norm, norm_first=norm_first, **kwargs)
 
 
 
166
  self.out_norm: tp.Optional[nn.Module] = None
167
  if norm_first:
168
  self.out_norm = create_norm_fn(norm, dim)
@@ -199,7 +215,10 @@ class LMModel(StreamingModule):
199
  depth = layer_idx + 1
200
  elif depthwise_init == 'global':
201
  depth = len(self.transformer.layers)
202
- init_fn = partial(init_layer, method=weight_init, init_depth=depth, zero_bias_init=zero_bias_init)
 
 
 
203
  tr_layer.apply(init_fn)
204
 
205
  for linear in self.linears:
@@ -215,91 +234,55 @@ class LMModel(StreamingModule):
215
 
216
  def forward(self,
217
  sequence,
218
- conditions,
219
  condition_tensors=None,
220
  stage = -1):
221
- B, K, S = sequence.shape
222
-
223
 
224
  input_ = sum([self.emb[k](sequence[:, k]) for k in range(K)])
225
 
226
 
227
- input_, cross_attention_input = self.fuser(input_, condition_tensors) # DEFINE conditioners.py
228
-
229
- # print(f'{input_.shape=} {cross_attention_input.shape=} FUSER LLM FORw')
230
- # input_.shape=torch.Size([1, 1, 1536]) cross_attention_input.shape=torch.Size([2, 7, 1536]) FUSER LLM FORw
231
 
232
  out = self.transformer(input_, cross_attention_src=cross_attention_input,
233
  src_mask=(self.attn_mask_per_stage[stage] if stage >= 0 else None))
234
  if self.out_norm:
235
  out = self.out_norm(out)
 
 
 
236
  logits = torch.stack([self.linears[k](out) for k in range(K)], dim=1) # [B, K, S, card]
237
-
238
  # remove the prefix from the model outputs
239
- if len(self.fuser.fuse2cond['prepend']) > 0:
240
- logits = logits[:, :, -S:]
241
- print('==========================================PRESFIX')
242
 
243
  return logits # [B, K, S, card]
244
 
245
 
246
- def _sample_next_token(self,
247
- sequence,
248
- cfg_conditions,
249
- unconditional_state):
250
- """self.n_draw"""
251
- B = sequence.shape[0]
252
-
253
- model = self if self._fsdp is None else self._fsdp
254
-
255
- condition_tensors = cfg_conditions
256
- # logits = [2, 4, 1, 2048]
257
- logits = model(
258
- sequence, # cond_logits = wav condition
259
- conditions=[], condition_tensors=condition_tensors) # uncond_logits already see the text
260
-
261
-
262
- # use cfg
263
- # logits = (3 * logits[1, :, :, :] - 2.4 * logits[0, :, :, :]).transpose(1,0)
264
-
265
- # or use 1 of logits
266
- logits = logits[0, :, :, :].transpose(1,0) # [2,4,1, 2048] -> [1,4,2048]
267
-
268
-
269
-
270
- # print(f'{B=}, {logits.shape=} SAMPLER {top_k=}')
271
- next_token = utils.sample_top_k(logits, n_draw=self.n_draw) # [1,4,2048] logits
272
- return next_token
273
-
274
  # GENERATE class revert_codebook_patterns()
275
  @torch.no_grad()
276
  def generate(self,
277
  prompt = None,
278
  conditions = [],
279
- num_samples = 1, # THIS IS HOW MANY GENERATIONS - A SAMPLE IS A FULL WAV
280
- max_gen_len=256, # unduplicated sequence length - actual len will be n_draw * maxgenlen
281
- use_sampling: bool = True,
282
- **kwargs):
283
 
284
- print(f'{num_samples=}')
285
  first_param = next(iter(self.parameters()))
286
  device = first_param.device
287
 
288
- # below we create set of conditions: one conditional and one unconditional
289
- # to do that we merge the regular condition together with the null condition
290
- # we then do 1 forward pass instead of 2.
291
- # the reason for that is two-fold:
292
- # 1. it is about x2 faster than doing 2 forward passes
293
- # 2. avoid the streaming API treating the 2 passes as part of different time steps
294
- # We also support doing two different passes, in particular to ensure that
295
- # the padding structure is exactly the same between train and test.
296
- # With a batch size of 1, this can be slower though.
297
- cfg_conditions: CFGConditions
298
- # two_step_cfg = self.two_step_cfg if two_step_cfg is None else two_step_cfg
299
-
300
- null_conditions = conditions
301
- conditions = conditions + null_conditions
302
  tokenized = self.condition_provider.tokenize(conditions)
303
  cfg_conditions = self.condition_provider(tokenized)
304
 
305
 
@@ -326,58 +309,59 @@ class LMModel(StreamingModule):
326
 
327
 
328
 
329
- with self.streaming():
330
 
331
- unconditional_state = self.get_streaming_state()
332
- prev_offset = 0
333
- gen_sequence_len = _gen_sequence.shape[-1] # gen_sequence shape is [B, K, S]
334
 
 
335
  # --
336
- # print(mask.shape, mask.sum(), 'MSK LM')
337
- # torch.Size([4, 39]) tensor(140, device='cuda:0') MSK LM ? Fully 1 normal no special token
338
- # --
339
- duplicate_draw = [
340
- _gen_sequence[:, :, 0:1].repeat(self.n_draw, 1, 1)
341
- ]
342
- # list to hold next tokens - draw sample multiple tokens at each time-step
343
- # but continue the sequence only with isingle next token
 
 
344
 
345
- for offset in range(1, gen_sequence_len): # start_offset_sequence=1
346
- # print(f'{_gen_sequence.shape=}') # [1,4,16]
347
- # starts from 1 not 0 thus uses the 0:1 as curr sequence
348
- # although this is empty contains -1 ?
349
 
350
- curr_sequence = _gen_sequence[..., prev_offset:offset]
351
-
352
-
353
-
354
- next_token = self._sample_next_token(
355
- curr_sequence,
356
- cfg_conditions,
357
- unconditional_state) # [5, 4, 1]
358
-
359
-
360
-
361
-
362
-
363
-
364
-
365
- # RUNS with = 2047 just different of self.special_token_id = 2047 = alwayssingletoken = drill noise
366
- # special_token_id is filler for CODEBOOK_PATTERN ?
367
-
368
- # next_token[:] = self.special_token_id # seanet.embed torch.embedding does not have this - out of bounds in detokenize
369
-
370
- _gen_sequence[..., offset:offset+1] = next_token[0, :, :] #gen_sequence.shape=torch.Size([1, 4, 39])
371
-
372
- duplicate_draw.append(next_token)
373
-
374
- prev_offset = offset
375
-
376
-
377
-
378
- unconditional_state.clear()
379
-
380
- gen_sequence = torch.cat(duplicate_draw, 2) # [self.n_draw, 4, len_seq]
381
 
382
  # revert codes as "batch"
383
 
@@ -415,4 +399,4 @@ class LMModel(StreamingModule):
415
 
416
 
417
 
418
- return out_codes # supposedly contains extra prompt
 
6
  import typing as tp
7
  import torch
8
  import torch.nn.functional as F
 
9
  from audiocraft.transformer import StreamingTransformer, create_norm_fn
10
  from dataclasses import dataclass
11
  from functools import partial
 
108
  mask: torch.Tensor # [B, K, T]
109
 
110
 
111
+ class LMModel(nn.Module):
112
  """Transformer-based language model on multiple streams of codes.
113
 
114
  Args:
 
147
  super().__init__()
148
  self.cfg_coef = cfg_coef
149
 
150
+ self.n_draw = 5
151
  self.condition_provider = condition_provider
152
  self.fuser = fuser
153
  self.card = card # 2048 ?
 
159
  self.emb = nn.ModuleList([ScaledEmbedding(embed_dim, dim, lr=emb_lr) for _ in range(n_q)])
160
  if 'activation' in kwargs:
161
  kwargs['activation'] = get_activation_fn(kwargs['activation'])
162
+ # ========================================================================
163
+ # {
164
+ # 'dtype': torch.float16, 'device': 'cuda',
165
+ # 'num_layers': 48, 'dropout': 0.0, 'activation': 'gelu',
166
+ # 'bias_ff': False, 'bias_attn': False,
167
+ # 'past_context': None, 'causal': True,
168
+ # 'custom': False, 'memory_efficient': True,
169
+ # 'attention_as_float32': False, 'positional_embedding': 'sin', 'xpos': False,
170
+ # 'checkpointing': 'none', 'cross_attention': True, 'qk_layer_norm': False,
171
+ # 'qk_layer_norm_cross': False, 'attention_dropout': None, 'kv_repeat': 1
172
+ # }
173
+ # ==========================================================================
174
+ kwargs.pop('layer_scale') # nn.Identity()
175
+
176
  self.transformer = StreamingTransformer(
177
+ d_model=dim,
178
+ num_heads=num_heads,
179
+ dim_feedforward=int(hidden_scale * dim),
180
+ norm=norm,
181
+ norm_first=norm_first, **kwargs)
182
  self.out_norm: tp.Optional[nn.Module] = None
183
  if norm_first:
184
  self.out_norm = create_norm_fn(norm, dim)
 
215
  depth = layer_idx + 1
216
  elif depthwise_init == 'global':
217
  depth = len(self.transformer.layers)
218
+ init_fn = partial(init_layer,
219
+ method=weight_init,
220
+ init_depth=depth,
221
+ zero_bias_init=zero_bias_init)
222
  tr_layer.apply(init_fn)
223
 
224
  for linear in self.linears:
 
234
 
235
  def forward(self,
236
  sequence,
 
237
  condition_tensors=None,
238
  stage = -1):
239
+ B, K, S = sequence.shape # linears are n_q
 
240
 
241
  input_ = sum([self.emb[k](sequence[:, k]) for k in range(K)])
242
 
243
 
244
+ # input_, cross_attention_input = self.fuser(input_, condition_tensors)
245
+ cross_attention_input = condition_tensors['description'][0]
246
+ # print(f'{input_.shape=} {cross_attention_input.shape=} FUSER LLM')
247
+
248
 
249
  out = self.transformer(input_, cross_attention_src=cross_attention_input,
250
  src_mask=(self.attn_mask_per_stage[stage] if stage >= 0 else None))
251
  if self.out_norm:
252
  out = self.out_norm(out)
253
+ # K = 2 because the LM produces 2 tokens?
254
+ # so only 2 of the 4 self.linears are used?
255
+ # why is torch.stack on dim=1?
256
  logits = torch.stack([self.linears[k](out) for k in range(K)], dim=1) # [B, K, S, card]
257
+ print(f'{input_.shape=} {out.shape=} {cross_attention_input.shape=} {logits.shape=} FUSER LLM')
258
  # remove the prefix from the model outputs
259
+ # if len(self.fuser.fuse2cond['prepend']) > 0:
260
+ # logits = logits[:, :, -S:]
261
+ # print('==========================================PRESFIX')
262
 
263
  return logits # [B, K, S, card]
264
 
265
 
 
 
 
266
  # GENERATE class revert_codebook_patterns()
267
  @torch.no_grad()
268
  def generate(self,
269
  prompt = None,
270
  conditions = [],
271
+ num_samples = 1, # N next token
272
+ max_gen_len=256):
 
 
273
 
274
+ print(f'{prompt=} {conditions=}')
275
  first_param = next(iter(self.parameters()))
276
  device = first_param.device
277
 
278
+
279
+
 
 
280
  tokenized = self.condition_provider.tokenize(conditions)
281
+ # print('TOKENIZ', tokenized) # 'description'
282
+ # TOKENIZ {'description': {'input_ids': tensor([[3887, 16, 2815, 1],
283
+ # [3887, 16, 2815, 1]], device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1],
284
+ # [1, 1, 1, 1]], device='cuda:0')}}
285
+
286
  cfg_conditions = self.condition_provider(tokenized)
287
 
288
 
 
309
 
310
 
311
 
312
+
313
+
314
+ # --
315
+ # print(mask.shape, mask.sum(), 'MSK LM')
316
+ # torch.Size([4, 39]) tensor(140, device='cuda:0') MSK LM ? Fully 1 normal no special token
317
+ # --\
318
 
319
+ # list for elongation: draw n_draw (5) candidate tokens at each time-step
320
+ # append them at end of sequence
321
+ duplicate_draw = [
322
+ _gen_sequence[:, :, 0:1].repeat(self.n_draw, 1, 1)
323
+ ]
324
+
325
+
326
+ for offset in range(1, _gen_sequence.shape[2]): # gen_sequence shape is [B, K, S]
327
+ # print(f'{_gen_sequence.shape=}') # [1,4,16]
328
+ # starts from 1 not 0 thus uses the 0:1 as curr sequence
329
+ # although this is empty contains -1 ?
330
+
331
+
332
 
333
+
334
+ # ====================== SAMPLE NEXT TOK
335
+ # next_token = self._sample_next_token(
336
+ # _gen_sequence[..., :offset],
337
+ # cfg_conditions) # [5, 4, 1]
338
  # --
339
+ # def _sample_next_token(self,
340
+ # sequence,
341
+ # cfg_conditions):
342
+ model = self if self._fsdp is None else self._fsdp
343
+
344
+ logits = model(_gen_sequence[..., :offset],
345
+ condition_tensors=cfg_conditions)
346
+ # print(logits.shape, 'Next Logits') # [1, 4, 2, 2048] why 2 tokens on query
347
+
348
+ # use cfg
349
+ # logits = (3 * logits[1, :, :, :] - 2.4 * logits[0, :, :, :]).transpose(1,0)
350
+
351
+ # or use 1 of logits
352
+ logits = logits[0, :, 0:1, :] # [1,4,2048]
353
+ next_token = utils.sample_top_k(logits, n_draw=self.n_draw) # [1,4,2048] logits
354
+ # =================================
355
+
356
+
357
+ _gen_sequence[:, :, offset] = next_token[0, :, 0] #gen_sequence.shape=torch.Size([1, 4, 39])
358
+
359
+ duplicate_draw.append(next_token)
360
 
 
 
 
 
361
 
362
+
363
+
364
+ gen_sequence = torch.cat(duplicate_draw, 2) # RESHAPE -> N_DRAW -> TIME
 
 
365
 
366
  # revert codes as "batch"
367
 
 
399
 
400
 
401
 
402
+ return out_codes
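The sampling loop above draws self.n_draw candidate tokens per step via utils.sample_top_k but advances the sequence with only the first draw; the extra draws are concatenated along time at the end. A plausible stand-in for sample_top_k, inferred from its call sites here (the real implementation may differ):

import torch

def sample_top_k_sketch(logits: torch.Tensor, k: int = 250, n_draw: int = 5) -> torch.Tensor:
    # logits: [K_codebooks, 1, card] for the current step -> token ids [n_draw, K_codebooks, 1]
    probs = torch.softmax(logits.squeeze(1), dim=-1)                         # [K, card]
    top_p, top_i = probs.topk(k, dim=-1)                                     # [K, k]
    draws = torch.multinomial(top_p, num_samples=n_draw, replacement=True)   # [K, n_draw]
    tokens = torch.gather(top_i, 1, draws)                                   # [K, n_draw]
    return tokens.t().unsqueeze(-1)                                          # [n_draw, K, 1]

next_token = sample_top_k_sketch(torch.randn(4, 1, 2048))
print(next_token.shape)  # torch.Size([5, 4, 1]), matching the [5, 4, 1] noted above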
audiocraft/streaming.py DELETED
@@ -1,131 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
- """
8
- Streaming module API that should be implemented by all Streaming components,
9
- """
10
-
11
- from contextlib import contextmanager
12
- import typing as tp
13
- from torch import nn
14
- import torch
15
-
16
-
17
- State = tp.Dict[str, torch.Tensor]
18
-
19
-
20
- class StreamingModule(nn.Module):
21
- """Common API for streaming components.
22
-
23
- Each streaming component has a streaming state, which is just a dict[str, Tensor].
24
- By convention, the first dim of each tensor must be the batch size.
25
- Don't use dots in the key names, as this would clash with submodules
26
- (like in state_dict).
27
-
28
- If `self._is_streaming` is True, the component should use and remember
29
- the proper state inside `self._streaming_state`.
30
-
31
- To set a streaming component in streaming state, use
32
-
33
- with module.streaming():
34
- ...
35
-
36
- This will automatically reset the streaming state when exiting the context manager.
37
- This also automatically propagates to all streaming children module.
38
-
39
- Some module might also implement the `StreamingModule.flush` method, although
40
- this one is trickier, as all parents module must be StreamingModule and implement
41
- it as well for it to work properly. See `StreamingSequential` after.
42
- """
43
- def __init__(self) -> None:
44
- super().__init__()
45
- self._streaming_state: State = {}
46
- self._is_streaming = False
47
-
48
- def _apply_named_streaming(self, fn: tp.Any):
49
- for name, module in self.named_modules():
50
- if isinstance(module, StreamingModule):
51
- fn(name, module)
52
-
53
- def _set_streaming(self, streaming: bool):
54
- def _set_streaming(name, module):
55
- module._is_streaming = streaming
56
- self._apply_named_streaming(_set_streaming)
57
-
58
- @contextmanager
59
- def streaming(self):
60
- """Context manager to enter streaming mode. Reset streaming state on exit."""
61
- self._set_streaming(True)
62
- try:
63
- yield
64
- finally:
65
- self._set_streaming(False)
66
- self.reset_streaming()
67
-
68
- def reset_streaming(self):
69
- """Reset the streaming state."""
70
- def _reset(name: str, module: StreamingModule):
71
- module._streaming_state.clear()
72
-
73
- self._apply_named_streaming(_reset)
74
-
75
- def get_streaming_state(self) -> State:
76
- """Return the streaming state, including that of sub-modules."""
77
- state: State = {}
78
-
79
- def _add(name: str, module: StreamingModule):
80
- if name:
81
- name += "."
82
- for key, value in module._streaming_state.items():
83
- state[name + key] = value
84
-
85
- self._apply_named_streaming(_add)
86
- return state
87
-
88
- def set_streaming_state(self, state: State):
89
- """Set the streaming state, including that of sub-modules."""
90
- state = dict(state)
91
-
92
- def _set(name: str, module: StreamingModule):
93
- if name:
94
- name += "."
95
- module._streaming_state.clear()
96
- for key, value in list(state.items()):
97
- # complexity is not ideal here, but probably fine.
98
- if key.startswith(name):
99
- local_key = key[len(name):]
100
- if '.' not in local_key:
101
- module._streaming_state[local_key] = value
102
- del state[key]
103
-
104
- self._apply_named_streaming(_set)
105
- assert len(state) == 0, list(state.keys())
106
-
107
- def flush(self, x: tp.Optional[torch.Tensor] = None):
108
- """Flush any remaining outputs that were waiting for completion.
109
- Typically, for convolutions, this will add the final padding
110
- and process the last buffer.
111
-
112
- This should take an optional argument `x`, which will be provided
113
- if a module before this one in the streaming pipeline has already
114
- spitted out a flushed out buffer.
115
- """
116
- if x is None:
117
- return None
118
- else:
119
- return self(x)
120
-
121
-
122
- class StreamingSequential(StreamingModule, nn.Sequential):
123
- """A streaming compatible alternative of `nn.Sequential`.
124
- """
125
- def flush(self, x: tp.Optional[torch.Tensor] = None):
126
- for module in self:
127
- if isinstance(module, StreamingModule):
128
- x = module.flush(x)
129
- elif x is not None:
130
- x = module(x)
131
- return x
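With StreamingModule gone there is no per-step key/value cache, so the sampling loop in lm.py above re-runs the transformer over the full prefix _gen_sequence[..., :offset] at every step. Rough cost sketch (plain arithmetic, not repo code):

def tokens_processed(total_len: int, kv_cache: bool) -> int:
    # positions pushed through the transformer over a whole decode:
    # with a KV cache each step only processes the newest token,
    # without one, step t re-processes all t prefix tokens
    return total_len if kv_cache else sum(range(1, total_len + 1))

print(tokens_processed(256, kv_cache=True))   # 256
print(tokens_processed(256, kv_cache=False))  # 32896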
audiocraft/transformer.py CHANGED
@@ -1,30 +1,10 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
- """
8
- Transformer model, with streaming support, xformer attention support
9
- and easy causal attention with a potentially finite receptive field.
10
-
11
- See `StreamingTransformer` for more information.
12
-
13
- Unlike regular PyTorch Transformer, we make the hard choice that batches are first.
14
- """
15
-
16
  import typing as tp
17
-
18
  from einops import rearrange
19
  import torch
20
  import torch.nn as nn
21
  from torch.nn import functional as F
22
- from torch.utils.checkpoint import checkpoint as torch_checkpoint
23
  from xformers import ops
24
 
25
- from .rope import RotaryEmbedding
26
- from .streaming import StreamingModule
27
-
28
  _efficient_attention_backend: str = 'torch'
29
 
30
 
@@ -35,14 +15,10 @@ def set_efficient_attention_backend(backend: str = 'torch'):
35
  _efficient_attention_backend = backend
36
 
37
 
38
- def _get_attention_time_dimension(memory_efficient: bool) -> int:
39
- if _efficient_attention_backend == 'torch' and memory_efficient:
40
- return 2
41
- else:
42
- return 1
43
 
44
 
45
- def _is_profiled() -> bool:
 
46
  # Return true if we are currently running with a xformers profiler activated.
47
  try:
48
  from xformers.profiler import profiler
@@ -51,16 +27,8 @@ def _is_profiled() -> bool:
51
  return profiler._Profiler._CURRENT_PROFILER is not None
52
 
53
 
54
- def create_norm_fn(norm_type: str, dim: int, **kwargs) -> nn.Module:
55
- """Create normalization module for transformer encoder layer.
56
 
57
- Args:
58
- norm_type (str): Normalization method.
59
- dim (int): Dimension of the normalized layer.
60
- **kwargs (dict): Additional parameters for normalization layer.
61
- Returns:
62
- nn.Module: Normalization module.
63
- """
64
  if norm_type == 'layer_norm':
65
  return nn.LayerNorm(dim, eps=1e-5, **kwargs)
66
  else:
@@ -86,87 +54,26 @@ def create_sin_embedding(positions: torch.Tensor, dim: int, max_period: float =
86
  adim = torch.arange(half_dim, device=positions.device, dtype=dtype).view(1, 1, -1)
87
  max_period_tensor = torch.full([], max_period, device=positions.device, dtype=dtype) # avoid sync point
88
  phase = positions / (max_period_tensor ** (adim / (half_dim - 1)))
89
- print('==============CONCAT 3 ============')
90
- return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)
91
 
92
 
93
- def expand_repeated_kv(x: torch.Tensor, n_rep: int, memory_efficient: bool) -> torch.Tensor:
94
- """torch.repeat_interleave(x, dim=2, repeats=n_rep) from xlformers."""
95
- if n_rep == 1:
96
- return x
97
- if _efficient_attention_backend == 'torch' and memory_efficient:
98
- bs, n_kv_heads, slen, head_dim = x.shape
99
- return (
100
- x[:, :, None, :, :]
101
- .expand(bs, n_kv_heads, n_rep, slen, head_dim)
102
- .reshape(bs, n_kv_heads * n_rep, slen, head_dim)
103
- )
104
- else:
105
- bs, slen, n_kv_heads, head_dim = x.shape
106
- return (
107
- x[:, :, :, None, :]
108
- .expand(bs, slen, n_kv_heads, n_rep, head_dim)
109
- .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
110
- )
111
 
112
 
113
- class LayerScale(nn.Module):
114
- """Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf).
115
- This rescales diagonally the residual outputs close to 0, with a learnt scale.
116
 
117
- Args:
118
- channels (int): Number of channels.
119
- init (float): Initial scale.
120
- channel_last (bool): If True, expect `[*, C]` shaped tensors, otherwise, `[*, C, T]`.
121
- device (torch.device or str, optional): Device on which to initialize the module.
122
- dtype (torch.dtype, optional): dtype to use to initialize the module.
123
- """
124
- def __init__(self, channels: int, init: float = 1e-4, channel_last: bool = True,
125
- device=None, dtype=None):
126
- super().__init__()
127
- self.channel_last = channel_last
128
- self.scale = nn.Parameter(
129
- torch.full((channels,), init,
130
- requires_grad=True, device=device, dtype=dtype))
131
-
132
- def forward(self, x: torch.Tensor):
133
- if self.channel_last:
134
- return self.scale * x
135
- else:
136
- return self.scale[:, None] * x
137
 
138
 
139
- class StreamingMultiheadAttention(StreamingModule):
140
- """Similar to `nn.MultiheadAttention` but with support for streaming, causal evaluation.
141
 
142
- Args:
143
- embed_dim (int): Dimension to project to.
144
- num_heads (int): Number of heads.
145
- dropout (float): Dropout level.
146
- bias (bool): Use bias in projections.
147
- causal (bool): Causal mask applied automatically.
148
- past_context (int, optional): Receptive field for the causal mask, infinite if None.
149
- custom (bool): Use custom MHA implementation, for testing / benchmarking.
150
- memory_efficient (bool): Use xformers based memory efficient attention.
151
- attention_as_float32 (bool): Perform the attention as float32
152
- (especially important with memory_efficient as autocast won't do this automatically).
153
- rope (`RotaryEmbedding`, optional): Rope embedding to use.
154
- cross_attention: Should be true when used as a cross attention.
155
- All keys and values must be available at once, streaming is only for the queries.
156
- Cannot be used with `causal` or `rope` (as it wouldn't make sens to
157
- interpret the time steps in the keys relative to those in the queries).
158
- safe_streaming (bool): Bug fix, will go away with xformers update.
159
- qk_layer_norm (bool): Layer normalization applied to queries and keys before dot product.
160
- kv_repeat (int): If > 1, will repeat keys and queries multiple times (need to divide num_heads).
161
- This will lead to faster decoding time on A100 or other GPUs with tensorcore.
162
- device (torch.device, optional): Device on which to initialize.
163
- dtype (torch.dtype, optional): dtype to use.
164
- """
165
- def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, bias: bool = True,
166
  causal: bool = False, past_context: tp.Optional[int] = None, custom: bool = False,
167
  memory_efficient: bool = False, attention_as_float32: bool = False,
168
- rope: tp.Optional[RotaryEmbedding] = None, cross_attention: bool = False,
169
- safe_streaming: bool = True, qk_layer_norm: bool = False, kv_repeat: int = 1,
170
  device=None, dtype=None):
171
  super().__init__()
172
  factory_kwargs = {'device': device, 'dtype': dtype}
@@ -178,15 +85,15 @@ class StreamingMultiheadAttention(StreamingModule):
178
  self.past_context = past_context
179
  self.memory_efficient = memory_efficient
180
  self.attention_as_float32 = attention_as_float32
181
- self.rope = rope
182
  self.cross_attention = cross_attention
183
- self.safe_streaming = safe_streaming
184
  self.num_heads = num_heads
185
  self.dropout = dropout
186
  self.kv_repeat = kv_repeat
187
  if cross_attention:
188
  assert not causal, "Causal cannot work with cross attention."
189
- assert rope is None, "Rope cannot work with cross attention."
190
 
191
  if memory_efficient:
192
  _verify_xformers_memory_efficient_compat()
@@ -231,123 +138,42 @@ class StreamingMultiheadAttention(StreamingModule):
231
  state_dict[prefix + "mha." + key] = state_dict.pop(prefix + key)
232
  super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
233
 
234
- def _get_mask(self, current_steps: int, device: torch.device, dtype: torch.dtype):
235
- # Return a causal mask, accounting for potentially stored past keys/values
236
- # We actually return a bias for the attention score, as this has the same
237
- # convention both in the builtin MHA in Pytorch, and Xformers functions.
238
- time_dim = _get_attention_time_dimension(self.memory_efficient)
239
- if self.memory_efficient:
240
- from xformers.ops import LowerTriangularMask
241
- if current_steps == 1:
242
- # If we only have one step, then we do not need a mask.
243
- return None
244
- elif 'past_keys' in self._streaming_state:
245
- raise RuntimeError("Not supported at the moment")
246
- else:
247
- # Then we can safely use a lower triangular mask
248
- return LowerTriangularMask()
249
- if self._streaming_state:
250
- past_keys = self._streaming_state['past_keys']
251
- past_steps = past_keys.shape[time_dim]
252
- else:
253
- past_steps = 0
254
-
255
- queries_pos = torch.arange(
256
- past_steps, current_steps + past_steps, device=device).view(-1, 1)
257
- keys_pos = torch.arange(past_steps + current_steps, device=device).view(1, -1)
258
- delta = queries_pos - keys_pos
259
- valid = delta >= 0
260
- if self.past_context is not None:
261
- valid &= (delta <= self.past_context)
262
- return torch.where(
263
- valid,
264
- torch.zeros([], device=device, dtype=dtype),
265
- torch.full([], float('-inf'), device=device, dtype=dtype))
266
-
267
- def _complete_kv(self, k, v):
268
-
269
- time_dim = _get_attention_time_dimension(self.memory_efficient)
270
- if self.cross_attention:
271
- # With cross attention we assume all keys and values
272
- # are already available, and streaming is with respect
273
- # to the queries only.
274
- return k, v
275
- # Complete the key/value pair using the streaming state.
276
- if self._streaming_state:
277
- pk = self._streaming_state['past_keys']
278
- nk = torch.cat([pk, k], dim=time_dim)
279
- print('==============CONCAT 1===============')
280
- if v is k:
281
- nv = nk
282
- else:
283
- pv = self._streaming_state['past_values']
284
- nv = torch.cat([pv, v], dim=time_dim)
285
- print('==============CONCAT 2================')
286
- else:
287
- nk = k
288
- nv = v
289
-
290
- assert nk.shape[time_dim] == nv.shape[time_dim]
291
- offset = 0
292
- if self.past_context is not None:
293
- offset = max(0, nk.shape[time_dim] - self.past_context)
294
- if self._is_streaming:
295
- self._streaming_state['past_keys'] = nk[:, offset:]
296
- if v is not k:
297
- self._streaming_state['past_values'] = nv[:, offset:]
298
- if 'offset' in self._streaming_state:
299
- self._streaming_state['offset'] += offset
300
- else:
301
- self._streaming_state['offset'] = torch.tensor(0)
302
- return nk, nv
303
-
304
- def _apply_rope(self, query: torch.Tensor, key: torch.Tensor):
305
- time_dim = _get_attention_time_dimension(self.memory_efficient)
306
- # Apply rope embeddings to query and key tensors.
307
- assert self.rope is not None
308
- if 'past_keys' in self._streaming_state:
309
- past_keys_offset = self._streaming_state['past_keys'].shape[1]
310
- else:
311
- past_keys_offset = 0
312
- if 'offset' in self._streaming_state:
313
- past_context_offset = int(self._streaming_state['offset'].item())
314
- else:
315
- past_context_offset = 0
316
- streaming_offset = past_context_offset + past_keys_offset
317
- return self.rope.rotate_qk(query, key, start=streaming_offset, time_dim=time_dim)
318
 
319
- def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
320
- key_padding_mask=None, need_weights=False, attn_mask=None,
321
- average_attn_weights=True, is_causal=False):
322
  assert not is_causal, ("New param added in torch 2.0.1 not supported, "
323
  "use the causal args in the constructor.")
324
-
325
- time_dim = _get_attention_time_dimension(self.memory_efficient)
326
  if time_dim == 2:
327
  layout = "b h t d"
328
  else:
329
  layout = "b t h d"
330
  dtype = query.dtype
331
- if self._is_streaming:
332
- assert self.causal or self.cross_attention, \
333
- "Streaming only available for causal or cross attention"
334
 
335
  custom_attn_mask = attn_mask is not None
336
 
337
- if self.causal:
338
- assert attn_mask is None
339
- # At the moment we specialize only for the self-attention case.
340
- assert query.shape[1] == key.shape[1], "Causal only for same length query / key / value"
341
- assert value.shape[1] == key.shape[1], "Causal only for same length query / key / value"
342
- attn_mask = self._get_mask(query.shape[1], query.device, query.dtype)
343
-
344
  if self.custom:
345
  # custom implementation
346
  assert need_weights is False
347
  assert key_padding_mask is None
348
  if self.cross_attention:
349
- # Different queries, keys, values, we have to spit manually the weights
350
- # before applying the linear.
 
351
  dim = self.in_proj_weight.shape[0] // 3
352
  if self.in_proj_bias is None:
353
  bias_q, bias_k, bias_v = None, None, None
@@ -356,14 +182,23 @@ class StreamingMultiheadAttention(StreamingModule):
356
  bias_k = self.in_proj_bias[dim: 2 * dim]
357
  bias_v = self.in_proj_bias[2 * dim:]
358
  q = nn.functional.linear(query, self.in_proj_weight[:dim], bias_q)
 
359
  # todo: when streaming, we could actually save k, v and check the shape actually match.
360
  k = nn.functional.linear(key, self.in_proj_weight[dim: 2 * dim], bias_k)
361
  v = nn.functional.linear(value, self.in_proj_weight[2 * dim:], bias_v)
362
  if self.qk_layer_norm is True:
363
  q = self.q_layer_norm(q)
364
  k = self.k_layer_norm(k)
 
365
  q, k, v = [rearrange(x, f"b t (h d) -> {layout}", h=self.num_heads) for x in [q, k, v]]
 
366
  else:
 
 
 
 
 
 
367
  if not _is_profiled():
368
  # profiling breaks that propertysomehow.
369
  assert query is key, "specialized implementation"
@@ -374,8 +209,13 @@ class StreamingMultiheadAttention(StreamingModule):
374
  bound_layout = "b h p t d"
375
  else:
376
  bound_layout = "b t p h d"
 
377
  packed = rearrange(projected, f"b t (p h d) -> {bound_layout}", p=3, h=self.num_heads)
 
 
 
378
  q, k, v = ops.unbind(packed, dim=2)
 
379
  else:
380
  embed_dim = self.embed_dim
381
  per_head_dim = (embed_dim // self.num_heads)
@@ -395,12 +235,12 @@ class StreamingMultiheadAttention(StreamingModule):
395
  q = self.q_layer_norm(q)
396
  k = self.k_layer_norm(k)
397
  q, k = [rearrange(x, f"b t (h d) -> {layout}", h=self.num_heads) for x in [q, k]]
398
- if self.rope:
399
- q, k = self._apply_rope(q, k)
400
- k, v = self._complete_kv(k, v)
401
  if self.kv_repeat > 1:
402
- k = expand_repeated_kv(k, self.kv_repeat, self.memory_efficient)
403
- v = expand_repeated_kv(v, self.kv_repeat, self.memory_efficient)
 
404
  if self.attention_as_float32:
405
  q, k, v = [x.float() for x in [q, k, v]]
406
  if self.memory_efficient:
@@ -429,11 +269,8 @@ class StreamingMultiheadAttention(StreamingModule):
429
  q = q / q.shape[-1] ** 0.5
430
  key_layout = layout.replace('t', 'k')
431
  query_layout = layout
432
- if self._is_streaming and self.safe_streaming and q.device.type == 'cuda':
433
- with torch.autocast(device_type=q.device.type, dtype=torch.float32):
434
- pre_w = torch.einsum(f"{query_layout},{key_layout}-> b h t k", q, k)
435
- else:
436
- pre_w = torch.einsum(f"{query_layout},{key_layout}-> b h t k", q, k)
437
  if attn_mask is not None:
438
  pre_w = pre_w + attn_mask
439
  w = torch.softmax(pre_w, dim=-1)
@@ -444,58 +281,24 @@ class StreamingMultiheadAttention(StreamingModule):
444
  x = rearrange(x, f"{layout} -> b t (h d)", h=self.num_heads)
445
  x = self.out_proj(x)
446
  else:
447
- key, value = self._complete_kv(key, value)
448
- if self.attention_as_float32:
449
- query, key, value = [x.float() for x in [query, key, value]]
450
- x, _ = self.mha(
451
- query, key, value, key_padding_mask,
452
- need_weights, attn_mask, average_attn_weights)
453
- x = x.to(dtype)
454
 
455
  return x, None
456
 
457
 
458
  class StreamingTransformerLayer(nn.TransformerEncoderLayer):
459
- """TransformerLayer with Streaming / Causal support.
460
- This also integrates cross_attention, when passing `cross_attention=True`,
461
- rather than having two separate classes like in PyTorch.
462
-
463
- Args:
464
- d_model (int): Dimension of the data.
465
- num_heads (int): Number of heads.
466
- dim_feedforward (int): Intermediate dimension of FF module.
467
- dropout (float): Dropout both for MHA and FF.
468
- bias_ff (bool): Use bias for FF.
469
- bias_attn (bool): Use bias for MHA.
470
- causal (bool): Causal mask applied automatically.
471
- past_context (int, optional): Receptive field for the causal mask, infinite if None.
472
- custom (bool): Use custom MHA implementation, for testing / benchmarking.
473
- memory_efficient (bool): Use xformers based memory efficient attention.
474
- attention_as_float32 (bool): Perform the attention as float32
475
- (especially important with memory_efficient as autocast won't do this automatically).
476
- qk_layer_norm (bool): Layer normalization applied to queries and keys before dot product in attention.
477
- qk_layer_norm_cross (bool): Same for the cross attention.
478
- cross_attention (bool): If True, expect to get secondary input for cross-attention.
479
- Cross attention will use the default MHA, as it typically won't require
480
- special treatment.
481
- layer_scale (float, optional): If not None, LayerScale will be used with
482
- the given value as initial scale.
483
- rope (`RotaryEmbedding`, optional): Rope embedding to use.
484
- attention_dropout (float, optional): If not None, separate the value of the dimension dropout
485
- in FFN and of the attention dropout.
486
- kv_repeat (int): If > 1, will repeat keys and queries multiple times (need to divide num_heads).
487
- This will lead to faster decoding time on A100 or other GPUs with tensorcore.
488
- device (torch.device, optional): Device on which to initialize.
489
- dtype (torch.dtype, optional): dtype to use.
490
- **kwargs: See `nn.TransformerEncoderLayer`.
491
- """
492
- def __init__(self, d_model: int, num_heads: int, dim_feedforward: int = 2048, dropout: float = 0.1,
493
  bias_ff: bool = True, bias_attn: bool = True, causal: bool = False,
494
  past_context: tp.Optional[int] = None, custom: bool = False,
495
  memory_efficient: bool = False, attention_as_float32: bool = False,
496
  qk_layer_norm: bool = False, qk_layer_norm_cross: bool = False,
497
- cross_attention: bool = False, layer_scale: tp.Optional[float] = None,
498
- rope: tp.Optional[RotaryEmbedding] = None, attention_dropout: tp.Optional[float] = None,
 
499
  kv_repeat: int = 1, norm: str = 'layer_norm', device=None, dtype=None, **kwargs):
500
  super().__init__(d_model, num_heads, dim_feedforward, dropout,
501
  device=device, dtype=dtype, batch_first=True, **kwargs)
@@ -511,22 +314,17 @@ class StreamingTransformerLayer(nn.TransformerEncoderLayer):
511
  'attention_as_float32': attention_as_float32,
512
  }
513
  self.self_attn: StreamingMultiheadAttention = StreamingMultiheadAttention(
514
- causal=causal, past_context=past_context, rope=rope, qk_layer_norm=qk_layer_norm,
 
 
515
  kv_repeat=kv_repeat, **attn_kwargs, **factory_kwargs) # type: ignore
516
  # Redefine feedforward layers to expose bias parameter
517
  self.linear1 = nn.Linear(d_model, dim_feedforward, bias=bias_ff, **factory_kwargs)
518
  self.linear2 = nn.Linear(dim_feedforward, d_model, bias=bias_ff, **factory_kwargs)
519
 
520
- self.layer_scale_1: nn.Module
521
- self.layer_scale_2: nn.Module
522
- if layer_scale is None:
523
- self.layer_scale_1 = nn.Identity()
524
- self.layer_scale_2 = nn.Identity()
525
- else:
526
- self.layer_scale_1 = LayerScale(d_model, layer_scale, **factory_kwargs)
527
- self.layer_scale_2 = LayerScale(d_model, layer_scale, **factory_kwargs)
528
 
529
- self.cross_attention: tp.Optional[nn.Module] = None
530
  if cross_attention:
531
  self.cross_attention = StreamingMultiheadAttention(
532
  cross_attention=True, qk_layer_norm=qk_layer_norm_cross,
@@ -535,98 +333,69 @@ class StreamingTransformerLayer(nn.TransformerEncoderLayer):
535
  self.dropout_cross = nn.Dropout(dropout)
536
  # eps value matching that used in PyTorch reference implementation.
537
  self.norm_cross = nn.LayerNorm(d_model, eps=1e-5, **factory_kwargs)
538
- self.layer_scale_cross: nn.Module
539
- if layer_scale is None:
540
- self.layer_scale_cross = nn.Identity()
541
- else:
542
- self.layer_scale_cross = LayerScale(d_model, layer_scale, **factory_kwargs)
543
  self.norm1 = create_norm_fn(norm, d_model, **factory_kwargs) # type: ignore
544
  self.norm2 = create_norm_fn(norm, d_model, **factory_kwargs) # type: ignore
545
 
546
- def _cross_attention_block(self, src: torch.Tensor,
547
- cross_attention_src: torch.Tensor) -> torch.Tensor:
548
- assert self.cross_attention is not None
 
549
  # queries are from src, keys and values from cross_attention_src.
550
  x = self.cross_attention(
551
  src, cross_attention_src, cross_attention_src, need_weights=False)[0]
552
  return self.dropout_cross(x) # type: ignore
553
 
554
- def forward(self, src: torch.Tensor, src_mask: tp.Optional[torch.Tensor] = None, # type: ignore
555
- src_key_padding_mask: tp.Optional[torch.Tensor] = None,
556
- cross_attention_src: tp.Optional[torch.Tensor] = None):
557
- if self.cross_attention is None:
558
- assert cross_attention_src is None
559
- else:
560
- assert cross_attention_src is not None
561
  x = src
562
  if self.norm_first:
563
- x = x + self.layer_scale_1(
564
- self._sa_block(self.norm1(x), src_mask, src_key_padding_mask))
 
 
 
 
565
  if cross_attention_src is not None:
566
- x = x + self.layer_scale_cross(
567
- self._cross_attention_block(
568
- self.norm_cross(x), cross_attention_src))
569
- x = x + self.layer_scale_2(self._ff_block(self.norm2(x)))
 
 
 
 
570
  else:
571
- x = self.norm1(x + self.layer_scale_1(
572
- self._sa_block(x, src_mask, src_key_padding_mask)))
573
- if cross_attention_src is not None:
574
- x = self.norm_cross(
575
- x + self.layer_scale_cross(
576
- self._cross_attention_block(src, cross_attention_src)))
577
- x = self.norm2(x + self.layer_scale_2(self._ff_block(x)))
578
  return x
579
 
580
 
581
- class StreamingTransformer(StreamingModule):
582
- """Transformer with Streaming / Causal support.
583
-
584
- Args:
585
- d_model (int): Dimension of the data.
586
- num_heads (int): Number of heads.
587
- dim_feedforward (int): Intermediate dimension of FF module.
588
- dropout (float): Dropout both for MHA and FF.
589
- bias_ff (bool): Use bias for FF.
590
- bias_attn (bool): Use bias for MHA.
591
- causal (bool): Causal mask applied automatically.
592
- past_context (int, optional): Receptive field for the causal mask, infinite if None.
593
- custom (bool): Use custom MHA implementation, for testing / benchmarking.
594
- memory_efficient (bool): Use xformers based memory efficient attention.
595
- attention_as_float32 (bool): Perform the attention as float32
596
- (especially important with memory_efficient as autocast won't do this automatically).
597
- cross_attention (bool): If True, expect to get secondary input for cross-attention.
598
- layer_scale (float, optional): If not None, LayerScale will be used
599
- with the given value as initial scale.
600
- positional_embedding (str): Positional embedding strategy (sin, rope, or sin_rope).
601
- max_period (float): Maximum period of the time embedding.
602
- positional_scale (float): Scale of positional embedding, set to 0 to deactivate.
603
- xpos (bool): Apply xpos exponential decay to positional embedding (rope only).
604
- lr (float, optional): learning rate override through the `make_optim_group` API.
605
- weight_decay (float, optional): Weight_decay override through the `make_optim_group` API.
606
- layer_class: (subclass of `StreamingTransformerLayer): class to use
607
- to initialize the layers, allowing further customization outside of AudioCraft.
608
- checkpointing (str): Checkpointing strategy to reduce memory usage.
609
- No checkpointing if set to 'none'. Per layer checkpointing using PyTorch
610
- if set to 'torch' (entire layer checkpointed, i.e. linears are evaluated twice,
611
- minimal memory usage, but maximal runtime). Finally, `xformers_default` provide
612
- a policy for opting-out some operations of the checkpointing like
613
- linear layers and attention, providing a middle ground between speed and memory.
614
- device (torch.device, optional): Device on which to initialize.
615
- dtype (torch.dtype, optional): dtype to use.
616
- **kwargs: See `nn.TransformerEncoderLayer`.
617
- """
618
  def __init__(self, d_model: int, num_heads: int, num_layers: int, dim_feedforward: int = 2048,
619
  dropout: float = 0.1, bias_ff: bool = True, bias_attn: bool = True,
620
  causal: bool = False, past_context: tp.Optional[int] = None,
621
  custom: bool = False, memory_efficient: bool = False, attention_as_float32: bool = False,
622
- cross_attention: bool = False, layer_scale: tp.Optional[float] = None,
623
  positional_embedding: str = 'sin', max_period: float = 10_000, positional_scale: float = 1.,
624
- xpos: bool = False, lr: tp.Optional[float] = None, weight_decay: tp.Optional[float] = None,
625
- layer_class: tp.Type[StreamingTransformerLayer] = StreamingTransformerLayer,
626
- checkpointing: str = 'none', device=None, dtype=None, **kwargs):
 
 
 
 
 
627
  super().__init__()
628
  assert d_model % num_heads == 0
629
-
630
  self.positional_embedding = positional_embedding
631
  self.max_period = max_period
632
  self.positional_scale = positional_scale
@@ -634,12 +403,6 @@ class StreamingTransformer(StreamingModule):
634
  self.lr = lr
635
 
636
  assert positional_embedding in ['sin', 'rope', 'sin_rope']
637
- self.rope: tp.Optional[RotaryEmbedding] = None
638
- if self.positional_embedding in ['rope', 'sin_rope']:
639
- assert _is_custom(custom, memory_efficient)
640
- self.rope = RotaryEmbedding(d_model // num_heads, max_period=max_period,
641
- xpos=xpos, scale=positional_scale, device=device)
642
-
643
  self.checkpointing = checkpointing
644
 
645
  assert checkpointing in ['none', 'torch', 'xformers_default', 'xformers_mm']
@@ -654,7 +417,8 @@ class StreamingTransformer(StreamingModule):
654
  dropout=dropout, bias_ff=bias_ff, bias_attn=bias_attn,
655
  causal=causal, past_context=past_context, custom=custom,
656
  memory_efficient=memory_efficient, attention_as_float32=attention_as_float32,
657
- cross_attention=cross_attention, layer_scale=layer_scale, rope=self.rope,
 
658
  device=device, dtype=dtype, **kwargs))
659
 
660
  if self.checkpointing != 'none':
@@ -663,58 +427,37 @@ class StreamingTransformer(StreamingModule):
663
  # backward hook inside of FSDP...
664
  layer._magma_checkpointed = True # type: ignore
665
 
666
- def _apply_layer(self, layer, *args, **kwargs):
667
- method = self.checkpointing
668
- if method == 'none':
669
- return layer(*args, **kwargs)
670
- elif method == 'torch':
671
- return torch_checkpoint(layer, *args, use_reentrant=False, **kwargs)
672
- elif method.startswith('xformers'):
673
- from xformers.checkpoint_fairinternal import checkpoint, _get_default_policy
674
- if method == 'xformers_default':
675
- # those operations will be saved, and not recomputed.
676
- # According to Francisco we can get smarter policies but this is a good start.
677
- allow_list = [
678
- "xformers.efficient_attention_forward_cutlass.default",
679
- "xformers_flash.flash_fwd.default",
680
- "aten.addmm.default",
681
- "aten.mm.default",
682
- ]
683
- elif method == 'xformers_mm':
684
- # those operations will be saved, and not recomputed.
685
- # According to Francisco we can get smarter policies but this is a good start.
686
- allow_list = [
687
- "aten.addmm.default",
688
- "aten.mm.default",
689
- ]
690
- else:
691
- raise ValueError(f"xformers checkpointing xformers policy {method} is not known.")
692
- policy_fn = _get_default_policy(allow_list)
693
- return checkpoint(layer, *args, policy_fn=policy_fn, **kwargs)
694
- else:
695
- raise ValueError(f"Checkpointing method {method} is unknown.")
696
 
697
  def forward(self, x: torch.Tensor, *args, **kwargs):
698
-
 
699
  B, T, C = x.shape
700
 
701
- if 'offsets' in self._streaming_state:
702
- offsets = self._streaming_state['offsets']
703
- else:
704
- offsets = torch.zeros(B, dtype=torch.long, device=x.device)
705
 
706
- if self.positional_embedding in ['sin', 'sin_rope']:
 
707
  positions = torch.arange(T, device=x.device).view(1, -1, 1)
708
- positions = positions + offsets.view(-1, 1, 1)
709
  pos_emb = create_sin_embedding(positions, C, max_period=self.max_period, dtype=x.dtype)
710
  x = x + self.positional_scale * pos_emb
711
-
712
- for layer in self.layers:
713
- x = self._apply_layer(layer, x, *args, **kwargs)
714
-
715
- if self._is_streaming:
716
- self._streaming_state['offsets'] = offsets + T
717
-
 
 
718
  return x
719
 
720
  def make_optim_group(self):
 
 
 
1
  import typing as tp
 
2
  from einops import rearrange
3
  import torch
4
  import torch.nn as nn
5
  from torch.nn import functional as F
 
6
  from xformers import ops
7
 
 
 
 
8
  _efficient_attention_backend: str = 'torch'
9
 
10
 
 
15
  _efficient_attention_backend = backend
16
 
17
 
 
 
 
 
 
18
 
19
 
20
+
21
+ def _is_profiled():
22
  # Return true if we are currently running with a xformers profiler activated.
23
  try:
24
  from xformers.profiler import profiler
 
27
  return profiler._Profiler._CURRENT_PROFILER is not None
28
 
29
 
30
+ def create_norm_fn(norm_type, dim, **kwargs):
 
31
 
 
 
 
 
 
 
 
32
  if norm_type == 'layer_norm':
33
  return nn.LayerNorm(dim, eps=1e-5, **kwargs)
34
  else:
 
54
  adim = torch.arange(half_dim, device=positions.device, dtype=dtype).view(1, 1, -1)
55
  max_period_tensor = torch.full([], max_period, device=positions.device, dtype=dtype) # avoid sync point
56
  phase = positions / (max_period_tensor ** (adim / (half_dim - 1)))
57
+ # print('==============CONCAT 3 ============')
58
+ return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)
59
 
60
 
 
 
61
 
62
 
 
 
 
63
 
 
 
64
 
65
 
 
 
66
 
67
+ class StreamingMultiheadAttention(nn.Module):
68
+
69
+ def __init__(self,
70
+ embed_dim,
71
+ num_heads,
72
+ dropout=0.0, bias: bool = True,
 
 
73
  causal: bool = False, past_context: tp.Optional[int] = None, custom: bool = False,
74
  memory_efficient: bool = False, attention_as_float32: bool = False,
75
+ cross_attention: bool = False,
76
+ qk_layer_norm: bool = False, kv_repeat: int = 1,
77
  device=None, dtype=None):
78
  super().__init__()
79
  factory_kwargs = {'device': device, 'dtype': dtype}
 
85
  self.past_context = past_context
86
  self.memory_efficient = memory_efficient
87
  self.attention_as_float32 = attention_as_float32
88
+
89
  self.cross_attention = cross_attention
90
+
91
  self.num_heads = num_heads
92
  self.dropout = dropout
93
  self.kv_repeat = kv_repeat
94
  if cross_attention:
95
  assert not causal, "Causal cannot work with cross attention."
96
+
97
 
98
  if memory_efficient:
99
  _verify_xformers_memory_efficient_compat()
 
138
  state_dict[prefix + "mha." + key] = state_dict.pop(prefix + key)
139
  super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
140
 
141
 
142
+
143
+
144
+
145
+
146
+
147
+ def forward(self,
148
+ query,
149
+ key,
150
+ value,
151
+ key_padding_mask=None,
152
+ need_weights=False,
153
+ attn_mask=None,
154
+ is_causal=False):
155
+
156
  assert not is_causal, ("New param added in torch 2.0.1 not supported, "
157
  "use the causal args in the constructor.")
158
+ # print(f'{query.shape=} {key.shape=} {value.shape=} MHA')
159
+ time_dim = 2
160
  if time_dim == 2:
161
  layout = "b h t d"
162
  else:
163
  layout = "b t h d"
164
  dtype = query.dtype
165
+
 
 
166
 
167
  custom_attn_mask = attn_mask is not None
168
169
  if self.custom:
170
  # custom implementation
171
  assert need_weights is False
172
  assert key_padding_mask is None
173
  if self.cross_attention:
174
+ # print('\n\n\n\nCROSS\n\n\n\n')
175
+
176
+
177
  dim = self.in_proj_weight.shape[0] // 3
178
  if self.in_proj_bias is None:
179
  bias_q, bias_k, bias_v = None, None, None
 
182
  bias_k = self.in_proj_bias[dim: 2 * dim]
183
  bias_v = self.in_proj_bias[2 * dim:]
184
  q = nn.functional.linear(query, self.in_proj_weight[:dim], bias_q)
185
+ # print(f'{q.shape=} TRANSF FORW - who concatenates here?')
186
  # todo: when streaming, we could actually save k, v and check the shape actually match.
187
  k = nn.functional.linear(key, self.in_proj_weight[dim: 2 * dim], bias_k)
188
  v = nn.functional.linear(value, self.in_proj_weight[2 * dim:], bias_v)
189
  if self.qk_layer_norm is True:
190
  q = self.q_layer_norm(q)
191
  k = self.k_layer_norm(k)
192
+
193
  q, k, v = [rearrange(x, f"b t (h d) -> {layout}", h=self.num_heads) for x in [q, k, v]]
194
+ # print(f'{q.shape=} {k.shape=} {v.shape=} after rearrange')
195
  else:
196
+ # print('\n\n\n\nSELF\n\n\n\n')
197
+ #
198
+ # 47x transformer layers: self-attn followed by cross-attn
199
+ #
200
+ # is self-attn applied over the whole history (previous keys) or only over the last token?
201
+
202
  if not _is_profiled():
203
  # profiling breaks that property somehow.
204
  assert query is key, "specialized implementation"
 
209
  bound_layout = "b h p t d"
210
  else:
211
  bound_layout = "b t p h d"
212
+
213
  packed = rearrange(projected, f"b t (p h d) -> {bound_layout}", p=3, h=self.num_heads)
214
+
215
+
216
+ # print(f'{query.shape=} before unbind') # [2, 1, 4, 2048] already bs=2
217
  q, k, v = ops.unbind(packed, dim=2)
218
+ # print(f'{q.shape=} {v.shape=} @L331 transformer.py') # packed is bs=2
219
  else:
220
  embed_dim = self.embed_dim
221
  per_head_dim = (embed_dim // self.num_heads)
 
235
  q = self.q_layer_norm(q)
236
  k = self.k_layer_norm(k)
237
  q, k = [rearrange(x, f"b t (h d) -> {layout}", h=self.num_heads) for x in [q, k]]
238
+
239
+
 
240
  if self.kv_repeat > 1:
241
+ #
242
+ print('Expanding k/v: kv_repeat > 1')
243
+
244
  if self.attention_as_float32:
245
  q, k, v = [x.float() for x in [q, k, v]]
246
  if self.memory_efficient:
 
269
  q = q / q.shape[-1] ** 0.5
270
  key_layout = layout.replace('t', 'k')
271
  query_layout = layout
272
+
273
+ pre_w = torch.einsum(f"{query_layout},{key_layout}-> b h t k", q, k)
 
 
 
274
  if attn_mask is not None:
275
  pre_w = pre_w + attn_mask
276
  w = torch.softmax(pre_w, dim=-1)
 
281
  x = rearrange(x, f"{layout} -> b t (h d)", h=self.num_heads)
282
  x = self.out_proj(x)
283
  else:
284
+ raise NotImplementedError
285
 
286
  return x, None
287
 
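# A minimal sketch of the "b h t d" einsum attention path used in the custom branch above
# (hedged: `toy_cross_attention` is illustrative, not the repo's code). The point relevant to
# this commit: the query length t (decoder stream) and the key/value length k (conditioning)
# are independent, and the output length follows the query.
import torch
from einops import rearrange

def toy_cross_attention(q, k, v):
    # q: [B, H, T, D]; k, v: [B, H, K, D]
    q = q / q.shape[-1] ** 0.5
    pre_w = torch.einsum("b h t d, b h k d -> b h t k", q, k)
    w = torch.softmax(pre_w, dim=-1)
    x = torch.einsum("b h t k, b h k d -> b h t d", w, v)
    return rearrange(x, "b h t d -> b t (h d)")

B, H, T, K, D = 2, 4, 1, 4, 64
out = toy_cross_attention(torch.randn(B, H, T, D), torch.randn(B, H, K, D), torch.randn(B, H, K, D))
print(out.shape)  # torch.Size([2, 1, 256]): one query step attending over 4 conditioning positions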
288
 
289
  class StreamingTransformerLayer(nn.TransformerEncoderLayer):
290
+ def __init__(self,
291
+ d_model,
292
+ num_heads,
293
+ dim_feedforward=2048,
294
+ dropout=0.1,
 
295
  bias_ff: bool = True, bias_attn: bool = True, causal: bool = False,
296
  past_context: tp.Optional[int] = None, custom: bool = False,
297
  memory_efficient: bool = False, attention_as_float32: bool = False,
298
  qk_layer_norm: bool = False, qk_layer_norm_cross: bool = False,
299
+ cross_attention: bool = False,
300
+ # rope=None,
301
+ attention_dropout: tp.Optional[float] = None,
302
  kv_repeat: int = 1, norm: str = 'layer_norm', device=None, dtype=None, **kwargs):
303
  super().__init__(d_model, num_heads, dim_feedforward, dropout,
304
  device=device, dtype=dtype, batch_first=True, **kwargs)
 
314
  'attention_as_float32': attention_as_float32,
315
  }
316
  self.self_attn: StreamingMultiheadAttention = StreamingMultiheadAttention(
317
+ causal=causal, past_context=past_context,
318
+ # rope=rope,
319
+ qk_layer_norm=qk_layer_norm,
320
  kv_repeat=kv_repeat, **attn_kwargs, **factory_kwargs) # type: ignore
321
  # Redefine feedforward layers to expose bias parameter
322
  self.linear1 = nn.Linear(d_model, dim_feedforward, bias=bias_ff, **factory_kwargs)
323
  self.linear2 = nn.Linear(dim_feedforward, d_model, bias=bias_ff, **factory_kwargs)
324
 
325
+
 
 
 
 
 
 
 
326
 
327
+ self.cross_attention = None # default
328
  if cross_attention:
329
  self.cross_attention = StreamingMultiheadAttention(
330
  cross_attention=True, qk_layer_norm=qk_layer_norm_cross,
 
333
  self.dropout_cross = nn.Dropout(dropout)
334
  # eps value matching that used in PyTorch reference implementation.
335
  self.norm_cross = nn.LayerNorm(d_model, eps=1e-5, **factory_kwargs)
336
+
 
 
 
 
337
  self.norm1 = create_norm_fn(norm, d_model, **factory_kwargs) # type: ignore
338
  self.norm2 = create_norm_fn(norm, d_model, **factory_kwargs) # type: ignore
339
 
340
+ def _cross_attention_block(self,
341
+ src,
342
+ cross_attention_src):
343
+
344
  # queries are from src, keys and values from cross_attention_src.
345
  x = self.cross_attention(
346
  src, cross_attention_src, cross_attention_src, need_weights=False)[0]
347
  return self.dropout_cross(x) # type: ignore
348
 
349
+ def forward(self,
350
+ src,
351
+ src_mask=None,
352
+ src_key_padding_mask=None, # key = value = the long conditioning sequence; possibly passed inverted?
353
+ cross_attention_src=None):
354
+
355
+
356
  x = src
357
  if self.norm_first:
358
+ # print('selfattn', x.shape, src_mask, src_key_padding_mask)
359
+ x = x + self._sa_block(self.norm1(x),
360
+ src_mask, #None
361
+ src_key_padding_mask # None
362
+ ) # _sa_block is inherited from nn.TransformerEncoderLayer
363
+ # print('crossattn', x.shape, cross_attention_src.shape)
364
  if cross_attention_src is not None:
365
+ x = x + self._cross_attention_block(
366
+ self.norm_cross(x),
367
+ cross_attention_src)
368
+ # selfattn torch.Size([2, 2, 1536]) None None NO 4D TOKEN!
369
+ # crossattn torch.Size([2, 2, 1536]) torch.Size([2, 4, 1536])
370
+ else:
371
+ raise NotImplementedError # do all layers have both self- and cross-attention?
372
+ x = x + self._ff_block(self.norm2(x))
373
  else:
374
+ print('NLAST')  # post-norm branch (norm_first is False)
375
+ # print('NT', x.shape) # [1, 2, 1536]
 
 
 
 
 
376
  return x
377
 
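# The norm_first branch above follows the pre-norm residual ordering
#   x = x + SelfAttn(norm1(x)); x = x + CrossAttn(norm_cross(x), cond, cond); x = x + FF(norm2(x)),
# with the conditioning tensor used as both key and value (the commit's "query or key?" question).
# A compact sketch with plain torch modules (hedged: PreNormBlockSketch is illustrative only):
import torch
import torch.nn as nn

class PreNormBlockSketch(nn.Module):
    def __init__(self, d_model=1536, num_heads=24):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.norm_cross = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.self_attn = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
        self.ff = nn.Sequential(nn.Linear(d_model, 4 * d_model), nn.GELU(), nn.Linear(4 * d_model, d_model))

    def forward(self, x, cond):
        h = self.norm1(x)
        x = x + self.self_attn(h, h, h, need_weights=False)[0]
        h = self.norm_cross(x)
        x = x + self.cross_attn(h, cond, cond, need_weights=False)[0]  # queries from x, keys/values from cond
        return x + self.ff(self.norm2(x))

blk = PreNormBlockSketch()
y = blk(torch.randn(2, 2, 1536), torch.randn(2, 4, 1536))
print(y.shape)  # torch.Size([2, 2, 1536]), matching the shapes in the debug comments above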
378
 
379
+ class StreamingTransformer(nn.Module):
380
+ '''layer_class=<class 'audiocraft.transformer.StreamingTransformerLayer'> StrTrnsf'''
381
+
382
  def __init__(self, d_model: int, num_heads: int, num_layers: int, dim_feedforward: int = 2048,
383
  dropout: float = 0.1, bias_ff: bool = True, bias_attn: bool = True,
384
  causal: bool = False, past_context: tp.Optional[int] = None,
385
  custom: bool = False, memory_efficient: bool = False, attention_as_float32: bool = False,
386
+ cross_attention: bool = False,
387
  positional_embedding: str = 'sin', max_period: float = 10_000, positional_scale: float = 1.,
388
+ xpos=False,
389
+ lr=None,
390
+ weight_decay=None,
391
+ layer_class=StreamingTransformerLayer,
392
+ checkpointing='none',
393
+ device=None,
394
+ dtype=None,
395
+ **kwargs):
396
  super().__init__()
397
  assert d_model % num_heads == 0
398
+
399
  self.positional_embedding = positional_embedding
400
  self.max_period = max_period
401
  self.positional_scale = positional_scale
 
403
  self.lr = lr
404
 
405
  assert positional_embedding in ['sin', 'rope', 'sin_rope']
 
 
 
 
 
 
406
  self.checkpointing = checkpointing
407
 
408
  assert checkpointing in ['none', 'torch', 'xformers_default', 'xformers_mm']
 
417
  dropout=dropout, bias_ff=bias_ff, bias_attn=bias_attn,
418
  causal=causal, past_context=past_context, custom=custom,
419
  memory_efficient=memory_efficient, attention_as_float32=attention_as_float32,
420
+ cross_attention=cross_attention,
421
+ # rope=self.rope,
422
  device=device, dtype=dtype, **kwargs))
423
 
424
  if self.checkpointing != 'none':
 
427
  # backward hook inside of FSDP...
428
  layer._magma_checkpointed = True # type: ignore
429
 
430
+
431
 
432
  def forward(self, x: torch.Tensor, *args, **kwargs):
433
+ # print(f'{x.shape=} StreamingTransf') # [1, 1, 1536] always; batch==2 never shows up here
434
+ # why is this called with time-len = 1? Shouldn't it be called with the full context?
435
  B, T, C = x.shape
436
 
437
+
438
+
 
 
439
 
440
+ if self.positional_embedding in ['sin',
441
+ 'sin_rope']:
442
  positions = torch.arange(T, device=x.device).view(1, -1, 1)
443
+
444
  pos_emb = create_sin_embedding(positions, C, max_period=self.max_period, dtype=x.dtype)
445
  x = x + self.positional_scale * pos_emb
446
+ # UNTIL HERE BATCH=1
447
+ for _, lay in enumerate(self.layers):
448
+ # if _ < 2:
449
+ # L=0 [1,1,1536]
450
+ # L=1 [2,1,1536]
451
+
452
+ print(f'L={_} {args=} {kwargs["cross_attention_src"].shape=} {x.shape=} StreamTransf ForLoop') # [2, 1, 1536] BATCH=2
453
+ # x = self._apply_layer(layer, x, *args, **kwargs)
454
+ # x = lay(x, **kwargs)
455
+ x = lay(x,
456
+ cross_attention_src=kwargs["cross_attention_src"],
457
+ src_mask=kwargs['src_mask'])
458
+ # concatenating old tokens onto the query happens in lm.generate, not here
459
+ print('OUT OF Tall', x.shape) # [1, 2, 1536] # why does this get filled with the sequence 1, 2, ...?
460
+ # should be a single query step
461
  return x
462
 
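# On the debug question above ("why is this called with time-len = 1?"): a plausible reading is
# that the LM samples autoregressively, passing only the newest step per call and relying on the
# streaming modules to remember earlier steps. A schematic loop (hedged: this is not lm.generate,
# just an illustration of why every call sees T == 1):
import torch

def toy_autoregressive_loop(step_fn, start, total_steps):
    # step_fn maps a [B, 1, C] tensor to the next [B, 1, C] tensor (one step per call).
    x = start
    outputs = []
    for _ in range(total_steps):
        x = step_fn(x)          # T == 1 on every call
        outputs.append(x)
    return torch.cat(outputs, dim=1)

out = toy_autoregressive_loop(lambda t: t + 1.0, torch.zeros(2, 1, 8), total_steps=5)
print(out.shape)  # torch.Size([2, 5, 8])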
463
  def make_optim_group(self):
demo.py CHANGED
@@ -7,7 +7,7 @@ print('\n\n\n\n___________________')
7
  txt = 'dogs in street'
8
 
9
  sound_generator = AudioGen.get_pretrained('facebook/audiogen-medium')
10
- sound_generator.set_generation_params(duration=.7) # why is generating so long at 14 seconds
11
 
12
  x = sound_generator.generate([txt])[0].detach().cpu().numpy()[0, :]
13
  x /= np.abs(x).max() + 1e-7
 
7
  txt = 'dogs in street'
8
 
9
  sound_generator = AudioGen.get_pretrained('facebook/audiogen-medium')
10
+ sound_generator.set_generation_params(duration=1.24) # why does generation take so long (~14 seconds)?
11
 
12
  x = sound_generator.generate([txt])[0].detach().cpu().numpy()[0, :]
13
  x /= np.abs(x).max() + 1e-7
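# A hedged note on the duration change above: the number of autoregressive LM steps grows
# linearly with the requested duration at the compression model's frame rate, which is why even
# sub-second clips can feel slow on CPU. The 50 Hz frame rate below is an assumption for
# illustration, not a value read from the checkpoint.
frame_rate = 50  # assumed
for duration in (0.7, 1.24):
    print(f'{duration=} -> ~{int(duration * frame_rate)} LM steps')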