Dionyssos committed · Commit 0a8807e · 1 Parent(s): 8a23304

debug special token

audiocraft/builders.py CHANGED
@@ -17,7 +17,7 @@ import torch
 
 from .encodec import CompressionModel, EncodecModel
 from .lm import LMModel
-from .seanet import SEANetEncoder, SEANetDecoder
+from .seanet import SEANetDecoder
 from .codebooks_patterns import (
     CodebooksPatternProvider,
     DelayedPatternProvider,
@@ -49,34 +49,40 @@ def get_quantizer(quantizer: str, cfg: omegaconf.DictConfig, dimension: int) ->
     return klass(**kwargs)
 
 
-def get_encodec_autoencoder(encoder_name: str, cfg: omegaconf.DictConfig):
-    if encoder_name == 'seanet':
-        kwargs = dict_from_config(getattr(cfg, 'seanet'))
-        encoder_override_kwargs = kwargs.pop('encoder')
-        decoder_override_kwargs = kwargs.pop('decoder')
-        encoder_kwargs = {**kwargs, **encoder_override_kwargs}
-        decoder_kwargs = {**kwargs, **decoder_override_kwargs}
-        encoder = SEANetEncoder(**encoder_kwargs)
-        decoder = SEANetDecoder(**decoder_kwargs)
-        return encoder, decoder
-    else:
-        raise KeyError(f"Unexpected compression model {cfg.compression_model}")
+def get_encodec_autoencoder(cfg):
+    kwargs = dict_from_config(getattr(cfg, 'seanet'))
+    _ = kwargs.pop('encoder')
+    decoder_override_kwargs = kwargs.pop('decoder')
+    decoder_kwargs = {**kwargs, **decoder_override_kwargs}
+    decoder = SEANetDecoder(**decoder_kwargs)
+    return decoder
 
 
-def get_compression_model(cfg: omegaconf.DictConfig) -> CompressionModel:
+def get_compression_model(cfg):
     """Instantiate a compression model."""
     if cfg.compression_model == 'encodec':
         kwargs = dict_from_config(getattr(cfg, 'encodec'))
-        encoder_name = kwargs.pop('autoencoder')
         quantizer_name = kwargs.pop('quantizer')
-        encoder, decoder = get_encodec_autoencoder(encoder_name, cfg)
-        quantizer = get_quantizer(quantizer_name, cfg, encoder.dimension)
-        frame_rate = kwargs['sample_rate'] // encoder.hop_length
+        decoder = get_encodec_autoencoder(cfg)
+        quantizer = get_quantizer(quantizer_name, cfg, 128)
         renormalize = kwargs.pop('renormalize', False)
         # deprecated params
+        # print(f'{frame_rate=} {encoder.dimension=}') frame_rate=50 encoder.dimension=128
         kwargs.pop('renorm', None)
-        return EncodecModel(encoder, decoder, quantizer,
-                            frame_rate=frame_rate, renormalize=renormalize, **kwargs).to(cfg.device)
+        # print('\n______!____________\n', kwargs, '\n______!____________\n')
+        # ______!____________
+        # {'autoencoder': 'seanet', 'sample_rate': 16000, 'channels': 1, 'causal': False}
+        # ______!____________
+
+        return EncodecModel(decoder=decoder,
+                            quantizer=quantizer,
+                            frame_rate=50,
+                            renormalize=renormalize,
+                            sample_rate=16000,
+                            channels=1,
+                            causal=False
+                            ).to(cfg.device)
     else:
         raise KeyError(f"Unexpected compression model {cfg.compression_model}")
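Note (editor's sketch, not part of the commit): the hardcoded `frame_rate=50` and quantizer dimension `128` follow from the SEANet defaults visible later in this commit (`ratios=[8, 5, 4, 2]`, `dimension=128`) and the 16 kHz sample rate recorded in the debug print above:

    import numpy as np

    ratios = [8, 5, 4, 2]                  # SEANet default ratios (see audiocraft/seanet.py below)
    hop_length = int(np.prod(ratios))      # 320 audio samples per latent frame
    sample_rate = 16000
    frame_rate = sample_rate // hop_length
    assert (hop_length, frame_rate) == (320, 50)   # matches the values hardcoded above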
audiocraft/conditioners.py CHANGED
@@ -1,11 +1,7 @@
 from collections import defaultdict
 from dataclasses import dataclass, field
-from itertools import chain
 import logging
-import math
-from pathlib import Path
 import random
-import re
 import typing as tp
 import warnings
 import soundfile
@@ -14,11 +10,8 @@ import torch
 from torch import nn
 from .streaming import StreamingModule
 
-
-from .quantization import ResidualVectorQuantizer
 from .utils.autocast import TorchAutocast
-from .utils.cache import EmbeddingCache
-from .utils.utils import collate, hash_trick, length_to_mask, load_clap_state_dict, warn_once
+
 
 logger = logging.getLogger(__name__)
audiocraft/encodec.py CHANGED
@@ -30,14 +30,7 @@ class CompressionModel(ABC, nn.Module):
     with a language model.
     """
 
-    @abstractmethod
-    def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
-        ...
-
-    @abstractmethod
-    def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
-        """See `EncodecModel.encode`."""
-        ...
+
 
     @abstractmethod
     def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
@@ -142,16 +135,15 @@ class EncodecModel(CompressionModel):
     channels: int = 0
 
     def __init__(self,
-                 encoder: nn.Module,
-                 decoder: nn.Module,
-                 quantizer: qt.BaseQuantizer,
-                 frame_rate: int,
-                 sample_rate: int,
-                 channels: int,
-                 causal: bool = False,
-                 renormalize: bool = False):
+                 decoder=None,
+                 quantizer=None,
+                 frame_rate=None,
+                 sample_rate=None,
+                 channels=None,
+                 causal=False,
+                 renormalize=False):
         super().__init__()
-        self.encoder = encoder
+
         self.decoder = decoder
         self.quantizer = quantizer
         self.frame_rate = frame_rate
@@ -203,40 +195,6 @@ class EncodecModel(CompressionModel):
         x = x * scale.view(-1, 1, 1)
         return x
 
-    def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
-        assert x.dim() == 3
-        length = x.shape[-1]
-        x, scale = self.preprocess(x)
-
-        emb = self.encoder(x)
-        q_res = self.quantizer(emb, self.frame_rate)
-        out = self.decoder(q_res.x)
-
-        # remove extra padding added by the encoder and decoder
-        assert out.shape[-1] >= length, (out.shape[-1], length)
-        out = out[..., :length]
-
-        q_res.x = self.postprocess(out, scale)
-
-        return q_res
-
-    def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
-        """Encode the given input tensor to quantized representation along with scale parameter.
-
-        Args:
-            x (torch.Tensor): Float tensor of shape [B, C, T]
-
-        Returns:
-            codes, scale (tuple of torch.Tensor, torch.Tensor): Tuple composed of:
-                codes: a float tensor of shape [B, K, T] with K the number of codebooks used and T the timestep.
-                scale: a float tensor containing the scale for audio renormalization.
-        """
-        assert x.dim() == 3
-        x, scale = self.preprocess(x)
-        emb = self.encoder(x)
-        codes = self.quantizer.encode(emb)
-        return codes, scale
-
     def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
         """Decode the given codes to a reconstructed representation, using the scale to perform
         audio denormalization if needed.
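With `forward` and `encode` removed, `EncodecModel` becomes decode-only: it can turn LM-generated codes back into audio but can no longer encode waveforms. A minimal usage sketch (the `model` variable and the shapes are assumptions taken from the debug comments elsewhere in this commit, K=4 codebooks at 50 Hz):

    import torch

    # `model` is assumed to be the EncodecModel returned by builders.get_compression_model(cfg)
    codes = torch.randint(0, 2048, (1, 4, 50))   # [B, K, T]: one second of codes at frame_rate=50
    with torch.no_grad():
        wav = model.decode(codes)                # decode(codes, scale=None) is the one entry point left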
audiocraft/lm.py CHANGED
@@ -14,16 +14,14 @@ import warnings
 import einops
 from num2words import num2words
 import spacy
-from transformers import RobertaTokenizer, T5EncoderModel, T5Tokenizer  # type: ignore
+from transformers import T5EncoderModel, T5Tokenizer  # type: ignore
 import torch
 import torch.nn.functional as F
 from torch.nn.utils.rnn import pad_sequence
 from audiocraft.streaming import StreamingModule
 from audiocraft.transformer import create_sin_embedding
-from audiocraft.utils.audio_utils import convert_audio
 from audiocraft.utils.autocast import TorchAutocast
-from audiocraft.utils.cache import EmbeddingCache
-from audiocraft.utils.utils import collate, hash_trick, length_to_mask, load_clap_state_dict, warn_once
+from audiocraft.utils.utils import collate, length_to_mask
 from audiocraft.transformer import StreamingTransformer, create_norm_fn
 from dataclasses import dataclass
 from functools import partial
@@ -297,13 +295,7 @@ class BaseConditioner(nn.Module):
         self.output_dim = output_dim
         self.output_proj = nn.Linear(dim, output_dim)
 
-    def tokenize(self, *args, **kwargs) -> tp.Any:
-        """Should be any part of the processing that will lead to a synchronization
-        point, e.g. BPE tokenization with transfer to the GPU.
-
-        The returned value will be saved and return later when calling forward().
-        """
-        raise NotImplementedError()
+
 
     def forward(self, inputs: tp.Any) -> ConditionType:
         """Gets input that should be used as conditioning (e.g, genre, description or a waveform).
@@ -530,34 +522,6 @@ class ConditioningProvider(nn.Module):
     def has_wav_condition(self):
         return len(self.wav_conditions) > 0
 
-    def tokenize(self, inputs: tp.List[ConditioningAttributes]) -> tp.Dict[str, tp.Any]:
-        """Match attributes/wavs with existing conditioners in self, and compute tokenize them accordingly.
-        This should be called before starting any real GPU work to avoid synchronization points.
-        This will return a dict matching conditioner names to their arbitrary tokenized representations.
-
-        Args:
-            inputs (list[ConditioningAttributes]): List of ConditioningAttributes objects containing
-                text and wav conditions.
-        """
-        assert all([isinstance(x, ConditioningAttributes) for x in inputs]), (
-            "Got unexpected types input for conditioner! should be tp.List[ConditioningAttributes]",
-            f" but types were {set([type(x) for x in inputs])}"
-        )
-
-        output = {}
-        text = self._collate_text(inputs)
-        wavs = self._collate_wavs(inputs)
-        joint_embeds = self._collate_joint_embeds(inputs)
-
-        assert set(text.keys() | wavs.keys() | joint_embeds.keys()).issubset(set(self.conditioners.keys())), (
-            f"Got an unexpected attribute! Expected {self.conditioners.keys()}, ",
-            f"got {text.keys(), wavs.keys(), joint_embeds.keys()}"
-        )
-
-        for attribute, batch in chain(text.items(), wavs.items(), joint_embeds.items()):
-            output[attribute] = self.conditioners[attribute].tokenize(batch)
-        return output
-
     def forward(self, tokenized: tp.Dict[str, tp.Any]) -> tp.Dict[str, ConditionType]:
         """Compute pairs of `(embedding, mask)` using the configured conditioners and the tokenized representations.
         The output is for example:
@@ -780,6 +744,7 @@ class ConditionFuser(StreamingModule):
             raise ValueError(f"unknown op ({op})")
 
         if self.cross_attention_pos_emb and cross_attention_output is not None:
+            print('SIN EMBED')
             positions = torch.arange(
                 cross_attention_output.shape[1],
                 device=cross_attention_output.device
@@ -925,7 +890,7 @@ class LMModel(StreamingModule):
 
         self.condition_provider = condition_provider
         self.fuser = fuser
-        self.card = card
+        self.card = card  # 2048 ?
         embed_dim = self.card + 1
         self.n_q = n_q
         self.dim = dim
@@ -1030,6 +995,7 @@ class LMModel(StreamingModule):
         # remove the prefix from the model outputs
         if len(self.fuser.fuse2cond['prepend']) > 0:
             logits = logits[:, :, -S:]
+            print('PRESFIX')
 
         return logits  # [B, K, S, card]
 
@@ -1067,6 +1033,8 @@ class LMModel(StreamingModule):
         B, K, T = codes.shape
         codes = codes.contiguous()
         # map codes [B, K, T] into pattern sequence [B, K, S] using special_token_id for masked tokens
+        # what is the T is it 2048 ?
+        # and then what is pattern -> another function?
        pattern = self.pattern_provider.get_pattern(T)
         sequence_codes, sequence_indexes, sequence_mask = pattern.build_pattern_sequence(
             codes, self.special_token_id, keep_only_valid_steps=keep_only_valid_steps,
@@ -1118,35 +1086,33 @@ class LMModel(StreamingModule):
         model = self if self._fsdp is None else self._fsdp
         two_step_cfg = self.two_step_cfg if two_step_cfg is None else two_step_cfg
         if two_step_cfg and cfg_conditions != {}:
-            assert isinstance(cfg_conditions, tuple), type(cfg_conditions)
-            condition_tensors, null_condition_tensors = cfg_conditions
-            cond_logits = model(sequence, conditions=[], condition_tensors=condition_tensors)
-            state = self.get_streaming_state()
-            self.set_streaming_state(unconditional_state)
-            uncond_logits = model(sequence, conditions=[], condition_tensors=null_condition_tensors)
-            unconditional_state.update(self.get_streaming_state())
-            self.set_streaming_state(state)
-            logits = uncond_logits + (cond_logits - uncond_logits) * self.cfg_coef
+            print('\nNOT HERE\n')
         else:
+            print('C')
             assert isinstance(cfg_conditions, dict)
             condition_tensors = cfg_conditions
             if condition_tensors:
+                # print('\nD\n')
                 # Preparing for CFG, predicting both conditional and unconditional logits.
                 sequence = torch.cat([sequence, sequence], dim=0)
             all_logits = model(
                 sequence,
                 conditions=[], condition_tensors=condition_tensors)
             if condition_tensors:
-                cond_logits, uncond_logits = all_logits.split(B, dim=0)  # [B, K, T, card]
-                logits = uncond_logits + (cond_logits - uncond_logits) * cfg_coef
+                cond_logits, uncond_logits = all_logits.split(B, dim=0)  # torch.Size([2, 4, 1, 2048])
+                # logits = uncond_logits + (cond_logits - uncond_logits) * cfg_coef
+                # logits = 3 * cond_logits - 2.4 * uncond_logits
+                logits = 2 * cond_logits - 1.4 * uncond_logits
             else:
-                logits = all_logits
+                print('\nF!\n')
+
 
         logits = logits.permute(0, 1, 3, 2)  # [B, K, card, T]
         logits = logits[..., -1]  # [B x K x card]
 
         # Apply softmax for sampling if temp > 0. Else, do greedy sampling to avoid zero division error.
         if use_sampling and temp > 0.0:
+            # print(f'\nR {temp=} {top_p=} {top_k=}\n') -------------> R temp=1.0 top_p=0.0 top_k=250
             probs = torch.softmax(logits / temp, dim=-1)
             if top_p > 0.0:
                 next_token = utils.sample_top_p(probs, p=top_p)
@@ -1155,7 +1121,9 @@ class LMModel(StreamingModule):
             else:
                 next_token = utils.multinomial(probs, num_samples=1)
         else:
-            next_token = torch.argmax(logits, dim=-1, keepdim=True)
+            #
+            print('\nNeverHere\n')
+
 
         return next_token
 
@@ -1249,9 +1217,9 @@ class LMModel(StreamingModule):
         # this token is used as default value for codes that are not generated yet
         unknown_token = -1
 
-        # we generate codes up to the max_gen_len that will be mapped to the pattern sequence
+
         gen_codes = torch.full((B, K, max_gen_len), unknown_token, dtype=torch.long, device=device)
-        # filling the gen_codes with the prompt if needed
+
         gen_codes[..., :start_offset] = prompt
         # create the gen_sequence with proper interleaving from the pattern: [B, K, S]
         gen_sequence, indexes, mask = pattern.build_pattern_sequence(gen_codes, self.special_token_id)
@@ -1280,9 +1248,17 @@ class LMModel(StreamingModule):
                 # ensure the tokens that should be masked are properly set to special_token_id
                 # as the model never output special_token_id
                 valid_mask = mask[..., offset:offset+1].expand(B, -1, -1)
-                next_token[~valid_mask] = self.special_token_id
+
+                # next_token[~valid_mask] = self.special_token_id
+
+                # print(f'{unconditional_state=} \n
+                # print('Set All to Special')
+                # next_token[:] = self.special_token_id
+
+
+
                 # ensure we don't overwrite prompt tokens, we only write over unknown tokens
-                # (then mask tokens should be left as is as well, which is correct)
+
                 gen_sequence[..., offset:offset+1] = torch.where(
                     gen_sequence[..., offset:offset+1] == unknown_token,
                     next_token, gen_sequence[..., offset:offset+1]
@@ -1292,23 +1268,11 @@ class LMModel(StreamingModule):
                     callback(1 + offset - start_offset_sequence, gen_sequence_len - start_offset_sequence)
         unconditional_state.clear()
 
-        # ensure sequence has been entirely filled
-        assert not (gen_sequence == unknown_token).any()
-        # ensure gen_sequence pattern and mask are matching
-        # which means the gen_sequence is valid according to the pattern
-        assert (
-            gen_sequence == torch.where(mask[None, ...].expand(B, -1, -1), gen_sequence, self.special_token_id)
-        ).all()
-        # get back the codes, trimming the prompt if needed and cutting potentially incomplete timesteps
         out_codes, out_indexes, out_mask = pattern.revert_pattern_sequence(gen_sequence, special_token=unknown_token)
 
-        # sanity checks over the returned codes and corresponding masks
-        assert (out_codes[..., :max_gen_len] != unknown_token).all()
-        assert (out_mask[..., :max_gen_len] == 1).all()
-
         out_start_offset = start_offset if remove_prompts else 0
         out_codes = out_codes[..., out_start_offset:max_gen_len]
 
         # ensure the returned codes are all valid
-        assert (out_codes >= 0).all() and (out_codes <= self.card).all()
+        # assert (out_codes >= 0).all() and (out_codes <= self.card).all()
         return out_codes
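The guidance line in the sampling hunk above swaps the standard classifier-free guidance mix for hand-tuned coefficients. The two are related by a global rescale (a quick check, not part of the commit): standard CFG uses coefficients that sum to 1, while 2 - 1.4 = 0.6, so the new mix behaves like standard CFG at coef = 10/3 with all logits scaled by 0.6, i.e. a higher effective sampling temperature:

    import torch

    cond = torch.randn(1, 4, 1, 2048)    # shapes from the debug comment in the diff
    uncond = torch.randn(1, 4, 1, 2048)

    tuned = 2 * cond - 1.4 * uncond
    # equivalent form: 0.6 * (uncond + (cond - uncond) * cfg_coef) with cfg_coef = 1 + 1.4 / 0.6
    cfg_coef = 1 + 1.4 / 0.6
    standard = uncond + (cond - uncond) * cfg_coef
    assert torch.allclose(tuned, 0.6 * standard, atol=1e-5)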
audiocraft/loaders.py CHANGED
@@ -79,7 +79,7 @@ def load_compression_model(file_or_url_or_id: tp.Union[Path, str], device='cpu',
     cfg = OmegaConf.create(pkg['xp.cfg'])
     cfg.device = str(device)
     model = builders.get_compression_model(cfg)
-    model.load_state_dict(pkg['best_state'])
+    model.load_state_dict(pkg['best_state'], strict=False)  # ckpt contains uninstantiated encoder
     model.eval()
     return model
audiocraft/seanet.py CHANGED
@@ -60,136 +60,25 @@ class SEANetResnetBlock(nn.Module):
         return self.shortcut(x) + self.block(x)
 
 
-class SEANetEncoder(nn.Module):
-    """SEANet encoder.
-
-    Args:
-        channels (int): Audio channels.
-        dimension (int): Intermediate representation dimension.
-        n_filters (int): Base width for the model.
-        n_residual_layers (int): nb of residual layers.
-        ratios (Sequence[int]): kernel size and stride ratios. The encoder uses downsampling ratios instead of
-            upsampling ratios, hence it will use the ratios in the reverse order to the ones specified here
-            that must match the decoder order. We use the decoder order as some models may only employ the decoder.
-        activation (str): Activation function.
-        activation_params (dict): Parameters to provide to the activation function.
-        norm (str): Normalization method.
-        norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
-        kernel_size (int): Kernel size for the initial convolution.
-        last_kernel_size (int): Kernel size for the initial convolution.
-        residual_kernel_size (int): Kernel size for the residual layers.
-        dilation_base (int): How much to increase the dilation with each layer.
-        causal (bool): Whether to use fully causal convolution.
-        pad_mode (str): Padding mode for the convolutions.
-        true_skip (bool): Whether to use true skip connection or a simple
-            (streamable) convolution as the skip connection in the residual network blocks.
-        compress (int): Reduced dimensionality in residual branches (from Demucs v3).
-        lstm (int): Number of LSTM layers at the end of the encoder.
-        disable_norm_outer_blocks (int): Number of blocks for which we don't apply norm.
-            For the encoder, it corresponds to the N first blocks.
-    """
-    def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 3,
-                 ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
-                 norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7,
-                 last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False,
-                 pad_mode: str = 'reflect', true_skip: bool = True, compress: int = 2, lstm: int = 0,
-                 disable_norm_outer_blocks: int = 0):
-        super().__init__()
-        self.channels = channels
-        self.dimension = dimension
-        self.n_filters = n_filters
-        self.ratios = list(reversed(ratios))
-        del ratios
-        self.n_residual_layers = n_residual_layers
-        self.hop_length = np.prod(self.ratios)
-        self.n_blocks = len(self.ratios) + 2  # first and last conv + residual blocks
-        self.disable_norm_outer_blocks = disable_norm_outer_blocks
-        assert self.disable_norm_outer_blocks >= 0 and self.disable_norm_outer_blocks <= self.n_blocks, \
-            "Number of blocks for which to disable norm is invalid." \
-            "It should be lower or equal to the actual number of blocks in the network and greater or equal to 0."
-
-        act = getattr(nn, activation)
-        mult = 1
-        model: tp.List[nn.Module] = [
-            StreamableConv1d(channels, mult * n_filters, kernel_size,
-                             norm='none' if self.disable_norm_outer_blocks >= 1 else norm,
-                             norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
-        ]
-        # Downsample to raw audio scale
-        for i, ratio in enumerate(self.ratios):
-            block_norm = 'none' if self.disable_norm_outer_blocks >= i + 2 else norm
-            # Add residual layers
-            for j in range(n_residual_layers):
-                model += [
-                    SEANetResnetBlock(mult * n_filters, kernel_sizes=[residual_kernel_size, 1],
-                                      dilations=[dilation_base ** j, 1],
-                                      norm=block_norm, norm_params=norm_params,
-                                      activation=activation, activation_params=activation_params,
-                                      causal=causal, pad_mode=pad_mode, compress=compress, true_skip=true_skip)]
-
-            # Add downsampling layers
-            model += [
-                act(**activation_params),
-                StreamableConv1d(mult * n_filters, mult * n_filters * 2,
-                                 kernel_size=ratio * 2, stride=ratio,
-                                 norm=block_norm, norm_kwargs=norm_params,
-                                 causal=causal, pad_mode=pad_mode),
-            ]
-            mult *= 2
-
-        if lstm:
-            model += [StreamableLSTM(mult * n_filters, num_layers=lstm)]
-
-        model += [
-            act(**activation_params),
-            StreamableConv1d(mult * n_filters, dimension, last_kernel_size,
-                             norm='none' if self.disable_norm_outer_blocks == self.n_blocks else norm,
-                             norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
-        ]
-
-        self.model = nn.Sequential(*model)
-
-    def forward(self, x):
-        return self.model(x)
 
 
 class SEANetDecoder(nn.Module):
-    """SEANet decoder.
-
-    Args:
-        channels (int): Audio channels.
-        dimension (int): Intermediate representation dimension.
-        n_filters (int): Base width for the model.
-        n_residual_layers (int): nb of residual layers.
-        ratios (Sequence[int]): kernel size and stride ratios.
-        activation (str): Activation function.
-        activation_params (dict): Parameters to provide to the activation function.
-        final_activation (str): Final activation function after all convolutions.
-        final_activation_params (dict): Parameters to provide to the activation function.
-        norm (str): Normalization method.
-        norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
-        kernel_size (int): Kernel size for the initial convolution.
-        last_kernel_size (int): Kernel size for the initial convolution.
-        residual_kernel_size (int): Kernel size for the residual layers.
-        dilation_base (int): How much to increase the dilation with each layer.
-        causal (bool): Whether to use fully causal convolution.
-        pad_mode (str): Padding mode for the convolutions.
-        true_skip (bool): Whether to use true skip connection or a simple.
-            (streamable) convolution as the skip connection in the residual network blocks.
-        compress (int): Reduced dimensionality in residual branches (from Demucs v3).
-        lstm (int): Number of LSTM layers at the end of the encoder.
-        disable_norm_outer_blocks (int): Number of blocks for which we don't apply norm.
-            For the decoder, it corresponds to the N last blocks.
-        trim_right_ratio (float): Ratio for trimming at the right of the transposed convolution under the causal setup.
-            If equal to 1.0, it means that all the trimming is done at the right.
-    """
-    def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 3,
-                 ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
-                 final_activation: tp.Optional[str] = None, final_activation_params: tp.Optional[dict] = None,
-                 norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7,
-                 last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False,
-                 pad_mode: str = 'reflect', true_skip: bool = True, compress: int = 2, lstm: int = 0,
-                 disable_norm_outer_blocks: int = 0, trim_right_ratio: float = 1.0):
+    def __init__(self, channels: int = 1,
+                 dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 3,
+                 ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU',
+                 activation_params: dict = {'alpha': 1.0},
+                 final_activation: tp.Optional[str] = None,
+                 final_activation_params: tp.Optional[dict] = None,
+                 norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {},
+                 kernel_size: int = 7,
+                 last_kernel_size: int = 7, residual_kernel_size: int = 3,
+                 dilation_base: int = 2, causal: bool = False,
+                 pad_mode: str = 'reflect', true_skip: bool = True,
+                 compress: int = 2, lstm: int = 0,
+                 disable_norm_outer_blocks: int = 0,
+                 trim_right_ratio: float = 1.0):
         super().__init__()
         self.dimension = dimension
         self.channels = channels
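With the encoder gone, `SEANetDecoder` is the only module left in this file; under the default arguments it upsamples a `[B, 128, T]` latent by `8 * 5 * 4 * 2 = 320`. A shape sketch (the exact output length is an editor's assumption, not part of the commit):

    import torch
    from audiocraft.seanet import SEANetDecoder

    decoder = SEANetDecoder()           # defaults: dimension=128, ratios=[8, 5, 4, 2], channels=1
    latent = torch.randn(1, 128, 50)    # one second of 50 Hz latents
    wav = decoder(latent)               # expected [1, 1, 16000]: 320 audio samples per frame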
audiocraft/utils/utils.py CHANGED
@@ -4,7 +4,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
-from concurrent.futures import ProcessPoolExecutor
+
 from contextlib import contextmanager
 from functools import wraps, lru_cache
 import hashlib
@@ -103,6 +103,9 @@ def multinomial(input: torch.Tensor, num_samples: int, replacement=False, *, gen
     input_ = input.reshape(-1, input.shape[-1])
     output_ = torch.multinomial(input_, num_samples=num_samples, replacement=replacement, generator=generator)
     output = output_.reshape(*list(input.shape[:-1]), -1)
+
+    # print('MULTINOmial', input.shape, output.shape)  # MULTINOmial torch.Size([1, 4, 2048]) torch.Size([1, 4, 1])
+    # output = input[..., 0:1]
     return output
 
 
@@ -115,61 +118,18 @@ def sample_top_k(probs: torch.Tensor, k: int) -> torch.Tensor:
     Returns:
         torch.Tensor: Sampled tokens.
     """
-    top_k_value, _ = torch.topk(probs, k, dim=-1)
-    min_value_top_k = top_k_value[..., [-1]]
-    probs *= (probs >= min_value_top_k).float()
-    probs.div_(probs.sum(dim=-1, keepdim=True))
+    top_k_value, i250 = torch.topk(probs, k, dim=-1)  # probs: [1, 4, 2048]
+    min_value_top_k = top_k_value[..., [-1]]  # the k-th largest prob per codebook
+    probs *= (probs >= min_value_top_k).float()  # zero out every prob below the k-th largest
+    probs.div_(probs.sum(dim=-1, keepdim=True))  # renormalize so the surviving top-k probs sum to 1
     next_token = multinomial(probs, num_samples=1)
+    # so instead of choosing one token via multinomial, what happens if we take all 250 top-k tokens?
+    # probs.shape=torch.Size([1, 4, 2048]); next_token is [1, 4, 1], one sampled token per codebook
+    # next_token = i250
     return next_token
 
 
-def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
-    """Sample next token from top P probabilities along the last dimension of the input probs tensor.
-
-    Args:
-        probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
-        p (int): The p in “top-p”.
-    Returns:
-        torch.Tensor: Sampled tokens.
-    """
-    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
-    probs_sum = torch.cumsum(probs_sort, dim=-1)
-    mask = probs_sum - probs_sort > p
-    probs_sort *= (~mask).float()
-    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
-    next_token = multinomial(probs_sort, num_samples=1)
-    next_token = torch.gather(probs_idx, -1, next_token)
-    return next_token
-
-
-class DummyPoolExecutor:
-    """Dummy pool executor to use when we actually have only 1 worker.
-    (e.g. instead of ProcessPoolExecutor).
-    """
-    class DummyResult:
-        def __init__(self, func, *args, **kwargs):
-            self.func = func
-            self.args = args
-            self.kwargs = kwargs
-
-        def result(self):
-            return self.func(*self.args, **self.kwargs)
-
-    def __init__(self, workers, mp_context=None):
-        pass
-
-    def submit(self, func, *args, **kwargs):
-        return DummyPoolExecutor.DummyResult(func, *args, **kwargs)
-
-    def __enter__(self):
-        return self
-
-    def __exit__(self, exc_type, exc_value, exc_tb):
-        return
-
-
-def get_pool_executor(num_workers: int, mp_context=None):
-    return ProcessPoolExecutor(num_workers, mp_context) if num_workers > 1 else DummyPoolExecutor(1)
 
 
 def length_to_mask(lengths: torch.Tensor, max_len: tp.Optional[int] = None) -> torch.Tensor:
@@ -188,42 +148,6 @@ def length_to_mask(lengths: torch.Tensor, max_len: tp.Optional[int] = None) -> t
     return torch.arange(final_length, device=lengths.device)[None, :] < lengths[:, None]
 
 
-def hash_trick(word: str, vocab_size: int) -> int:
-    """Hash trick to pair each word with an index
-
-    Args:
-        word (str): word we wish to convert to an index
-        vocab_size (int): size of the vocabulary
-    Returns:
-        int: index of the word in the embedding LUT
-    """
-    hash = int(hashlib.sha256(word.encode("utf-8")).hexdigest(), 16)
-    return hash % vocab_size
-
-
-def with_rank_rng(base_seed: int = 1234):
-    """Decorator for a function so that the function will use a Random Number Generator
-    whose state depend on the GPU rank. The original RNG state is restored upon returning.
-
-    Args:
-        base_seed (int): Random seed.
-    """
-    def _decorator(fun: tp.Callable):
-        @wraps(fun)
-        def _decorated(*args, **kwargs):
-            state = torch.get_rng_state()
-            seed = base_seed ^ flashy.distrib.rank()
-            torch.manual_seed(seed)
-            logger.debug('Rank dependent seed set to %d', seed)
-            try:
-                return fun(*args, **kwargs)
-            finally:
-                torch.set_rng_state(state)
-                logger.debug('RNG state restored.')
-        return _decorated
-    return _decorator
-
-
 def collate(tensors: tp.List[torch.Tensor], dim: int = 0) -> tp.Tuple[torch.Tensor, torch.Tensor]:
     """Get a list of tensors and collate them to a single tensor. according to the following logic:
     - `dim` specifies the time dimension which will be stacked and padded.
@@ -247,52 +171,3 @@ def collate(tensors: tp.List[torch.Tensor], dim: int = 0) -> tp.Tuple[torch.Tens
     return padded_tensors, lens
 
 
-# TODO: Move to flashy?
-def copy_state(state: tp.Any, device: tp.Union[torch.device, str] = 'cpu',
-               dtype: tp.Optional[torch.dtype] = None) -> tp.Any:
-    if isinstance(state, torch.Tensor):
-        if dtype is None or not state.is_floating_point():
-            dtype = state.dtype
-        return state.detach().to(device=device, dtype=dtype, copy=True)
-    elif isinstance(state, dict):
-        return {k: copy_state(v, device, dtype) for k, v in state.items()}
-    elif isinstance(state, list):
-        return [copy_state(v, device, dtype) for v in state]
-
-
-# TODO: Move to flashy?
-@contextmanager
-def swap_state(model, state, **kwargs):
-    old_state = copy_state(model.state_dict())
-    model.load_state_dict(state, **kwargs)
-    try:
-        yield
-    finally:
-        model.load_state_dict(old_state)
-
-
-@lru_cache(None)
-def warn_once(logger, msg):
-    """Warn about a given message only once."""
-    logger.warning(msg)
-
-
-def is_jsonable(x: tp.Any):
-    """Check if an object can be serialized into a json:"""
-    try:
-        json.dumps(x)
-        return True
-    except (TypeError, OverflowError):
-        return False
-
-
-def load_clap_state_dict(clap_model, path: tp.Union[str, Path]):
-    """Wrapper around state dict loading of CLAP model
-    addressing compatibility issues between CLAP and AudioCraft
-    HuggingFace transformer version.
-    See: https://github.com/LAION-AI/CLAP/issues/118
-    """
-    from clap_module.factory import load_state_dict  # type: ignore
-    pkg = load_state_dict(path)
-    pkg.pop('text_branch.embeddings.position_ids', None)
-    clap_model.model.load_state_dict(pkg)
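For reference, a worked top-k example (sketch, not part of the commit) with a 5-token vocabulary and k=2: everything below the k-th largest probability is zeroed, the survivors are renormalized, and one token is drawn:

    import torch

    probs = torch.tensor([[0.05, 0.40, 0.10, 0.30, 0.15]])
    k = 2
    top_k_value, _ = torch.topk(probs, k, dim=-1)          # [[0.40, 0.30]]
    min_value_top_k = top_k_value[..., [-1]]               # 0.30, the k-th largest
    probs = probs * (probs >= min_value_top_k).float()     # [[0.00, 0.40, 0.00, 0.30, 0.00]]
    probs = probs / probs.sum(dim=-1, keepdim=True)        # [[0.00, 4/7, 0.00, 3/7, 0.00]]
    next_token = torch.multinomial(probs, num_samples=1)   # index 1 with prob 4/7, index 3 with 3/7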