Spaces:
Running
Running
# Copyright (c) 2024 Amphion.
#
# This code is modified from https://github.com/imdanboy/jets/blob/main/espnet2/gan_tts/jets/generator.py
# Licensed under Apache License 2.0
import torch | |
import torch.nn as nn | |
import numpy as np | |
import torch.nn.functional as F | |
from modules.transformer.Models import Encoder, Decoder | |
from modules.transformer.Layers import PostNet | |
from collections import OrderedDict | |
from models.tts.jets.alignments import ( | |
AlignmentModule, | |
viterbi_decode, | |
average_by_duration, | |
make_pad_mask, | |
make_non_pad_mask, | |
get_random_segments, | |
) | |
from models.tts.jets.length_regulator import GaussianUpsampling | |
from models.vocoders.gan.generator.hifigan import HiFiGAN | |
import os | |
import json | |
from utils.util import load_config | |
def get_mask_from_lengths(lengths, max_len=None):
    """Build a boolean padding mask from a batch of sequence lengths.

    Positions at or beyond a sequence's length are marked True (padding).

    Args:
        lengths (LongTensor): Batch of lengths, shape (B,).
        max_len (int, optional): Width of the mask; defaults to max(lengths).

    Returns:
        BoolTensor: Mask of shape (B, max_len), True where padded.
    """
    if max_len is None:
        max_len = int(lengths.max().item())
    # (1, max_len) positions broadcast against (B, 1) lengths.
    positions = torch.arange(0, max_len, device=lengths.device).unsqueeze(0)
    return positions >= lengths.unsqueeze(1)
def pad(input_ele, mel_max_length=None):
    """Right-pad a list of 1-D or 2-D tensors to a common length and stack them.

    Args:
        input_ele (list[Tensor]): Tensors of shape (T,) or (T, D) with
            possibly different T.
        mel_max_length (int, optional): Target length; defaults to the
            longest tensor in the list.

    Returns:
        Tensor: Stacked tensor of shape (B, max_len) or (B, max_len, D),
        zero-padded on the right along the time axis.

    Raises:
        ValueError: If any tensor is not 1-D or 2-D.
    """
    if mel_max_length:
        max_len = mel_max_length
    else:
        max_len = max(ele.size(0) for ele in input_ele)

    out_list = []
    for batch in input_ele:
        if batch.dim() == 1:
            one_batch_padded = F.pad(
                batch, (0, max_len - batch.size(0)), "constant", 0.0
            )
        elif batch.dim() == 2:
            one_batch_padded = F.pad(
                batch, (0, 0, 0, max_len - batch.size(0)), "constant", 0.0
            )
        else:
            # BUG FIX: an unsupported rank previously reused the previous
            # iteration's padded tensor (or raised NameError on the first
            # element). Fail loudly instead.
            raise ValueError(
                f"pad expects 1-D or 2-D tensors, got {batch.dim()}-D"
            )
        out_list.append(one_batch_padded)
    return torch.stack(out_list)
