import itertools
from typing import List

from lightning.pytorch.core import LightningModule
import torch
from torch import nn
from torch.optim import AdamW, Optimizer
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import DataLoader

from models.config import (
    HifiGanConfig,
    HifiGanPretrainingConfig,
)
from models.config import (
    PreprocessingConfigHifiGAN as PreprocessingConfig,
)
from training.datasets.hifi_gan_dataset import train_dataloader
from training.loss import (
    DiscriminatorLoss,
    FeatureMatchingLoss,
    GeneratorLoss,
)
from training.preprocess import TacotronSTFT

from .generator import Generator
from .mp_discriminator import MultiPeriodDiscriminator
from .ms_discriminator import MultiScaleDiscriminator


class HifiGan(LightningModule):
    r"""HifiGan module.

    This module contains the `Generator` and `Discriminator` models, and
    handles training and optimization. A usage sketch appears at the bottom
    of this file.
    """

    def __init__(
        self,
        lang: str = "en",
        batch_size: int = 16,
        sampling_rate: int = 44100,
    ):
        r"""Initializes the `HifiGan`.

        Args:
            lang (str): Language of the dataset. Defaults to "en".
            batch_size (int): The batch size. Defaults to 16.
            sampling_rate (int): The sampling rate of the audio. Defaults to 44100.
        """
        super().__init__()

        self.batch_size = batch_size
        self.sampling_rate = sampling_rate
        self.lang = lang

        self.preprocess_config = PreprocessingConfig(
            "multilingual",
            sampling_rate=sampling_rate,
        )
        self.train_config = HifiGanPretrainingConfig()

        self.generator = Generator(
            h=HifiGanConfig(),
            p=self.preprocess_config,
        )
        self.mpd = MultiPeriodDiscriminator()
        self.msd = MultiScaleDiscriminator()

        self.feature_loss = FeatureMatchingLoss()
        self.discriminator_loss = DiscriminatorLoss()
        self.generator_loss = GeneratorLoss()
        self.mae_loss = nn.L1Loss()

        self.tacotronSTFT = TacotronSTFT(
            filter_length=self.preprocess_config.stft.filter_length,
            hop_length=self.preprocess_config.stft.hop_length,
            win_length=self.preprocess_config.stft.win_length,
            n_mel_channels=self.preprocess_config.stft.n_mel_channels,
            sampling_rate=self.preprocess_config.sampling_rate,
            mel_fmin=self.preprocess_config.stft.mel_fmin,
            mel_fmax=self.preprocess_config.stft.mel_fmax,
            center=False,
        )
        # Mark TacotronSTFT as non-trainable
        for param in self.tacotronSTFT.parameters():
            param.requires_grad = False

        # Switch to manual optimization
        self.automatic_optimization = False
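        # Note: Lightning 2.x no longer steps multiple optimizers
        # automatically, so `training_step` below performs the backward
        # passes and the optimizer/scheduler steps for the generator and
        # the discriminators itself.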

    def forward(self, y_pred: torch.Tensor) -> torch.Tensor:
        r"""Performs a forward pass through the HifiGan generator.

        Args:
            y_pred (torch.Tensor): The predicted mel spectrogram.

        Returns:
            torch.Tensor: The generated waveform, with the channel dimension squeezed out.
        """
        wav_prediction = self.generator.forward(y_pred)

        return wav_prediction.squeeze()

    def training_step(self, batch: List, batch_idx: int):
        r"""Performs a training step for the model.

        Since the module uses manual optimization, this method runs the
        backward passes and the optimizer and scheduler steps itself, and
        logs the individual losses instead of returning them.

        Args:
            batch (List): The batch of data for training, containing the item IDs, the audio waveforms, and the mel spectrograms.
            batch_idx (int): Index of the batch.
        """
        _, audio, mel = batch

        # Access the optimizers and schedulers
        optimizers = self.optimizers()
        schedulers = self.lr_schedulers()
        opt_generator: Optimizer = optimizers[0]  # type: ignore
        sch_generator: ExponentialLR = schedulers[0]  # type: ignore
        opt_discriminator: Optimizer = optimizers[1]  # type: ignore
        sch_discriminator: ExponentialLR = schedulers[1]  # type: ignore

        # Generate fake audio
        audio_pred = self.generator.forward(mel)
        _, fake_mel = self.tacotronSTFT(audio_pred.squeeze(1))

        # Train discriminator
        opt_discriminator.zero_grad()
        mpd_score_real, mpd_score_gen, _, _ = self.mpd.forward(
            y=audio,
            y_hat=audio_pred.detach(),
        )
        loss_disc_mpd, _, _ = self.discriminator_loss.forward(
            disc_real_outputs=mpd_score_real,
            disc_generated_outputs=mpd_score_gen,
        )
        msd_score_real, msd_score_gen, _, _ = self.msd(
            y=audio,
            y_hat=audio_pred.detach(),
        )
        loss_disc_msd, _, _ = self.discriminator_loss(
            disc_real_outputs=msd_score_real,
            disc_generated_outputs=msd_score_gen,
        )
        loss_d = loss_disc_msd + loss_disc_mpd

        # Step for the discriminator
        self.manual_backward(loss_d, retain_graph=True)
        opt_discriminator.step()

        # Train generator
        opt_generator.zero_grad()
        loss_mel = self.mae_loss(fake_mel, mel)
        _, mpd_score_gen, fmap_mpd_real, fmap_mpd_gen = self.mpd.forward(
            y=audio,
            y_hat=audio_pred,
        )
        _, msd_score_gen, fmap_msd_real, fmap_msd_gen = self.msd.forward(
            y=audio,
            y_hat=audio_pred,
        )
        loss_fm_mpd = self.feature_loss.forward(
            fmap_r=fmap_mpd_real,
            fmap_g=fmap_mpd_gen,
        )
        loss_fm_msd = self.feature_loss.forward(
            fmap_r=fmap_msd_real,
            fmap_g=fmap_msd_gen,
        )
        loss_gen_mpd, _ = self.generator_loss.forward(disc_outputs=mpd_score_gen)
        loss_gen_msd, _ = self.generator_loss.forward(disc_outputs=msd_score_gen)
        loss_g = (
            loss_gen_msd
            + loss_gen_mpd
            + loss_fm_msd
            + loss_fm_mpd
            + loss_mel * self.train_config.l1_factor
        )

        # Step for the generator
        self.manual_backward(loss_g, retain_graph=True)
        opt_generator.step()

        # Schedulers step
        sch_generator.step()
        sch_discriminator.step()

        # Generator losses
        self.log(
            "loss_gen_msd",
            loss_gen_msd,
            sync_dist=True,
            batch_size=self.batch_size,
        )
        self.log(
            "loss_gen_mpd",
            loss_gen_mpd,
            sync_dist=True,
            batch_size=self.batch_size,
        )
        self.log(
            "loss_fm_msd",
            loss_fm_msd,
            sync_dist=True,
            batch_size=self.batch_size,
        )
        self.log(
            "loss_fm_mpd",
            loss_fm_mpd,
            sync_dist=True,
            batch_size=self.batch_size,
        )
        self.log(
            "mel_loss",
            loss_mel,
            sync_dist=True,
            batch_size=self.batch_size,
        )

        # Discriminator losses
        self.log(
            "loss_disc_msd",
            loss_disc_msd,
            sync_dist=True,
            batch_size=self.batch_size,
        )
        self.log(
            "loss_disc_mpd",
            loss_disc_mpd,
            sync_dist=True,
            batch_size=self.batch_size,
        )
        self.log(
            "total_loss_disc",
            loss_d,
            sync_dist=True,
            batch_size=self.batch_size,
        )

    def configure_optimizers(self):
        r"""Configures the optimizers and learning rate schedulers for the `Generator` and `Discriminator` models.

        This method creates an `AdamW` optimizer and an `ExponentialLR` scheduler for each model.
        The learning rate, betas, and decay rate for the optimizers and schedulers are taken from the training configuration.

        Returns:
            tuple: A tuple containing two dictionaries. Each dictionary contains the optimizer and learning rate scheduler for one of the models.
        """
        optim_generator = AdamW(
            self.generator.parameters(),
            self.train_config.learning_rate,
            betas=(self.train_config.adam_b1, self.train_config.adam_b2),
        )
        scheduler_generator = ExponentialLR(
            optim_generator,
            gamma=self.train_config.lr_decay,
            last_epoch=-1,
        )

        optim_discriminator = AdamW(
            itertools.chain(self.msd.parameters(), self.mpd.parameters()),
            self.train_config.learning_rate,
            betas=(self.train_config.adam_b1, self.train_config.adam_b2),
        )
        scheduler_discriminator = ExponentialLR(
            optim_discriminator,
            gamma=self.train_config.lr_decay,
            last_epoch=-1,
        )

        return (
            {"optimizer": optim_generator, "lr_scheduler": scheduler_generator},
            {"optimizer": optim_discriminator, "lr_scheduler": scheduler_discriminator},
        )

    def train_dataloader(
        self,
        root: str = "datasets_cache",
        cache: bool = True,
        cache_dir: str = "/dev/shm",
    ) -> DataLoader:
        r"""Returns the training dataloader, which uses the LibriTTS dataset.

        Args:
            root (str): The root directory of the dataset. Defaults to "datasets_cache".
            cache (bool): Whether to cache the preprocessed data. Defaults to True.
            cache_dir (str): The directory for the cache. Defaults to "/dev/shm".

        Returns:
            DataLoader: The training dataloader.
        """
        return train_dataloader(
            batch_size=self.batch_size,
            num_workers=self.preprocess_config.workers,
            root=root,
            cache=cache,
            cache_dir=cache_dir,
            lang=self.lang,
        )
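

# Usage sketch (not part of the original module): a minimal example of how
# one might train this vocoder and run inference. The `Trainer` arguments
# and the dummy mel shape below are illustrative assumptions, not values
# taken from this project's configuration.
if __name__ == "__main__":
    from lightning.pytorch import Trainer

    model = HifiGan(lang="en", batch_size=16, sampling_rate=44100)

    # trainer.fit() picks up the module's own train_dataloader() defined above.
    trainer = Trainer(max_epochs=1, accelerator="auto", devices=1)
    trainer.fit(model)

    # Inference: synthesize a waveform from a mel spectrogram with the
    # generator. A random tensor stands in for a real mel spectrogram here.
    model.eval()
    with torch.no_grad():
        n_mels = model.preprocess_config.stft.n_mel_channels
        mel = torch.randn(1, n_mels, 100)  # (batch, n_mel_channels, frames)
        wav = model(mel)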