# MegaTTS3 / tts/modules/ar_dur/ar_dur_predictor.py
# Commit 593f3bc by ZiyueJiang: "first commit for huggingface space"
# Copyright 2025 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
from copy import deepcopy
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import Linear
from tqdm import tqdm
from tts.modules.ar_dur.commons.layers import Embedding, LayerNorm
from tts.modules.ar_dur.commons.nar_tts_modules import PosEmb
from tts.modules.ar_dur.commons.rot_transformer import RotTransformerDecoderLayer
from tts.modules.ar_dur.commons.transformer import SinusoidalPositionalEmbedding
from tts.modules.ar_dur.commons.rel_transformer import RelTransformerEncoder
def _build_rel_fft_encoder(hp, dict_size):
    """Construct a ``RelTransformerEncoder`` whose sizes come from the hparams dict ``hp``.

    Args:
        hp: hyper-parameter dict; reads ``hidden_size``, ``ffn_hidden_size``,
            ``num_heads``, ``enc_layers``, ``enc_ffn_kernel_size``, ``dropout``,
            ``enc_prenet`` and ``enc_pre_ln``.
        dict_size: vocabulary size passed as the encoder's first argument.
    """
    hidden = hp['hidden_size']
    return RelTransformerEncoder(
        dict_size, hidden, hidden,
        hp['ffn_hidden_size'], hp['num_heads'], hp['enc_layers'],
        hp['enc_ffn_kernel_size'], hp['dropout'],
        prenet=hp['enc_prenet'], pre_ln=hp['enc_pre_ln'],
    )


# Registry of encoder factories, selected by hparams['encoder_type'];
# each factory is called as factory(hparams, dict_size).
FS_ENCODERS = {
    'rel_fft': _build_rel_fft_encoder,
}
def fill_with_neg_inf2(t):
    """Return a new tensor shaped like ``t`` filled with a large negative value.

    The fill value is ``-1e8`` (not a true ``-inf``). It is written in float32
    and then cast to ``t``'s dtype, so under float16 it saturates to ``-inf``
    (-1e8 is outside the fp16 range). The input ``t`` itself is never modified.
    The previous ``t.float().fill_()`` form aliased and overwrote float32 inputs.
    """
    return torch.full_like(t, -1e8, dtype=torch.float32).type_as(t)
def expand_states(h, mel2token):
    """Gather token-level states onto a longer (e.g. frame-level) index sequence.

    One all-zero step is prepended to dim 1 of ``h``. An index of 0 in
    ``mel2token`` therefore selects zeros, and index ``i`` selects ``h[:, i - 1]``.

    Args:
        h: [B, T_token, H] token states.
        mel2token: [B, T] integer indices (1-based; 0 means "no token").

    Returns:
        [B, T, H] gathered states.
    """
    padded = F.pad(h, [0, 0, 1, 0])
    index = mel2token.unsqueeze(-1).expand(-1, -1, padded.size(-1))
    return padded.gather(1, index)  # [B, T, H]