class VarianceAdaptor(nn.Module):
    """FastSpeech2-style Variance Adaptor.

    Predicts phone durations, pitch and energy from encoder hidden states,
    quantizes pitch/energy into learned embedding bins, and length-regulates
    the hidden sequence from phone level to frame level. Quantization bin
    edges are derived from ``statistics.json`` files written by preprocessing.
    """

    def __init__(self, cfg):
        super(VarianceAdaptor, self).__init__()
        # One VariancePredictor per variance stream (duration/pitch/energy).
        self.duration_predictor = VariancePredictor(cfg)
        self.length_regulator = LengthRegulator()
        self.pitch_predictor = VariancePredictor(cfg)
        self.energy_predictor = VariancePredictor(cfg)

        # assign the pitch/energy feature level
        if cfg.preprocess.use_frame_pitch:
            self.pitch_feature_level = "frame_level"
            self.pitch_dir = cfg.preprocess.pitch_dir
        else:
            self.pitch_feature_level = "phoneme_level"
            self.pitch_dir = cfg.preprocess.phone_pitch_dir
        if cfg.preprocess.use_frame_energy:
            self.energy_feature_level = "frame_level"
            self.energy_dir = cfg.preprocess.energy_dir
        else:
            self.energy_feature_level = "phoneme_level"
            self.energy_dir = cfg.preprocess.phone_energy_dir
        assert self.pitch_feature_level in ["phoneme_level", "frame_level"]
        assert self.energy_feature_level in ["phoneme_level", "frame_level"]

        pitch_quantization = cfg.model.variance_embedding.pitch_quantization
        energy_quantization = cfg.model.variance_embedding.energy_quantization
        n_bins = cfg.model.variance_embedding.n_bins
        assert pitch_quantization in ["linear", "log"]
        assert energy_quantization in ["linear", "log"]

        # Energy statistics. NOTE(review): only cfg.dataset[0] is consulted,
        # so multi-dataset training reuses the first dataset's statistics —
        # confirm this is intended.
        with open(
            os.path.join(
                cfg.preprocess.processed_dir,
                cfg.dataset[0],
                self.energy_dir,
                "statistics.json",
            )
        ) as f:
            stats = json.load(f)
            stats = stats[cfg.dataset[0] + "_" + cfg.dataset[0]]
            mean, std = (
                stats["voiced_positions"]["mean"],
                stats["voiced_positions"]["std"],
            )
            # Global min/max normalized with the voiced-frame mean/std.
            energy_min = (stats["total_positions"]["min"] - mean) / std
            energy_max = (stats["total_positions"]["max"] - mean) / std

        # Pitch statistics (same file layout as the energy statistics).
        with open(
            os.path.join(
                cfg.preprocess.processed_dir,
                cfg.dataset[0],
                self.pitch_dir,
                "statistics.json",
            )
        ) as f:
            stats = json.load(f)
            stats = stats[cfg.dataset[0] + "_" + cfg.dataset[0]]
            mean, std = (
                stats["voiced_positions"]["mean"],
                stats["voiced_positions"]["std"],
            )
            pitch_min = (stats["total_positions"]["min"] - mean) / std
            pitch_max = (stats["total_positions"]["max"] - mean) / std

        # n_bins - 1 edges so torch.bucketize maps values into n_bins buckets.
        # Stored as frozen Parameters (requires_grad=False): fixed edges that
        # still move with the module across devices/checkpoints.
        if pitch_quantization == "log":
            # NOTE(review): np.log requires pitch_min > 0 after normalization;
            # log quantization assumes strictly positive bounds — confirm.
            self.pitch_bins = nn.Parameter(
                torch.exp(
                    torch.linspace(np.log(pitch_min), np.log(pitch_max), n_bins - 1)
                ),
                requires_grad=False,
            )
        else:
            self.pitch_bins = nn.Parameter(
                torch.linspace(pitch_min, pitch_max, n_bins - 1),
                requires_grad=False,
            )
        if energy_quantization == "log":
            self.energy_bins = nn.Parameter(
                torch.exp(
                    torch.linspace(np.log(energy_min), np.log(energy_max), n_bins - 1)
                ),
                requires_grad=False,
            )
        else:
            self.energy_bins = nn.Parameter(
                torch.linspace(energy_min, energy_max, n_bins - 1),
                requires_grad=False,
            )

        # One embedding vector per quantization bucket.
        self.pitch_embedding = nn.Embedding(
            n_bins, cfg.model.transformer.encoder_hidden
        )
        self.energy_embedding = nn.Embedding(
            n_bins, cfg.model.transformer.encoder_hidden
        )

    def get_pitch_embedding(self, x, target, mask, control):
        """Predict pitch and return (prediction, bucketized embedding).

        When ``target`` is given (teacher forcing) the embedding is looked up
        from the target; otherwise the prediction is scaled by ``control``
        and embedded instead.
        """
        prediction = self.pitch_predictor(x, mask)
        if target is not None:
            embedding = self.pitch_embedding(torch.bucketize(target, self.pitch_bins))
        else:
            prediction = prediction * control
            embedding = self.pitch_embedding(
                torch.bucketize(prediction, self.pitch_bins)
            )
        return prediction, embedding

    def get_energy_embedding(self, x, target, mask, control):
        """Energy counterpart of get_pitch_embedding (same control flow)."""
        prediction = self.energy_predictor(x, mask)
        if target is not None:
            embedding = self.energy_embedding(torch.bucketize(target, self.energy_bins))
        else:
            prediction = prediction * control
            embedding = self.energy_embedding(
                torch.bucketize(prediction, self.energy_bins)
            )
        return prediction, embedding

    def forward(
        self,
        x,
        src_mask,
        mel_mask=None,
        max_len=None,
        pitch_target=None,
        energy_target=None,
        duration_target=None,
        p_control=1.0,
        e_control=1.0,
        d_control=1.0,
        pitch_embedding=None,
        energy_embedding=None,
    ):
        """Predict variances and length-regulate x to frame resolution.

        Args:
            x: Encoder output (phone-level hidden sequence).
            src_mask: Padding mask for x (True = padded).
            mel_mask / max_len: Optional frame-level mask and target length.
            duration_target: Ground-truth durations; when given they drive
                the length regulator, otherwise predicted durations are used.
            d_control: Multiplier applied to predicted durations.
            pitch_embedding / energy_embedding: Precomputed terms added to x
                before pitch/energy prediction (supplied by the caller).

        NOTE(review): pitch_target, energy_target, p_control and e_control
        are accepted but unused here — the pitch/energy teacher-forcing path
        lives in get_pitch_embedding/get_energy_embedding; confirm intended.

        Returns:
            Tuple of (regulated x, pitch prediction, energy prediction,
            log duration prediction, rounded durations, frame lengths,
            frame-level mask).
        """
        # Duration is predicted from x BEFORE the pitch/energy terms are added.
        log_duration_prediction = self.duration_predictor(x, src_mask)
        x = x + pitch_embedding
        x = x + energy_embedding
        pitch_prediction = self.pitch_predictor(x, src_mask)
        energy_prediction = self.energy_predictor(x, src_mask)

        if duration_target is not None:
            # Teacher forcing: expand with ground-truth durations.
            x, mel_len = self.length_regulator(x, duration_target, max_len)
            duration_rounded = duration_target
        else:
            # Predictor outputs log(d + 1); invert, scale and clamp at 0.
            duration_rounded = torch.clamp(
                (torch.round(torch.exp(log_duration_prediction) - 1) * d_control),
                min=0,
            )
            x, mel_len = self.length_regulator(x, duration_rounded, max_len)
            mel_mask = get_mask_from_lengths(mel_len)

        return (
            x,
            pitch_prediction,
            energy_prediction,
            log_duration_prediction,
            duration_rounded,
            mel_len,
            mel_mask,
        )

    def inference(
        self,
        x,
        src_mask,
        mel_mask=None,
        max_len=None,
        pitch_target=None,
        energy_target=None,
        duration_target=None,
        p_control=1.0,
        e_control=1.0,
        d_control=1.0,
        pitch_embedding=None,
        energy_embedding=None,
    ):
        """Return raw pitch/energy/log-duration predictions from x.

        NOTE(review): all optional arguments are ignored here; the caller
        (Jets.inference) applies embeddings and length regulation itself.
        """
        p_outs = self.pitch_predictor(x, src_mask)
        e_outs = self.energy_predictor(x, src_mask)
        d_outs = self.duration_predictor(x, src_mask)
        return p_outs, e_outs, d_outs
