Spaces:
Running
on
Zero
Running
on
Zero
# Copyright 2025 ByteDance and/or its affiliates. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
import random | |
from copy import deepcopy | |
import torch | |
import torch.nn.functional as F | |
from torch import nn | |
from torch.nn import Linear | |
from tqdm import tqdm | |
from tts.modules.ar_dur.commons.layers import Embedding, LayerNorm | |
from tts.modules.ar_dur.commons.nar_tts_modules import PosEmb | |
from tts.modules.ar_dur.commons.rot_transformer import RotTransformerDecoderLayer | |
from tts.modules.ar_dur.commons.transformer import SinusoidalPositionalEmbedding | |
from tts.modules.ar_dur.commons.rel_transformer import RelTransformerEncoder | |
FS_ENCODERS = { | |
'rel_fft': lambda hp, dict_size: RelTransformerEncoder( | |
dict_size, hp['hidden_size'], hp['hidden_size'], | |
hp['ffn_hidden_size'], hp['num_heads'], hp['enc_layers'], | |
hp['enc_ffn_kernel_size'], hp['dropout'], prenet=hp['enc_prenet'], pre_ln=hp['enc_pre_ln']), | |
} | |
def fill_with_neg_inf2(t): | |
"""FP16-compatible function that fills a tensor with -inf.""" | |
return t.float().fill_(-1e8).type_as(t) | |
def expand_states(h, mel2token): | |
h = F.pad(h, [0, 0, 1, 0]) | |
mel2token_ = mel2token[..., None].repeat([1, 1, h.shape[-1]]) | |
h = torch.gather(h, 1, mel2token_) # [B, T, H] | |
return h | |
class CodePredictor(nn.Module): | |
def __init__(self, hparams, hidden_size, dec_hidden_size, lm_num_layers, dict_size, code_size): | |
super().__init__() | |
self.hparams = deepcopy(hparams) | |
self.hparams['hidden_size'] = hidden_size | |
self.hidden_size = hidden_size | |
char_dict_size = hparams.get('char_dict_size', 4000) | |
if not hparams.get('lm_use_enc'): | |
self.encoder = nn.Embedding(dict_size, self.hidden_size, padding_idx=0) | |
if hparams.get('mega_use_char', True): | |
self.char_encoder = nn.Embedding(char_dict_size, | |
self.hidden_size, padding_idx=0) | |
else: | |
self.encoder = FS_ENCODERS[self.hparams['encoder_type']](self.hparams, dict_size) | |
if hparams.get('mega_use_char', True): | |
self.char_encoder = FS_ENCODERS[self.hparams['encoder_type']](self.hparams, char_dict_size) | |
if hparams['use_ph_pos_embed']: | |
self.ph_pos_embed = PosEmb(self.hidden_size) | |
self.char_empty_embed = nn.Embedding(1, self.hidden_size) | |
if hparams.get('use_bert_input'): | |
self.bert_input_proj = nn.Linear(768, self.hidden_size) | |
self.ling_label_embed_layers = nn.ModuleDict() | |
for k, s in zip(hparams['ling_labels'], hparams['ling_label_dict_size']): | |
self.ling_label_embed_layers[k] = Embedding(s + 3, self.hidden_size, padding_idx=0) | |
self.dec_hidden_size = dec_hidden_size | |
self.enc_proj = nn.Linear(self.hidden_size, dec_hidden_size) | |
self.code_emb = Embedding(code_size + 2, dec_hidden_size, 0) | |
self.use_pos_embed = hparams.get('use_pos_embed', False) | |
if self.use_pos_embed: | |
self.embed_positions = SinusoidalPositionalEmbedding(dec_hidden_size, 0, init_size=1024) | |
self.use_post_ln = hparams.get('use_post_ln', False) | |
self.layers = None | |
if not self.use_post_ln: | |
self.layer_norm = LayerNorm(dec_hidden_size) | |
self.code_size = code_size | |
self.project_out_dim = Linear(dec_hidden_size, code_size + 1, bias=True) | |
def forward_ling_encoder( | |
self, txt_tokens, ling_feas, char_tokens, ph2char, bert_embed, spk_id, spk_embed, mels_timbre): | |
ph_tokens = txt_tokens | |
hparams = self.hparams | |
ph_nonpadding = (ph_tokens > 0).float()[:, :, None] # [B, T_phone, 1] | |
x_spk = self.forward_style_embed(spk_embed, spk_id, mels_timbre) | |
# enc_ph | |
if not hparams.get('lm_use_enc'): | |
x_ph = self.encoder(ph_tokens) | |
x_ph = x_ph + sum( | |
[self.ling_label_embed_layers[k](ling_feas[k]) for k in hparams['ling_labels']]) \ | |
if len(hparams['ling_labels']) > 0 else 0 | |
x_ph = x_ph + x_spk | |
else: | |
# enc_ph | |
ph_enc_oembed = sum( | |
[self.ling_label_embed_layers[k](ling_feas[k]) for k in hparams['ling_labels']]) \ | |
if len(hparams['ling_labels']) > 0 else 0 | |
ph_enc_oembed = ph_enc_oembed + self.ph_pos_embed( | |
torch.arange(0, ph_tokens.shape[1])[None,].to(ph_tokens.device)) | |
ph_enc_oembed = ph_enc_oembed + x_spk | |
ph_enc_oembed = ph_enc_oembed * ph_nonpadding | |
x_ph = self.encoder(ph_tokens, other_embeds=ph_enc_oembed) | |
# enc_char | |
if char_tokens is not None and ph2char is not None: | |
char_nonpadding = (char_tokens > 0).float()[:, :, None] | |
x_char = self.char_encoder(char_tokens) | |
empty_char = (ph2char > 100000).long() | |
ph2char = ph2char * (1 - empty_char) | |
x_char_phlevel = \ | |
expand_states(x_char * char_nonpadding, ph2char) \ | |
* (1 - empty_char)[..., None] + \ | |
self.char_empty_embed(torch.zeros_like(ph_tokens)) * empty_char[..., None] | |
else: | |
x_char_phlevel = 0 | |
# x_ling | |
x_ling = x_ph + x_char_phlevel | |
x_ling = x_ling * ph_nonpadding | |
x_ling = self.enc_proj(x_ling) | |
return x_ling | |
def sample_one_step(self, vq_pred): | |
hparams = self.hparams | |
if hparams.get('infer_top_k'): | |
top_k = hparams.get('infer_top_k') | |
temperature = hparams.get('infer_temperature', 1) | |
vq_pred = vq_pred[:, -1] / temperature | |
# optionally crop the logits to only the top k options | |
if top_k is not None: | |
v, _ = torch.topk(vq_pred, min(top_k, vq_pred.size(-1))) | |
vq_pred[vq_pred < v[:, [-1]]] = -float('Inf') | |
# apply softmax to convert logits to (normalized) probabilities | |
probs = F.softmax(vq_pred, dim=-1) | |
# sample from the distribution | |
vq_pred = torch.multinomial(probs, num_samples=1) | |
else: | |
vq_pred = torch.argmax(F.softmax(vq_pred[:, -1], dim=-1), 1) | |
return vq_pred | |
def forward_style_embed(self, spk_embed=None, spk_id=None, mel_ref=None): | |
# add spk embed | |
style_embed = 0 | |
if self.hparams['use_spk_embed']: | |
style_embed = style_embed + self.spk_embed_proj(spk_embed)[:, None, :] | |
if self.hparams['use_spk_id']: | |
style_embed = style_embed + self.spk_id_proj(spk_id)[:, None, :] | |
if self.hparams['use_spk_enc']: | |
style_embed = style_embed + self.spk_enc(mel_ref)[:, None, :] | |
return style_embed | |
def buffered_future_mask(self, tensor): | |
dim = tensor.size(0) | |
if ( | |
not hasattr(self, '_future_mask') | |
or self._future_mask is None | |
or self._future_mask.device != tensor.device | |
or self._future_mask.size(0) < dim | |
): | |
self._future_mask = torch.triu(fill_with_neg_inf2(tensor.new(dim, dim)), 1) | |
return self._future_mask[:dim, :dim] | |
class ARDurPredictor(CodePredictor): | |
def __init__(self, hparams, hidden_size, dec_hidden_size, lm_num_layers, dict_size, code_size, use_rot_embed=True, | |
op_version=1): | |
super().__init__(hparams, hidden_size, dec_hidden_size, lm_num_layers, dict_size, code_size) | |
self.use_rot_embed = use_rot_embed | |
bias = hparams.get('lm_bias', True) | |
if self.use_rot_embed: | |
self.layers = nn.ModuleList([]) | |
self.layers.extend([ | |
RotTransformerDecoderLayer( | |
dec_hidden_size, 0.0, kernel_size=1, ffn_hidden_size=dec_hidden_size * 4, | |
post_ln=self.use_post_ln, op_version=op_version, bias=bias) | |
for _ in range(lm_num_layers) | |
]) | |
if hparams['dur_model_type'] == 'ar_mse': | |
self.project_out_dim = nn.Sequential(torch.nn.Linear(dec_hidden_size, 1), nn.Softplus()) | |
else: | |
self.project_out_dim = torch.nn.Linear(dec_hidden_size, code_size + 1) | |
def forward(self, txt_tokens, ling_feas, char_tokens, ph2char, bert_embed, | |
prev_code, spk_id=None, spk_embed=None, mels_timbre=None, mel2ph=None, | |
incremental_state=None, x_ling=None, attn_mask=None, spk_pos_ids_flat=None, | |
prompt_length=None, cache_size=20, streaming=False): | |
x = self.code_emb(prev_code) | |
if x_ling is None: | |
x_ling = self.forward_ling_encoder( | |
txt_tokens, ling_feas, char_tokens, ph2char, bert_embed, spk_id, spk_embed, mels_timbre) | |
x_ling = x_ling.flatten(0, 1) | |
txt_tokens = txt_tokens.flatten(0, 1) | |
x_ling = x_ling[txt_tokens > 0][None] | |
# run decoder | |
self_attn_padding_mask = None | |
if self.use_pos_embed: | |
positions = self.embed_positions( | |
prev_code, | |
incremental_state=incremental_state | |
) | |
if incremental_state is not None: | |
x_ling = x_ling[:, x.shape[1] - 1:x.shape[1]] | |
if spk_pos_ids_flat is not None: | |
spk_pos_ids_flat = spk_pos_ids_flat[:, x.shape[1] - 1:x.shape[1]] | |
x = x[:, -1:] | |
if self.use_pos_embed: | |
positions = positions[:, -1:] | |
if streaming: | |
# Shift Pos: query pos is min(cache_size, idx) | |
spk_pos_ids_flat = torch.min(torch.LongTensor([prompt_length + cache_size]).to(x.device), | |
spk_pos_ids_flat) | |
# # B x T x C -> T x B x C | |
if self.use_pos_embed: | |
x = x + positions | |
x_ling = x_ling[:, :self.hparams['max_tokens']].contiguous() | |
T = min(self.hparams.get('max_tokens_per_item', 1e9), x_ling.shape[1]) | |
x_ling = x_ling.reshape(-1, T, x_ling.shape[-1]) | |
x = x + x_ling | |
x = x.transpose(0, 1) | |
for idx, layer in enumerate(self.layers): | |
if incremental_state is None: | |
self_attn_mask = self.buffered_future_mask(x) | |
if attn_mask is not None: | |
self_attn_mask = self_attn_mask + (1 - attn_mask.float()) * -1e8 | |
self_attn_mask = self_attn_mask.clamp_min(-1e8) | |
else: | |
self_attn_mask = None | |
x, attn_weights = layer( | |
x, | |
incremental_state=incremental_state, | |
self_attn_mask=self_attn_mask, | |
self_attn_padding_mask=self_attn_padding_mask, | |
spk_pos_ids_flat=spk_pos_ids_flat | |
) | |
if streaming and incremental_state != {}: | |
for k, v in incremental_state.items(): | |
if 'attn_state' in k: | |
prev_key, prev_value = incremental_state[k]['prev_key'], incremental_state[k]['prev_value'] | |
cur_length = prev_key.shape[2] | |
if cur_length - prompt_length > cache_size: | |
prev_key = torch.cat((prev_key[:, :, :prompt_length], prev_key[:, :, -cache_size:]), dim=2) | |
prev_value = torch.cat((prev_value[:, :, :prompt_length], prev_value[:, :, -cache_size:]), | |
dim=2) | |
incremental_state[k]['prev_key'], incremental_state[k]['prev_value'] = prev_key, prev_value | |
if not self.use_post_ln: | |
x = self.layer_norm(x) | |
# T x B x C -> B x T x C | |
x = x.transpose(0, 1) | |
x = self.project_out_dim(x) | |
return x | |
def infer(self, txt_tokens, ling_feas, char_tokens, ph2char, bert_embed, | |
spk_id=None, spk_embed=None, mels_timbre=None, | |
incremental_state=None, ctx_vqcodes=None, spk_pos_ids_flat=None, return_state=False, | |
first_step_min=0, return_probs=False, first_decoder_inp=None, dur_disturb=0.0, **kwargs): | |
if incremental_state is None: | |
incremental_state = {} | |
x_ling = self.forward_ling_encoder( | |
txt_tokens, ling_feas, char_tokens, ph2char, bert_embed, | |
spk_id, spk_embed, mels_timbre) | |
x_ling = x_ling.flatten(0, 1) | |
txt_tokens_ori = txt_tokens | |
txt_tokens_withpad = txt_tokens = txt_tokens.flatten(0, 1) | |
x_ling = x_ling[txt_tokens > 0][None] | |
txt_tokens = txt_tokens[txt_tokens > 0][None] | |
decoded = torch.zeros_like(txt_tokens) | |
decoded = F.pad(decoded, [1, 0], value=self.code_size + 1) | |
if incremental_state != {}: | |
if first_decoder_inp is None: | |
assert ctx_vqcodes is not None | |
decoded[:, :ctx_vqcodes.shape[1]] = ctx_vqcodes | |
ctx_vqcodes = None | |
else: | |
decoded[:, :1] = first_decoder_inp | |
probs = [] | |
for step in range(decoded.shape[1] - 1): | |
vq_pred = self(txt_tokens, None, None, None, None, | |
decoded[:, :step + 1], None, None, None, | |
incremental_state=incremental_state, x_ling=x_ling, | |
spk_pos_ids_flat=spk_pos_ids_flat, **kwargs) | |
probs.append(vq_pred.cpu()) | |
if ctx_vqcodes is None or step >= ctx_vqcodes.shape[1]: | |
if self.hparams['dur_model_type'] == 'ar_mse': | |
d = vq_pred[:, -1, 0] | |
if dur_disturb > 0 and step >= 1: | |
if random.random() > 0.5: | |
d = d * (1 + random.random() * dur_disturb) | |
else: | |
d = d / (1 + random.random() * dur_disturb) | |
d = torch.clamp_max(d, self.code_size - 1) | |
vq_pred = torch.round(d).long() | |
else: | |
vq_pred = self.sample_one_step(vq_pred) | |
decoded[:, step + 1] = torch.clamp_min(vq_pred, 1) | |
if step == 0: | |
decoded[:, step + 1] = torch.clamp_min(vq_pred, first_step_min) | |
else: | |
decoded[:, step + 1] = ctx_vqcodes[:, step] | |
decoded = decoded[:, 1:] | |
decoded_2d = torch.zeros_like(txt_tokens_ori) | |
decoded_2d.flatten(0, 1)[txt_tokens_withpad > 0] = decoded | |
if return_state: | |
return decoded_2d, incremental_state | |
if return_probs: | |
return decoded_2d, torch.cat(probs, 1) | |
return decoded_2d | |
def streaming_infer(self, txt_tokens, ling_feas, char_tokens, ph2char, bert_embed, | |
spk_id=None, spk_embed=None, mels_timbre=None, | |
incremental_state=None, ctx_vqcodes=None, spk_pos_ids_flat=None, return_state=False, | |
**kwargs): | |
if incremental_state is None: | |
incremental_state = {} | |
x_ling = self.forward_ling_encoder( | |
txt_tokens, ling_feas, char_tokens, ph2char, bert_embed, | |
spk_id, spk_embed, mels_timbre) | |
x_ling = x_ling.flatten(0, 1) | |
txt_tokens_ori = txt_tokens | |
txt_tokens_withpad = txt_tokens = txt_tokens.flatten(0, 1) | |
x_ling = x_ling[txt_tokens > 0][None] | |
txt_tokens = txt_tokens[txt_tokens > 0][None] | |
vq_decoded = torch.zeros_like(txt_tokens) | |
vq_decoded = F.pad(vq_decoded, [1, 0], value=self.code_size + 1) | |
if incremental_state != {}: | |
assert ctx_vqcodes is not None | |
vq_decoded[:, :ctx_vqcodes.shape[1]] = ctx_vqcodes | |
ctx_vqcodes = None | |
prompt_length = list(incremental_state.items())[0][1]['prev_key'].shape[2] | |
for step in tqdm(range(vq_decoded.shape[1] - 1), desc='AR Duration Predictor inference...'): | |
vq_pred = self(txt_tokens, None, None, None, None, | |
vq_decoded[:, :step + 1], None, None, None, | |
incremental_state=incremental_state, x_ling=x_ling, | |
spk_pos_ids_flat=spk_pos_ids_flat, prompt_length=prompt_length, streaming=True, **kwargs) | |
if ctx_vqcodes is None or step >= ctx_vqcodes.shape[1]: | |
if self.hparams['dur_model_type'] == 'ar_mse': | |
vq_pred = torch.round(vq_pred[:, -1, 0]).long() | |
else: | |
vq_pred = self.sample_one_step(vq_pred) | |
vq_decoded[:, step + 1] = vq_pred | |
else: | |
vq_decoded[:, step + 1] = ctx_vqcodes[:, step] | |
vq_decoded = vq_decoded[:, 1:] | |
vq_decoded_2d = torch.zeros_like(txt_tokens_ori) | |
vq_decoded_2d.flatten(0, 1)[txt_tokens_withpad > 0] = vq_decoded | |
if return_state: | |
return vq_decoded_2d, incremental_state | |
return vq_decoded_2d |