from dataclasses import dataclass, field
from typing import Dict, List, Tuple, Union

import torch
import torch.nn as nn
import torchaudio
from coqpit import Coqpit
from torch.nn import functional as F
from torch.utils.data import DataLoader
from trainer.torch import DistributedSampler
from trainer.trainer_utils import get_optimizer, get_scheduler

from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.datasets.dataset import TTSDataset
from TTS.tts.layers.tortoise.arch_utils import TorchMelSpectrogram
from TTS.tts.layers.xtts.dvae import DiscreteVAE
from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer
from TTS.tts.layers.xtts.trainer.dataset import XTTSDataset
from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.models.xtts import Xtts, XttsArgs, XttsAudioConfig
from TTS.utils.io import load_fsspec


@dataclass
class GPTTrainerConfig(XttsConfig):
    lr: float = 5e-06
    training_seed: int = 1
    optimizer_wd_only_on_weights: bool = False
    weighted_loss_attrs: dict = field(default_factory=lambda: {})
    weighted_loss_multipliers: dict = field(default_factory=lambda: {})
    test_sentences: List[dict] = field(default_factory=lambda: [])


@dataclass
class XttsAudioConfig(XttsAudioConfig):
    """Extends the imported XttsAudioConfig with the DVAE sample rate."""

    dvae_sample_rate: int = 22050


@dataclass
class GPTArgs(XttsArgs):
    min_conditioning_length: int = 66150  # 3 seconds at 22050 Hz
    max_conditioning_length: int = 132300  # 6 seconds at 22050 Hz
    gpt_loss_text_ce_weight: float = 0.01
    gpt_loss_mel_ce_weight: float = 1.0
    gpt_num_audio_tokens: int = 8194
    debug_loading_failures: bool = False
    max_wav_length: int = 255995  # ~11.6 seconds at 22050 Hz
    max_text_length: int = 200
    tokenizer_file: str = ""
    mel_norm_file: str = "https://coqui.gateway.scarf.sh/v0.14.0_models/mel_norms.pth"
    dvae_checkpoint: str = ""
    xtts_checkpoint: str = ""
    gpt_checkpoint: str = ""  # if defined, it replaces the GPT weights of the XTTS model
    vocoder: str = ""  # override the vocoder key in the config to avoid JSON serialization issues
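
# The helper below is an illustrative sketch (not part of the original module):
# it shows how a training recipe typically wires the dataclasses above together.
# All paths are placeholders, and the batch size / sample rates are example
# values consistent with the defaults used elsewhere in this file.
def _example_build_config() -> GPTTrainerConfig:
    model_args = GPTArgs(
        tokenizer_file="/path/to/vocab.json",  # placeholder path
        dvae_checkpoint="/path/to/dvae.pth",  # placeholder path; required for training
        xtts_checkpoint="/path/to/model.pth",  # placeholder path
        min_conditioning_length=66150,  # 3 s of conditioning audio at 22050 Hz
        max_conditioning_length=132300,  # 6 s of conditioning audio at 22050 Hz
    )
    audio_config = XttsAudioConfig(sample_rate=22050, dvae_sample_rate=22050, output_sample_rate=24000)
    return GPTTrainerConfig(model_args=model_args, audio=audio_config, batch_size=3, lr=5e-06)
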
def callback_clearml_load_save(operation_type, model_info):
    # Returning None skips the file upload/log; returning model_info continues with it.
    # You can also change the upload destination file name via model_info.upload_filename,
    # or check the local file size with Path(model_info.local_model_path).stat().st_size.
    assert operation_type in ("load", "save")
    # print(operation_type, model_info.__dict__)
    if "similarities.pth" in model_info.__dict__["local_model_path"]:
        return None

    return model_info


class GPTTrainer(BaseTTS):
    def __init__(self, config: Coqpit):
        """XTTS GPT trainer (Tortoise-style GPT training class)."""
        super().__init__(config, ap=None, tokenizer=None)
        self.config = config
        # init XTTS model
        self.xtts = Xtts(self.config)
        # create the tokenizer with the target vocabulary
        self.xtts.tokenizer = VoiceBpeTokenizer(self.args.tokenizer_file)
        # init GPT encoder and HiFi-GAN decoder
        self.xtts.init_models()

        if self.args.xtts_checkpoint:
            self.load_checkpoint(self.config, self.args.xtts_checkpoint, eval=False, strict=False)

        # set mel stats
        if self.args.mel_norm_file:
            self.xtts.mel_stats = load_fsspec(self.args.mel_norm_file)

        # load GPT if available
        if self.args.gpt_checkpoint:
            gpt_checkpoint = torch.load(self.args.gpt_checkpoint, map_location=torch.device("cpu"))
            # deal with a Coqui Trainer exported model
            if "model" in gpt_checkpoint.keys() and "config" in gpt_checkpoint.keys():
                print("Coqui Trainer checkpoint detected! Converting it!")
                gpt_checkpoint = gpt_checkpoint["model"]
                states_keys = list(gpt_checkpoint.keys())
                # keep only the GPT weights and strip the "gpt." prefix from their keys
                for key in states_keys:
                    if "gpt." in key:
                        new_key = key.replace("gpt.", "")
                        gpt_checkpoint[new_key] = gpt_checkpoint[key]
                        del gpt_checkpoint[key]
                    else:
                        del gpt_checkpoint[key]

            # edit the checkpoint if the number of text tokens has changed, to enable
            # the best possible transfer learning from the pretrained weights
            if (
                "text_embedding.weight" in gpt_checkpoint
                and gpt_checkpoint["text_embedding.weight"].shape != self.xtts.gpt.text_embedding.weight.shape
            ):
                num_new_tokens = (
                    self.xtts.gpt.text_embedding.weight.shape[0] - gpt_checkpoint["text_embedding.weight"].shape[0]
                )
                print(f" > Loading checkpoint with {num_new_tokens} additional tokens.")

                # add new rows to the embedding layer (text_embedding), keeping the
                # stop token row at the last position
                emb_g = gpt_checkpoint["text_embedding.weight"]
                new_row = torch.randn(num_new_tokens, emb_g.shape[1])
                start_token_row = emb_g[-1, :]
                emb_g = torch.cat([emb_g, new_row], axis=0)
                emb_g[-1, :] = start_token_row
                gpt_checkpoint["text_embedding.weight"] = emb_g

                # add new weights to the linear layer (text_head)
                text_head_weight = gpt_checkpoint["text_head.weight"]
                start_token_row = text_head_weight[-1, :]
                new_entry = torch.randn(num_new_tokens, self.xtts.gpt.text_head.weight.shape[1])
                text_head_weight = torch.cat([text_head_weight, new_entry], axis=0)
                text_head_weight[-1, :] = start_token_row
                gpt_checkpoint["text_head.weight"] = text_head_weight

                # add new biases to the linear layer (text_head)
                text_head_bias = gpt_checkpoint["text_head.bias"]
                start_token_row = text_head_bias[-1]
                new_bias_entry = torch.zeros(num_new_tokens)
                text_head_bias = torch.cat([text_head_bias, new_bias_entry], axis=0)
                text_head_bias[-1] = start_token_row
                gpt_checkpoint["text_head.bias"] = text_head_bias

            self.xtts.gpt.load_state_dict(gpt_checkpoint, strict=True)
            print(">> GPT weights restored from:", self.args.gpt_checkpoint)

        # Mel spectrogram extractor for conditioning
        if self.args.gpt_use_perceiver_resampler:
            self.torch_mel_spectrogram_style_encoder = TorchMelSpectrogram(
                filter_length=2048,
                hop_length=256,
                win_length=1024,
                normalize=False,
                sampling_rate=config.audio.sample_rate,
                mel_fmin=0,
                mel_fmax=8000,
                n_mel_channels=80,
                mel_norm_file=self.args.mel_norm_file,
            )
        else:
            self.torch_mel_spectrogram_style_encoder = TorchMelSpectrogram(
                filter_length=4096,
                hop_length=1024,
                win_length=4096,
                normalize=False,
                sampling_rate=config.audio.sample_rate,
                mel_fmin=0,
                mel_fmax=8000,
                n_mel_channels=80,
                mel_norm_file=self.args.mel_norm_file,
            )

        # Load DVAE
        self.dvae = DiscreteVAE(
            channels=80,
            normalization=None,
            positional_dims=1,
            num_tokens=self.args.gpt_num_audio_tokens - 2,  # 2 ids are reserved for the start/stop audio tokens
            codebook_dim=512,
            hidden_dim=512,
            num_resnet_blocks=3,
            kernel_size=3,
            num_layers=2,
            use_transposed_convs=False,
        )
        self.dvae.eval()
        if self.args.dvae_checkpoint:
            dvae_checkpoint = torch.load(self.args.dvae_checkpoint, map_location=torch.device("cpu"))
            self.dvae.load_state_dict(dvae_checkpoint, strict=False)
            print(">> DVAE weights restored from:", self.args.dvae_checkpoint)
        else:
            raise RuntimeError(
                "You need to specify config.model_args.dvae_checkpoint path to be able to train the GPT decoder!!"
            )

        # Mel spectrogram extractor for the DVAE
        self.torch_mel_spectrogram_dvae = TorchMelSpectrogram(
            mel_norm_file=self.args.mel_norm_file, sampling_rate=config.audio.dvae_sample_rate
        )
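
    # Orientation note (added commentary, not original code): only the GPT
    # decoder is optimized here. The DVAE stays frozen in eval mode and is used
    # solely to turn target audio into discrete codes, while the two mel
    # extractors are deterministic feature pipelines (see `on_train_epoch_start`
    # and `get_optimizer` below).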
    @property
    def device(self):
        return next(self.parameters()).device

    def forward(self, text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels, cond_idxs, cond_lens):
        """
        Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode
        (actuated by `text_first`).

        text_inputs: long tensor, (b, t)
        text_lengths: long tensor, (b,)
        audio_codes: long tensor, (b, m)
        wav_lengths: long tensor, (b,)
        cond_mels: MEL float tensor, (b, num_samples, 80, t_m)
        cond_idxs: conditioning start and end indices, (b, 2)
        cond_lens: long tensor, (b,)
        """
        losses = self.xtts.gpt(
            text_inputs,
            text_lengths,
            audio_codes,
            wav_lengths,
            cond_mels=cond_mels,
            cond_idxs=cond_idxs,
            cond_lens=cond_lens,
        )
        return losses

    @torch.no_grad()
    def test_run(self, assets) -> Tuple[Dict, Dict]:  # pylint: disable=W0613
        test_audios = {}  # initialized up front so the return works when there are no test sentences
        if self.config.test_sentences:
            # init gpt for inference mode
            self.xtts.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache, use_deepspeed=False)
            self.xtts.gpt.eval()
            print(" | > Synthesizing test sentences.")
            for idx, s_info in enumerate(self.config.test_sentences):
                wav = self.xtts.synthesize(
                    s_info["text"],
                    self.config,
                    s_info["speaker_wav"],
                    s_info["language"],
                    gpt_cond_len=3,
                )["wav"]
                test_audios["{}-audio".format(idx)] = wav

            # delete inference layers
            del self.xtts.gpt.gpt_inference
            del self.xtts.gpt.gpt.wte
        return {"audios": test_audios}

    def test_log(
        self, outputs: dict, logger: "Logger", assets: dict, steps: int  # pylint: disable=unused-argument
    ) -> None:
        logger.test_audios(steps, outputs["audios"], self.args.output_sample_rate)

    def format_batch(self, batch: Dict) -> Dict:
        return batch

    @torch.no_grad()  # torch no grad to avoid gradients from the pre-processing and DVAE codes extraction
    def format_batch_on_device(self, batch):
        """Compute spectrograms on the device."""
        batch["text_lengths"] = batch["text_lengths"]
        batch["wav_lengths"] = batch["wav_lengths"]
        batch["text_inputs"] = batch["padded_text"]
        batch["cond_idxs"] = batch["cond_idxs"]
        # compute conditioning mel specs
        # reshape waves from [B, num_cond_samples, 1, T] to [B * num_cond_samples, 1, T]
        # because a single batched call is faster than iterating over the tensor
        B, num_cond_samples, C, T = batch["conditioning"].size()
        conditioning_reshaped = batch["conditioning"].view(B * num_cond_samples, C, T)
        paired_conditioning_mel = self.torch_mel_spectrogram_style_encoder(conditioning_reshaped)
        # reshape [B * num_cond_samples, n_mel, T_mel] back to [B, num_cond_samples, n_mel, T_mel]
        n_mel = self.torch_mel_spectrogram_style_encoder.n_mel_channels  # paired_conditioning_mel.size(1)
        T_mel = paired_conditioning_mel.size(2)
        paired_conditioning_mel = paired_conditioning_mel.view(B, num_cond_samples, n_mel, T_mel)
        # store the conditioning mels
        batch["cond_mels"] = paired_conditioning_mel
        # compute codes using the DVAE
        if self.config.audio.sample_rate != self.config.audio.dvae_sample_rate:
            dvae_wav = torchaudio.functional.resample(
                batch["wav"],
                orig_freq=self.config.audio.sample_rate,
                new_freq=self.config.audio.dvae_sample_rate,
                lowpass_filter_width=64,
                rolloff=0.9475937167399596,
                resampling_method="kaiser_window",
                beta=14.769656459379492,
            )
        else:
            dvae_wav = batch["wav"]
        dvae_mel_spec = self.torch_mel_spectrogram_dvae(dvae_wav)
        codes = self.dvae.get_codebook_indices(dvae_mel_spec)

        batch["audio_codes"] = codes
        # delete batch tensors that are no longer needed
        del batch["padded_text"]
        del batch["wav"]
        del batch["conditioning"]
        return batch
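
    # Shape walk-through of `format_batch_on_device` (added commentary; names as
    # used above): conditioning [B, num_cond_samples, 1, T_wav] -> cond_mels
    # [B, num_cond_samples, 80, T_mel], and wav [B, 1, T_wav] -> audio_codes
    # [B, T_code] via the frozen DVAE.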
    def train_step(self, batch, criterion):
        loss_dict = {}
        cond_mels = batch["cond_mels"]
        text_inputs = batch["text_inputs"]
        text_lengths = batch["text_lengths"]
        audio_codes = batch["audio_codes"]
        wav_lengths = batch["wav_lengths"]
        cond_idxs = batch["cond_idxs"]
        cond_lens = batch["cond_lens"]

        loss_text, loss_mel, _ = self.forward(
            text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels, cond_idxs, cond_lens
        )
        loss_dict["loss_text_ce"] = loss_text * self.args.gpt_loss_text_ce_weight
        loss_dict["loss_mel_ce"] = loss_mel * self.args.gpt_loss_mel_ce_weight
        loss_dict["loss"] = loss_dict["loss_text_ce"] + loss_dict["loss_mel_ce"]
        return {"model_outputs": None}, loss_dict

    def eval_step(self, batch, criterion):
        # ignore masking for more consistent evaluation
        batch["cond_idxs"] = None
        return self.train_step(batch, criterion)

    def on_train_epoch_start(self, trainer):
        trainer.model.eval()  # set the whole model to eval mode
        # put the GPT model back in training mode
        trainer.model.xtts.gpt.train()

    def on_init_end(self, trainer):  # pylint: disable=W0613
        # ignore similarities.pth on ClearML save/upload
        if self.config.dashboard_logger.lower() == "clearml":
            from clearml.binding.frameworks import WeightsFileHandler

            WeightsFileHandler.add_pre_callback(callback_clearml_load_save)

    @torch.no_grad()
    def inference(
        self,
        x,
        aux_input=None,
    ):  # pylint: disable=dangerous-default-value
        return None

    @staticmethod
    def get_criterion():
        return None

    def get_sampler(self, dataset: TTSDataset, num_gpus=1):
        # sampler for DDP
        batch_sampler = DistributedSampler(dataset) if num_gpus > 1 else None
        return batch_sampler

    def get_data_loader(
        self,
        config: Coqpit,
        assets: Dict,
        is_eval: bool,
        samples: Union[List[Dict], List[List]],
        verbose: bool,
        num_gpus: int,
        rank: int = None,
    ) -> "DataLoader":  # pylint: disable=W0613
        if is_eval and not config.run_eval:
            loader = None
        else:
            # init dataloader
            dataset = XTTSDataset(self.config, samples, self.xtts.tokenizer, config.audio.sample_rate, is_eval)

            # wait for all the DDP processes to be ready
            if num_gpus > 1:
                torch.distributed.barrier()

            # sort input sequences from short to long
            # dataset.preprocess_samples()

            # get samplers
            sampler = self.get_sampler(dataset, num_gpus)

            # ignore the sampler during eval: changing the sampler parameters would
            # make results incomparable with previous runs
            if sampler is None or is_eval:
                loader = DataLoader(
                    dataset,
                    batch_size=config.eval_batch_size if is_eval else config.batch_size,
                    shuffle=False,
                    drop_last=False,
                    collate_fn=dataset.collate_fn,
                    num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
                    pin_memory=False,
                )
            else:
                loader = DataLoader(
                    dataset,
                    batch_sampler=sampler,
                    collate_fn=dataset.collate_fn,
                    num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
                    pin_memory=False,
                )
        return loader
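
    # Added note: with `optimizer_wd_only_on_weights=True`, `get_optimizer`
    # below hands the optimizer (e.g. torch.optim.AdamW, per the config) two
    # parameter groups of roughly this form:
    #   [{"params": weights, "weight_decay": wd},
    #    {"params": biases/norms/embeddings, "weight_decay": 0}]
    # so that decay is applied only to true weight matrices.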
    def get_optimizer(self) -> List:
        """Initialize and return the optimizer based on the config parameters."""
        # TODO: deal with multi-GPU training
        if self.config.optimizer_wd_only_on_weights:
            # restrict the parameters to the GPT model only
            net = self.xtts.gpt

            # normalization layers
            norm_modules = (
                nn.BatchNorm2d,
                nn.InstanceNorm2d,
                nn.BatchNorm1d,
                nn.InstanceNorm1d,
                nn.BatchNorm3d,
                nn.InstanceNorm3d,
                nn.GroupNorm,
                nn.LayerNorm,
            )
            # nn.Embedding
            emb_modules = (nn.Embedding, nn.EmbeddingBag)

            param_names_notweights = set()
            all_param_names = set()
            param_map = {}
            for mn, m in net.named_modules():
                for k, v in m.named_parameters():
                    v.is_bias = k.endswith(".bias")
                    v.is_weight = k.endswith(".weight")
                    v.is_norm = isinstance(m, norm_modules)
                    v.is_emb = isinstance(m, emb_modules)

                    fpn = "%s.%s" % (mn, k) if mn else k  # full param name
                    all_param_names.add(fpn)
                    param_map[fpn] = v
                    if v.is_bias or v.is_norm or v.is_emb:
                        param_names_notweights.add(fpn)

            params_names_notweights = sorted(list(param_names_notweights))
            params_notweights = [param_map[k] for k in params_names_notweights]
            # symmetric difference equals set difference here, since the no-decay
            # names are a subset of all parameter names
            params_names_weights = sorted(list(all_param_names ^ param_names_notweights))
            params_weights = [param_map[k] for k in params_names_weights]

            groups = [
                {"params": params_weights, "weight_decay": self.config.optimizer_params["weight_decay"]},
                {"params": params_notweights, "weight_decay": 0},
            ]
            # torch.optim.AdamW
            opt = get_optimizer(
                self.config.optimizer,
                self.config.optimizer_params,
                self.config.lr,
                parameters=groups,
            )
            opt._group_names = [params_names_weights, params_names_notweights]
            return opt

        return get_optimizer(
            self.config.optimizer,
            self.config.optimizer_params,
            self.config.lr,
            # optimize only the GPT model parameters
            parameters=self.xtts.gpt.parameters(),
        )

    def get_scheduler(self, optimizer) -> List:
        """Set the scheduler for the optimizer.

        Args:
            optimizer: `torch.optim.Optimizer`.
        """
        return get_scheduler(self.config.lr_scheduler, self.config.lr_scheduler_params, optimizer)

    def load_checkpoint(
        self,
        config,
        checkpoint_path,
        eval=False,
        strict=True,
        cache_storage="/tmp/tts_cache",
        target_protocol="s3",
        target_options={"anon": True},
    ):  # pylint: disable=unused-argument, disable=W0201, disable=W0102, redefined-builtin
        """Load the model checkpoint and set up for training or inference."""
        state = self.xtts.get_compatible_checkpoint_state_dict(checkpoint_path)

        # load the model weights
        self.xtts.load_state_dict(state, strict=strict)

        if eval:
            self.xtts.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache, use_deepspeed=False)
            self.eval()
            assert not self.training

    @staticmethod
    def init_from_config(config: "GPTTrainerConfig", samples: Union[List[List], List[Dict]] = None):
        """Initialize the model from the config.

        Args:
            config (GPTTrainerConfig): Model config.
            samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training.
                Defaults to None.
        """
        return GPTTrainer(config)
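

# The function below is an illustrative sketch, not part of the original module:
# it shows how this trainer is typically driven by the Coqui `trainer` package
# (as in the XTTS training recipes). `config` is a GPTTrainerConfig such as the
# one from `_example_build_config` above; `train_samples`/`eval_samples` are
# sample metadata lists, e.g. from `TTS.tts.datasets.load_tts_samples`.
def _example_fit(config, train_samples, eval_samples, output_path="/tmp/xtts_gpt_run"):
    from trainer import Trainer, TrainerArgs  # local import: only needed for this sketch

    model = GPTTrainer.init_from_config(config)
    trainer = Trainer(
        TrainerArgs(),
        config,
        output_path=output_path,  # hypothetical output directory
        model=model,
        train_samples=train_samples,
        eval_samples=eval_samples,
    )
    trainer.fit()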