class LengthRegulator(nn.Module):
    """Length Regulator.

    Expands each phone-level vector by its duration count so the hidden
    sequence reaches frame-level resolution, then pads to a common length.
    """

    def __init__(self):
        super(LengthRegulator, self).__init__()

    def LR(self, x, duration, max_len):
        """Expand every batch item by its durations and pad the results.

        Args:
            x (Tensor): Phone-level features, (B, T_phone, D).
            duration (Tensor): Per-phone repeat counts, (B, T_phone).
            max_len (int or None): Optional fixed output length.

        Returns:
            tuple: (padded output, LongTensor of per-item frame lengths).
        """
        device = x.device
        expanded_seqs = [
            self.expand(seq, counts) for seq, counts in zip(x, duration)
        ]
        frame_lens = [seq.shape[0] for seq in expanded_seqs]
        if max_len is not None:
            padded = pad(expanded_seqs, max_len)
        else:
            padded = pad(expanded_seqs)
        return padded, torch.LongTensor(frame_lens).to(device)

    def expand(self, batch, predicted):
        """Repeat each row of `batch` predicted[i] times along time."""
        pieces = [
            vec.expand(max(int(predicted[idx].item()), 0), -1)
            for idx, vec in enumerate(batch)
        ]
        return torch.cat(pieces, 0)

    def forward(self, x, duration, max_len):
        return self.LR(x, duration, max_len)
