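# NOTE(review): the next line looks like page-scrape residue rather than module
# content; it is kept here as a comment: "Spaces: Runtime error / Runtime error".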
import math
import random
from dataclasses import dataclass
from typing import Dict, Iterator, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import utils
from fairseq.data import Dictionary
from transformers import AutoConfig, AutoModel, AutoModelForImageClassification
from transformers.modeling_utils import PreTrainedModel

from src.slam_llm.models.vallex.transformers import (
    LayerNorm,
    TransformerEncoder,
    TransformerEncoderLayer,
)
from src.slam_llm.models.vallex.vallex_config import VallexConfig
@dataclass
class ModelOutput:
    """Container for the values returned by a training forward pass.

    Declared as a dataclass so it can be built with keyword arguments,
    e.g. ``ModelOutput(logits=..., loss=..., acc=...)``.
    """

    # Per-position (log-)probabilities produced by the model.
    logits: torch.Tensor
    # Training loss.
    loss: torch.Tensor
    # Prediction accuracy reported alongside the loss.
    acc: torch.Tensor
def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=None, reduce=True, scale=1, prob_mask=None):
    """Label-smoothed negative log-likelihood, normalised by the number of scored targets.

    Args:
        lprobs: Log-probabilities; the last dimension indexes the classes.
        target: Class indices. A trailing singleton dimension is added when it
            has exactly one dimension fewer than ``lprobs``.
        epsilon: Label-smoothing weight.
        ignore_index: Target value whose positions contribute neither to the
            loss nor to the normaliser. ``None`` scores every position.
        reduce: When true, sum the per-position terms before normalising.
        scale: Multiplier applied to the smoothing term only.
        prob_mask: Optional boolean mask; masked log-probabilities are zeroed
            and the class count becomes the number of unmasked entries.

    Returns:
        ``(loss, nll_loss)``, each divided by the number of scored targets.
    """
    if target.dim() == lprobs.dim() - 1:
        target = target.unsqueeze(-1)
    if prob_mask is not None:
        lprobs = lprobs.masked_fill(prob_mask, 0.0)
        n_class = (1 - prob_mask.float()).sum()
    else:
        n_class = lprobs.size(-1)
    nll_loss = -lprobs.gather(dim=-1, index=target)
    smooth_loss = -lprobs.sum(dim=-1, keepdim=True) * scale
    if ignore_index is not None:
        pad_mask = target.eq(ignore_index)
        nll_loss.masked_fill_(pad_mask, 0.0)
        smooth_loss.masked_fill_(pad_mask, 0.0)
        num_targets = (1 - pad_mask.to(torch.float)).sum()
    else:
        nll_loss = nll_loss.squeeze(-1)
        smooth_loss = smooth_loss.squeeze(-1)
        # Nothing is ignored, so every target position counts. Without this the
        # normaliser was undefined on this path and the return statement raised.
        num_targets = float(target.numel())
    if reduce:
        nll_loss = nll_loss.sum()
        smooth_loss = smooth_loss.sum()
    eps_i = epsilon / (n_class - 1)
    loss = (1.0 - epsilon - eps_i) * nll_loss + eps_i * smooth_loss
    return loss / num_targets, nll_loss / num_targets
class SinusoidalPositionalEmbedding(nn.Module):
    """Fixed sinusoidal position embeddings of any length.

    Positions are derived from the token tensor via ``utils.make_positions``
    using ``padding_idx``. The embedding table is grown on demand when a longer
    sequence is seen.
    """

    def __init__(self, embedding_dim, padding_idx, init_size=1024):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.padding_idx = padding_idx if padding_idx is not None else 0
        self.weights = SinusoidalPositionalEmbedding.get_embedding(
            init_size, embedding_dim, padding_idx
        )
        self.onnx_trace = False
        # Dummy buffer: it follows the module's device/dtype so that
        # ``self.weights`` can be moved to match it in ``forward``.
        self.register_buffer("_float_tensor", torch.FloatTensor(1))
        self.max_positions = int(1e5)

    def prepare_for_onnx_export_(self):
        """Switch ``forward`` to the ONNX-traceable code paths."""
        self.onnx_trace = True

    @staticmethod
    def get_embedding(
        num_embeddings: int, embedding_dim: int, padding_idx=None
    ):
        """Build a ``(num_embeddings, embedding_dim)`` sinusoidal table.

        The first half of each row holds sines and the second half cosines.
        An odd ``embedding_dim`` gets one zero column appended. The row at
        ``padding_idx`` (if given) is zeroed. The computation divides by
        ``embedding_dim // 2 - 1``, so ``embedding_dim`` must be at least 4.
        """
        half_dim = embedding_dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
        emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(
            1
        ) * emb.unsqueeze(0)
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(
            num_embeddings, -1
        )
        if embedding_dim % 2 == 1:
            # zero pad
            emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
        if padding_idx is not None:
            emb[padding_idx, :] = 0
        return emb

    def forward(
        self,
        input,
        incremental_state=None,
        timestep=None,
        positions=None,
    ):
        """Return position embeddings for ``input`` (read as ``[bsz, seq_len]``).

        With ``incremental_state`` set, a single position per batch row is
        returned (``timestep + 1`` if given, else ``seq_len``). The
        ``positions`` argument is ignored and recomputed from ``input``.
        """
        bspair = torch.onnx.operators.shape_as_tensor(input)
        bsz, seq_len = bspair[0], bspair[1]
        max_pos = self.padding_idx + 1 + seq_len
        if self.weights is None or max_pos > self.weights.size(0):
            # recompute/expand embeddings if needed
            self.weights = SinusoidalPositionalEmbedding.get_embedding(
                max_pos, self.embedding_dim, self.padding_idx
            )
        self.weights = self.weights.to(self._float_tensor)

        if incremental_state is not None:
            # positions is the same for every token when decoding a single step
            pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len
            if self.onnx_trace:
                return (
                    self.weights.index_select(index=self.padding_idx + pos, dim=0)
                    .unsqueeze(1)
                    .repeat(bsz, 1, 1)
                )
            return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1)

        positions = utils.make_positions(
            input, self.padding_idx, onnx_trace=self.onnx_trace
        )
        if self.onnx_trace:
            flat_embeddings = self.weights.detach().index_select(0, positions.view(-1))
            embedding_shape = torch.cat(
                (bsz.view(1), seq_len.view(1), torch.tensor([-1], dtype=torch.long))
            )
            embeddings = torch.onnx.operators.reshape_from_tensor_shape(
                flat_embeddings, embedding_shape
            )
            return embeddings
        return (
            self.weights.index_select(0, positions.view(-1))
            .view(bsz, seq_len, -1)
            .detach()
        )
class Transpose(nn.Identity):
    """Module that swaps dimensions 1 and 2 of its input."""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        swapped = torch.transpose(input, 1, 2)
        return swapped
class VALLF(PreTrainedModel):
    """Base network that builds the AR sub-modules and, unless
    ``config.only_ar`` is set, the NAR sub-modules.

    Sub-module attribute names start with ``ar_`` or ``nar_``;
    ``stage_parameters`` / ``stage_named_parameters`` rely on those prefixes.
    """

    config_class = VallexConfig

    def __init__(
        self,
        config: VallexConfig
    ):
        super().__init__(config)
        # Token dictionaries: "at" = audio tokens, "st" = text-side tokens
        # (naming taken from the config keys; TODO confirm).
        self.ar_at_dict = Dictionary.load(self.config.ar_at_dict)
        self.ar_st_dict = Dictionary.load(self.config.ar_st_dict)
        self.nar_at_dict = Dictionary.load(self.config.nar_at_dict)
        self.nar_st_dict = Dictionary.load(self.config.nar_st_dict)

        # Task flags are appended to the dictionaries; their ids are stored on them.
        self.ar_at_dict.tts_flag = self.ar_at_dict.add_symbol("<TTS>")
        self.ar_st_dict.asr_flag = self.ar_st_dict.add_symbol("<ASR>")
        self.ar_st_dict.mt_flag = self.ar_st_dict.add_symbol("<MT>")

        self.padding_idx = self.ar_at_dict.pad()
        self.config = config

        dim = self.config.n_dim
        nar_scale = self.config.nar_scale_factor
        use_bos = self.config.prepend_bos
        pre_norm = self.config.norm_first
        depth = self.config.n_layer
        self.NUM_AUDIO_TOKENS = self.ar_at_dict.eos()
        nar_dim = int(dim * nar_scale)

        self.ar_text_embedding = nn.Embedding(len(self.ar_st_dict), dim, self.ar_st_dict.pad())  # W_x
        if not config.only_ar:
            self.nar_text_embedding = nn.Embedding(len(self.nar_st_dict), dim, self.nar_st_dict.pad())

        # pad_y_eos uses NUM_AUDIO_TOKENS + 1 as the BOS id when this flag is set.
        self.ar_audio_prepend_bos = use_bos
        self.ar_audio_embedding = EncodecDecoderLstm(
            dictionary=self.ar_at_dict, emb_dim=dim
        )
        self.ar_text_prenet = nn.Identity()
        self.ar_audio_prenet = nn.Identity()

        # Both AR position tables are keyed on the *audio* dictionary's pad index.
        ar_pad = self.ar_at_dict.pad()
        self.ar_text_position = SinusoidalPositionalEmbedding(
            dim,
            padding_idx=ar_pad,
            init_size=1024 + ar_pad + 1,
        )
        self.ar_audio_position = SinusoidalPositionalEmbedding(
            dim,
            padding_idx=ar_pad,
            init_size=1024 + ar_pad + 1,
        )

        ar_layer = TransformerEncoderLayer(
            dim,
            self.config.n_head,
            dim_feedforward=dim * 4,
            dropout=0.1,
            batch_first=True,
            norm_first=pre_norm,
        )
        self.ar_decoder = TransformerEncoder(
            ar_layer,
            num_layers=depth,
            norm=LayerNorm(dim) if pre_norm else None,
        )
        self.ar_predict_layer = nn.Linear(
            dim, len(self.ar_at_dict), bias=False
        )

        self.rng = random.Random(0)
        self.num_heads = self.config.n_head
        self.prefix_mode = self.config.prefix_mode
        self.num_quantizers = self.config.num_quantizers
        assert self.num_quantizers >= 1

        # NAR stack: only built when NAR is enabled and there is more than one codebook.
        if not config.only_ar and self.num_quantizers > 1:
            self.nar_audio_embeddings = NATEncodecDecoderLstm(
                codecs=[0, 1, 2, 3, 4, 5, 6, 7], dictionary=self.nar_at_dict, emb_dim=dim
            )  # W_a
            self.nar_text_prenet = nn.Identity()
            self.nar_audio_prenet = nn.Identity()

            nar_pad = self.nar_at_dict.pad()
            self.nar_text_position = SinusoidalPositionalEmbedding(
                dim,
                padding_idx=nar_pad,
                init_size=1024 + nar_pad + 1,
            )
            self.nar_audio_position = SinusoidalPositionalEmbedding(
                dim,
                padding_idx=nar_pad,
                init_size=1024 + nar_pad + 1,
            )

            nar_layer = TransformerEncoderLayer(
                nar_dim,
                int(self.num_heads * nar_scale),
                dim_feedforward=nar_dim * 4,
                dropout=0.1,
                batch_first=True,
                norm_first=pre_norm,
                adaptive_layer_norm=True,
            )
            self.nar_decoder = TransformerEncoder(
                nar_layer,
                num_layers=int(depth * nar_scale),
                norm=nn.LayerNorm(nar_dim) if pre_norm else None,
            )
            # One output projection per quantizer.
            self.nar_predict_layers = nn.ModuleList(
                [
                    nn.Linear(nar_dim, len(self.nar_at_dict), bias=False)
                    for _ in range(self.num_quantizers)
                ]
            )
            self.nar_stage_embeddings = None

    def stage_parameters(self, stage: int = 1) -> Iterator[nn.Parameter]:
        """Yield (and print the names of) the parameters of one training stage.

        Stage 1 selects parameters whose name starts with ``ar_``, stage 2
        those starting with ``nar_``. Any other positive stage yields nothing.
        """
        assert stage > 0
        if stage == 1:
            for name, param in self.named_parameters():
                if not name.startswith("ar_"):
                    continue
                print(f" AR parameter: {name}")
                yield param
        elif stage == 2:
            for name, param in self.named_parameters():
                if not name.startswith("nar_"):
                    continue
                print(f"NAR parameter: {name}")
                yield param

    def stage_named_parameters(
        self, stage: int = 1
    ) -> Iterator[Tuple[str, nn.Parameter]]:
        """Yield ``(name, parameter)`` pairs of one training stage, by name prefix."""
        assert stage > 0
        if stage == 1:
            yield from (
                item for item in self.named_parameters() if item[0].startswith("ar_")
            )
        elif stage == 2:
            yield from (
                item for item in self.named_parameters() if item[0].startswith("nar_")
            )

    def pad_y_eos(self, y, y_mask_int, eos_id):
        """Build ``(inputs, targets)`` for AR training from ``y``.

        One column is appended to ``y``; ``eos_id`` is added wherever
        ``y_mask_int`` is 1 and in the appended column. With
        ``ar_audio_prepend_bos`` the inputs are the targets shifted right behind
        a BOS id (``NUM_AUDIO_TOKENS + 1``); otherwise inputs and targets are the
        sequence without its last / first column respectively.
        """
        padded_y = F.pad(y, (0, 1), value=0)
        padded_mask = F.pad(y_mask_int, (0, 1), value=1)
        targets = padded_y + eos_id * padded_mask
        # inputs, targets
        if not self.ar_audio_prepend_bos:
            return targets[:, :-1], targets[:, 1:]
        inputs = F.pad(targets[:, :-1], (1, 0), value=self.NUM_AUDIO_TOKENS + 1)
        return (inputs, targets)
class VALLE(VALLF):
    """VALLF plus language embeddings, an AR training ``forward`` and
    AR + NAR inference (``inference_24L``)."""

    config_class = VallexConfig

    def __init__(
        self,
        config: VallexConfig,
        **kwargs,
    ):
        super(VALLE, self).__init__(
            config,
            **kwargs,
        )
        print(config)
        self.config = config
        d_model = self.config.n_dim
        self.eps = config.eps  # label-smoothing weight passed to the loss
        self.language_ID = {
            'en': 0,
            'zh': 1,
        }
        # Three language slots; index 2 is the padding slot.
        self.ar_language_embedding = nn.Embedding(3, d_model, padding_idx=2)
        self.nar_language_embedding = nn.Embedding(3, d_model, padding_idx=2)
        # Constant multiplier applied to token and language embeddings.
        self.embed_scale = 32.0

    def forward(
        self,
        zh,
        en
    ):
        """Run one AR training step on either the ``zh`` or the ``en`` batch.

        One of the two batches is picked at random on every call. Each batch
        is a dict with the keys ``"st_tokens"``, ``"at_tokens_wbos"``,
        ``"at_tokens_tgt"``, ``"self_atten_mask"``, ``"padding_mask"`` and
        ``"langid"``.

        Returns:
            ``(ModelOutput, accuracy)``. ``ModelOutput.logits`` holds the
            flattened log-probabilities returned by ``calculate_loss``.
        """
        flag = (np.random.randint(low=0, high=1000) % 2 == 0)  # zh or en
        if flag:
            data = zh
        else:
            data = en

        st_tokens = data["st_tokens"]
        at_tokens_wbos = data["at_tokens_wbos"]
        at_tokens_tgt = data["at_tokens_tgt"]
        self_atten_mask = data["self_atten_mask"]
        padding_mask = data["padding_mask"]
        langid = data["langid"]

        st_len = st_tokens.size(1)

        # Text side: scaled token embedding + language embedding + positions.
        st_emb = self.embed_scale * self.ar_text_embedding(st_tokens)
        src_lang_emb = self.embed_scale * self.ar_language_embedding(langid)
        st_emb += src_lang_emb
        st_pos = self.ar_text_position(st_tokens)
        st_emb += st_pos

        # Audio side: same recipe, using the same ``langid``.
        at_emb, _ = self.ar_audio_embedding(at_tokens_wbos, None)
        at_emb = self.embed_scale * at_emb
        tgt_lang_emb = self.embed_scale * self.ar_language_embedding(langid)
        at_emb += tgt_lang_emb
        at_pos = self.ar_audio_position(at_tokens_wbos)
        at_emb += at_pos

        x = torch.concat([st_emb, at_emb], dim=1)
        x = self.ar_decoder(
            x,
            mask=self_atten_mask,
            src_key_padding_mask=padding_mask
        )
        x = self.ar_predict_layer(x)
        # Only the audio positions are scored.
        x = x[:, st_len:, :]
        loss, nll_loss, lprob, right_rate = self.calculate_loss(
            x, at_tokens_tgt
        )
        return ModelOutput(logits=lprob, loss=loss, acc=right_rate), right_rate

    def calculate_loss(self, encoder_out, target, reduce=True, scale=1.0, prob_mask=None, acc=True):
        """Compute label-smoothed loss and accuracy for ``encoder_out`` logits.

        Returns:
            ``(loss, nll_loss, lprob, right_rate)``. ``lprob`` is the
            log-softmax of ``encoder_out`` flattened to two dimensions, and
            ``right_rate`` is the percentage of non-padding targets whose
            argmax prediction is correct. ``acc`` is accepted but not used.
        """
        lprob = self.get_normalized_probs(encoder_out, log_probs=True)
        with torch.no_grad():
            mask = target.ne(self.padding_idx)
            n_correct = torch.sum(
                lprob.argmax(-1).masked_select(mask).eq(target.masked_select(mask))
            )
            total = torch.sum(mask)
            right_rate = n_correct * 100.0 / total
        lprob, target = lprob.view(-1, lprob.size(-1)), target.view(-1)
        loss, nll_loss = label_smoothed_nll_loss(
            lprob,
            target,
            self.eps,
            ignore_index=self.padding_idx,
            reduce=reduce,
            scale=scale,
            prob_mask=prob_mask
        )
        return loss, nll_loss, lprob, right_rate

    def get_normalized_probs(self, encoder_out, log_probs, sample=None):
        """Return softmax (or log-softmax if ``log_probs``) of float32 logits.

        Raises:
            TypeError: if ``encoder_out`` is not a tensor.
        """
        if not torch.is_tensor(encoder_out):
            # Fail loudly instead of falling through without a usable result.
            raise TypeError(
                f"encoder_out must be a torch.Tensor, got {type(encoder_out).__name__}"
            )
        logits = encoder_out.float()
        if log_probs:
            return F.log_softmax(logits, dim=-1)
        return F.softmax(logits, dim=-1)

    def inference_24L(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        y: torch.Tensor,
        enroll_x_lens: torch.Tensor,
        top_k: int = -100,
        temperature: float = 1.0,
        prompt_language: Optional[torch.Tensor] = None,
        text_language: Optional[torch.Tensor] = None,
        best_of: int = 1,
        length_penalty: float = 1.0,
        return_worst: bool = False,
        at_eos: int = -1
    ) -> torch.Tensor:
        """Generate audio codes: sample the first codebook with the AR decoder,
        then predict the remaining codebooks greedily with the NAR decoder.

        Args:
            x: 2-D text token ids.
            x_lens: 1-D text lengths, all positive.
            y: 3-D audio prompt with batch size 1; the last dimension indexes
                the codebooks.
            enroll_x_lens: Length of the prompt part of the text. It splits
                the text between the prompt and target language embeddings.
            top_k: Top-k sampling size (values <= 0 disable top-k).
            temperature: Sampling temperature.
            prompt_language: Language-id tensor for the prompt. Required,
                despite the ``None`` default.
            text_language: Language-id tensor for the text to synthesise.
                Required, despite the ``None`` default.
            best_of: Number of AR candidates sampled in parallel.
            length_penalty: Exponent applied to candidate length when
                averaging log-probabilities.
            return_worst: Return the lowest-scoring candidate instead of the best.
            at_eos: End-of-sequence id. It is stored in
                ``self.NUM_AUDIO_TOKENS`` and persists after the call.

        Returns:
            Generated codes stacked along a new last dimension, one slice per
            quantizer.
        """
        assert x.ndim == 2, x.shape
        assert x_lens.ndim == 1, x_lens.shape
        assert y.ndim == 3, y.shape
        assert y.shape[0] == 1, y.shape
        assert torch.all(x_lens > 0)

        self.NUM_AUDIO_TOKENS = at_eos
        text = x
        x = self.embed_scale * self.ar_text_embedding(text)
        prompt_language_id = prompt_language.to(x.device)
        text_language_id = text_language.to(x.device)
        src_lang_emb = self.embed_scale * self.ar_language_embedding(prompt_language_id)
        tgt_lang_emb = self.embed_scale * self.ar_language_embedding(text_language_id)
        x[:, :enroll_x_lens, :] += src_lang_emb
        x[:, enroll_x_lens:, :] += tgt_lang_emb
        x = self.ar_text_prenet(x)
        x_pos = self.ar_text_position(text)

        text_len = x_lens.max()
        prompts = y
        prefix_len = y.shape[1]

        # AR Decoder
        # TODO: Managing decoder steps avoid repetitive computation
        y = prompts[..., 0]
        if self.ar_audio_prepend_bos:
            y = F.pad(y, (1, 0), value=self.ar_at_dict.tts_flag)

        x_len = x_lens.max()
        x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)

        kv_cache = None
        use_kv_caching = True

        sum_logprobs = torch.zeros(best_of, device=y.device)  # implement batch decoding here
        x = x.repeat(best_of, 1, 1)
        y = y.repeat(best_of, 1)

        lstm_h = None
        first_ar = True
        while True:
            if first_ar:
                # First step: embed the whole prompt at once.
                y_emb, lstm_h = self.ar_audio_embedding(y, lstm_h)
                y_emb = y_emb * self.embed_scale
                y_emb = self.ar_audio_prenet(y_emb)
                y_pos = self.ar_audio_position(y)
                y_emb[:, :prefix_len] = y_emb[:, :prefix_len] + src_lang_emb
                y_emb[:, prefix_len:] = y_emb[:, prefix_len:] + tgt_lang_emb
                xy_pos_token = torch.concat([x_pos + x, y_pos + y_emb], dim=1)
                first_ar = False
            else:
                # Later steps: embed only the newest token, carrying the LSTM state.
                y_emb_cur, lstm_h = self.ar_audio_embedding(y[:, -1:], lstm_h)
                y_emb_cur = y_emb_cur * self.embed_scale
                y_emb_cur = self.ar_audio_prenet(y_emb_cur)
                y_pos_cur = self.ar_audio_position(y)[:, -1:]
                # NOTE(review): both language embeddings are added here, unlike
                # the first step above; kept as-is, confirm this is intended.
                y_emb_cur = y_emb_cur + src_lang_emb
                y_emb_cur = y_emb_cur + tgt_lang_emb
                xy_pos_token = torch.concat([xy_pos_token, y_pos_cur + y_emb_cur], dim=1)

            y_len = y.shape[1]
            # Text rows may not attend to audio columns.
            x_attn_mask_pad = F.pad(
                x_attn_mask,
                (0, y_len),
                value=True,
            )
            # Audio rows see all text and are causal over audio.
            y_attn_mask = F.pad(
                torch.triu(
                    torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1
                ),
                (x_len, 0),
                value=False,
            )
            xy_attn_mask = torch.concat(
                [x_attn_mask_pad, y_attn_mask], dim=0
            ).to(y.device)

            if use_kv_caching and kv_cache is not None:
                # With a cache, only the newest position is fed to the decoder.
                xy_pos = xy_pos_token[:, [-1]]
                xy_attn_mask = xy_attn_mask[:, [-1]]
            else:
                xy_pos = xy_pos_token

            xy_dec, kv_cache = self.ar_decoder.infer(
                xy_pos,
                mask=xy_attn_mask,
                past_kv=kv_cache,
                use_cache=use_kv_caching,
            )
            logits = self.ar_predict_layer(xy_dec[:, -1])
            samples, current_logprobs = topk_sampling(
                logits, top_k=top_k, top_p=1, temperature=temperature
            )
            # Candidates that already emitted EOS stop accumulating score and keep emitting EOS.
            sum_logprobs += current_logprobs * (y[:, -1] != self.NUM_AUDIO_TOKENS)
            samples[y[:, -1] == self.NUM_AUDIO_TOKENS] = self.NUM_AUDIO_TOKENS
            completed = (samples[:, -1] == self.NUM_AUDIO_TOKENS).all()
            if (
                completed
                or (y.shape[1] - prompts.shape[1]) > x_lens.max() * 32
            ):
                if prompts.shape[1] == y.shape[1]:
                    # NOTE(review): SyntaxError is an odd type for a runtime
                    # condition; kept so existing handlers keep working.
                    raise SyntaxError(
                        "well trained model shouldn't reach here."
                    )
                lengths = torch.sum(y != self.NUM_AUDIO_TOKENS, dim=1)
                avg_logprobs = sum_logprobs / lengths ** length_penalty
                # choose the best beam according to sum_logprobs
                best_beam = y[torch.argmax(avg_logprobs), :]
                worst_beam = y[torch.argmin(avg_logprobs), :]
                # strip all eos tokens
                best_beam = best_beam[best_beam != self.NUM_AUDIO_TOKENS]
                worst_beam = worst_beam[worst_beam != self.NUM_AUDIO_TOKENS]
                if return_worst:
                    y = worst_beam.unsqueeze(0)
                else:
                    y = best_beam.unsqueeze(0)
                print(f"VALL-E EOS [{prompts.shape[1]} -> {y.shape[1]}]")
                break

            y = torch.concat([y, samples], dim=1)

        codes = [y[:, prefix_len + int(self.ar_audio_prepend_bos):]]
        if self.num_quantizers == 1:
            return torch.stack(codes, dim=-1)

        if self.prefix_mode in [2, 4]:  # Exclude enrolled_phonemes
            enrolled_len = enroll_x_lens.max().item()
            # SOS + Synthesis Text + EOS
            text = torch.concat(
                [
                    text[:, :1],
                    text[:, enrolled_len - 1:],
                ],
                dim=1,
            )
            text_len = text_len - (enrolled_len - 2)
            assert text.shape[0] == 1

        x = self.embed_scale * self.nar_text_embedding(text)
        # Add language embedding
        prompt_language_id = prompt_language.to(x.device)
        text_language_id = text_language.to(x.device)
        src_lang_emb = self.embed_scale * self.nar_language_embedding(prompt_language_id)
        tgt_lang_emb = self.embed_scale * self.nar_language_embedding(text_language_id)
        x[:, :enroll_x_lens, :] += src_lang_emb
        x[:, enroll_x_lens:, :] += tgt_lang_emb
        x = self.nar_text_prenet(x)
        x_pos = self.nar_text_position(text)

        if self.prefix_mode == 0:
            # NOTE(review): this branch indexes self.nar_stage_embeddings, which
            # VALLF.__init__ sets to None; it will fail unless that attribute
            # is assigned elsewhere. Confirm before relying on it.
            for i, predict_layer in enumerate(
                self.nar_predict_layers
            ):
                y_pos = self.nar_audio_prenet(y_emb)
                y_pos = self.nar_audio_position(y_pos)
                xy_pos = torch.concat([x, y_pos], dim=1)
                xy_dec, _ = self.nar_decoder(
                    (xy_pos, self.nar_stage_embeddings[i].weight)
                )
                logits = predict_layer(xy_dec[:, text_len + prefix_len:])
                samples = torch.argmax(logits, dim=-1)
                codes.append(samples)
                if i < self.num_quantizers - 2:
                    y_emb[:, :prefix_len] += self.embed_scale * self.nar_audio_embeddings(
                        prompts[..., i + 1]
                    )[0]
                    y_emb[:, prefix_len:] += self.embed_scale * self.nar_audio_embeddings(samples)[0]
        else:
            y_pos = self.nar_audio_position(y[:, int(self.ar_audio_prepend_bos):])
            ref_at_emb = self.embed_scale * self.nar_audio_embeddings(prompts)[0] + src_lang_emb
            est_at = y[:, prefix_len + int(self.ar_audio_prepend_bos):].unsqueeze(-1)
            # One pass per remaining codebook. The bound follows num_quantizers
            # (it was a literal 8) so it agrees with the assertion below.
            for i in range(1, self.num_quantizers):
                y_emb, _ = self.nar_audio_embeddings(est_at)
                y_emb = self.embed_scale * y_emb + tgt_lang_emb
                y_emb = torch.concat([ref_at_emb, y_emb], dim=1)
                xy_pos = torch.concat([x + x_pos, y_emb + y_pos], dim=1)
                xy_dec = self.nar_decoder(
                    xy_pos
                )
                logits = self.nar_predict_layers[i - 1](xy_dec[:, text_len + prefix_len:])
                samples = torch.argmax(logits, dim=-1)
                # Feed the new codebook back in for the next pass.
                est_at = torch.concat([est_at, samples.unsqueeze(-1)], dim=-1)
                codes.append(samples)

        assert len(codes) == self.num_quantizers
        return torch.stack(codes, dim=-1)
def top_k_top_p_filtering(
    logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1
):
    """Filter logits with top-k and/or nucleus (top-p) filtering.

    Filtering runs along the last dimension, so ``logits`` may have any number
    of leading dimensions. ``logits`` is modified in place and also returned.

    Args:
        logits: Scores whose last dimension indexes the vocabulary.
        top_k: If > 0, keep only the ``top_k`` highest-scoring entries.
        top_p: If < 1.0, keep the smallest set of top entries whose cumulative
            probability exceeds ``top_p``.
        filter_value: Value written to removed entries.
        min_tokens_to_keep: Lower bound on the number of entries kept.

    Returns:
        The same ``logits`` tensor, filtered.
    """
    if top_k > 0:
        top_k = min(
            max(top_k, min_tokens_to_keep), logits.size(-1)
        )  # Safety check
        # Remove all tokens with a probability less than the last token of the top-k
        kth_best = torch.topk(logits, top_k)[0][..., -1, None]
        logits[logits < kth_best] = filter_value

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(
            F.softmax(sorted_logits, dim=-1), dim=-1
        )
        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        if min_tokens_to_keep > 1:
            # Keep at least min_tokens_to_keep (the shift below then keeps one more)
            sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
            ..., :-1
        ].clone()
        sorted_indices_to_remove[..., 0] = 0
        # Scatter back to the original ordering along the last dimension, so
        # 1-D and batched N-D inputs work as well as 2-D ones.
        indices_to_remove = sorted_indices_to_remove.scatter(
            -1, sorted_indices, sorted_indices_to_remove
        )
        logits[indices_to_remove] = filter_value
    return logits
def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0):
    """Sample one token per row from temperature-scaled, filtered logits.

    Returns:
        ``(token, current_logprobs)``: the sampled ids with shape
        ``(rows, 1)`` and the log-probability of each sampled id under the
        filtered distribution.
    """
    scaled = logits if temperature == 1.0 else logits / temperature
    # Top-p/top-k filtering
    filtered = top_k_top_p_filtering(scaled, top_k=top_k, top_p=top_p)
    # Sample
    probs = F.softmax(filtered, dim=-1)
    token = torch.multinomial(probs, num_samples=1)
    logprobs = F.log_softmax(filtered.float(), dim=-1)
    # Pick, for every row, the log-probability of the token that was drawn.
    current_logprobs = logprobs.gather(1, token).squeeze(1)
    return token, current_logprobs
class SLSTM(nn.Module):
    """LSTM over the last axis of a 3-D input, with an optional skip connection.

    The input is permuted with ``(2, 0, 1)`` before the LSTM and the output is
    permuted back with ``(1, 2, 0)``, so the output has the input's layout.
    When ``bidirectional`` is set, a linear layer maps the doubled LSTM output
    back to ``dimension``.
    """

    def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True, bidirectional=False):
        super().__init__()
        self.skip = skip
        self.lstm = nn.LSTM(dimension, dimension, num_layers, bidirectional=bidirectional)
        self.out_fc = nn.Linear(dimension * 2, dimension) if bidirectional else None

    def forward(self, x, hidden=None):
        """Return ``(output, hidden)``; ``hidden`` is the LSTM state to carry over."""
        seq_first = x.permute(2, 0, 1)
        out, hidden = self.lstm(seq_first, hidden)
        if self.out_fc is not None:
            out = self.out_fc(out)
        if self.skip:
            out = out + seq_first
        return out.permute(1, 2, 0), hidden
class EncodecDecoderLstm(nn.Module):
    """Token embedding followed by an SLSTM and an ELU activation.

    Args:
        dictionary: Provides ``len()``, ``pad()`` and ``pad_index``.
        emb_dim: Embedding size.
        out_dim: SLSTM dimension; defaults to ``emb_dim``. The embedding output
            is fed straight into the SLSTM, so a different value would not
            match its input size.
        num_layers: Number of LSTM layers.
        lstm_skip: Enable the SLSTM skip connection.
        lstm_bidire: Use a bidirectional LSTM.
        activation_param: Keyword arguments for ``nn.ELU``; defaults to
            ``{'alpha': 1.0}``.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, dictionary, emb_dim,
                 out_dim=None,
                 num_layers=3, lstm_skip=True, lstm_bidire=False,
                 activation_param=None, **kwargs):
        super().__init__()
        if out_dim is None:
            out_dim = emb_dim
        if activation_param is None:
            # Built per call so no dict is shared between instances.
            activation_param = {'alpha': 1.0}
        self.slstm = SLSTM(dimension=out_dim, num_layers=num_layers, skip=lstm_skip, bidirectional=lstm_bidire)
        self.elu = nn.ELU(**activation_param)
        self.embedding_dim = emb_dim
        self.padding_idx = dictionary.pad()
        self.emb = nn.Embedding(len(dictionary), emb_dim, dictionary.pad_index)

    def forward(self, x, hidden=None):
        """Embed token ids and run them through the SLSTM.

        Args:
            x: Token ids. The embedding output is permuted as a 3-D tensor, so
                ``x`` is expected to be 2-D (batch, time).
            hidden: Optional LSTM state from a previous call.

        Returns:
            ``(out, hidden)`` where ``out`` has the embedding output's layout.
        """
        quantized_out = self.emb(x)
        # SLSTM takes the feature axis second, so swap it in and back out.
        out, hidden = self.slstm(quantized_out.permute(0, 2, 1), hidden)
        out = self.elu(out)
        return out.permute(0, 2, 1), hidden
class NATEncodecDecoderLstm(nn.Module):
    """Sum of per-codebook token embeddings followed by an SLSTM and an ELU.

    Args:
        codecs: One entry per codebook; only its length is used, to size the
            list of embedding tables.
        dictionary: Provides ``len()``, ``pad()`` and ``pad_index``.
        emb_dim: Embedding size.
        out_dim: SLSTM dimension; defaults to ``emb_dim``.
        num_layers: Number of LSTM layers.
        lstm_skip: Enable the SLSTM skip connection.
        lstm_bidire: Use a bidirectional LSTM.
        activation_param: Keyword arguments for ``nn.ELU``; defaults to
            ``{'alpha': 1.0}``.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, codecs, dictionary, emb_dim,
                 out_dim=None,
                 num_layers=3, lstm_skip=True, lstm_bidire=False,
                 activation_param=None, **kwargs):
        super().__init__()
        if out_dim is None:
            out_dim = emb_dim
        if activation_param is None:
            # Built per call so no dict is shared between instances.
            activation_param = {'alpha': 1.0}
        self.slstm = SLSTM(dimension=out_dim, num_layers=num_layers, skip=lstm_skip, bidirectional=lstm_bidire)
        self.elu = nn.ELU(**activation_param)
        self.codecs = codecs
        self.embedding_dim = emb_dim
        self.padding_idx = dictionary.pad()
        self.emb_list = nn.ModuleList(
            [nn.Embedding(len(dictionary), emb_dim, dictionary.pad_index) for _ in self.codecs]
        )

    def forward(self, x, hidden=None):
        """Embed up to ``len(codecs)`` codebooks, sum them, and run the SLSTM.

        Args:
            x: Token ids, either 2-D (treated as a single codebook) or 3-D with
                the codebook axis last. A 3-D input whose second axis, but not
                its last, equals ``len(codecs)`` is taken to be codebook-major
                and is transposed.
            hidden: Optional LSTM state from a previous call.

        Returns:
            ``(out, hidden)``.
        """
        if len(x.size()) == 2:
            x = x.unsqueeze(-1)
        # NOTE(review): this layout guess is ambiguous when the time axis
        # happens to have length len(codecs) and fewer codebooks are passed;
        # such an input is transposed as well. Confirm callers avoid that case.
        if x.size(2) != len(self.codecs) and x.size(1) == len(self.codecs):
            x = x.permute(0, 2, 1)
        quantized_out = 0
        for i in range(x.size(2)):
            # Codebook i has its own embedding table.
            quantized_out = quantized_out + self.emb_list[i](x[:, :, i])
        out, hidden = self.slstm(quantized_out.permute(0, 2, 1), hidden)
        out = self.elu(out)
        return out.permute(0, 2, 1), hidden
# Register VALLE as the AutoModel class for VallexConfig.
AutoModel.register(VallexConfig, VALLE)