# XCodec2_24khz / lightning_module.py
import os
import random
import hydra
import numpy as np
import librosa
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import pytorch_lightning as pl
from vq import CodecEncoder, CodecDecoderVocos
from module import HiFiGANMultiPeriodDiscriminator, SpecDiscriminator
from criterions import GANLoss, MultiResolutionMelSpectrogramLoss, MultiResolutionSTFTLoss
from common.schedulers import WarmupLR
from transformers import AutoModel
from vq.module import SemanticDecoder, SemanticEncoder
from transformers import AutoFeatureExtractor, Wav2Vec2BertModel
import sys
# sys.path.append('/home/ubuntu/X-Codec-2.0/UniSpeech/downstreams/speaker_verification') # We use wavlm_large_finetune as a validation metric during training, https://github.com/microsoft/UniSpeech/tree/main/downstreams/speaker_verification
# from verification import init_model
# model_spk = init_model('wavlm_large','/aifs4su/data/zheny/models_fd_ckpt/wavlm_large_finetune.pth')
class CodecLightningModule(pl.LightningModule):
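    """LightningModule wrapping the X-Codec-style neural audio codec.

    It combines an acoustic encoder (CodecEncoder), a VQ decoder/vocoder
    (CodecDecoderVocos), multi-period and spectral discriminators, and two
    frozen auxiliary models: w2v-BERT 2.0 for semantic targets and a WavLM
    speaker model used only as a validation metric. Training is adversarial
    with manual optimization (separate generator/discriminator updates).
    """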
def __init__(self, cfg):
super().__init__()
self.cfg = cfg
self.ocwd = hydra.utils.get_original_cwd()
self.construct_model()
self.construct_criteria()
self.save_hyperparameters()
self.automatic_optimization = False
def construct_model(self):
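        """Instantiate all trainable submodules and the frozen semantic/speaker models from the Hydra config."""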
        # Initialize the codec encoder
enccfg = self.cfg.model.codec_encoder
self.CodecEnc = CodecEncoder(
ngf=enccfg.ngf,
up_ratios=enccfg.up_ratios,
dilations=enccfg.dilations,
hidden_dim=enccfg['hidden_dim'],
depth=enccfg['depth'],
heads=enccfg['heads'],
pos_meb_dim=enccfg['pos_meb_dim'],
)
        # Initialize the codec decoder
deccfg = self.cfg.model.codec_decoder
self.generator = CodecDecoderVocos(
hidden_dim=deccfg.hidden_dim,
depth=deccfg.depth,
heads=deccfg.heads,
pos_meb_dim=deccfg.pos_meb_dim,
hop_length=960,
            vq_num_quantizers=deccfg.vq_num_quantizers,      # number of VQ quantizers
            vq_dim=deccfg.vq_dim,                            # VQ dimension
            vq_commit_weight=deccfg.vq_commit_weight,        # VQ commitment weight
            vq_weight_init=deccfg.vq_weight_init,            # VQ weight initialization
            vq_full_commit_loss=deccfg.vq_full_commit_loss,  # whether to use the full commitment loss
            codebook_size=deccfg.codebook_size,              # codebook size
            codebook_dim=deccfg.codebook_dim,                # codebook dimension
)
        # Initialize the multi-period discriminator
mpdcfg = self.cfg.model.mpd
self.discriminator = HiFiGANMultiPeriodDiscriminator(
periods=mpdcfg.periods,
max_downsample_channels=mpdcfg.max_downsample_channels,
channels=mpdcfg.channels,
channel_increasing_factor=mpdcfg.channel_increasing_factor,
)
        # Initialize the spectral discriminator
mstftcfg = self.cfg.model.mstft
self.spec_discriminator = SpecDiscriminator(
stft_params=mstftcfg.stft_params,
in_channels=mstftcfg.in_channels,
out_channels=mstftcfg.out_channels,
kernel_sizes=mstftcfg.kernel_sizes,
channels=mstftcfg.channels,
max_downsample_channels=mstftcfg.max_downsample_channels,
downsample_scales=mstftcfg.downsample_scales,
use_weight_norm=mstftcfg.use_weight_norm,
)
        # Optionally compile individual submodules with torch.compile
# self.CodecEnc = torch.compile(self.CodecEnc)
# self.generator.backbone = torch.compile(self.generator )
# self.mel_conv = torch.compile(self.mel_conv)
# self.model_spk = model_spk .eval()
# self.semantic_model = AutoModel.from_pretrained("microsoft/wavlm-large")
# self.semantic_model.eval()
# self.semantic_model.requires_grad_(False)
self.speaker_model = AutoModel.from_pretrained("microsoft/wavlm-base-plus-sv")
self.speaker_feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/wavlm-base-plus-sv")
self.speaker_model.eval()
self.speaker_model.requires_grad_(False)
        # fc_prior fuses the semantic and acoustic features (1024 + 1024 dims) before quantization;
        # fc_post_a / fc_post_s project the quantized embedding back to the decoder and semantic spaces.
        self.fc_prior = nn.Linear(1024 + 1024, deccfg.vq_dim)
        self.fc_post_a = nn.Linear(deccfg.vq_dim, deccfg.hidden_dim)
self.fc_post_s = nn.Linear(deccfg.vq_dim, 1024)
self.SemanticDecoder_module = SemanticDecoder(1024, 1024, 1024)
self.SemanticEncoder_module = SemanticEncoder(1024, 1024, 1024)
self.semantic_model = Wav2Vec2BertModel.from_pretrained("facebook/w2v-bert-2.0", output_hidden_states=True)
self.semantic_model.eval()
self.semantic_model.requires_grad_(False)
# self.register_buffer('mel_basis', mel_basis)
# self.perception_model = AutoModel.from_pretrained("facebook/wav2vec2-large-xlsr-53")
# self.perception_model.eval()
# self.perception_model.requires_grad_(False)
def construct_criteria(self):
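        """Build the loss modules (mel, STFT, feature-matching, GAN, L1/L2) enabled in cfg.train."""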
cfg = self.cfg.train
self.criteria = nn.ModuleDict()
if cfg.use_mel_loss:
self.criteria['mel_loss'] = MultiResolutionMelSpectrogramLoss(sample_rate=self.cfg.preprocess.audio.sr)
if cfg.use_stft_loss:
self.criteria['stft_loss'] = MultiResolutionSTFTLoss(
fft_sizes=cfg.stft_loss_params.fft_sizes,
hop_sizes=cfg.stft_loss_params.hop_sizes,
win_sizes=cfg.stft_loss_params.win_lengths
)
if cfg.use_feat_match_loss:
self.criteria['fm_loss'] = nn.L1Loss()
self.criteria['gan_loss'] = GANLoss()
self.criteria['l1_loss'] = nn.L1Loss()
self.criteria['l2_loss'] = nn.MSELoss()
print(self.criteria)
# def forward(self, batch):
# wav = batch['wav']
# feats= batch['feats']
# vq_emb = self.CodecEnc(wav.unsqueeze(1))
# vq_emb = vq_emb.transpose(1, 2)
# with torch.no_grad():
# semantic_target = self.semantic_model(feats[:,0,:,:])
# semantic_target = semantic_target.hidden_states[16]
# semantic_target = semantic_target.detach()
# semantic_target = semantic_target.transpose(1, 2)
# semantic_target_processed = self.SemanticEncoder_module(semantic_target)
    # # Concatenate the semantic embedding with the encoder output
# vq_emb = torch.cat([semantic_target_processed, vq_emb], dim=1)
# vq_emb = self.fc_prior(vq_emb.transpose(1, 2)).transpose(1, 2)
# vq_post_emb, vq_code, vq_loss = self.generator(vq_emb, vq=True)
# semantic_recon = self.fc_post_s(vq_post_emb.transpose(1, 2)).transpose(1, 2)
# semantic_recon = self.SemanticDecoder_module(semantic_recon)
# y_ ,_ = self.generator(
# self.fc_post_a(vq_post_emb.transpose(1, 2)) ,
# vq=False
# )
# y = wav.unsqueeze(1)
# # gt_perceptual = self.perception_model(wav.squeeze(1), output_hidden_states=True) .hidden_states
# # gen_perceptual = self.perception_model(y_.squeeze(1), output_hidden_states=True) .hidden_states
# # gt_perceptual_se = gt_perceptual[10:22]
# # gen_perceptual_se = gen_perceptual[10:22]
# # perceptual_se_loss = [tensor1 - tensor2 for tensor1, tensor2 in zip(gt_perceptual_se, gen_perceptual_se)]
    # # # Element-wise subtraction using a list comprehension
# # perceptual_se_loss_l2 = [F.mse_loss(tensor1.detach(), tensor2) for tensor1, tensor2 in zip(gt_perceptual_se, gen_perceptual_se)]
# # perceptual_se_loss_l2 =torch.stack(perceptual_se_loss_l2).mean()
# output = {
# 'gt_wav': y,
# 'gen_wav': y_,
# 'vq_loss': vq_loss,
# 'vq_code': vq_code,
# 'semantic_recon_loss': F.mse_loss(semantic_recon, semantic_target),
# # 'perceptual_se_loss_l2': perceptual_se_loss_l2,
# }
# return output
def forward(self, batch):
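        """Encode, quantize, and reconstruct a batch of waveforms.

        The acoustic embedding from CodecEnc is concatenated with a semantic
        target (hidden layer 16 of the frozen w2v-BERT model, interpolated to
        the codec frame rate when the lengths differ), projected by fc_prior,
        vector-quantized, and decoded back into both a semantic reconstruction
        and a waveform.
        """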
wav = batch['wav']
feats = batch['feats']
vq_emb = self.CodecEnc(wav.unsqueeze(1))
with torch.no_grad():
semantic_target = self.semantic_model(feats)
semantic_target = semantic_target.hidden_states[16].detach()
T_codec = vq_emb.shape[1]
T_semantic = semantic_target.shape[1]
semantic_target_for_loss = semantic_target.clone()
if T_codec != T_semantic:
semantic_target = F.interpolate(
semantic_target.transpose(1, 2),
size=T_codec,
mode='linear',
align_corners=False
).transpose(1, 2)
semantic_target_transposed = semantic_target.transpose(1, 2)
semantic_target_processed = self.SemanticEncoder_module(semantic_target_transposed)
semantic_target_processed = semantic_target_processed.transpose(1, 2)
vq_emb = torch.cat([semantic_target_processed, vq_emb], dim=2)
vq_emb = self.fc_prior(vq_emb)
vq_emb = vq_emb.transpose(1, 2)
vq_post_emb, vq_code, vq_loss = self.generator(vq_emb, vq=True)
vq_post_emb_t = vq_post_emb.transpose(1, 2)
semantic_recon = self.fc_post_s(vq_post_emb_t)
semantic_recon_transposed = semantic_recon.transpose(1, 2)
semantic_recon = self.SemanticDecoder_module(semantic_recon_transposed)
semantic_recon = semantic_recon.transpose(1, 2)
# Interpolate back for loss calculation
if T_codec != T_semantic:
semantic_recon_for_loss = F.interpolate(
semantic_recon.transpose(1, 2),
size=T_semantic,
mode='linear',
align_corners=False
).transpose(1, 2)
else:
semantic_recon_for_loss = semantic_recon
# Audio generation
gen_input = self.fc_post_a(vq_post_emb_t)
y_, _ = self.generator(gen_input.transpose(1, 2), vq=False)
y = wav.unsqueeze(1)
output = {
'gt_wav': y,
'gen_wav': y_,
'vq_loss': vq_loss,
'vq_code': vq_code,
'semantic_recon_loss': F.mse_loss(semantic_recon_for_loss, semantic_target_for_loss),
}
return output
@torch.inference_mode()
def inference(self, wav):
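        """Reconstruct a waveform directly from the acoustic encoder output.

        Note: unlike forward(), this path skips the semantic branch and the
        fc_prior / fc_post_a projections.
        """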
vq_emb = self.CodecEnc(wav.unsqueeze(1))
vq_post_emb, vq_code, vq_loss = self.generator(vq_emb, vq=True)
        # generator(..., vq=False) returns a (waveform, extra) tuple, as in forward()
        y_, _ = self.generator(vq_post_emb, vq=False)
        return y_.squeeze(1)  # [B, T]
def compute_disc_loss(self, batch, output):
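        """Discriminator loss (real + fake terms) over the multi-period and spectral discriminators, computed on detached generator output."""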
y, y_ = output['gt_wav'], output['gen_wav']
y_ = y_.detach()
p = self.discriminator(y)
p_ = self.discriminator(y_)
real_loss_list, fake_loss_list = [], []
for i in range(len(p)):
real_loss, fake_loss = self.criteria['gan_loss'].disc_loss(p[i][-1], p_[i][-1])
real_loss_list.append(real_loss)
fake_loss_list.append(fake_loss)
if hasattr(self, 'spec_discriminator'):
sd_p = self.spec_discriminator(y)
sd_p_ = self.spec_discriminator(y_)
for i in range(len(sd_p)):
real_loss, fake_loss = self.criteria['gan_loss'].disc_loss(sd_p[i][-1], sd_p_[i][-1])
real_loss_list.append(real_loss)
fake_loss_list.append(fake_loss)
real_loss = sum(real_loss_list)
fake_loss = sum(fake_loss_list)
disc_loss = real_loss + fake_loss
disc_loss = self.cfg.train.lambdas.lambda_disc * disc_loss
output = {
'real_loss': real_loss,
'fake_loss': fake_loss,
'disc_loss': disc_loss,
}
return output
def compute_gen_loss(self, batch, output):
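        """Generator loss: mel, adversarial, feature-matching, VQ, and semantic reconstruction terms, weighted by cfg.train.lambdas."""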
y, y_ = output['gt_wav'], output['gen_wav']
vq_loss, vq_code = output['vq_loss'], output['vq_code']
semantic_recon_loss = output['semantic_recon_loss']
# perceptual_se_loss_l2 = output['perceptual_se_loss_l2']
# x_feat_recon_loss = output['x_feat_recon_loss']
gen_loss = 0.0
self.set_discriminator_gradients(False)
output_dict = {}
cfg = self.cfg.train
# Mel spectrogram loss
if cfg.use_mel_loss:
mel_loss = self.criteria['mel_loss'](y_.squeeze(1), y.squeeze(1))
gen_loss += mel_loss * cfg.lambdas.lambda_mel_loss
output_dict['mel_loss'] = mel_loss
# GAN loss
p_ = self.discriminator(y_)
adv_loss_list = []
for i in range(len(p_)):
adv_loss_list.append(self.criteria['gan_loss'].gen_loss(p_[i][-1]))
if hasattr(self, 'spec_discriminator'):
sd_p_ = self.spec_discriminator(y_)
for i in range(len(sd_p_)):
adv_loss_list.append(self.criteria['gan_loss'].gen_loss(sd_p_[i][-1]))
adv_loss = sum(adv_loss_list)
gen_loss += adv_loss * cfg.lambdas.lambda_adv
output_dict['adv_loss'] = adv_loss
# Feature Matching loss
if cfg.use_feat_match_loss:
fm_loss = 0.0
with torch.no_grad():
p = self.discriminator(y)
for i in range(len(p_)):
for j in range(len(p_[i]) - 1):
fm_loss += self.criteria['fm_loss'](p_[i][j], p[i][j].detach())
gen_loss += fm_loss * cfg.lambdas.lambda_feat_match_loss
output_dict['fm_loss'] = fm_loss
if hasattr(self, 'spec_discriminator'):
spec_fm_loss = 0.0
with torch.no_grad():
sd_p = self.spec_discriminator(y)
for i in range(len(sd_p_)):
for j in range(len(sd_p_[i]) - 1):
spec_fm_loss += self.criteria['fm_loss'](sd_p_[i][j], sd_p[i][j].detach())
gen_loss += spec_fm_loss * cfg.lambdas.lambda_feat_match_loss
output_dict['spec_fm_loss'] = spec_fm_loss
# VQ loss
if vq_loss is not None:
vq_loss = sum(vq_loss)
gen_loss += vq_loss
output_dict['vq_loss'] = vq_loss
# Semantic reconstruction loss
output_dict['semantic_recon_loss'] = semantic_recon_loss
gen_loss += output_dict['semantic_recon_loss'] * cfg.lambdas.lambda_semantic_loss
# Perceptual loss
# output_dict['perceptual_se_loss_l2'] = perceptual_se_loss_l2
# gen_loss += output_dict['perceptual_se_loss_l2'] * cfg.lambdas.lambda_perceptual_loss
self.set_discriminator_gradients(True)
output_dict['gen_loss'] = gen_loss
return output_dict
def training_step(self, batch, batch_idx):
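        """One manual-optimization step: update the discriminators on the current batch, then the generator, each with gradient clipping and its own scheduler step."""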
output = self(batch)
gen_opt, disc_opt = self.optimizers()
gen_sche, disc_sche = self.lr_schedulers()
        # Train the discriminators
disc_losses = self.compute_disc_loss(batch, output)
disc_loss = disc_losses['disc_loss']
disc_opt.zero_grad()
self.manual_backward(disc_loss)
self.clip_gradients(
disc_opt,
gradient_clip_val=self.cfg.train.disc_grad_clip,
gradient_clip_algorithm='norm'
)
disc_opt.step()
disc_sche.step()
        # Train the generator
gen_losses = self.compute_gen_loss(batch, output)
gen_loss = gen_losses['gen_loss']
gen_opt.zero_grad()
self.manual_backward(gen_loss)
self.clip_gradients(
gen_opt,
gradient_clip_val=self.cfg.train.gen_grad_clip,
gradient_clip_algorithm='norm'
)
gen_opt.step()
gen_sche.step()
        # Log the losses
self.log_dict(
disc_losses,
on_step=True,
on_epoch=True,
prog_bar=True,
logger=True,
batch_size=self.cfg.dataset.train.batch_size,
sync_dist=True
)
self.log_dict(
gen_losses,
on_step=True,
on_epoch=True,
prog_bar=True,
logger=True,
batch_size=self.cfg.dataset.train.batch_size,
sync_dist=True
)
# def validation_step(self, batch, batch_idx):
    # # Validation logic can be implemented here
    # output = self(batch)
    # y = output['gt_wav']    # ground-truth audio
    # y_ = output['gen_wav']  # generated reconstruction
    # embeddings1 = self.model_spk(y.squeeze(1))
    # # Process the target file
    # embeddings2 = self.model_spk(y_.squeeze(1))
    # # Compute cosine similarity
# sim = F.cosine_similarity(embeddings1, embeddings2)
# sim = sim.mean()
# self.log('val/sim', sim, on_step=False, on_epoch=True, prog_bar=True, logger=True)
# return {'sim': sim}
def validation_step(self, batch, batch_idx):
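        """Log speaker similarity between ground truth and reconstruction.

        Both waveforms are resampled to 16 kHz, embedded with the frozen WavLM
        speaker model (mean-pooled and L2-normalized), and compared with cosine
        similarity, logged as 'val/sim'.
        """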
output = self(batch)
y = output['gt_wav'] # shape: [B, 1, T]
y_ = output['gen_wav'] # shape: [B, 1, T]
# Remove channel dimension and move to CPU for processing
y_audio = y.squeeze(1).cpu().numpy() # [B, T]
y_recon_audio = y_.squeeze(1).cpu().numpy() # [B, T]
embeddings1_list = []
embeddings2_list = []
# Process each sample in the batch
for i in range(y_audio.shape[0]):
# Resample from 24kHz to 16kHz
y_16k = librosa.resample(y_audio[i], orig_sr=self.cfg.preprocess.audio.sr, target_sr=16000)
y_recon_16k = librosa.resample(y_recon_audio[i], orig_sr=self.cfg.preprocess.audio.sr, target_sr=16000)
# Extract features
inputs1 = self.speaker_feature_extractor(
y_16k,
sampling_rate=16000,
return_tensors="pt"
).to(self.device)
inputs2 = self.speaker_feature_extractor(
y_recon_16k,
sampling_rate=16000,
return_tensors="pt"
).to(self.device)
# Get embeddings
with torch.no_grad():
outputs1 = self.speaker_model(**inputs1)
outputs2 = self.speaker_model(**inputs2)
# Mean pooling over time dimension
embedding1 = torch.mean(outputs1.last_hidden_state, dim=1)
embedding2 = torch.mean(outputs2.last_hidden_state, dim=1)
# L2 normalize
embedding1 = F.normalize(embedding1, p=2, dim=1)
embedding2 = F.normalize(embedding2, p=2, dim=1)
embeddings1_list.append(embedding1)
embeddings2_list.append(embedding2)
# Stack embeddings
embeddings1 = torch.cat(embeddings1_list, dim=0)
embeddings2 = torch.cat(embeddings2_list, dim=0)
# Calculate cosine similarity
sim = F.cosine_similarity(embeddings1, embeddings2)
sim = sim.mean()
self.log('val/sim', sim, on_step=False, on_epoch=True, prog_bar=True, logger=True)
return {'sim': sim}
def test_step(self, batch, batch_idx):
        # Test logic can be implemented here
pass
def configure_optimizers(self):
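        """Create two AdamW optimizers (generator side and discriminators), each with a WarmupLR schedule from the config."""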
from itertools import chain
        # Discriminator parameters
disc_params = self.discriminator.parameters()
# if hasattr(self, 'spec_discriminator'):
disc_params = chain(disc_params, self.spec_discriminator.parameters())
        # Generator parameters
gen_params = chain(
self.CodecEnc.parameters(),
self.generator.parameters(),
# self.mel_conv.parameters(),
self.fc_prior.parameters(),
self.fc_post_a.parameters(),
self.fc_post_s.parameters(),
self.SemanticDecoder_module.parameters(),
self.SemanticEncoder_module.parameters()
)
        # Optimizers
gen_opt = optim.AdamW(gen_params, **self.cfg.train.gen_optim_params)
disc_opt = optim.AdamW(disc_params, **self.cfg.train.disc_optim_params)
        # Learning-rate schedulers
gen_sche = WarmupLR(gen_opt, **self.cfg.train.gen_schedule_params)
disc_sche = WarmupLR(disc_opt, **self.cfg.train.disc_schedule_params)
print(f'Generator optim: {gen_opt}')
print(f'Discriminator optim: {disc_opt}')
return [gen_opt, disc_opt], [gen_sche, disc_sche]
def set_discriminator_gradients(self, flag=True):
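        """Toggle requires_grad on both discriminators (used to freeze them during the generator update)."""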
for p in self.discriminator.parameters():
p.requires_grad = flag
if hasattr(self, 'spec_discriminator'):
for p in self.spec_discriminator.parameters():
p.requires_grad = flag
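
# Minimal usage sketch (illustrative only): the config path/name and the
# datamodule below are assumptions, not part of this file.
#
#   @hydra.main(config_path='configs', config_name='default')
#   def main(cfg):
#       model = CodecLightningModule(cfg)
#       trainer = pl.Trainer(devices=1)
#       trainer.fit(model, datamodule=my_datamodule)  # `my_datamodule` is hypothetical
#
#   if __name__ == '__main__':
#       main()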