class VariancePredictor(nn.Module):
    """Duration, Pitch and Energy Predictor.

    Two (Conv1d -> ReLU -> LayerNorm -> Dropout) stages followed by a linear
    projection to one scalar per time step. Padded positions are zeroed via
    the supplied mask.
    """

    def __init__(self, cfg):
        super(VariancePredictor, self).__init__()

        self.input_size = cfg.model.transformer.encoder_hidden
        self.filter_size = cfg.model.variance_predictor.filter_size
        self.kernel = cfg.model.variance_predictor.kernel_size
        self.conv_output_size = cfg.model.variance_predictor.filter_size
        self.dropout = cfg.model.variance_predictor.dropout

        self.conv_layer = nn.Sequential(
            OrderedDict(
                [
                    (
                        "conv1d_1",
                        Conv(
                            self.input_size,
                            self.filter_size,
                            kernel_size=self.kernel,
                            padding=(self.kernel - 1) // 2,
                        ),
                    ),
                    ("relu_1", nn.ReLU()),
                    ("layer_norm_1", nn.LayerNorm(self.filter_size)),
                    ("dropout_1", nn.Dropout(self.dropout)),
                    (
                        "conv1d_2",
                        Conv(
                            self.filter_size,
                            self.filter_size,
                            kernel_size=self.kernel,
                            # BUG FIX: was hard-coded padding=1, which only
                            # preserves sequence length for kernel_size == 3.
                            # Any other configured kernel changed the time
                            # dimension, breaking masked_fill in forward().
                            # Identical behavior for the common kernel_size=3.
                            padding=(self.kernel - 1) // 2,
                        ),
                    ),
                    ("relu_2", nn.ReLU()),
                    ("layer_norm_2", nn.LayerNorm(self.filter_size)),
                    ("dropout_2", nn.Dropout(self.dropout)),
                ]
            )
        )

        self.linear_layer = nn.Linear(self.conv_output_size, 1)

    def forward(self, encoder_output, mask):
        """Predict one scalar per position.

        Args:
            encoder_output (Tensor): Hidden sequence, (B, T, encoder_hidden).
            mask (BoolTensor or None): Padding mask (B, T); True positions
                are zeroed in the output.

        Returns:
            Tensor: Predictions of shape (B, T).
        """
        out = self.conv_layer(encoder_output)
        out = self.linear_layer(out)
        out = out.squeeze(-1)

        if mask is not None:
            out = out.masked_fill(mask, 0.0)

        return out