class CodePredictor(nn.Module):
    """Base module for autoregressive code predictors.

    Holds the linguistic encoders (phoneme plus optional character), the
    embedding of previously decoded codes, and an output projection. The
    decoder layer stack ``self.layers`` is left as ``None`` here and is
    presumably populated by subclasses.

    Args:
        hparams: hyper-parameter dict. It is deep-copied, and the copy's
            ``hidden_size`` entry is overwritten with ``hidden_size``.
        hidden_size: width of the linguistic encoders.
        dec_hidden_size: decoder width; encoder output is projected to it.
        lm_num_layers: number of decoder layers (not used by this base class).
        dict_size: phoneme vocabulary size.
        code_size: number of code values. The code embedding has
            ``code_size + 2`` rows and the default head emits ``code_size + 1``
            logits.
    """

    def __init__(self, hparams, hidden_size, dec_hidden_size, lm_num_layers, dict_size, code_size):
        super().__init__()
        self.hparams = deepcopy(hparams)
        self.hparams['hidden_size'] = hidden_size
        self.hidden_size = hidden_size
        char_dict_size = hparams.get('char_dict_size', 4000)
        if not hparams.get('lm_use_enc'):
            # Plain lookup-table encoders.
            self.encoder = nn.Embedding(dict_size, self.hidden_size, padding_idx=0)
            if hparams.get('mega_use_char', True):
                self.char_encoder = nn.Embedding(char_dict_size,
                                                 self.hidden_size, padding_idx=0)
        else:
            # Encoders built from the FS_ENCODERS registry by hparams['encoder_type'].
            self.encoder = FS_ENCODERS[self.hparams['encoder_type']](self.hparams, dict_size)
            if hparams.get('mega_use_char', True):
                self.char_encoder = FS_ENCODERS[self.hparams['encoder_type']](self.hparams, char_dict_size)
            if hparams['use_ph_pos_embed']:
                self.ph_pos_embed = PosEmb(self.hidden_size)
        # Learned embedding used for phonemes that have no aligned character.
        self.char_empty_embed = nn.Embedding(1, self.hidden_size)
        if hparams.get('use_bert_input'):
            self.bert_input_proj = nn.Linear(768, self.hidden_size)
        self.ling_label_embed_layers = nn.ModuleDict()
        for k, s in zip(hparams['ling_labels'], hparams['ling_label_dict_size']):
            self.ling_label_embed_layers[k] = Embedding(s + 3, self.hidden_size, padding_idx=0)
        self.dec_hidden_size = dec_hidden_size
        self.enc_proj = nn.Linear(self.hidden_size, dec_hidden_size)
        self.code_emb = Embedding(code_size + 2, dec_hidden_size, 0)
        self.use_pos_embed = hparams.get('use_pos_embed', False)
        if self.use_pos_embed:
            self.embed_positions = SinusoidalPositionalEmbedding(dec_hidden_size, 0, init_size=1024)
        self.use_post_ln = hparams.get('use_post_ln', False)
        self.layers = None
        if not self.use_post_ln:
            self.layer_norm = LayerNorm(dec_hidden_size)
        self.code_size = code_size
        self.project_out_dim = Linear(dec_hidden_size, code_size + 1, bias=True)

    def forward_ling_encoder(
            self, txt_tokens, ling_feas, char_tokens, ph2char, bert_embed, spk_id, spk_embed, mels_timbre):
        """Encode phonemes (plus optional characters) into decoder-width features.

        Args:
            txt_tokens: [B, T_phone] phoneme ids; 0 is padding.
            ling_feas: dict mapping each name in ``hparams['ling_labels']`` to the
                id tensor for the matching embedding layer.
            char_tokens: character ids (0 = padding), or ``None``.
            ph2char: per-phoneme 1-based index into ``char_tokens`` (resolved with
                ``expand_states``). Values > 100000 mark phonemes with no character.
                May be ``None``.
            bert_embed: currently unused.
            spk_id, spk_embed, mels_timbre: forwarded to ``forward_style_embed``.

        Returns:
            [B, T_phone, dec_hidden_size] features, zeroed at padded phonemes.
        """
        hparams = self.hparams
        ph_tokens = txt_tokens
        ph_nonpadding = (ph_tokens > 0).float()[:, :, None]  # [B, T_phone, 1]
        x_spk = self.forward_style_embed(spk_embed, spk_id, mels_timbre)
        # Sum of the per-phoneme linguistic label embeddings. sum() of an empty
        # generator is 0, so this is 0 when no labels are configured.
        ling_label_embed = sum(
            self.ling_label_embed_layers[k](ling_feas[k]) for k in hparams['ling_labels'])
        # enc_ph
        if not hparams.get('lm_use_enc'):
            # The label sum is its own term, so an empty label list adds 0 and the
            # phoneme embedding is never discarded.
            x_ph = self.encoder(ph_tokens) + ling_label_embed + x_spk
        else:
            ph_pos = torch.arange(ph_tokens.shape[1], device=ph_tokens.device)[None]
            ph_enc_oembed = ling_label_embed + self.ph_pos_embed(ph_pos) + x_spk
            ph_enc_oembed = ph_enc_oembed * ph_nonpadding
            x_ph = self.encoder(ph_tokens, other_embeds=ph_enc_oembed)
        # enc_char
        if char_tokens is not None and ph2char is not None:
            char_nonpadding = (char_tokens > 0).float()[:, :, None]
            x_char = self.char_encoder(char_tokens)
            # Sentinel: ph2char > 100000 means "no aligned character". Such
            # phonemes gather index 0 (zeros) and use char_empty_embed instead.
            empty_char = (ph2char > 100000).long()
            ph2char = ph2char * (1 - empty_char)
            x_char_phlevel = \
                expand_states(x_char * char_nonpadding, ph2char) * (1 - empty_char)[..., None] + \
                self.char_empty_embed(torch.zeros_like(ph_tokens)) * empty_char[..., None]
        else:
            x_char_phlevel = 0
        # x_ling
        x_ling = (x_ph + x_char_phlevel) * ph_nonpadding
        return self.enc_proj(x_ling)

    def sample_one_step(self, vq_pred):
        """Pick the next code from the logits at the last time step of ``vq_pred``.

        When ``hparams['infer_top_k']`` is truthy, the logits are divided by
        ``hparams.get('infer_temperature', 1)``. Everything below the k-th largest
        logit is masked out, and one sample is drawn; the result has shape [B, 1].
        Otherwise the argmax is returned with shape [B].
        """
        hparams = self.hparams
        top_k = hparams.get('infer_top_k')
        if top_k:
            temperature = hparams.get('infer_temperature', 1)
            logits = vq_pred[:, -1] / temperature
            # Crop the logits to only the top-k options.
            kth_best = torch.topk(logits, min(top_k, logits.size(-1)))[0][:, [-1]]
            logits[logits < kth_best] = -float('Inf')
            probs = F.softmax(logits, dim=-1)
            return torch.multinomial(probs, num_samples=1)
        return torch.argmax(F.softmax(vq_pred[:, -1], dim=-1), 1)

    def forward_style_embed(self, spk_embed=None, spk_id=None, mel_ref=None):
        """Sum the enabled speaker/style conditioning terms.

        Each enabled term gets a singleton time axis (``[:, None, :]``) so it
        broadcasts over phonemes. Returns the int ``0`` when every flag is off.

        NOTE(review): ``spk_embed_proj``, ``spk_id_proj`` and ``spk_enc`` are not
        created in this class. Enabling the matching hparams flag requires them to
        be defined elsewhere; confirm.
        """
        style_embed = 0
        if self.hparams['use_spk_embed']:
            style_embed = style_embed + self.spk_embed_proj(spk_embed)[:, None, :]
        if self.hparams['use_spk_id']:
            style_embed = style_embed + self.spk_id_proj(spk_id)[:, None, :]
        if self.hparams['use_spk_enc']:
            style_embed = style_embed + self.spk_enc(mel_ref)[:, None, :]
        return style_embed

    def buffered_future_mask(self, tensor):
        """Return a [T, T] additive causal mask, where T = ``tensor.size(0)``.

        Entries above the diagonal hold ``fill_with_neg_inf2``'s fill value; the
        rest are 0. The mask is cached and rebuilt only when it is too small, or
        when the device or dtype of ``tensor`` changes.
        """
        dim = tensor.size(0)
        cached = getattr(self, '_future_mask', None)
        if (
            cached is None
            or cached.device != tensor.device
            or cached.dtype != tensor.dtype
            or cached.size(0) < dim
        ):
            self._future_mask = torch.triu(fill_with_neg_inf2(tensor.new(dim, dim)), 1)
        return self._future_mask[:dim, :dim]
