debug special token
- audiocraft/builders.py +26 -20
- audiocraft/conditioners.py +1 -8
- audiocraft/encodec.py +9 -51
- audiocraft/lm.py +34 -70
- audiocraft/loaders.py +1 -1
- audiocraft/seanet.py +14 -125
- audiocraft/utils/utils.py +11 -136
audiocraft/builders.py  CHANGED

@@ -17,7 +17,7 @@ import torch
 
 from .encodec import CompressionModel, EncodecModel
 from .lm import LMModel
-from .seanet import …
+from .seanet import SEANetDecoder
 from .codebooks_patterns import (
     CodebooksPatternProvider,
     DelayedPatternProvider,
@@ -49,34 +49,40 @@ def get_quantizer(quantizer: str, cfg: omegaconf.DictConfig, dimension: int) ->
     return klass(**kwargs)
 
 
-def get_encodec_autoencoder(…
-    …
-        decoder = SEANetDecoder(**decoder_kwargs)
-        return encoder, decoder
-    else:
-        raise KeyError(f"Unexpected compression model {cfg.compression_model}")
+def get_encodec_autoencoder(cfg):
+    kwargs = dict_from_config(getattr(cfg, 'seanet'))
+    _ = kwargs.pop('encoder')
+    decoder_override_kwargs = kwargs.pop('decoder')
+    decoder_kwargs = {**kwargs, **decoder_override_kwargs}
+    decoder = SEANetDecoder(**decoder_kwargs)
+    return decoder
+
 
 
-def get_compression_model(cfg…
+def get_compression_model(cfg):
     """Instantiate a compression model."""
     if cfg.compression_model == 'encodec':
         kwargs = dict_from_config(getattr(cfg, 'encodec'))
-        encoder_name = kwargs.pop('autoencoder')
         quantizer_name = kwargs.pop('quantizer')
-        …
-        quantizer = get_quantizer(quantizer_name, cfg, …
-        frame_rate = kwargs['sample_rate'] // encoder.hop_length
+        decoder = get_encodec_autoencoder(cfg)
+        quantizer = get_quantizer(quantizer_name, cfg, 128)
         renormalize = kwargs.pop('renormalize', False)
         # deprecated params
+        # print(f'{frame_rate=} {encoder.dimension=}') frame_rate=50 encoder.dimension=128
         kwargs.pop('renorm', None)
-        …
+        # print('\n______!____________\n', kwargs, '\n______!____________\n')
+        # ______!____________
+        # {'autoencoder': 'seanet', 'sample_rate': 16000, 'channels': 1, 'causal': False}
+        # ______!____________
+
+        return EncodecModel(decoder=decoder,
+                            quantizer=quantizer,
+                            frame_rate=50,
+                            renormalize=renormalize,
+                            sample_rate=16000,
+                            channels=1,
+                            causal=False
+                            ).to(cfg.device)
     else:
         raise KeyError(f"Unexpected compression model {cfg.compression_model}")
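The new `get_compression_model` hard-codes `frame_rate=50` and `sample_rate=16000` instead of deriving them from the now-removed encoder, and passes a fixed `128` to `get_quantizer`. A quick sanity check of those numbers, assuming the default SEANet ratios shown in `seanet.py` below (hop length is the product of the ratios; frame rate is the sample rate divided by the hop; `128` matches the decoder's default `dimension`):

    import math

    ratios = [8, 5, 4, 2]              # default SEANet stride ratios (see seanet.py below)
    hop_length = math.prod(ratios)     # 320 audio samples per latent frame
    frame_rate = 16000 // hop_length   # 16000 / 320 = 50 frames per second
    print(hop_length, frame_rate)      # -> 320 50

This is consistent with the debug comment left in the hunk above (`frame_rate=50 encoder.dimension=128`).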
audiocraft/conditioners.py  CHANGED

@@ -1,11 +1,7 @@
 from collections import defaultdict
 from dataclasses import dataclass, field
-from itertools import chain
 import logging
-import math
-from pathlib import Path
 import random
-import re
 import typing as tp
 import warnings
 import soundfile
@@ -14,11 +10,8 @@ import torch
 from torch import nn
 from .streaming import StreamingModule
 
-
-from .quantization import ResidualVectorQuantizer
 from .utils.autocast import TorchAutocast
-
-from .utils.utils import collate, hash_trick, length_to_mask, load_clap_state_dict, warn_once
+
 
 
 logger = logging.getLogger(__name__)
audiocraft/encodec.py  CHANGED

@@ -30,14 +30,7 @@ class CompressionModel(ABC, nn.Module):
     with a language model.
     """
 
-
-    def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
-        ...
-
-    @abstractmethod
-    def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
-        """See `EncodecModel.encode`."""
-        ...
+
 
     @abstractmethod
     def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
@@ -142,16 +135,15 @@ class EncodecModel(CompressionModel):
     channels: int = 0
 
     def __init__(self,
-                 …
-                 renormalize: bool = False):
+                 decoder=None,
+                 quantizer=None,
+                 frame_rate=None,
+                 sample_rate=None,
+                 channels=None,
+                 causal=False,
+                 renormalize=False):
         super().__init__()
-
+
         self.decoder = decoder
         self.quantizer = quantizer
         self.frame_rate = frame_rate
@@ -203,40 +195,6 @@ class EncodecModel(CompressionModel):
         x = x * scale.view(-1, 1, 1)
         return x
 
-    def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
-        assert x.dim() == 3
-        length = x.shape[-1]
-        x, scale = self.preprocess(x)
-
-        emb = self.encoder(x)
-        q_res = self.quantizer(emb, self.frame_rate)
-        out = self.decoder(q_res.x)
-
-        # remove extra padding added by the encoder and decoder
-        assert out.shape[-1] >= length, (out.shape[-1], length)
-        out = out[..., :length]
-
-        q_res.x = self.postprocess(out, scale)
-
-        return q_res
-
-    def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
-        """Encode the given input tensor to quantized representation along with scale parameter.
-
-        Args:
-            x (torch.Tensor): Float tensor of shape [B, C, T]
-
-        Returns:
-            codes, scale (tuple of torch.Tensor, torch.Tensor): Tuple composed of:
-                codes: a float tensor of shape [B, K, T] with K the number of codebooks used and T the timestep.
-                scale: a float tensor containing the scale for audio renormalization.
-        """
-        assert x.dim() == 3
-        x, scale = self.preprocess(x)
-        emb = self.encoder(x)
-        codes = self.quantizer.encode(emb)
-        return codes, scale
-
     def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
         """Decode the given codes to a reconstructed representation, using the scale to perform
         audio denormalization if needed.
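With `forward` and `encode` removed, `EncodecModel` is decode-only: it can only turn pre-computed codebook indices back into audio. A hypothetical usage sketch, reusing the names from builders.py above; the `[B, K, T]` code layout follows the docstring of the removed `encode`, the 4-codebook / 50 Hz numbers come from debug comments elsewhere in this commit, and the output shape is an assumption:

    import torch

    model = builders.get_compression_model(cfg)      # decoder + quantizer only (see builders.py above)
    codes = torch.zeros(1, 4, 50, dtype=torch.long)   # [B, K, T]: 4 codebooks, one second at 50 Hz
    with torch.no_grad():
        wav = model.decode(codes)                     # expected [1, 1, 16000] at 16 kHz (assumption)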
audiocraft/lm.py  CHANGED

@@ -14,16 +14,14 @@ import warnings
 import einops
 from num2words import num2words
 import spacy
-from transformers import …
+from transformers import T5EncoderModel, T5Tokenizer  # type: ignore
 import torch
 import torch.nn.functional as F
 from torch.nn.utils.rnn import pad_sequence
 from audiocraft.streaming import StreamingModule
 from audiocraft.transformer import create_sin_embedding
-from audiocraft.utils.audio_utils import convert_audio
 from audiocraft.utils.autocast import TorchAutocast
-from audiocraft.utils.…
-from audiocraft.utils.utils import collate, hash_trick, length_to_mask, load_clap_state_dict, warn_once
+from audiocraft.utils.utils import collate, length_to_mask
 from audiocraft.transformer import StreamingTransformer, create_norm_fn
 from dataclasses import dataclass
 from functools import partial
@@ -297,13 +295,7 @@ class BaseConditioner(nn.Module):
         self.output_dim = output_dim
         self.output_proj = nn.Linear(dim, output_dim)
 
-
-        """Should be any part of the processing that will lead to a synchronization
-        point, e.g. BPE tokenization with transfer to the GPU.
-
-        The returned value will be saved and return later when calling forward().
-        """
-        raise NotImplementedError()
+
 
     def forward(self, inputs: tp.Any) -> ConditionType:
         """Gets input that should be used as conditioning (e.g, genre, description or a waveform).
@@ -530,34 +522,6 @@ class ConditioningProvider(nn.Module):
     def has_wav_condition(self):
         return len(self.wav_conditions) > 0
 
-    def tokenize(self, inputs: tp.List[ConditioningAttributes]) -> tp.Dict[str, tp.Any]:
-        """Match attributes/wavs with existing conditioners in self, and compute tokenize them accordingly.
-        This should be called before starting any real GPU work to avoid synchronization points.
-        This will return a dict matching conditioner names to their arbitrary tokenized representations.
-
-        Args:
-            inputs (list[ConditioningAttributes]): List of ConditioningAttributes objects containing
-                text and wav conditions.
-        """
-        assert all([isinstance(x, ConditioningAttributes) for x in inputs]), (
-            "Got unexpected types input for conditioner! should be tp.List[ConditioningAttributes]",
-            f" but types were {set([type(x) for x in inputs])}"
-        )
-
-        output = {}
-        text = self._collate_text(inputs)
-        wavs = self._collate_wavs(inputs)
-        joint_embeds = self._collate_joint_embeds(inputs)
-
-        assert set(text.keys() | wavs.keys() | joint_embeds.keys()).issubset(set(self.conditioners.keys())), (
-            f"Got an unexpected attribute! Expected {self.conditioners.keys()}, ",
-            f"got {text.keys(), wavs.keys(), joint_embeds.keys()}"
-        )
-
-        for attribute, batch in chain(text.items(), wavs.items(), joint_embeds.items()):
-            output[attribute] = self.conditioners[attribute].tokenize(batch)
-        return output
-
     def forward(self, tokenized: tp.Dict[str, tp.Any]) -> tp.Dict[str, ConditionType]:
         """Compute pairs of `(embedding, mask)` using the configured conditioners and the tokenized representations.
         The output is for example:
@@ -780,6 +744,7 @@ class ConditionFuser(StreamingModule):
             raise ValueError(f"unknown op ({op})")
 
         if self.cross_attention_pos_emb and cross_attention_output is not None:
+            print('SIN EMBED')
             positions = torch.arange(
                 cross_attention_output.shape[1],
                 device=cross_attention_output.device
@@ -925,7 +890,7 @@ class LMModel(StreamingModule):
 
         self.condition_provider = condition_provider
         self.fuser = fuser
-        self.card = card
+        self.card = card  # 2048 ?
         embed_dim = self.card + 1
         self.n_q = n_q
         self.dim = dim
@@ -1030,6 +995,7 @@ class LMModel(StreamingModule):
         # remove the prefix from the model outputs
         if len(self.fuser.fuse2cond['prepend']) > 0:
             logits = logits[:, :, -S:]
+            print('PRESFIX')
 
         return logits  # [B, K, S, card]
 
@@ -1067,6 +1033,8 @@ class LMModel(StreamingModule):
         B, K, T = codes.shape
         codes = codes.contiguous()
         # map codes [B, K, T] into pattern sequence [B, K, S] using special_token_id for masked tokens
+        # what is the T is it 2048 ?
+        # and then what is pattern -> another function?
         pattern = self.pattern_provider.get_pattern(T)
         sequence_codes, sequence_indexes, sequence_mask = pattern.build_pattern_sequence(
             codes, self.special_token_id, keep_only_valid_steps=keep_only_valid_steps,
@@ -1118,35 +1086,33 @@ class LMModel(StreamingModule):
         model = self if self._fsdp is None else self._fsdp
         two_step_cfg = self.two_step_cfg if two_step_cfg is None else two_step_cfg
         if two_step_cfg and cfg_conditions != {}:
-
-            condition_tensors, null_condition_tensors = cfg_conditions
-            cond_logits = model(sequence, conditions=[], condition_tensors=condition_tensors)
-            state = self.get_streaming_state()
-            self.set_streaming_state(unconditional_state)
-            uncond_logits = model(sequence, conditions=[], condition_tensors=null_condition_tensors)
-            unconditional_state.update(self.get_streaming_state())
-            self.set_streaming_state(state)
-            logits = uncond_logits + (cond_logits - uncond_logits) * self.cfg_coef
+            print('\nNOT HERE\n')
         else:
+            print('C')
             assert isinstance(cfg_conditions, dict)
             condition_tensors = cfg_conditions
            if condition_tensors:
+                # print('\nD\n')
                 # Preparing for CFG, predicting both conditional and unconditional logits.
                 sequence = torch.cat([sequence, sequence], dim=0)
             all_logits = model(
                 sequence,
                 conditions=[], condition_tensors=condition_tensors)
             if condition_tensors:
-                cond_logits, uncond_logits = all_logits.split(B, dim=0)
-                logits = uncond_logits + (cond_logits - uncond_logits) * cfg_coef
+                cond_logits, uncond_logits = all_logits.split(B, dim=0)  # torch.Size([2, 4, 1, 2048])
+                # logits = uncond_logits + (cond_logits - uncond_logits) * cfg_coef
+                # logits = 3 * cond_logits - 2.4 * uncond_logits
+                logits = 2 * cond_logits - 1.4 * uncond_logits
             else:
-                …
+                print('\nF!\n')
+
 
         logits = logits.permute(0, 1, 3, 2)  # [B, K, card, T]
         logits = logits[..., -1]  # [B x K x card]
 
         # Apply softmax for sampling if temp > 0. Else, do greedy sampling to avoid zero division error.
         if use_sampling and temp > 0.0:
+            # print(f'\nR {temp=} {top_p=} {top_k=}\n') -------------> R temp=1.0 top_p=0.0 top_k=250
             probs = torch.softmax(logits / temp, dim=-1)
             if top_p > 0.0:
                 next_token = utils.sample_top_p(probs, p=top_p)
@@ -1155,7 +1121,9 @@ class LMModel(StreamingModule):
             else:
                 next_token = utils.multinomial(probs, num_samples=1)
         else:
-            …
+            #
+            print('\nNeverHere\n')
+
 
         return next_token
 
@@ -1249,9 +1217,9 @@ class LMModel(StreamingModule):
         # this token is used as default value for codes that are not generated yet
        unknown_token = -1
 
-        …
+
         gen_codes = torch.full((B, K, max_gen_len), unknown_token, dtype=torch.long, device=device)
-        …
+
         gen_codes[..., :start_offset] = prompt
         # create the gen_sequence with proper interleaving from the pattern: [B, K, S]
         gen_sequence, indexes, mask = pattern.build_pattern_sequence(gen_codes, self.special_token_id)
@@ -1280,9 +1248,17 @@ class LMModel(StreamingModule):
             # ensure the tokens that should be masked are properly set to special_token_id
             # as the model never output special_token_id
             valid_mask = mask[..., offset:offset+1].expand(B, -1, -1)
-            …
+
+            # next_token[~valid_mask] = self.special_token_id
+
+            # print(f'{unconditional_state=} \n
+            # print('Set All to Special')
+            # next_token[:] = self.special_token_id
+
+
+
             # ensure we don't overwrite prompt tokens, we only write over unknown tokens
-            …
+
             gen_sequence[..., offset:offset+1] = torch.where(
                 gen_sequence[..., offset:offset+1] == unknown_token,
                 next_token, gen_sequence[..., offset:offset+1]
@@ -1292,23 +1268,11 @@ class LMModel(StreamingModule):
                 callback(1 + offset - start_offset_sequence, gen_sequence_len - start_offset_sequence)
         unconditional_state.clear()
 
-        # ensure sequence has been entirely filled
-        assert not (gen_sequence == unknown_token).any()
-        # ensure gen_sequence pattern and mask are matching
-        # which means the gen_sequence is valid according to the pattern
-        assert (
-            gen_sequence == torch.where(mask[None, ...].expand(B, -1, -1), gen_sequence, self.special_token_id)
-        ).all()
-        # get back the codes, trimming the prompt if needed and cutting potentially incomplete timesteps
         out_codes, out_indexes, out_mask = pattern.revert_pattern_sequence(gen_sequence, special_token=unknown_token)
 
-        # sanity checks over the returned codes and corresponding masks
-        assert (out_codes[..., :max_gen_len] != unknown_token).all()
-        assert (out_mask[..., :max_gen_len] == 1).all()
-
         out_start_offset = start_offset if remove_prompts else 0
         out_codes = out_codes[..., out_start_offset:max_gen_len]
 
         # ensure the returned codes are all valid
-        assert (out_codes >= 0).all() and (out_codes <= self.card).all()
+        # assert (out_codes >= 0).all() and (out_codes <= self.card).all()
         return out_codes
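The most consequential change in this file is the classifier-free-guidance mix: the standard formula `uncond + (cond - uncond) * cfg_coef` keeps the two weights summing to 1 (with `cfg_coef = 2` it equals `2 * cond - 1 * uncond`), whereas the hard-coded `2 * cond_logits - 1.4 * uncond_logits` has weights summing to 0.6, so the overall logit scale shrinks before the softmax. A small comparison sketch with placeholder tensors, not real model outputs:

    import torch

    cond = torch.randn(1, 4, 1, 2048)     # conditional logits, [B, K, S, card] per the debug comment above
    uncond = torch.randn(1, 4, 1, 2048)   # unconditional logits

    cfg_coef = 3.0
    standard = uncond + (cond - uncond) * cfg_coef   # == cfg_coef * cond - (cfg_coef - 1) * uncond
    patched = 2 * cond - 1.4 * uncond                # weights 2 and -1.4 sum to 0.6

    print(torch.allclose(standard, 3 * cond - 2 * uncond, atol=1e-6))  # True

Since these logits go straight into `softmax(logits / temp)`, the smaller overall scale of the patched mix behaves roughly like sampling at a higher temperature.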
audiocraft/loaders.py  CHANGED

@@ -79,7 +79,7 @@ def load_compression_model(file_or_url_or_id: tp.Union[Path, str], device='cpu',
     cfg = OmegaConf.create(pkg['xp.cfg'])
     cfg.device = str(device)
     model = builders.get_compression_model(cfg)
-    model.load_state_dict(pkg['best_state'])
+    model.load_state_dict(pkg['best_state'], strict=False)  # ckpt contains uninstantiated encoder
     model.eval()
     return model
 
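`strict=False` is what lets the slimmed-down model load the original checkpoint: the `encoder.*` weights still stored in `best_state` no longer have a module to land in and are simply ignored. A minimal standalone illustration of that PyTorch behaviour, with a toy module and made-up keys unrelated to the actual checkpoint:

    import torch
    from torch import nn

    toy = nn.Linear(2, 2)
    ckpt = {'weight': torch.zeros(2, 2), 'bias': torch.zeros(2),
            'encoder.weight': torch.zeros(2, 2)}        # extra key, like the dropped encoder weights
    result = toy.load_state_dict(ckpt, strict=False)
    print(result.unexpected_keys)                       # ['encoder.weight'] - skipped instead of raising
    print(result.missing_keys)                          # [] - every model parameter was found

The trade-off is that genuinely missing decoder or quantizer weights would also be skipped silently, so checking `missing_keys` after loading is a cheap safeguard.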
audiocraft/seanet.py  CHANGED

@@ -60,136 +60,25 @@ class SEANetResnetBlock(nn.Module):
         return self.shortcut(x) + self.block(x)
 
 
-class SEANetEncoder(nn.Module):
-    """SEANet encoder.
-
-    Args:
-        channels (int): Audio channels.
-        dimension (int): Intermediate representation dimension.
-        n_filters (int): Base width for the model.
-        n_residual_layers (int): nb of residual layers.
-        ratios (Sequence[int]): kernel size and stride ratios. The encoder uses downsampling ratios instead of
-            upsampling ratios, hence it will use the ratios in the reverse order to the ones specified here
-            that must match the decoder order. We use the decoder order as some models may only employ the decoder.
-        activation (str): Activation function.
-        activation_params (dict): Parameters to provide to the activation function.
-        norm (str): Normalization method.
-        norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
-        kernel_size (int): Kernel size for the initial convolution.
-        last_kernel_size (int): Kernel size for the initial convolution.
-        residual_kernel_size (int): Kernel size for the residual layers.
-        dilation_base (int): How much to increase the dilation with each layer.
-        causal (bool): Whether to use fully causal convolution.
-        pad_mode (str): Padding mode for the convolutions.
-        true_skip (bool): Whether to use true skip connection or a simple
-            (streamable) convolution as the skip connection in the residual network blocks.
-        compress (int): Reduced dimensionality in residual branches (from Demucs v3).
-        lstm (int): Number of LSTM layers at the end of the encoder.
-        disable_norm_outer_blocks (int): Number of blocks for which we don't apply norm.
-            For the encoder, it corresponds to the N first blocks.
-    """
-    def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 3,
-                 ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
-                 norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7,
-                 last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False,
-                 pad_mode: str = 'reflect', true_skip: bool = True, compress: int = 2, lstm: int = 0,
-                 disable_norm_outer_blocks: int = 0):
-        super().__init__()
-        self.channels = channels
-        self.dimension = dimension
-        self.n_filters = n_filters
-        self.ratios = list(reversed(ratios))
-        del ratios
-        self.n_residual_layers = n_residual_layers
-        self.hop_length = np.prod(self.ratios)
-        self.n_blocks = len(self.ratios) + 2  # first and last conv + residual blocks
-        self.disable_norm_outer_blocks = disable_norm_outer_blocks
-        assert self.disable_norm_outer_blocks >= 0 and self.disable_norm_outer_blocks <= self.n_blocks, \
-            "Number of blocks for which to disable norm is invalid." \
-            "It should be lower or equal to the actual number of blocks in the network and greater or equal to 0."
-
-        act = getattr(nn, activation)
-        mult = 1
-        model: tp.List[nn.Module] = [
-            StreamableConv1d(channels, mult * n_filters, kernel_size,
-                             norm='none' if self.disable_norm_outer_blocks >= 1 else norm,
-                             norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
-        ]
-        # Downsample to raw audio scale
-        for i, ratio in enumerate(self.ratios):
-            block_norm = 'none' if self.disable_norm_outer_blocks >= i + 2 else norm
-            # Add residual layers
-            for j in range(n_residual_layers):
-                model += [
-                    SEANetResnetBlock(mult * n_filters, kernel_sizes=[residual_kernel_size, 1],
-                                      dilations=[dilation_base ** j, 1],
-                                      norm=block_norm, norm_params=norm_params,
-                                      activation=activation, activation_params=activation_params,
-                                      causal=causal, pad_mode=pad_mode, compress=compress, true_skip=true_skip)]
-
-            # Add downsampling layers
-            model += [
-                act(**activation_params),
-                StreamableConv1d(mult * n_filters, mult * n_filters * 2,
-                                 kernel_size=ratio * 2, stride=ratio,
-                                 norm=block_norm, norm_kwargs=norm_params,
-                                 causal=causal, pad_mode=pad_mode),
-            ]
-            mult *= 2
-
-        if lstm:
-            model += [StreamableLSTM(mult * n_filters, num_layers=lstm)]
-
-        model += [
-            act(**activation_params),
-            StreamableConv1d(mult * n_filters, dimension, last_kernel_size,
-                             norm='none' if self.disable_norm_outer_blocks == self.n_blocks else norm,
-                             norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
-        ]
-
-        self.model = nn.Sequential(*model)
-
-    def forward(self, x):
-        return self.model(x)
 
 
 class SEANetDecoder(nn.Module):
-    """SEANet decoder.
-
-        …
-        residual_kernel_size (int): Kernel size for the residual layers.
-        dilation_base (int): How much to increase the dilation with each layer.
-        causal (bool): Whether to use fully causal convolution.
-        pad_mode (str): Padding mode for the convolutions.
-        true_skip (bool): Whether to use true skip connection or a simple.
-            (streamable) convolution as the skip connection in the residual network blocks.
-        compress (int): Reduced dimensionality in residual branches (from Demucs v3).
-        lstm (int): Number of LSTM layers at the end of the encoder.
-        disable_norm_outer_blocks (int): Number of blocks for which we don't apply norm.
-            For the decoder, it corresponds to the N last blocks.
-        trim_right_ratio (float): Ratio for trimming at the right of the transposed convolution under the causal setup.
-            If equal to 1.0, it means that all the trimming is done at the right.
-    """
-    def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 3,
-                 ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
-                 final_activation: tp.Optional[str] = None, final_activation_params: tp.Optional[dict] = None,
-                 norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7,
-                 last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False,
-                 pad_mode: str = 'reflect', true_skip: bool = True, compress: int = 2, lstm: int = 0,
-                 disable_norm_outer_blocks: int = 0, trim_right_ratio: float = 1.0):
+
+    def __init__(self, channels: int = 1,
+                 dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 3,
+                 ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU',
+                 activation_params: dict = {'alpha': 1.0},
+                 final_activation: tp.Optional[str] = None,
+                 final_activation_params: tp.Optional[dict] = None,
+                 norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {},
+                 kernel_size: int = 7,
+                 last_kernel_size: int = 7, residual_kernel_size: int = 3,
+                 dilation_base: int = 2, causal: bool = False,
+                 pad_mode: str = 'reflect', true_skip: bool = True,
+                 compress: int = 2, lstm: int = 0,
+                 disable_norm_outer_blocks: int = 0,
+                 trim_right_ratio: float = 1.0):
         super().__init__()
         self.dimension = dimension
         self.channels = channels
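With `SEANetEncoder` deleted, only the decoder half of the autoencoder remains. A smoke-test sketch of the trimmed class, assuming the defaults in the signature above, that the decoder's forward pass is untouched by this commit, and that it upsamples by the product of `ratios` (320x) as the mirror image of the removed encoder:

    import torch
    from audiocraft.seanet import SEANetDecoder

    dec = SEANetDecoder()                # channels=1, dimension=128, ratios=[8, 5, 4, 2]
    latent = torch.randn(1, 128, 50)     # [B, dimension, T]: one second of 50 Hz latents
    wav = dec(latent)                    # expected [1, 1, 16000], i.e. 50 * 320 samples (assumption)
    print(wav.shape)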
audiocraft/utils/utils.py  CHANGED

@@ -4,7 +4,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
-
+
 from contextlib import contextmanager
 from functools import wraps, lru_cache
 import hashlib
@@ -103,6 +103,9 @@ def multinomial(input: torch.Tensor, num_samples: int, replacement=False, *, gen
     input_ = input.reshape(-1, input.shape[-1])
     output_ = torch.multinomial(input_, num_samples=num_samples, replacement=replacement, generator=generator)
     output = output_.reshape(*list(input.shape[:-1]), -1)
+
+    # print('MULTINOmial', input.shape, output.shape)  # MULTINOmial torch.Size([1, 4, 2048]) torch.Size([1, 4, 1])
+    # output = input[..., 0:1]
     return output
 
 
@@ -115,61 +118,18 @@ def sample_top_k(probs: torch.Tensor, k: int) -> torch.Tensor:
     Returns:
         torch.Tensor: Sampled tokens.
     """
-    top_k_value, …
-    min_value_top_k = top_k_value[..., [-1]]
-    probs *= (probs >= min_value_top_k).float()
-    probs.div_(probs.sum(dim=-1, keepdim=True))
+    top_k_value, i250 = torch.topk(probs, k, dim=-1)  # probs: [1, 4, 2048]
+    min_value_top_k = top_k_value[..., [-1]]  #
+    probs *= (probs >= min_value_top_k).float()  # multiply all being > of min_topk with 1 thus zeroing others
+    probs.div_(probs.sum(dim=-1, keepdim=True))  # why normalize by the sum ? oh in order to choose mult
     next_token = multinomial(probs, num_samples=1)
+    # so instead of chooose multinomial what happens if we take all 250 topk tokens
+    # probs.shape=torch.Size([1, 4, 2048]) <, print(next_token,f'{probs.shape=}', 'h')  # 1,4,1 next token is 4tok
+    # next_token = i250
     return next_token
 
 
-def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
-    """Sample next token from top P probabilities along the last dimension of the input probs tensor.
-
-    Args:
-        probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
-        p (int): The p in “top-p”.
-    Returns:
-        torch.Tensor: Sampled tokens.
-    """
-    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
-    probs_sum = torch.cumsum(probs_sort, dim=-1)
-    mask = probs_sum - probs_sort > p
-    probs_sort *= (~mask).float()
-    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
-    next_token = multinomial(probs_sort, num_samples=1)
-    next_token = torch.gather(probs_idx, -1, next_token)
-    return next_token
-
-
-class DummyPoolExecutor:
-    """Dummy pool executor to use when we actually have only 1 worker.
-    (e.g. instead of ProcessPoolExecutor).
-    """
-    class DummyResult:
-        def __init__(self, func, *args, **kwargs):
-            self.func = func
-            self.args = args
-            self.kwargs = kwargs
-
-        def result(self):
-            return self.func(*self.args, **self.kwargs)
-
-    def __init__(self, workers, mp_context=None):
-        pass
-
-    def submit(self, func, *args, **kwargs):
-        return DummyPoolExecutor.DummyResult(func, *args, **kwargs)
-
-    def __enter__(self):
-        return self
-
-    def __exit__(self, exc_type, exc_value, exc_tb):
-        return
-
-
-def get_pool_executor(num_workers: int, mp_context=None):
-    return ProcessPoolExecutor(num_workers, mp_context) if num_workers > 1 else DummyPoolExecutor(1)
 
 
 def length_to_mask(lengths: torch.Tensor, max_len: tp.Optional[int] = None) -> torch.Tensor:
@@ -188,42 +148,6 @@ def length_to_mask(lengths: torch.Tensor, max_len: tp.Optional[int] = None) -> t
     return torch.arange(final_length, device=lengths.device)[None, :] < lengths[:, None]
 
 
-def hash_trick(word: str, vocab_size: int) -> int:
-    """Hash trick to pair each word with an index
-
-    Args:
-        word (str): word we wish to convert to an index
-        vocab_size (int): size of the vocabulary
-    Returns:
-        int: index of the word in the embedding LUT
-    """
-    hash = int(hashlib.sha256(word.encode("utf-8")).hexdigest(), 16)
-    return hash % vocab_size
-
-
-def with_rank_rng(base_seed: int = 1234):
-    """Decorator for a function so that the function will use a Random Number Generator
-    whose state depend on the GPU rank. The original RNG state is restored upon returning.
-
-    Args:
-        base_seed (int): Random seed.
-    """
-    def _decorator(fun: tp.Callable):
-        @wraps(fun)
-        def _decorated(*args, **kwargs):
-            state = torch.get_rng_state()
-            seed = base_seed ^ flashy.distrib.rank()
-            torch.manual_seed(seed)
-            logger.debug('Rank dependent seed set to %d', seed)
-            try:
-                return fun(*args, **kwargs)
-            finally:
-                torch.set_rng_state(state)
-                logger.debug('RNG state restored.')
-        return _decorated
-    return _decorator
-
-
 def collate(tensors: tp.List[torch.Tensor], dim: int = 0) -> tp.Tuple[torch.Tensor, torch.Tensor]:
     """Get a list of tensors and collate them to a single tensor. according to the following logic:
     - `dim` specifies the time dimension which will be stacked and padded.
@@ -247,52 +171,3 @@ def collate(tensors: tp.List[torch.Tensor], dim: int = 0) -> tp.Tuple[torch.Tens
     return padded_tensors, lens
 
 
-# TODO: Move to flashy?
-def copy_state(state: tp.Any, device: tp.Union[torch.device, str] = 'cpu',
-               dtype: tp.Optional[torch.dtype] = None) -> tp.Any:
-    if isinstance(state, torch.Tensor):
-        if dtype is None or not state.is_floating_point():
-            dtype = state.dtype
-        return state.detach().to(device=device, dtype=dtype, copy=True)
-    elif isinstance(state, dict):
-        return {k: copy_state(v, device, dtype) for k, v in state.items()}
-    elif isinstance(state, list):
-        return [copy_state(v, device, dtype) for v in state]
-
-
-# TODO: Move to flashy?
-@contextmanager
-def swap_state(model, state, **kwargs):
-    old_state = copy_state(model.state_dict())
-    model.load_state_dict(state, **kwargs)
-    try:
-        yield
-    finally:
-        model.load_state_dict(old_state)
-
-
-@lru_cache(None)
-def warn_once(logger, msg):
-    """Warn about a given message only once."""
-    logger.warning(msg)
-
-
-def is_jsonable(x: tp.Any):
-    """Check if an object can be serialized into a json:"""
-    try:
-        json.dumps(x)
-        return True
-    except (TypeError, OverflowError):
-        return False
-
-
-def load_clap_state_dict(clap_model, path: tp.Union[str, Path]):
-    """Wrapper around state dict loading of CLAP model
-    addressing compatibility issues between CLAP and AudioCraft
-    HuggingFace transformer version.
-    See: https://github.com/LAION-AI/CLAP/issues/118
-    """
-    from clap_module.factory import load_state_dict  # type: ignore
-    pkg = load_state_dict(path)
-    pkg.pop('text_branch.embeddings.position_ids', None)
-    clap_model.model.load_state_dict(pkg)
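For reference, here is a tiny worked example of the top-k filtering that the new inline comments in `sample_top_k` are asking about (a made-up 5-way distribution with k=2, not real model output): everything below the k-th largest probability is zeroed, and the renormalisation is exactly what turns the survivors back into a valid distribution for `multinomial`.

    import torch

    probs = torch.tensor([[0.05, 0.40, 0.10, 0.30, 0.15]])
    top_k_value, _ = torch.topk(probs, 2, dim=-1)        # [[0.40, 0.30]]
    min_value_top_k = top_k_value[..., [-1]]             # 0.30, the smallest probability we keep
    probs = probs * (probs >= min_value_top_k).float()   # [[0.00, 0.40, 0.00, 0.30, 0.00]]
    probs = probs / probs.sum(dim=-1, keepdim=True)      # [[0.0000, 0.5714, 0.0000, 0.4286, 0.0000]]
    next_token = torch.multinomial(probs, num_samples=1) # index 1 or 3, drawn with probability 0.571 / 0.429
    print(probs, next_token)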