from typing import List, Optional, Tuple
from lightning.pytorch.core import LightningModule
import torch
from torch.optim import AdamW, Optimizer, swa_utils
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import DataLoader
from models.config import (
PreprocessingConfigUnivNet as PreprocessingConfig,
)
from models.config import (
VocoderFinetuningConfig,
VocoderModelConfig,
VocoderPretrainingConfig,
VoicoderTrainingConfig,
)
from models.helpers.dataloaders import train_dataloader
from training.loss import UnivnetLoss
from .discriminator import Discriminator
from .generator import Generator
class UnivNet(LightningModule):
r"""Univnet module.
This module contains the `Generator` and `Discriminator` models, and handles training and optimization.
"""
def __init__(
self,
fine_tuning: bool = False,
lang: str = "en",
acc_grad_steps: int = 10,
batch_size: int = 6,
root: str = "datasets_cache/LIBRITTS",
checkpoint_path_v1: Optional[str] = "vocoder_pretrained.pt",
):
r"""Initializes the `VocoderModule`.
Args:
fine_tuning (bool, optional): Whether to use fine-tuning mode or not. Defaults to False.
lang (str): Language of the dataset.
acc_grad_steps (int): Number of batches over which gradients are accumulated before each optimizer step.
batch_size (int): The batch size.
root (str, optional): The root directory for the dataset. Defaults to "datasets_cache/LIBRITTS".
checkpoint_path_v1 (str, optional): The path to the v0.1.0 checkpoint. If provided, the generator and discriminator weights are loaded from it. Defaults to "vocoder_pretrained.pt".
"""
super().__init__()
# Switch to manual optimization
self.automatic_optimization = False
self.acc_grad_steps = acc_grad_steps
self.batch_size = batch_size
self.lang = lang
self.root = root
model_config = VocoderModelConfig()
preprocess_config = PreprocessingConfig("english_only")
self.univnet = Generator(
model_config=model_config,
preprocess_config=preprocess_config,
)
self.discriminator = Discriminator(model_config=model_config)
# Initialize SWA
self.swa_averaged_univnet = swa_utils.AveragedModel(self.univnet)
self.swa_averaged_discriminator = swa_utils.AveragedModel(self.discriminator)
self.loss = UnivnetLoss()
self.train_config: VoicoderTrainingConfig = (
VocoderFinetuningConfig() if fine_tuning else VocoderPretrainingConfig()
)
# NOTE: this code is used only for the v0.1.0 checkpoint.
# In the future, this code will be removed!
self.checkpoint_path_v1 = checkpoint_path_v1
if checkpoint_path_v1 is not None:
generator, discriminator, _, _ = self.get_weights_v1(checkpoint_path_v1)
self.univnet.load_state_dict(generator, strict=False)
self.discriminator.load_state_dict(discriminator, strict=False)
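# Construction sketch (illustrative only; the argument values below are assumptions,
# not values required by the project): `fine_tuning` switches between
# `VocoderFinetuningConfig` and `VocoderPretrainingConfig`, and passing
# `checkpoint_path_v1=None` skips the v0.1.0 weight loading performed above.
#
#   module = UnivNet(
#       fine_tuning=True,
#       lang="en",
#       batch_size=6,
#       root="datasets_cache/LIBRITTS",
#       checkpoint_path_v1=None,
#   )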
def get_weights_v1(self, checkpoint_path: str) -> Tuple[dict, dict, dict, dict]:
r"""NOTE: this method is used only for the v0.1.0 checkpoint.
Prepares the weights for the model.
This is required for the model to be loaded from the checkpoint.
Args:
checkpoint_path (str): The path to the checkpoint.
Returns:
Tuple[dict, dict, dict, dict]: The weights for the generator and discriminator.
"""
ckpt_acoustic = torch.load(checkpoint_path, map_location=torch.device("cpu"))
return (
ckpt_acoustic["generator"],
ckpt_acoustic["discriminator"],
ckpt_acoustic["optim_g"],
ckpt_acoustic["optim_d"],
)
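# The v0.1.0 checkpoint is assumed to be a plain dict written with `torch.save` and
# holding the four keys read above; a minimal sketch of how such a file could have been
# produced (the `generator`, `discriminator`, `optim_g`, `optim_d` objects are hypothetical):
#
#   torch.save(
#       {
#           "generator": generator.state_dict(),
#           "discriminator": discriminator.state_dict(),
#           "optim_g": optim_g.state_dict(),
#           "optim_d": optim_d.state_dict(),
#       },
#       "vocoder_pretrained.pt",
#   )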
def forward(self, y_pred: torch.Tensor) -> torch.Tensor:
r"""Performs a forward pass through the UnivNet model.
Args:
y_pred (torch.Tensor): The predicted mel spectrogram.
Returns:
torch.Tensor: The generated waveform as a 1-D tensor.
"""
mel_lens = torch.tensor(
[y_pred.shape[2]],
dtype=torch.int32,
device=y_pred.device,
)
wav_prediction = self.univnet.infer(y_pred, mel_lens)
return wav_prediction[0, 0]
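# Inference sketch (the mel shape and channel count are assumptions for illustration,
# not values fixed by this file): the module expects a batched mel spectrogram of shape
# [B, n_mels, T] and returns a 1-D waveform tensor. With randomly initialized weights
# the output is noise; load trained weights first for real audio.
#
#   module = UnivNet(checkpoint_path_v1=None)
#   mel = torch.randn(1, 100, 128)  # hypothetical [batch, n_mels, frames]
#   with torch.no_grad():
#       wav = module(mel)  # 1-D tensor containing the generated waveform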
def training_step(self, batch: List, batch_idx: int):
r"""Performs a training step for the model.
Args:
batch (List): The training batch (a 13-element tuple from the dataloader); only the mel spectrograms and the target waveforms are used here.
batch_idx (int): Index of the batch.
Note:
This module uses manual optimization, so backward passes, gradient accumulation, and optimizer/scheduler steps all happen inside this method and nothing is returned.
"""
(
_,
_,
_,
_,
_,
mels,
_,
_,
_,
_,
_,
wavs,
_,
) = batch
# Access the manually managed optimizers and LR schedulers
optimizers = self.optimizers()
schedulers = self.lr_schedulers()
opt_univnet: Optimizer = optimizers[0] # type: ignore
sch_univnet: ExponentialLR = schedulers[0] # type: ignore
opt_discriminator: Optimizer = optimizers[1] # type: ignore
sch_discriminator: ExponentialLR = schedulers[1] # type: ignore
audio = wavs
fake_audio = self.univnet(mels)
res_fake, period_fake = self.discriminator(fake_audio.detach())
res_real, period_real = self.discriminator(audio)
(
total_loss_gen,
total_loss_disc,
stft_loss,
score_loss,
esr_loss,
snr_loss,
) = self.loss.forward(
audio,
fake_audio,
res_fake,
period_fake,
res_real,
period_real,
)
self.log(
"total_loss_gen",
total_loss_gen,
sync_dist=True,
batch_size=self.batch_size,
)
self.log(
"total_loss_disc",
total_loss_disc,
sync_dist=True,
batch_size=self.batch_size,
)
self.log("stft_loss", stft_loss, sync_dist=True, batch_size=self.batch_size)
self.log("esr_loss", esr_loss, sync_dist=True, batch_size=self.batch_size)
self.log("snr_loss", snr_loss, sync_dist=True, batch_size=self.batch_size)
self.log("score_loss", score_loss, sync_dist=True, batch_size=self.batch_size)
# Perform manual optimization
self.manual_backward(total_loss_gen / self.acc_grad_steps, retain_graph=True)
self.manual_backward(total_loss_disc / self.acc_grad_steps, retain_graph=True)
# Accumulate gradients over `acc_grad_steps` batches before stepping
if (batch_idx + 1) % self.acc_grad_steps == 0:
# clip gradients
self.clip_gradients(
opt_univnet,
gradient_clip_val=0.5,
gradient_clip_algorithm="norm",
)
self.clip_gradients(
opt_discriminator,
gradient_clip_val=0.5,
gradient_clip_algorithm="norm",
)
# optimizer step
opt_univnet.step()
opt_discriminator.step()
# Scheduler step
sch_univnet.step()
sch_discriminator.step()
# zero the gradients
opt_univnet.zero_grad()
opt_discriminator.zero_grad()
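# Gradient-accumulation note: with the constructor defaults (batch_size=6,
# acc_grad_steps=10), gradients from 10 consecutive batches are summed before each
# optimizer/scheduler step, so the effective batch size per update is roughly:
#
#   effective_batch = batch_size * acc_grad_steps  # 6 * 10 = 60 samples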
def configure_optimizers(self):
r"""Configures the optimizers and learning rate schedulers for the `UnivNet` and `Discriminator` models.
This method creates an `AdamW` optimizer and an `ExponentialLR` scheduler for each model.
The learning rate, betas, and decay rate for the optimizers and schedulers are taken from the training configuration.
Returns:
tuple: A tuple of two dictionaries, each containing the optimizer and learning rate scheduler for one of the models.
Example:
```python
univnet = UnivNet()
optimizers = univnet.configure_optimizers()
print(optimizers)
(
{"optimizer": <torch.optim.adamw.AdamW object at 0x7f8c0c0b3d90>, "lr_scheduler": <torch.optim.lr_scheduler.ExponentialLR object at 0x7f8c0c0b3e50>},
{"optimizer": <torch.optim.adamw.AdamW object at 0x7f8c0c0b3f10>, "lr_scheduler": <torch.optim.lr_scheduler.ExponentialLR object at 0x7f8c0c0b3fd0>}
)
```
"""
optim_univnet = AdamW(
self.univnet.parameters(),
self.train_config.learning_rate,
betas=(self.train_config.adam_b1, self.train_config.adam_b2),
)
scheduler_univnet = ExponentialLR(
optim_univnet,
gamma=self.train_config.lr_decay,
last_epoch=-1,
)
optim_discriminator = AdamW(
self.discriminator.parameters(),
self.train_config.learning_rate,
betas=(self.train_config.adam_b1, self.train_config.adam_b2),
)
scheduler_discriminator = ExponentialLR(
optim_discriminator,
gamma=self.train_config.lr_decay,
last_epoch=-1,
)
# NOTE: this code is used only for the v0.1.0 checkpoint.
# In the future, this code will be removed!
if self.checkpoint_path_v1 is not None:
_, _, optim_g, optim_d = self.get_weights_v1(self.checkpoint_path_v1)
optim_univnet.load_state_dict(optim_g)
optim_discriminator.load_state_dict(optim_d)
return (
{"optimizer": optim_univnet, "lr_scheduler": scheduler_univnet},
{"optimizer": optim_discriminator, "lr_scheduler": scheduler_discriminator},
)
def on_train_epoch_end(self):
r"""Updates the averaged model after each optimizer step with SWA."""
self.swa_averaged_univnet.update_parameters(self.univnet)
self.swa_averaged_discriminator.update_parameters(self.discriminator)
def on_train_end(self):
# Recompute BatchNorm statistics for the SWA-averaged models once training has finished
swa_utils.update_bn(self.train_dataloader(), self.swa_averaged_univnet)
swa_utils.update_bn(self.train_dataloader(), self.swa_averaged_discriminator)
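# SWA usage sketch (an assumption about how the averaged weights might be consumed;
# this file does not prescribe it): after `trainer.fit(...)` completes, the averaged
# generator can be unwrapped from `AveragedModel` and exported.
#
#   averaged_generator = module.swa_averaged_univnet.module
#   torch.save(averaged_generator.state_dict(), "univnet_swa.pt")  # hypothetical path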
def train_dataloader(
self,
num_workers: int = 5,
root: str = "datasets_cache/LIBRITTS",
cache: bool = True,
cache_dir: str = "datasets_cache",
mem_cache: bool = False,
url: str = "train-clean-360",
) -> DataLoader:
r"""Returns the training dataloader, that is using the LibriTTS dataset.
Args:
num_workers (int): The number of workers.
root (str): The root directory of the dataset.
cache (bool): Whether to cache the preprocessed data.
cache_dir (str): The directory for the cache.
mem_cache (bool): Whether to use memory cache.
url (str): The URL of the dataset.
Returns:
DataLoader: The training and validation dataloaders.
"""
return train_dataloader(
batch_size=self.batch_size,
num_workers=num_workers,
root=root,
cache=cache,
cache_dir=cache_dir,
mem_cache=mem_cache,
url=url,
lang=self.lang,
)
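# End-to-end training sketch (hedged: the Trainer arguments are illustrative, not
# settings mandated by this module). Since `train_dataloader` is defined on the module,
# `trainer.fit` can be called without passing dataloaders explicitly.
#
#   from lightning.pytorch import Trainer
#
#   module = UnivNet(fine_tuning=False, checkpoint_path_v1=None)
#   trainer = Trainer(max_epochs=1, accelerator="auto", devices=1)
#   trainer.fit(module)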