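# NOTE(review): the three lines below look like web-page residue (avatar alt
# text, commit message, commit hash) rather than module code; kept as comments.
# nickovchinnikov's picture
# Init
# 9d61c9b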
import itertools
from typing import List
from lightning.pytorch.core import LightningModule
import torch
from torch import nn
from torch.optim import AdamW, Optimizer
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import DataLoader
from models.config import (
HifiGanConfig,
HifiGanPretrainingConfig,
)
from models.config import (
PreprocessingConfigHifiGAN as PreprocessingConfig,
)
from training.datasets.hifi_gan_dataset import train_dataloader
from training.loss import (
DiscriminatorLoss,
FeatureMatchingLoss,
GeneratorLoss,
)
from training.preprocess import TacotronSTFT
from .generator import Generator
from .mp_discriminator import MultiPeriodDiscriminator
from .ms_discriminator import MultiScaleDiscriminator
class HifiGan(LightningModule):
    r"""HifiGan module.

    This module contains the `Generator` and the two discriminators
    (`MultiPeriodDiscriminator` and `MultiScaleDiscriminator`), and handles
    training and optimization.

    Optimization is manual (`automatic_optimization = False`): `training_step`
    drives both optimizers and both learning-rate schedulers itself.
    """

    def __init__(
        self,
        lang: str = "en",
        batch_size: int = 16,
        sampling_rate: int = 44100,
    ):
        r"""Initializes the `HifiGan`.

        Args:
            lang (str): Language of the dataset. Defaults to "en".
            batch_size (int): The batch size. Defaults to 16.
            sampling_rate (int): The sampling rate of the audio. Defaults to 44100.
        """
        super().__init__()

        self.batch_size = batch_size
        self.sampling_rate = sampling_rate
        self.lang = lang

        self.preprocess_config = PreprocessingConfig(
            "multilingual",
            sampling_rate=sampling_rate,
        )
        self.train_config = HifiGanPretrainingConfig()

        # Models: one generator, two discriminators.
        self.generator = Generator(
            h=HifiGanConfig(),
            p=self.preprocess_config,
        )
        self.mpd = MultiPeriodDiscriminator()
        self.msd = MultiScaleDiscriminator()

        # Loss functions.
        self.feature_loss = FeatureMatchingLoss()
        self.discriminator_loss = DiscriminatorLoss()
        self.generator_loss = GeneratorLoss()
        self.mae_loss = nn.L1Loss()

        # Used in `training_step` to turn generated audio back into a mel
        # spectrogram for the L1 mel loss.
        stft_config = self.preprocess_config.stft
        self.tacotronSTFT = TacotronSTFT(
            filter_length=stft_config.filter_length,
            hop_length=stft_config.hop_length,
            win_length=stft_config.win_length,
            n_mel_channels=stft_config.n_mel_channels,
            sampling_rate=self.preprocess_config.sampling_rate,
            mel_fmin=stft_config.mel_fmin,
            mel_fmax=stft_config.mel_fmax,
            center=False,
        )

        # Mark TacotronSTFT as non-trainable
        for param in self.tacotronSTFT.parameters():
            param.requires_grad = False

        # Switch to manual optimization
        self.automatic_optimization = False

    def forward(self, y_pred: torch.Tensor) -> torch.Tensor:
        r"""Performs a forward pass through the generator.

        Args:
            y_pred (torch.Tensor): The predicted mel spectrogram.

        Returns:
            torch.Tensor: The generator output with all size-1 dimensions removed.
        """
        wav_prediction = self.generator.forward(y_pred)

        # NOTE(review): `squeeze()` without a dim drops every size-1 dimension,
        # which includes the batch dimension when a single item is passed.
        return wav_prediction.squeeze()

    def training_step(self, batch: List, batch_idx: int) -> None:
        r"""Performs one manual-optimization training step.

        The discriminators are updated first on the detached generator output,
        then the generator is updated with adversarial, feature-matching and
        mel (L1) losses. Both schedulers are stepped and all losses are logged.

        Args:
            batch (List): A three-item batch; the first item is ignored, the
                second is the audio and the third is the mel spectrogram.
            batch_idx (int): Index of the batch.

        Returns:
            None: Optimizers are stepped manually; losses are reported via `self.log`.
        """
        _, audio, mel = batch

        # Access your optimizers
        optimizers = self.optimizers()
        schedulers = self.lr_schedulers()

        # Index 0 is the generator, index 1 the discriminators
        # (the order returned by `configure_optimizers`).
        opt_generator: Optimizer = optimizers[0]  # type: ignore
        sch_generator: ExponentialLR = schedulers[0]  # type: ignore

        opt_discriminator: Optimizer = optimizers[1]  # type: ignore
        sch_discriminator: ExponentialLR = schedulers[1]  # type: ignore

        # Generate fake audio
        audio_pred = self.generator.forward(mel)
        # assumes audio_pred carries a size-1 dim at index 1 — TODO confirm
        _, fake_mel = self.tacotronSTFT(audio_pred.squeeze(1))

        # Train discriminator
        opt_discriminator.zero_grad()

        # `detach()` keeps the discriminator loss from reaching the generator.
        mpd_score_real, mpd_score_gen, _, _ = self.mpd.forward(
            y=audio,
            y_hat=audio_pred.detach(),
        )
        loss_disc_mpd, _, _ = self.discriminator_loss.forward(
            disc_real_outputs=mpd_score_real,
            disc_generated_outputs=mpd_score_gen,
        )

        msd_score_real, msd_score_gen, _, _ = self.msd(
            y=audio,
            y_hat=audio_pred.detach(),
        )
        loss_disc_msd, _, _ = self.discriminator_loss(
            disc_real_outputs=msd_score_real,
            disc_generated_outputs=msd_score_gen,
        )

        loss_d = loss_disc_msd + loss_disc_mpd

        # Step for the discriminator.
        # No `retain_graph`: this graph is built on the detached prediction and
        # is never backpropagated through again; the generator pass below runs
        # fresh discriminator forwards.
        self.manual_backward(loss_d)
        opt_discriminator.step()

        # Train generator
        opt_generator.zero_grad()
        loss_mel = self.mae_loss(fake_mel, mel)

        _, mpd_score_gen, fmap_mpd_real, fmap_mpd_gen = self.mpd.forward(
            y=audio,
            y_hat=audio_pred,
        )
        _, msd_score_gen, fmap_msd_real, fmap_msd_gen = self.msd.forward(
            y=audio,
            y_hat=audio_pred,
        )

        loss_fm_mpd = self.feature_loss.forward(
            fmap_r=fmap_mpd_real,
            fmap_g=fmap_mpd_gen,
        )
        loss_fm_msd = self.feature_loss.forward(
            fmap_r=fmap_msd_real,
            fmap_g=fmap_msd_gen,
        )

        loss_gen_mpd, _ = self.generator_loss.forward(disc_outputs=mpd_score_gen)
        loss_gen_msd, _ = self.generator_loss.forward(disc_outputs=msd_score_gen)

        loss_g = (
            loss_gen_msd
            + loss_gen_mpd
            + loss_fm_msd
            + loss_fm_mpd
            + loss_mel * self.train_config.l1_factor
        )

        # Step for the generator.
        # The graph is not reused after this backward, so it is not retained.
        self.manual_backward(loss_g)
        opt_generator.step()

        # Schedulers step
        # NOTE(review): both schedulers are stepped once per batch, so
        # `lr_decay` is applied per training step — confirm this is intended
        # rather than a per-epoch decay.
        sch_generator.step()
        sch_discriminator.step()

        # Generator losses first, then discriminator losses.
        metrics = {
            "loss_gen_msd": loss_gen_msd,
            "loss_gen_mpd": loss_gen_mpd,
            "loss_fm_msd": loss_fm_msd,
            "loss_fm_mpd": loss_fm_mpd,
            "mel_loss": loss_mel,
            "loss_disc_msd": loss_disc_msd,
            "loss_disc_mpd": loss_disc_mpd,
            "total_loss_disc": loss_d,
        }
        for name, value in metrics.items():
            self.log(
                name,
                value,
                sync_dist=True,
                batch_size=self.batch_size,
            )

    def configure_optimizers(self):
        r"""Configures the optimizers and learning rate schedulers for the generator and the discriminators.

        This method creates an `AdamW` optimizer and an `ExponentialLR` scheduler
        for the generator, and another pair shared by both discriminators.
        The learning rate, betas, and decay rate are taken from the training configuration.

        Returns
            tuple: A tuple containing two dictionaries (generator first, discriminators second).
            Each dictionary contains an "optimizer" and its "lr_scheduler".
        """
        betas = (self.train_config.adam_b1, self.train_config.adam_b2)

        optim_generator = AdamW(
            self.generator.parameters(),
            self.train_config.learning_rate,
            betas=betas,
        )
        scheduler_generator = ExponentialLR(
            optim_generator,
            gamma=self.train_config.lr_decay,
            last_epoch=-1,
        )

        # A single optimizer covers both discriminators.
        optim_discriminator = AdamW(
            itertools.chain(self.msd.parameters(), self.mpd.parameters()),
            self.train_config.learning_rate,
            betas=betas,
        )
        scheduler_discriminator = ExponentialLR(
            optim_discriminator,
            gamma=self.train_config.lr_decay,
            last_epoch=-1,
        )

        return (
            {"optimizer": optim_generator, "lr_scheduler": scheduler_generator},
            {"optimizer": optim_discriminator, "lr_scheduler": scheduler_discriminator},
        )

    def train_dataloader(
        self,
        root: str = "datasets_cache",
        cache: bool = True,
        cache_dir: str = "/dev/shm",
    ) -> DataLoader:
        r"""Returns the training dataloader.

        Delegates to the module-level `train_dataloader` helper imported from
        `training.datasets.hifi_gan_dataset`, passing this module's batch size,
        worker count and language.

        Args:
            root (str): The root directory of the dataset. Defaults to "datasets_cache".
            cache (bool): Whether to cache the preprocessed data. Defaults to True.
            cache_dir (str): The directory for the cache. Defaults to "/dev/shm".

        Returns:
            DataLoader: The training dataloader.
        """
        # Inside the method body this name resolves to the imported
        # module-level function, not to this method.
        return train_dataloader(
            batch_size=self.batch_size,
            num_workers=self.preprocess_config.workers,
            root=root,
            cache=cache,
            cache_dir=cache_dir,
            lang=self.lang,
        )