class Conv(nn.Module):
    """
    Convolution Module operating on channels-last input.

    Wraps nn.Conv1d so callers can pass (B, T, C) tensors; the transpose to
    Conv1d's (B, C, T) layout and back is handled internally.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=1,
        stride=1,
        padding=0,
        dilation=1,
        bias=True,
        w_init="linear",
    ):
        """
        :param in_channels: dimension of input
        :param out_channels: dimension of output
        :param kernel_size: size of kernel
        :param stride: size of stride
        :param padding: size of padding
        :param dilation: dilation rate
        :param bias: boolean. if True, bias is included.
        :param w_init: str. weight inits with xavier initialization.
        """
        super(Conv, self).__init__()

        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias,
        )

    def forward(self, x):
        # (B, T, C) -> (B, C, T), convolve, then restore channels-last.
        y = self.conv(x.contiguous().transpose(1, 2))
        return y.contiguous().transpose(1, 2)
class Jets(nn.Module):
    """JETS end-to-end TTS generator.

    Text encoder + alignment module + FastSpeech2-style variance adaptor +
    decoder + HiFiGAN waveform generator, trained jointly. Training forward
    returns a random waveform segment plus intermediate quantities for the
    JETS losses; inference synthesizes a full waveform from text.
    """

    def __init__(self, cfg) -> None:
        super(Jets, self).__init__()
        self.cfg = cfg
        self.encoder = Encoder(cfg.model)
        self.variance_adaptor = VarianceAdaptor(cfg)
        self.decoder = Decoder(cfg.model)
        # Separate regulator used only by inference (training uses the
        # variance adaptor's internal one).
        self.length_regulator_infer = LengthRegulator()
        self.mel_linear = nn.Linear(
            cfg.model.transformer.decoder_hidden,
            cfg.preprocess.n_mel,
        )
        self.postnet = PostNet(n_mel_channels=cfg.preprocess.n_mel)
        self.speaker_emb = None
        if cfg.train.multi_speaker_training:
            # Speaker count is taken from the first dataset's spk2id map.
            with open(
                os.path.join(
                    cfg.preprocess.processed_dir, cfg.dataset[0], "spk2id.json"
                ),
                "r",
            ) as f:
                n_speaker = len(json.load(f))
            self.speaker_emb = nn.Embedding(
                n_speaker,
                cfg.model.transformer.encoder_hidden,
            )

        output_dim = cfg.preprocess.n_mel
        attention_dim = 256
        self.alignment_module = AlignmentModule(attention_dim, output_dim)

        # NOTE(kan-bayashi): We use continuous pitch + FastPitch style avg
        pitch_embed_kernel_size: int = 9
        pitch_embed_dropout: float = 0.5
        self.pitch_embed = torch.nn.Sequential(
            torch.nn.Conv1d(
                in_channels=1,
                out_channels=attention_dim,
                kernel_size=pitch_embed_kernel_size,
                padding=(pitch_embed_kernel_size - 1) // 2,
            ),
            torch.nn.Dropout(pitch_embed_dropout),
        )
        # NOTE(kan-bayashi): We use continuous enegy + FastPitch style avg
        energy_embed_kernel_size: int = 9
        energy_embed_dropout: float = 0.5
        self.energy_embed = torch.nn.Sequential(
            torch.nn.Conv1d(
                in_channels=1,
                out_channels=attention_dim,
                kernel_size=energy_embed_kernel_size,
                padding=(energy_embed_kernel_size - 1) // 2,
            ),
            torch.nn.Dropout(energy_embed_dropout),
        )

        # define length regulator
        self.length_regulator = GaussianUpsampling()
        self.segment_size = cfg.train.segment_size

        # Define HiFiGAN generator
        # NOTE(review): vocoder config is loaded from a hard-coded repo path;
        # the mel-channel count is overridden to attention_dim so HiFiGAN
        # consumes decoder hidden states directly rather than mels.
        hifi_cfg = load_config("egs/vocoder/gan/hifigan/exp_config.json")
        # hifi_cfg.model.hifigan.resblock_kernel_sizes = [3, 7, 11]
        hifi_cfg.preprocess.n_mel = attention_dim
        self.generator = HiFiGAN(hifi_cfg)

    def _source_mask(self, ilens: torch.Tensor) -> torch.Tensor:
        """Make masks for self-attention.
        Args:
            ilens (LongTensor): Batch of lengths (B,).
        Returns:
            Tensor: Mask tensor for self-attention.
                dtype=torch.uint8 in PyTorch 1.2-
                dtype=torch.bool in PyTorch 1.2+ (including 1.2)
        Examples:
            >>> ilens = [5, 3]
            >>> self._source_mask(ilens)
            tensor([[[1, 1, 1, 1, 1],
                     [1, 1, 1, 0, 0]]], dtype=torch.uint8)
        """
        x_masks = make_non_pad_mask(ilens).to(next(self.parameters()).device)
        return x_masks.unsqueeze(-2)

    def forward(self, data, p_control=1.0, e_control=1.0, d_control=1.0):
        """Training forward pass.

        Args:
            data (dict): Batch with keys "spk_id", "texts", "text_len",
                "mel", and optionally "target_len", "pitch", "energy".
            p_control / e_control / d_control: Variance scaling factors
                forwarded to the variance adaptor.

        Returns:
            Tuple of (generated waveform segment, alignment bin loss,
            log attention probs, segment start indices, log duration
            predictions, durations, pitch predictions, averaged pitch,
            energy predictions, averaged energy, text lengths,
            feature lengths).
        """
        speakers = data["spk_id"]
        texts = data["texts"]
        src_lens = data["text_len"]
        max_src_len = max(src_lens)
        feats = data["mel"]
        mel_lens = data["target_len"] if "target_len" in data else None
        feats_lengths = mel_lens
        max_mel_len = max(mel_lens) if "target_len" in data else None
        p_targets = data["pitch"] if "pitch" in data else None
        e_targets = data["energy"] if "energy" in data else None
        src_masks = get_mask_from_lengths(src_lens, max_src_len)
        mel_masks = (
            get_mask_from_lengths(mel_lens, max_mel_len)
            if mel_lens is not None
            else None
        )
        output = self.encoder(texts, src_masks)

        if self.speaker_emb is not None:
            # Broadcast the speaker embedding across all text positions.
            output = output + self.speaker_emb(speakers).unsqueeze(1).expand(
                -1, max_src_len, -1
            )

        # Forward alignment module and obtain duration, averaged pitch, energy
        h_masks = make_pad_mask(src_lens).to(output.device)
        log_p_attn = self.alignment_module(
            output, feats, src_lens, feats_lengths, h_masks
        )
        ds, bin_loss = viterbi_decode(log_p_attn, src_lens, feats_lengths)
        ps = average_by_duration(
            ds, p_targets.squeeze(-1), src_lens, feats_lengths
        ).unsqueeze(-1)
        es = average_by_duration(
            ds, e_targets.squeeze(-1), src_lens, feats_lengths
        ).unsqueeze(-1)
        # NOTE(review): p_embs/e_embs are computed but never used below; the
        # raw averaged values ps/es are what get passed to the variance
        # adaptor as pitch_embedding/energy_embedding (they broadcast over
        # the hidden dim). Verify whether p_embs/e_embs were intended.
        p_embs = self.pitch_embed(ps.transpose(1, 2)).transpose(1, 2)
        e_embs = self.energy_embed(es.transpose(1, 2)).transpose(1, 2)

        # FastSpeech2 variance adaptor
        (
            output,
            p_predictions,
            e_predictions,
            log_d_predictions,
            d_rounded,
            mel_lens,
            mel_masks,
        ) = self.variance_adaptor(
            output,
            src_masks,
            mel_masks,
            max_mel_len,
            p_targets,
            e_targets,
            ds,
            p_control,
            e_control,
            d_control,
            ps,
            es,
        )

        # forward decoder
        zs, _ = self.decoder(output, mel_masks)  # (B, T_feats, adim)

        # get random segments
        z_segments, z_start_idxs = get_random_segments(
            zs.transpose(1, 2),
            feats_lengths,
            self.segment_size,
        )

        # forward generator
        wav = self.generator(z_segments)

        return (
            wav,
            bin_loss,
            log_p_attn,
            z_start_idxs,
            log_d_predictions,
            ds,
            p_predictions,
            ps,
            e_predictions,
            es,
            src_lens,
            feats_lengths,
        )

    def inference(self, data, p_control=1.0, e_control=1.0, d_control=1.0):
        """Synthesize a waveform from text (no ground-truth alignment).

        Args:
            data (dict): Batch with "spk_id", "texts", "text_len" and
                optionally "target_len", "pitch", "energy", "durations".
            p_control / e_control / d_control: Accepted for interface
                symmetry with forward().

        Returns:
            Tuple of (generated waveform, predicted durations).

        NOTE(review): x_masks, mel_masks and the p/e/d targets extracted
        below are computed but unused in this method; speakers is unused
        too (no speaker embedding is applied at inference) — confirm.
        """
        speakers = data["spk_id"]
        texts = data["texts"]
        src_lens = data["text_len"]
        max_src_len = max(src_lens)
        mel_lens = data["target_len"] if "target_len" in data else None
        feats_lengths = mel_lens
        max_mel_len = max(mel_lens) if "target_len" in data else None
        p_targets = data["pitch"] if "pitch" in data else None
        e_targets = data["energy"] if "energy" in data else None
        d_targets = data["durations"] if "durations" in data else None
        src_masks = get_mask_from_lengths(src_lens, max_src_len)
        mel_masks = (
            get_mask_from_lengths(mel_lens, max_mel_len)
            if mel_lens is not None
            else None
        )
        x_masks = self._source_mask(src_lens)
        hs = self.encoder(texts, src_masks)

        (
            p_outs,
            e_outs,
            d_outs,
        ) = self.variance_adaptor.inference(
            hs,
            src_masks,
        )

        # Embed predicted pitch/energy and add to the hidden sequence.
        p_embs = self.pitch_embed(p_outs.unsqueeze(-1).transpose(1, 2)).transpose(1, 2)
        e_embs = self.energy_embed(e_outs.unsqueeze(-1).transpose(1, 2)).transpose(1, 2)
        hs = hs + e_embs + p_embs

        # Duration predictor inference mode: log_d_pred to d_pred
        offset = 1.0
        d_predictions = torch.clamp(
            torch.round(d_outs.exp() - offset), min=0
        ).long()  # avoid negative value

        # forward decoder
        hs, mel_len = self.length_regulator_infer(hs, d_predictions, max_mel_len)
        mel_mask = get_mask_from_lengths(mel_len)
        zs, _ = self.decoder(hs, mel_mask)  # (B, T_feats, adim)

        # forward generator
        wav = self.generator(zs.transpose(1, 2))

        return wav, d_predictions