chong.zhang
update
96fe5d9
import math
import numpy as np
import pytorch_lightning as pl
import torch
import torchaudio
import transformers
import yaml
from decoder.discriminator_dac import DACDiscriminator
from decoder.discriminators import MultiPeriodDiscriminator, MultiResolutionDiscriminator
from decoder.feature_extractors import FeatureExtractor
from decoder.heads import FourierHead
from decoder.helpers import plot_spectrogram_to_numpy
from decoder.loss import DiscriminatorLoss, GeneratorLoss, FeatureMatchingLoss, MelSpecReconstructionLoss, DACGANLoss
from decoder.models import Backbone
from decoder.modules import safe_log
from decoder.pretrained_model import instantiate_class
class VocosExp(pl.LightningModule):
# noinspection PyUnusedLocal
def __init__(
self,
feature_extractor: FeatureExtractor,
backbone: Backbone,
head: FourierHead,
resume_config: str,
resume_model: str,
sample_rate: int = 24000,
initial_learning_rate: float = 2e-4,
num_warmup_steps: int = 0,
mel_loss_coeff: float = 45,
mrd_loss_coeff: float = 1.0,
pretrain_mel_steps: int = 0,
decay_mel_coeff: bool = False,
evaluate_utmos: bool = False,
evaluate_pesq: bool = False,
evaluate_periodicty: bool = False,
resume: bool = False,
):
"""
Args:
feature_extractor (FeatureExtractor): An instance of FeatureExtractor to extract features from audio signals.
backbone (Backbone): An instance of Backbone model.
head (FourierHead): An instance of Fourier head to generate spectral coefficients and reconstruct a waveform.
sample_rate (int): Sampling rate of the audio signals.
initial_learning_rate (float): Initial learning rate for the optimizer.
num_warmup_steps (int): Number of steps for the warmup phase of learning rate scheduler. Default is 0.
mel_loss_coeff (float, optional): Coefficient for Mel-spectrogram loss in the loss function. Default is 45.
mrd_loss_coeff (float, optional): Coefficient for Multi Resolution Discriminator loss. Default is 1.0.
pretrain_mel_steps (int, optional): Number of steps to pre-train the model without the GAN objective. Default is 0.
decay_mel_coeff (bool, optional): If True, the Mel-spectrogram loss coefficient is decayed during training. Default is False.
evaluate_utmos (bool, optional): If True, UTMOS scores are computed for each validation run.
evaluate_pesq (bool, optional): If True, PESQ scores are computed for each validation run.
evaluate_periodicty (bool, optional): If True, periodicity scores are computed for each validation run.
"""
super().__init__()
self.save_hyperparameters(ignore=["feature_extractor", "backbone", "head"])
self.feature_extractor = feature_extractor
self.backbone = backbone
self.head = head
self.resume_config = resume_config
self.resume_model = resume_model
self.resume = resume
self.multiperioddisc = MultiPeriodDiscriminator()
self.multiresddisc = MultiResolutionDiscriminator()
self.dac = DACDiscriminator()
self.dacdiscriminator = DACGANLoss(self.dac)
self.disc_loss = DiscriminatorLoss()
self.gen_loss = GeneratorLoss()
self.feat_matching_loss = FeatureMatchingLoss()
self.melspec_loss = MelSpecReconstructionLoss(sample_rate=sample_rate)
self.train_discriminator = False
self.base_mel_coeff = self.mel_loss_coeff = mel_loss_coeff
def configure_optimizers(self):
disc_params = [
{"params": self.multiperioddisc.parameters()},
{"params": self.multiresddisc.parameters()},
{"params": self.dac.parameters()},
]
gen_params = [
{"params": self.feature_extractor.parameters()},
{"params": self.backbone.parameters()},
{"params": self.head.parameters()},
]
opt_disc = torch.optim.AdamW(disc_params, lr=self.hparams.initial_learning_rate)
opt_gen = torch.optim.AdamW(gen_params, lr=self.hparams.initial_learning_rate)
max_steps = self.trainer.max_steps // 2 # Max steps per optimizer
scheduler_disc = transformers.get_cosine_schedule_with_warmup(
opt_disc, num_warmup_steps=self.hparams.num_warmup_steps, num_training_steps=max_steps,
)
scheduler_gen = transformers.get_cosine_schedule_with_warmup(
opt_gen, num_warmup_steps=self.hparams.num_warmup_steps, num_training_steps=max_steps,
)
return (
[opt_disc, opt_gen],
[{"scheduler": scheduler_disc, "interval": "step"}, {"scheduler": scheduler_gen, "interval": "step"}],
)
def forward(self, audio_input, **kwargs):
features, _, commit_loss = self.feature_extractor(audio_input, **kwargs)
# print('1111', self.feature_extractor.state_dict()['encodec.decoder.model.3.convtr.convtr.weight_g'])
x = self.backbone(features, **kwargs)
audio_output = self.head(x)
return audio_output, commit_loss
def training_step(self, batch, batch_idx, optimizer_idx, **kwargs):
audio_input = batch
# train discriminator
if optimizer_idx == 0 and self.train_discriminator:
with torch.no_grad():
audio_hat, _ = self(audio_input, **kwargs)
loss_dac=self.dacdiscriminator.discriminator_loss(audio_hat.unsqueeze(1),audio_input.unsqueeze(1))
real_score_mp, gen_score_mp, _, _ = self.multiperioddisc(y=audio_input, y_hat=audio_hat, **kwargs,)
real_score_mrd, gen_score_mrd, _, _ = self.multiresddisc(y=audio_input, y_hat=audio_hat, **kwargs,)
loss_mp, loss_mp_real, _ = self.disc_loss(
disc_real_outputs=real_score_mp, disc_generated_outputs=gen_score_mp
)
loss_mrd, loss_mrd_real, _ = self.disc_loss(
disc_real_outputs=real_score_mrd, disc_generated_outputs=gen_score_mrd
)
loss_mp /= len(loss_mp_real)
loss_mrd /= len(loss_mrd_real)
loss = loss_mp + self.hparams.mrd_loss_coeff * loss_mrd + loss_dac
self.log("discriminator/total", loss, prog_bar=True)
self.log("discriminator/multi_period_loss", loss_mp)
self.log("discriminator/multi_res_loss", loss_mrd)
self.log("discriminator/dac", loss_dac)
return loss
# train generator
if optimizer_idx == 1:
audio_hat, commit_loss = self(audio_input, **kwargs)
if self.train_discriminator:
loss_dac_1,loss_dac_2 = self.dacdiscriminator.generator_loss(audio_hat.unsqueeze(1),audio_input.unsqueeze(1))
_, gen_score_mp, fmap_rs_mp, fmap_gs_mp = self.multiperioddisc(
y=audio_input, y_hat=audio_hat, **kwargs,
)
_, gen_score_mrd, fmap_rs_mrd, fmap_gs_mrd = self.multiresddisc(
y=audio_input, y_hat=audio_hat, **kwargs,
)
loss_gen_mp, list_loss_gen_mp = self.gen_loss(disc_outputs=gen_score_mp)
loss_gen_mrd, list_loss_gen_mrd = self.gen_loss(disc_outputs=gen_score_mrd)
loss_gen_mp = loss_gen_mp / len(list_loss_gen_mp)
loss_gen_mrd = loss_gen_mrd / len(list_loss_gen_mrd)
loss_fm_mp = self.feat_matching_loss(fmap_r=fmap_rs_mp, fmap_g=fmap_gs_mp) / len(fmap_rs_mp)
loss_fm_mrd = self.feat_matching_loss(fmap_r=fmap_rs_mrd, fmap_g=fmap_gs_mrd) / len(fmap_rs_mrd)
self.log("generator/multi_period_loss", loss_gen_mp)
self.log("generator/multi_res_loss", loss_gen_mrd)
self.log("generator/feature_matching_mp", loss_fm_mp)
self.log("generator/feature_matching_mrd", loss_fm_mrd)
self.log("generator/loss_dac_1", loss_dac_1)
self.log("generator/loss_dac_2", loss_dac_2)
else:
loss_gen_mp = loss_gen_mrd = loss_fm_mp = loss_fm_mrd = 0
mel_loss = self.melspec_loss(audio_hat, audio_input)
loss = (
loss_gen_mp
+ self.hparams.mrd_loss_coeff * loss_gen_mrd
+ loss_fm_mp
+ self.hparams.mrd_loss_coeff * loss_fm_mrd
+ self.mel_loss_coeff * mel_loss
+ 1000 * commit_loss
+ loss_dac_1
+ loss_dac_2
)
self.log("generator/total_loss", loss, prog_bar=True)
self.log("mel_loss_coeff", self.mel_loss_coeff)
self.log("generator/mel_loss", mel_loss)
self.log("commit_loss", commit_loss)
if self.global_step % 1000 == 0 and self.global_rank == 0:
self.logger.experiment.add_audio(
"train/audio_in", audio_input[0].data.cpu(), self.global_step, self.hparams.sample_rate
)
self.logger.experiment.add_audio(
"train/audio_pred", audio_hat[0].data.cpu(), self.global_step, self.hparams.sample_rate
)
with torch.no_grad():
mel = safe_log(self.melspec_loss.mel_spec(audio_input[0]))
mel_hat = safe_log(self.melspec_loss.mel_spec(audio_hat[0]))
self.logger.experiment.add_image(
"train/mel_target",
plot_spectrogram_to_numpy(mel.data.cpu().numpy()),
self.global_step,
dataformats="HWC",
)
self.logger.experiment.add_image(
"train/mel_pred",
plot_spectrogram_to_numpy(mel_hat.data.cpu().numpy()),
self.global_step,
dataformats="HWC",
)
return loss
def on_validation_epoch_start(self):
if self.hparams.evaluate_utmos:
from metrics.UTMOS import UTMOSScore
if not hasattr(self, "utmos_model"):
self.utmos_model = UTMOSScore(device=self.device)
def validation_step(self, batch, batch_idx, **kwargs):
audio_input = batch
audio_hat, commit_loss = self(audio_input, **kwargs)
audio_16_khz = torchaudio.functional.resample(audio_input, orig_freq=self.hparams.sample_rate, new_freq=16000)
audio_hat_16khz = torchaudio.functional.resample(audio_hat, orig_freq=self.hparams.sample_rate, new_freq=16000)
if self.hparams.evaluate_periodicty:
from metrics.periodicity import calculate_periodicity_metrics
periodicity_loss, pitch_loss, f1_score = calculate_periodicity_metrics(audio_16_khz, audio_hat_16khz)
else:
periodicity_loss = pitch_loss = f1_score = 0
if self.hparams.evaluate_utmos:
utmos_score = self.utmos_model.score(audio_hat_16khz.unsqueeze(1)).mean()
else:
utmos_score = torch.zeros(1, device=self.device)
if self.hparams.evaluate_pesq:
from pesq import pesq
pesq_score = 0
for ref, deg in zip(audio_16_khz.cpu().numpy(), audio_hat_16khz.cpu().numpy()):
pesq_score += pesq(16000, ref, deg, "wb", on_error=1)
pesq_score /= len(audio_16_khz)
pesq_score = torch.tensor(pesq_score)
else:
pesq_score = torch.zeros(1, device=self.device)
mel_loss = self.melspec_loss(audio_hat.unsqueeze(1), audio_input.unsqueeze(1))
total_loss = mel_loss + (5 - utmos_score) + (5 - pesq_score) + 1000 * commit_loss
return {
"val_loss": total_loss,
"mel_loss": mel_loss,
"utmos_score": utmos_score,
"pesq_score": pesq_score,
"periodicity_loss": periodicity_loss,
"pitch_loss": pitch_loss,
"f1_score": f1_score,
"audio_input": audio_input[0],
"audio_pred": audio_hat[0],
}
def validation_epoch_end(self, outputs):
if self.global_rank == 0:
*_, audio_in, audio_pred = outputs[0].values()
self.logger.experiment.add_audio(
"val_in", audio_in.data.cpu().numpy(), self.global_step, self.hparams.sample_rate
)
self.logger.experiment.add_audio(
"val_pred", audio_pred.data.cpu().numpy(), self.global_step, self.hparams.sample_rate
)
mel_target = safe_log(self.melspec_loss.mel_spec(audio_in))
mel_hat = safe_log(self.melspec_loss.mel_spec(audio_pred))
self.logger.experiment.add_image(
"val_mel_target",
plot_spectrogram_to_numpy(mel_target.data.cpu().numpy()),
self.global_step,
dataformats="HWC",
)
self.logger.experiment.add_image(
"val_mel_hat",
plot_spectrogram_to_numpy(mel_hat.data.cpu().numpy()),
self.global_step,
dataformats="HWC",
)
avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
mel_loss = torch.stack([x["mel_loss"] for x in outputs]).mean()
utmos_score = torch.stack([x["utmos_score"] for x in outputs]).mean()
pesq_score = torch.stack([x["pesq_score"] for x in outputs]).mean()
periodicity_loss = np.array([x["periodicity_loss"] for x in outputs]).mean()
pitch_loss = np.array([x["pitch_loss"] for x in outputs]).mean()
f1_score = np.array([x["f1_score"] for x in outputs]).mean()
self.log("val_loss", avg_loss, sync_dist=True)
self.log("val/mel_loss", mel_loss, sync_dist=True)
self.log("val/utmos_score", utmos_score, sync_dist=True)
self.log("val/pesq_score", pesq_score, sync_dist=True)
self.log("val/periodicity_loss", periodicity_loss, sync_dist=True)
self.log("val/pitch_loss", pitch_loss, sync_dist=True)
self.log("val/f1_score", f1_score, sync_dist=True)
@property
def global_step(self):
"""
Override global_step so that it returns the total number of batches processed
"""
return self.trainer.fit_loop.epoch_loop.total_batch_idx
def on_train_batch_start(self, *args):
if self.global_step >= self.hparams.pretrain_mel_steps:
self.train_discriminator = True
else:
self.train_discriminator = False
def on_train_batch_end(self, *args):
def mel_loss_coeff_decay(current_step, num_cycles=0.5):
max_steps = self.trainer.max_steps // 2
if current_step < self.hparams.num_warmup_steps:
return 1.0
progress = float(current_step - self.hparams.num_warmup_steps) / float(
max(1, max_steps - self.hparams.num_warmup_steps)
)
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
if self.hparams.decay_mel_coeff:
self.mel_loss_coeff = self.base_mel_coeff * mel_loss_coeff_decay(self.global_step + 1)
class WavTokenizer(VocosExp):
"""
WavTokenizer is a subclass of VocosExp that overrides the parent experiment to function as a conditional GAN.
It manages an additional `bandwidth_id` attribute, which denotes a learnable embedding corresponding to
a specific bandwidth value of EnCodec. During training, a random bandwidth_id is generated for each step,
while during validation, a fixed bandwidth_id is used.
"""
def __init__(
self,
feature_extractor: FeatureExtractor,
backbone: Backbone,
head: FourierHead,
resume_config: str,
resume_model: str,
sample_rate: int = 24000,
initial_learning_rate: float = 2e-4,
num_warmup_steps: int = 0,
mel_loss_coeff: float = 45,
mrd_loss_coeff: float = 1.0,
pretrain_mel_steps: int = 0,
decay_mel_coeff: bool = False,
evaluate_utmos: bool = False,
evaluate_pesq: bool = False,
evaluate_periodicty: bool = False,
resume: bool = False,
):
super().__init__(
feature_extractor,
backbone,
head,
resume_config,
resume_model,
sample_rate,
initial_learning_rate,
num_warmup_steps,
mel_loss_coeff,
mrd_loss_coeff,
pretrain_mel_steps,
decay_mel_coeff,
evaluate_utmos,
evaluate_pesq,
evaluate_periodicty,
resume
)
# Override with conditional discriminators
# VocosExp.__init__(self, feature_extractor, backbone, head, resume_config, resume_model)
# if self.resume:
# VocosExp.load_from_checkpoint(self.resume_model)
self.multiperioddisc = MultiPeriodDiscriminator(num_embeddings=len(self.feature_extractor.bandwidths))
self.multiresddisc = MultiResolutionDiscriminator(num_embeddings=len(self.feature_extractor.bandwidths))
self.dac = DACDiscriminator()
if self.resume:
print('加载预训练模型:', self.resume_model)
# with open(self.resume_config, "r") as f:
# config = yaml.safe_load(f)
# feature_extractor = instantiate_class(args=(), init=config['model']['init_args']["feature_extractor"])
# backbone = instantiate_class(args=(), init=config['model']['init_args']["backbone"])
# head = instantiate_class(args=(), init=config['model']['init_args']["head"])
# 不加载量化器部分权重
state_dict_raw = torch.load(self.resume_model, map_location=self.device)['state_dict']
state_dict_fa_qa = dict()
state_dict_fa_en = dict()
state_dict_fa_de = dict()
state_dict_bb = dict()
state_dict_hd = dict()
state_dict_mp = dict()
state_dict_mr = dict()
state_dict_dac = dict()
for k, v in state_dict_raw.items():
# breakpoint()
if k.startswith('feature_extractor.encodec.quantizer'):
# breakpoint()
# print("*****",k)
ss = k[46:48]
if ss[-1] == '.':
num = int(ss[0])
# print("num,k",num,k[36:])
if num <= 7:
state_dict_fa_qa[k[36:]] = v
if k.startswith('feature_extractor.encodec.encoder'):
state_dict_fa_en[k[34:]] = v
if k.startswith('feature_extractor.encodec.decoder'):
state_dict_fa_de[k[34:]] = v
if k.startswith('backbone.'):
state_dict_bb[k[9:]] = v
if k.startswith('head.'):
state_dict_hd[k[5:]] = v
if k.startswith('multiperioddisc.'):
state_dict_mp[k[16:]] = v
if k.startswith('multiresddisc.'):
state_dict_mr[k[14:]] = v
if k.startswith('dac.'):
state_dict_dac[k[4:]] = v
# breakpoint()
# feature_extractor.encodec.quantizer.load_state_dict(state_dict_fa_qa, strict=True)
feature_extractor.encodec.encoder.load_state_dict(state_dict_fa_en, strict=True)
feature_extractor.encodec.decoder.load_state_dict(state_dict_fa_de, strict=True)
feature_extractor.encodec.quantizer.load_state_dict(state_dict_fa_qa, strict=True)
backbone.load_state_dict(state_dict_bb, strict=True)
head.load_state_dict(state_dict_hd, strict=True)
self.feature_extractor = feature_extractor.to(self.device)
self.backbone = backbone.to(self.device)
self.head = head.to(self.device)
self.multiperioddisc.load_state_dict(state_dict_mp, strict=True)
self.multiresddisc.load_state_dict(state_dict_mr, strict=True)
self.dac.load_state_dict(state_dict_dac, strict=True)
def training_step(self, *args):
# print('-------------------train--------------------')
# if self.global_rank == 0 and self.resume:
# config_path = self.resume_config
# model_path = self.resume_model
# self.pretrained_load(config_path, model_path)
# print('加载预训练模型:', model_path)
bandwidth_id = torch.randint(low=0, high=len(self.feature_extractor.bandwidths), size=(1,), device=self.device,)
output = super().training_step(*args, bandwidth_id=bandwidth_id)
return output
def validation_step(self, *args):
# print('-------------------valid--------------------')
bandwidth_id = torch.tensor([0], device=self.device)
output = super().validation_step(*args, bandwidth_id=bandwidth_id)
return output
def validation_epoch_end(self, outputs):
if self.global_rank == 0:
*_, audio_in, _ = outputs[0].values()
# Resynthesis with encodec for reference
self.feature_extractor.encodec.set_target_bandwidth(self.feature_extractor.bandwidths[0])
encodec_audio = self.feature_extractor.encodec(audio_in[None, None, :])
self.logger.experiment.add_audio(
"encodec", encodec_audio[0, 0].data.cpu().numpy(), self.global_step, self.hparams.sample_rate,
)
super().validation_epoch_end(outputs)