|
import hydra

import librosa

import torch

import torch.nn as nn

import torch.nn.functional as F

import torch.optim as optim

import pytorch_lightning as pl

from transformers import AutoFeatureExtractor, AutoModel, Wav2Vec2BertModel

from vq import CodecEncoder, CodecDecoderVocos

from vq.module import SemanticDecoder, SemanticEncoder

from module import HiFiGANMultiPeriodDiscriminator, SpecDiscriminator

from criterions import GANLoss, MultiResolutionMelSpectrogramLoss, MultiResolutionSTFTLoss

from common.schedulers import WarmupLR
|
|
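# Trains the neural codec end-to-end: CodecEncoder, a VQ bottleneck with a
# W2v-BERT semantic-distillation branch, a Vocos-style decoder, and HiFi-GAN
# multi-period plus multi-resolution spectrogram discriminators.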
class CodecLightningModule(pl.LightningModule): |
|
def __init__(self, cfg): |
|
super().__init__() |
|
self.cfg = cfg |
|
self.ocwd = hydra.utils.get_original_cwd() |
|
self.construct_model() |
|
self.construct_criteria() |
|
self.save_hyperparameters() |
|
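        # GAN setup: two optimizers (generator / discriminators) stepped by hand in training_step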
self.automatic_optimization = False |
|
|
|
def construct_model(self): |
|
|
|
|
|
enccfg = self.cfg.model.codec_encoder |
|
|
|
|
|
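        # Acoustic encoder: maps the raw waveform (B, 1, T) to a latent sequence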
        self.CodecEnc = CodecEncoder(

            ngf=enccfg.ngf,

            up_ratios=enccfg.up_ratios,

            dilations=enccfg.dilations,

            hidden_dim=enccfg.hidden_dim,

            depth=enccfg.depth,

            heads=enccfg.heads,

            pos_meb_dim=enccfg.pos_meb_dim,

        )
|
|
|
|
|
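        # Decoder with the VQ bottleneck: generator(x, vq=True) quantizes; generator(x, vq=False) synthesizes audio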
deccfg = self.cfg.model.codec_decoder |
|
|
|
self.generator = CodecDecoderVocos( |
|
hidden_dim=deccfg.hidden_dim, |
|
depth=deccfg.depth, |
|
heads=deccfg.heads, |
|
pos_meb_dim=deccfg.pos_meb_dim, |
|
            hop_length=960,  # NOTE: hard-coded hop size in samples, not read from the config
|
vq_num_quantizers=deccfg.vq_num_quantizers, |
|
vq_dim=deccfg.vq_dim, |
|
vq_commit_weight=deccfg.vq_commit_weight, |
|
vq_weight_init=deccfg.vq_weight_init, |
|
vq_full_commit_loss=deccfg.vq_full_commit_loss, |
|
codebook_size=deccfg.codebook_size, |
|
            codebook_dim=deccfg.codebook_dim,
) |
|
|
|
|
|
|
|
|
|
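        # HiFi-GAN multi-period waveform discriminator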
mpdcfg = self.cfg.model.mpd |
|
self.discriminator = HiFiGANMultiPeriodDiscriminator( |
|
periods=mpdcfg.periods, |
|
max_downsample_channels=mpdcfg.max_downsample_channels, |
|
channels=mpdcfg.channels, |
|
channel_increasing_factor=mpdcfg.channel_increasing_factor, |
|
) |
|
|
|
|
|
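        # Multi-resolution STFT spectrogram discriminator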
mstftcfg = self.cfg.model.mstft |
|
self.spec_discriminator = SpecDiscriminator( |
|
stft_params=mstftcfg.stft_params, |
|
in_channels=mstftcfg.in_channels, |
|
out_channels=mstftcfg.out_channels, |
|
kernel_sizes=mstftcfg.kernel_sizes, |
|
channels=mstftcfg.channels, |
|
max_downsample_channels=mstftcfg.max_downsample_channels, |
|
downsample_scales=mstftcfg.downsample_scales, |
|
use_weight_norm=mstftcfg.use_weight_norm, |
|
        )
|
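        # Frozen WavLM speaker-verification model, used only for the validation speaker-similarity metric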
self.speaker_model = AutoModel.from_pretrained("microsoft/wavlm-base-plus-sv") |
|
self.speaker_feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/wavlm-base-plus-sv") |
|
self.speaker_model.eval() |
|
self.speaker_model.requires_grad_(False) |
|
|
|
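        # fc_prior fuses the concatenated semantic + acoustic features into the VQ
        # input; fc_post_a / fc_post_s split the quantized latent back into the
        # acoustic (vocoder) and semantic (distillation) streams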
        self.fc_prior = nn.Linear(1024 + 1024, deccfg.vq_dim)

        self.fc_post_a = nn.Linear(deccfg.vq_dim, deccfg.hidden_dim)

        self.fc_post_s = nn.Linear(deccfg.vq_dim, 1024)
|
|
|
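        # Semantic branch: frozen W2v-BERT 2.0 provides 1024-d distillation targets;
        # SemanticEncoder/SemanticDecoder adapt between that space and the codec latent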
self.SemanticDecoder_module = SemanticDecoder(1024, 1024, 1024) |
|
self.SemanticEncoder_module = SemanticEncoder(1024, 1024, 1024) |
|
self.semantic_model = Wav2Vec2BertModel.from_pretrained("facebook/w2v-bert-2.0", output_hidden_states=True) |
|
self.semantic_model.eval() |
|
        self.semantic_model.requires_grad_(False)
|
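    # Build the loss dictionary; mel / STFT / feature-matching terms are toggled by cfg.train flags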
def construct_criteria(self): |
|
cfg = self.cfg.train |
|
self.criteria = nn.ModuleDict() |
|
if cfg.use_mel_loss: |
|
self.criteria['mel_loss'] = MultiResolutionMelSpectrogramLoss(sample_rate=self.cfg.preprocess.audio.sr) |
|
if cfg.use_stft_loss: |
|
self.criteria['stft_loss'] = MultiResolutionSTFTLoss( |
|
fft_sizes=cfg.stft_loss_params.fft_sizes, |
|
hop_sizes=cfg.stft_loss_params.hop_sizes, |
|
win_sizes=cfg.stft_loss_params.win_lengths |
|
) |
|
if cfg.use_feat_match_loss: |
|
self.criteria['fm_loss'] = nn.L1Loss() |
|
self.criteria['gan_loss'] = GANLoss() |
|
self.criteria['l1_loss'] = nn.L1Loss() |
|
self.criteria['l2_loss'] = nn.MSELoss() |
|
        print(self.criteria)
|
def forward(self, batch): |
|
wav = batch['wav'] |
|
feats = batch['feats'] |
|
|
|
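        # Acoustic latent from the encoder, shaped (B, T_codec, C)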
vq_emb = self.CodecEnc(wav.unsqueeze(1)) |
|
|
|
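        # Distillation target: hidden layer 16 of the frozen W2v-BERT model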
with torch.no_grad(): |
|
semantic_target = self.semantic_model(feats) |
|
semantic_target = semantic_target.hidden_states[16].detach() |
|
|
|
T_codec = vq_emb.shape[1] |
|
T_semantic = semantic_target.shape[1] |
|
|
|
|
|
semantic_target_for_loss = semantic_target.clone() |
|
|
|
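        # Match the semantic sequence to the codec frame rate before fusion;
        # the original-rate copy above is kept for the distillation loss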
if T_codec != T_semantic: |
|
semantic_target = F.interpolate( |
|
semantic_target.transpose(1, 2), |
|
size=T_codec, |
|
mode='linear', |
|
align_corners=False |
|
).transpose(1, 2) |
|
|
|
semantic_target_transposed = semantic_target.transpose(1, 2) |
|
semantic_target_processed = self.SemanticEncoder_module(semantic_target_transposed) |
|
semantic_target_processed = semantic_target_processed.transpose(1, 2) |
|
|
|
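        # Fuse semantic and acoustic streams, then project to the VQ input dimension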
vq_emb = torch.cat([semantic_target_processed, vq_emb], dim=2) |
|
vq_emb = self.fc_prior(vq_emb) |
|
|
|
vq_emb = vq_emb.transpose(1, 2) |
|
vq_post_emb, vq_code, vq_loss = self.generator(vq_emb, vq=True) |
|
|
|
vq_post_emb_t = vq_post_emb.transpose(1, 2) |
|
|
|
semantic_recon = self.fc_post_s(vq_post_emb_t) |
|
|
|
semantic_recon_transposed = semantic_recon.transpose(1, 2) |
|
semantic_recon = self.SemanticDecoder_module(semantic_recon_transposed) |
|
semantic_recon = semantic_recon.transpose(1, 2) |
|
|
|
|
|
if T_codec != T_semantic: |
|
semantic_recon_for_loss = F.interpolate( |
|
semantic_recon.transpose(1, 2), |
|
size=T_semantic, |
|
mode='linear', |
|
align_corners=False |
|
).transpose(1, 2) |
|
else: |
|
semantic_recon_for_loss = semantic_recon |
|
|
|
|
|
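        # Acoustic path: project the quantized latent and synthesize the waveform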
gen_input = self.fc_post_a(vq_post_emb_t) |
|
y_, _ = self.generator(gen_input.transpose(1, 2), vq=False) |
|
y = wav.unsqueeze(1) |
|
|
|
output = { |
|
'gt_wav': y, |
|
'gen_wav': y_, |
|
'vq_loss': vq_loss, |
|
'vq_code': vq_code, |
|
'semantic_recon_loss': F.mse_loss(semantic_recon_for_loss, semantic_target_for_loss), |
|
} |
|
return output |
|
|
|
    @torch.inference_mode()

    def inference(self, wav):

        # NOTE: unlike forward(), no W2v-BERT features are available here, so the
        # semantic fusion (fc_prior) is skipped; this assumes the encoder output
        # dimension already matches the VQ input dimension.
        vq_emb = self.CodecEnc(wav.unsqueeze(1))

        vq_post_emb, _, _ = self.generator(vq_emb.transpose(1, 2), vq=True)

        gen_input = self.fc_post_a(vq_post_emb.transpose(1, 2))

        y_, _ = self.generator(gen_input.transpose(1, 2), vq=False)

        return y_.squeeze(1)
|
|
|
def compute_disc_loss(self, batch, output): |
|
y, y_ = output['gt_wav'], output['gen_wav'] |
|
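        # Detach the generated waveform so discriminator gradients do not reach the generator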
y_ = y_.detach() |
|
p = self.discriminator(y) |
|
p_ = self.discriminator(y_) |
|
|
|
real_loss_list, fake_loss_list = [], [] |
|
for i in range(len(p)): |
|
real_loss, fake_loss = self.criteria['gan_loss'].disc_loss(p[i][-1], p_[i][-1]) |
|
real_loss_list.append(real_loss) |
|
fake_loss_list.append(fake_loss) |
|
|
|
if hasattr(self, 'spec_discriminator'): |
|
sd_p = self.spec_discriminator(y) |
|
sd_p_ = self.spec_discriminator(y_) |
|
|
|
for i in range(len(sd_p)): |
|
real_loss, fake_loss = self.criteria['gan_loss'].disc_loss(sd_p[i][-1], sd_p_[i][-1]) |
|
real_loss_list.append(real_loss) |
|
fake_loss_list.append(fake_loss) |
|
|
|
real_loss = sum(real_loss_list) |
|
fake_loss = sum(fake_loss_list) |
|
|
|
disc_loss = real_loss + fake_loss |
|
disc_loss = self.cfg.train.lambdas.lambda_disc * disc_loss |
|
|
|
output = { |
|
'real_loss': real_loss, |
|
'fake_loss': fake_loss, |
|
'disc_loss': disc_loss, |
|
} |
|
return output |
|
|
|
def compute_gen_loss(self, batch, output): |
|
y, y_ = output['gt_wav'], output['gen_wav'] |
|
vq_loss, vq_code = output['vq_loss'], output['vq_code'] |
|
semantic_recon_loss = output['semantic_recon_loss'] |
|
|
|
|
|
gen_loss = 0.0 |
|
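        # Freeze discriminator weights so the adversarial term only updates the generator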
self.set_discriminator_gradients(False) |
|
output_dict = {} |
|
cfg = self.cfg.train |
|
|
|
|
|
if cfg.use_mel_loss: |
|
mel_loss = self.criteria['mel_loss'](y_.squeeze(1), y.squeeze(1)) |
|
gen_loss += mel_loss * cfg.lambdas.lambda_mel_loss |
|
output_dict['mel_loss'] = mel_loss |
|
|
|
|
|
p_ = self.discriminator(y_) |
|
adv_loss_list = [] |
|
for i in range(len(p_)): |
|
adv_loss_list.append(self.criteria['gan_loss'].gen_loss(p_[i][-1])) |
|
if hasattr(self, 'spec_discriminator'): |
|
sd_p_ = self.spec_discriminator(y_) |
|
for i in range(len(sd_p_)): |
|
adv_loss_list.append(self.criteria['gan_loss'].gen_loss(sd_p_[i][-1])) |
|
adv_loss = sum(adv_loss_list) |
|
gen_loss += adv_loss * cfg.lambdas.lambda_adv |
|
output_dict['adv_loss'] = adv_loss |
|
|
|
|
|
if cfg.use_feat_match_loss: |
|
fm_loss = 0.0 |
|
with torch.no_grad(): |
|
p = self.discriminator(y) |
|
for i in range(len(p_)): |
|
for j in range(len(p_[i]) - 1): |
|
fm_loss += self.criteria['fm_loss'](p_[i][j], p[i][j].detach()) |
|
gen_loss += fm_loss * cfg.lambdas.lambda_feat_match_loss |
|
output_dict['fm_loss'] = fm_loss |
|
if hasattr(self, 'spec_discriminator'): |
|
spec_fm_loss = 0.0 |
|
with torch.no_grad(): |
|
sd_p = self.spec_discriminator(y) |
|
for i in range(len(sd_p_)): |
|
for j in range(len(sd_p_[i]) - 1): |
|
spec_fm_loss += self.criteria['fm_loss'](sd_p_[i][j], sd_p[i][j].detach()) |
|
gen_loss += spec_fm_loss * cfg.lambdas.lambda_feat_match_loss |
|
output_dict['spec_fm_loss'] = spec_fm_loss |
|
|
|
|
|
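        # vq_loss is a list of per-quantizer losses (commitment terms weighted via
        # vq_commit_weight inside the quantizer); summed and added without an extra lambda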
if vq_loss is not None: |
|
vq_loss = sum(vq_loss) |
|
gen_loss += vq_loss |
|
output_dict['vq_loss'] = vq_loss |
|
|
|
|
|
output_dict['semantic_recon_loss'] = semantic_recon_loss |
|
gen_loss += output_dict['semantic_recon_loss'] * cfg.lambdas.lambda_semantic_loss |
|
|
|
|
|
|
|
|
|
|
|
self.set_discriminator_gradients(True) |
|
output_dict['gen_loss'] = gen_loss |
|
return output_dict |
|
|
|
def training_step(self, batch, batch_idx): |
|
output = self(batch) |
|
|
|
gen_opt, disc_opt = self.optimizers() |
|
gen_sche, disc_sche = self.lr_schedulers() |
|
|
|
|
|
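        # Discriminator update (generator output is detached inside compute_disc_loss)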
disc_losses = self.compute_disc_loss(batch, output) |
|
disc_loss = disc_losses['disc_loss'] |
|
disc_opt.zero_grad() |
|
self.manual_backward(disc_loss) |
|
self.clip_gradients( |
|
disc_opt, |
|
gradient_clip_val=self.cfg.train.disc_grad_clip, |
|
gradient_clip_algorithm='norm' |
|
) |
|
disc_opt.step() |
|
disc_sche.step() |
|
|
|
|
|
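        # Generator update (adversarial + reconstruction + VQ + semantic losses)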
gen_losses = self.compute_gen_loss(batch, output) |
|
gen_loss = gen_losses['gen_loss'] |
|
gen_opt.zero_grad() |
|
self.manual_backward(gen_loss) |
|
self.clip_gradients( |
|
gen_opt, |
|
gradient_clip_val=self.cfg.train.gen_grad_clip, |
|
gradient_clip_algorithm='norm' |
|
) |
|
gen_opt.step() |
|
gen_sche.step() |
|
|
|
|
|
self.log_dict( |
|
disc_losses, |
|
on_step=True, |
|
on_epoch=True, |
|
prog_bar=True, |
|
logger=True, |
|
batch_size=self.cfg.dataset.train.batch_size, |
|
sync_dist=True |
|
) |
|
self.log_dict( |
|
gen_losses, |
|
on_step=True, |
|
on_epoch=True, |
|
prog_bar=True, |
|
logger=True, |
|
batch_size=self.cfg.dataset.train.batch_size, |
|
sync_dist=True |
|
        )
|
def validation_step(self, batch, batch_idx): |
|
output = self(batch) |
|
y = output['gt_wav'] |
|
y_ = output['gen_wav'] |
|
|
|
|
|
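        # Speaker-similarity metric: embed ground truth and reconstruction with the
        # frozen WavLM-SV model at 16 kHz and compare by cosine similarity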
y_audio = y.squeeze(1).cpu().numpy() |
|
y_recon_audio = y_.squeeze(1).cpu().numpy() |
|
|
|
embeddings1_list = [] |
|
embeddings2_list = [] |
|
|
|
|
|
for i in range(y_audio.shape[0]): |
|
|
|
y_16k = librosa.resample(y_audio[i], orig_sr=self.cfg.preprocess.audio.sr, target_sr=16000) |
|
y_recon_16k = librosa.resample(y_recon_audio[i], orig_sr=self.cfg.preprocess.audio.sr, target_sr=16000) |
|
|
|
|
|
inputs1 = self.speaker_feature_extractor( |
|
y_16k, |
|
sampling_rate=16000, |
|
return_tensors="pt" |
|
).to(self.device) |
|
|
|
inputs2 = self.speaker_feature_extractor( |
|
y_recon_16k, |
|
sampling_rate=16000, |
|
return_tensors="pt" |
|
).to(self.device) |
|
|
|
|
|
with torch.no_grad(): |
|
outputs1 = self.speaker_model(**inputs1) |
|
outputs2 = self.speaker_model(**inputs2) |
|
|
|
|
|
embedding1 = torch.mean(outputs1.last_hidden_state, dim=1) |
|
embedding2 = torch.mean(outputs2.last_hidden_state, dim=1) |
|
|
|
|
|
embedding1 = F.normalize(embedding1, p=2, dim=1) |
|
embedding2 = F.normalize(embedding2, p=2, dim=1) |
|
|
|
embeddings1_list.append(embedding1) |
|
embeddings2_list.append(embedding2) |
|
|
|
|
|
embeddings1 = torch.cat(embeddings1_list, dim=0) |
|
embeddings2 = torch.cat(embeddings2_list, dim=0) |
|
|
|
|
|
sim = F.cosine_similarity(embeddings1, embeddings2) |
|
sim = sim.mean() |
|
|
|
self.log('val/sim', sim, on_step=False, on_epoch=True, prog_bar=True, logger=True) |
|
|
|
        return {'sim': sim}
|
def test_step(self, batch, batch_idx): |
|
|
|
pass |
|
|
|
def configure_optimizers(self): |
|
from itertools import chain |
|
|
|
|
|
disc_params = self.discriminator.parameters() |
|
|
|
disc_params = chain(disc_params, self.spec_discriminator.parameters()) |
|
|
|
|
|
gen_params = chain( |
|
self.CodecEnc.parameters(), |
|
self.generator.parameters(), |
|
|
|
self.fc_prior.parameters(), |
|
self.fc_post_a.parameters(), |
|
self.fc_post_s.parameters(), |
|
self.SemanticDecoder_module.parameters(), |
|
self.SemanticEncoder_module.parameters() |
|
) |
|
|
|
|
|
gen_opt = optim.AdamW(gen_params, **self.cfg.train.gen_optim_params) |
|
disc_opt = optim.AdamW(disc_params, **self.cfg.train.disc_optim_params) |
|
|
|
|
|
gen_sche = WarmupLR(gen_opt, **self.cfg.train.gen_schedule_params) |
|
disc_sche = WarmupLR(disc_opt, **self.cfg.train.disc_schedule_params) |
|
|
|
print(f'Generator optim: {gen_opt}') |
|
print(f'Discriminator optim: {disc_opt}') |
|
|
|
return [gen_opt, disc_opt], [gen_sche, disc_sche] |
|
|
|
def set_discriminator_gradients(self, flag=True): |
|
for p in self.discriminator.parameters(): |
|
p.requires_grad = flag |
|
|
|
if hasattr(self, 'spec_discriminator'): |
|
for p in self.spec_discriminator.parameters(): |
|
p.requires_grad = flag |
|
|
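# A minimal usage sketch, not part of the module (assumptions: a Hydra config that
# provides the cfg.model / cfg.train / cfg.dataset groups read above, and a
# LightningDataModule whose batches contain 'wav' and 'feats'; names below are
# illustrative only):
#
#     module = CodecLightningModule(cfg)
#     trainer = pl.Trainer(max_epochs=100, accelerator="gpu", devices=1)
#     trainer.fit(module, datamodule=datamodule)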