class ARDurPredictor(CodePredictor):
    """Autoregressive per-phoneme duration predictor.

    The packed linguistic features from ``forward_ling_encoder`` are added to
    embeddings of the previously decoded durations. The result runs through
    ``RotTransformerDecoderLayer`` blocks under a causal mask. With
    ``hparams['dur_model_type'] == 'ar_mse'``, the head regresses one
    non-negative value per step (Linear + Softplus). Otherwise it emits
    ``code_size + 1`` logits that are decoded with ``sample_one_step``.
    """

    def __init__(self, hparams, hidden_size, dec_hidden_size, lm_num_layers, dict_size, code_size, use_rot_embed=True,
                 op_version=1):
        super().__init__(hparams, hidden_size, dec_hidden_size, lm_num_layers, dict_size, code_size)
        self.use_rot_embed = use_rot_embed
        bias = hparams.get('lm_bias', True)
        # NOTE(review): with use_rot_embed=False, self.layers stays None (set by the
        # base class) and forward() cannot iterate it; confirm that path is unused.
        if self.use_rot_embed:
            self.layers = nn.ModuleList([])
            self.layers.extend([
                RotTransformerDecoderLayer(
                    dec_hidden_size, 0.0, kernel_size=1, ffn_hidden_size=dec_hidden_size * 4,
                    post_ln=self.use_post_ln, op_version=op_version, bias=bias)
                for _ in range(lm_num_layers)
            ])
        if hparams['dur_model_type'] == 'ar_mse':
            self.project_out_dim = nn.Sequential(torch.nn.Linear(dec_hidden_size, 1), nn.Softplus())
        else:
            self.project_out_dim = torch.nn.Linear(dec_hidden_size, code_size + 1)

    def _encode_valid_tokens(self, txt_tokens, ling_feas, char_tokens, ph2char, bert_embed,
                             spk_id, spk_embed, mels_timbre):
        """Encode ``txt_tokens`` and pack every non-padding position into one row.

        Returns ``(x_ling, packed_tokens, flat_tokens)``:
            - ``x_ling``: [1, N, C] features of the N non-zero tokens of the whole batch.
            - ``packed_tokens``: [1, N] ids of those tokens.
            - ``flat_tokens``: ``txt_tokens`` flattened over its first two dims, with
              padding kept, for scattering results back.
        """
        x_ling = self.forward_ling_encoder(
            txt_tokens, ling_feas, char_tokens, ph2char, bert_embed,
            spk_id, spk_embed, mels_timbre)
        flat_tokens = txt_tokens.flatten(0, 1)
        valid = flat_tokens > 0
        return x_ling.flatten(0, 1)[valid][None], flat_tokens[valid][None], flat_tokens

    @staticmethod
    def _unpack_to_padded(packed, txt_tokens_ori, flat_tokens):
        """Scatter [1, N] packed predictions back into the padded layout of ``txt_tokens_ori``."""
        # Force a contiguous result so flatten(0, 1) is a view and the write is kept.
        out = torch.zeros_like(txt_tokens_ori, memory_format=torch.contiguous_format)
        out.flatten(0, 1)[flat_tokens > 0] = packed
        return out

    @staticmethod
    def _trim_streaming_cache(incremental_state, prompt_length, cache_size):
        """Keep only the prompt plus the last ``cache_size`` steps in each attention cache.

        Entries whose key contains ``'attn_state'`` hold ``'prev_key'`` and
        ``'prev_value'`` tensors with time on dim 2. They are trimmed in place once
        they exceed ``prompt_length + cache_size`` steps.
        """
        for k, state in incremental_state.items():
            if 'attn_state' not in k:
                continue
            prev_key, prev_value = state['prev_key'], state['prev_value']
            cur_length = prev_key.shape[2]
            if cur_length - prompt_length > cache_size:
                # Explicit start index: a `-cache_size:` slice would keep everything when cache_size == 0.
                tail = cur_length - cache_size
                state['prev_key'] = torch.cat(
                    (prev_key[:, :, :prompt_length], prev_key[:, :, tail:]), dim=2)
                state['prev_value'] = torch.cat(
                    (prev_value[:, :, :prompt_length], prev_value[:, :, tail:]), dim=2)

    def forward(self, txt_tokens, ling_feas, char_tokens, ph2char, bert_embed,
                prev_code, spk_id=None, spk_embed=None, mels_timbre=None, mel2ph=None,
                incremental_state=None, x_ling=None, attn_mask=None, spk_pos_ids_flat=None,
                prompt_length=None, cache_size=20, streaming=False):
        """Run the duration decoder over previously decoded durations.

        Args:
            txt_tokens, ling_feas, char_tokens, ph2char, bert_embed, spk_id, spk_embed, mels_timbre:
                inputs to ``forward_ling_encoder``; only used when ``x_ling`` is None.
            prev_code: [B, T] previously decoded codes (decoder input).
            mel2ph: unused.
            incremental_state: cache dict for step-by-step decoding. When given, only
                the last position of ``prev_code`` is processed and no causal mask
                is built.
            x_ling: precomputed packed linguistic features; skips the encoder.
            attn_mask: optional; positions where it is 0 get -1e8 added to the
                causal mask (full-sequence mode only).
            spk_pos_ids_flat: optional position ids passed to every layer.
            prompt_length, cache_size, streaming: in streaming mode, query positions
                are capped at ``prompt_length + cache_size``. Each attention cache is
                trimmed to the prompt plus the last ``cache_size`` steps.

        Returns:
            [B, T', out] output of ``project_out_dim``, where T' is 1 in incremental mode.
        """
        x = self.code_emb(prev_code)
        if x_ling is None:
            x_ling, txt_tokens, _ = self._encode_valid_tokens(
                txt_tokens, ling_feas, char_tokens, ph2char, bert_embed, spk_id, spk_embed, mels_timbre)
        # run decoder
        self_attn_padding_mask = None
        if self.use_pos_embed:
            positions = self.embed_positions(
                prev_code,
                incremental_state=incremental_state
            )
        if incremental_state is not None:
            # Only the newest step is fed through the layers; both slices use the
            # pre-truncation length of x.
            cur = x.shape[1]
            x_ling = x_ling[:, cur - 1:cur]
            if spk_pos_ids_flat is not None:
                spk_pos_ids_flat = spk_pos_ids_flat[:, cur - 1:cur]
            x = x[:, -1:]
            if self.use_pos_embed:
                positions = positions[:, -1:]
            if streaming and spk_pos_ids_flat is not None:
                # Shift Pos: query pos is min(prompt_length + cache_size, idx), matching the trimmed cache.
                spk_pos_ids_flat = torch.min(
                    torch.tensor([prompt_length + cache_size], dtype=torch.long, device=x.device),
                    spk_pos_ids_flat)
        if self.use_pos_embed:
            x = x + positions
        x_ling = x_ling[:, :self.hparams['max_tokens']].contiguous()
        T = min(self.hparams.get('max_tokens_per_item', 1e9), x_ling.shape[1])
        x_ling = x_ling.reshape(-1, T, x_ling.shape[-1])
        x = x + x_ling
        # B x T x C -> T x B x C
        x = x.transpose(0, 1)
        # The causal mask depends only on T and device, so build it once for all layers.
        if incremental_state is None:
            self_attn_mask = self.buffered_future_mask(x)
            if attn_mask is not None:
                self_attn_mask = self_attn_mask + (1 - attn_mask.float()) * -1e8
            self_attn_mask = self_attn_mask.clamp_min(-1e8)
        else:
            self_attn_mask = None
        for layer in self.layers:
            x, _ = layer(
                x,
                incremental_state=incremental_state,
                self_attn_mask=self_attn_mask,
                self_attn_padding_mask=self_attn_padding_mask,
                spk_pos_ids_flat=spk_pos_ids_flat
            )
            if streaming and incremental_state:
                self._trim_streaming_cache(incremental_state, prompt_length, cache_size)
        if not self.use_post_ln:
            x = self.layer_norm(x)
        # T x B x C -> B x T x C
        x = x.transpose(0, 1)
        return self.project_out_dim(x)

    def infer(self, txt_tokens, ling_feas, char_tokens, ph2char, bert_embed,
              spk_id=None, spk_embed=None, mels_timbre=None,
              incremental_state=None, ctx_vqcodes=None, spk_pos_ids_flat=None, return_state=False,
              first_step_min=0, return_probs=False, first_decoder_inp=None, dur_disturb=0.0, **kwargs):
        """Autoregressively predict one duration per non-padding token.

        Decoding starts from the token ``code_size + 1``.
        - ``ctx_vqcodes``: known durations that are teacher-forced.
        - Non-empty ``incremental_state``: the context is written into the decoder
          input instead, or ``first_decoder_inp`` when that is given.
        - Predicted durations are clamped to at least 1, or to at least
          ``first_step_min`` at step 0.
        - ``ar_mse`` mode: durations are capped at ``code_size - 1``. From step 1
          on, ``dur_disturb > 0`` randomly scales each prediction up or down by
          ``1 + U * dur_disturb``.

        Returns:
            A tensor shaped like ``txt_tokens`` (0 at padding). With ``return_state``,
            also the incremental state; else with ``return_probs``, also the raw
            per-step outputs concatenated on dim 1.
        """
        if incremental_state is None:
            incremental_state = {}
        txt_tokens_ori = txt_tokens
        x_ling, txt_tokens, txt_tokens_withpad = self._encode_valid_tokens(
            txt_tokens, ling_feas, char_tokens, ph2char, bert_embed,
            spk_id, spk_embed, mels_timbre)
        decoded = F.pad(torch.zeros_like(txt_tokens), [1, 0], value=self.code_size + 1)
        if incremental_state:
            if first_decoder_inp is None:
                assert ctx_vqcodes is not None
                decoded[:, :ctx_vqcodes.shape[1]] = ctx_vqcodes
                ctx_vqcodes = None
            else:
                decoded[:, :1] = first_decoder_inp
        probs = []
        for step in range(decoded.shape[1] - 1):
            vq_pred = self(txt_tokens, None, None, None, None,
                           decoded[:, :step + 1], None, None, None,
                           incremental_state=incremental_state, x_ling=x_ling,
                           spk_pos_ids_flat=spk_pos_ids_flat, **kwargs)
            if return_probs:
                # Only pay for the device->host copy when the caller wants the outputs.
                probs.append(vq_pred.cpu())
            if ctx_vqcodes is not None and step < ctx_vqcodes.shape[1]:
                decoded[:, step + 1] = ctx_vqcodes[:, step]
                continue
            if self.hparams['dur_model_type'] == 'ar_mse':
                d = vq_pred[:, -1, 0]
                if dur_disturb > 0 and step >= 1:
                    # Draw order (direction first, then magnitude) is kept stable for seeded runs.
                    grow = random.random() > 0.5
                    scale = 1 + random.random() * dur_disturb
                    d = d * scale if grow else d / scale
                d = torch.clamp_max(d, self.code_size - 1)
                vq_pred = torch.round(d).long()
            else:
                vq_pred = self.sample_one_step(vq_pred)
            min_dur = first_step_min if step == 0 else 1
            decoded[:, step + 1] = torch.clamp_min(vq_pred, min_dur)
        decoded_2d = self._unpack_to_padded(decoded[:, 1:], txt_tokens_ori, txt_tokens_withpad)
        if return_state:
            return decoded_2d, incremental_state
        if return_probs:
            return decoded_2d, torch.cat(probs, 1)
        return decoded_2d

    def streaming_infer(self, txt_tokens, ling_feas, char_tokens, ph2char, bert_embed,
                        spk_id=None, spk_embed=None, mels_timbre=None,
                        incremental_state=None, ctx_vqcodes=None, spk_pos_ids_flat=None, return_state=False,
                        **kwargs):
        """Like ``infer`` but with a bounded attention cache (see ``forward``'s streaming mode).

        The prompt length is the time length of the first cached key in
        ``incremental_state``, or 0 when the state is empty. ``ar_mse`` predictions
        are capped at ``code_size - 1``, as in ``infer``, so the next step's
        ``code_emb`` lookup stays in range.

        Returns:
            A tensor shaped like ``txt_tokens`` (0 at padding), plus the incremental
            state when ``return_state`` is set.
        """
        if incremental_state is None:
            incremental_state = {}
        txt_tokens_ori = txt_tokens
        x_ling, txt_tokens, txt_tokens_withpad = self._encode_valid_tokens(
            txt_tokens, ling_feas, char_tokens, ph2char, bert_embed,
            spk_id, spk_embed, mels_timbre)
        vq_decoded = F.pad(torch.zeros_like(txt_tokens), [1, 0], value=self.code_size + 1)
        if incremental_state:
            assert ctx_vqcodes is not None
            vq_decoded[:, :ctx_vqcodes.shape[1]] = ctx_vqcodes
            ctx_vqcodes = None
            prompt_length = next(iter(incremental_state.values()))['prev_key'].shape[2]
        else:
            prompt_length = 0
        for step in tqdm(range(vq_decoded.shape[1] - 1), desc='AR Duration Predictor inference...'):
            vq_pred = self(txt_tokens, None, None, None, None,
                           vq_decoded[:, :step + 1], None, None, None,
                           incremental_state=incremental_state, x_ling=x_ling,
                           spk_pos_ids_flat=spk_pos_ids_flat, prompt_length=prompt_length, streaming=True, **kwargs)
            if ctx_vqcodes is not None and step < ctx_vqcodes.shape[1]:
                vq_decoded[:, step + 1] = ctx_vqcodes[:, step]
            elif self.hparams['dur_model_type'] == 'ar_mse':
                d = torch.clamp_max(vq_pred[:, -1, 0], self.code_size - 1)
                vq_decoded[:, step + 1] = torch.round(d).long()
            else:
                vq_decoded[:, step + 1] = self.sample_one_step(vq_pred)
        vq_decoded_2d = self._unpack_to_padded(vq_decoded[:, 1:], txt_tokens_ori, txt_tokens_withpad)
        if return_state:
            return vq_decoded_2d, incremental_state
        return vq_decoded_2d