import argparse
import os

import numpy as np
import torch
import yaml
from torch.optim import SGD, Adam
from torch.optim.lr_scheduler import CyclicLR, OneCycleLR, ReduceLROnPlateau
from torch.utils.data import BatchSampler, DataLoader, SubsetRandomSampler
from torch.utils.tensorboard import SummaryWriter

from mtts.datasets.dataset import Dataset, collate_fn
from mtts.loss import FS2Loss
from mtts.models.fs2_model import FastSpeech2
from mtts.optimizer import ScheduledOptim
from mtts.utils.logging import get_logger
from mtts.utils.utils import save_image

logger = get_logger(__file__)


class AverageMeter:
    """Keeps running averages of the individual loss terms."""
    def __init__(self):
        self.mel_loss_v = 0.0
        self.postnet_loss_v = 0.0
        self.d_loss_v = 0.0
        self.total_loss_v = 0.0
        self._i = 0

    def update(self, mel_loss, postnet_loss, d_loss, total_loss):
        self.mel_loss_v = ((self.mel_loss_v * self._i) + mel_loss.item()) / (self._i + 1)
        self.postnet_loss_v = ((self.postnet_loss_v * self._i) + postnet_loss.item()) / (self._i + 1)
        self.d_loss_v = ((self.d_loss_v * self._i) + d_loss.item()) / (self._i + 1)
        self.total_loss_v = ((self.total_loss_v * self._i) + total_loss.item()) / (self._i + 1)
        self._i += 1
        return self.mel_loss_v, self.postnet_loss_v, self.d_loss_v, self.total_loss_v


def split_batch(data, i, n_split):
    """Return the i-th of n_split equal sub-batches; data[0] is sliced on dim 1, the rest on dim 0."""
    n = data[1].shape[0]
    k = n // n_split
    ds = [d[:, i * k:(i + 1) * k] if j == 0 else d[i * k:(i + 1) * k] for j, d in enumerate(data)]
    return ds


def shuffle(data):
    """Shuffle every tensor in data with the same random permutation of the batch axis."""
    n = data[1].shape[0]
    idx = np.random.permutation(n)
    data_shuffled = [d[:, idx] if i == 0 else d[idx] for i, d in enumerate(data)]
    return data_shuffled


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-r', '--restore', type=str, default='')
    parser.add_argument('-c', '--config', type=str, default='./config.yaml')
    parser.add_argument('-d', '--device', type=str, default='cuda')
    args = parser.parse_args()

    device = args.device
    logger.info(f'using device {device}')

    # Read the config text once so it can be both parsed and logged
    # (calling f.read() after yaml.safe_load(f) would only log an empty string).
    with open(args.config) as f:
        config_text = f.read()
    config = yaml.safe_load(config_text)
    logger.info(config_text)

    dataset = Dataset(config)
    dataloader = DataLoader(dataset,
                            batch_size=config['training']['batch_size'],
                            shuffle=False,
                            collate_fn=collate_fn,
                            drop_last=False,
                            num_workers=config['training']['num_workers'])
    step_per_epoch = len(dataloader) * config['training']['batch_size']

    model = FastSpeech2(config)
    model = model.to(args.device)
    # model.encoder.emb_layers.to(device)  # ?
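    # Note on the batch layout (inferred from shuffle()/split_batch() above rather than from
    # the dataset code itself, so treat it as an assumption): a batch from collate_fn unpacks
    # as (tokens, duration, mel_truth, seq_len, mel_len), where tokens carries the batch on
    # dim 1 and the remaining tensors carry it on dim 0. With batch_size=16 and batch_split=4,
    # for example, split_batch(data, k, 4) yields micro-batches of 4 utterances each.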
    optim_conf = config['optimizer']
    optim_class = eval(optim_conf['type'])  # e.g. 'Adam' or 'SGD', resolved against the torch.optim imports
    logger.info(optim_conf['params'])
    optimizer = optim_class(model.parameters(), **optim_conf['params'])

    if args.restore != '':
        logger.info(f'Loading checkpoint {args.restore}')
        content = torch.load(args.restore)
        model.load_state_dict(content['model'])
        optimizer.load_state_dict(content['optimizer'])
        current_step = content['step']
        start_epoch = current_step // step_per_epoch
        logger.info(f'loaded checkpoint at step {current_step}, epoch {start_epoch}')
    else:
        current_step = 0
        start_epoch = 0
        logger.info(f'Start training from scratch, step={current_step}, epoch={start_epoch}')

    # Linear warm-up schedule: ramp the learning rate from 0 to the configured lr
    # over the first n_warm_up_step steps.
    lrs = np.linspace(0, optim_conf['params']['lr'], optim_conf['n_warm_up_step'])

    Scheduler = eval(config['lr_scheduler']['type'])  # e.g. 'OneCycleLR', 'CyclicLR' or 'ReduceLROnPlateau'
    lr_scheduler = Scheduler(optimizer, **config['lr_scheduler']['params'])

    loss_fn = FS2Loss().to(device)

    train_logger = SummaryWriter(config['training']['log_path'])
    val_logger = SummaryWriter(config['training']['log_path'])
    avg = AverageMeter()

    for epoch in range(start_epoch, config['training']['epochs']):
        model.train()
        for i, data in enumerate(dataloader):
            data = shuffle(data)
            max_src_len = torch.max(data[-2])
            max_mel_len = torch.max(data[-1])
            # Each loader batch is split into batch_split smaller micro-batches.
            for k in range(config['training']['batch_split']):
                data_split = split_batch(data, k, config['training']['batch_split'])
                tokens, duration, mel_truth, seq_len, mel_len = data_split
                # print(mel_len)
                tokens = tokens.to(device)
                duration = duration.to(device)
                mel_truth = mel_truth.to(device)
                seq_len = seq_len.to(device)
                mel_len = mel_len.to(device)

                # if torch.max(log_D) > 50:
                #     logger.info('skipping sample')
                #     continue

                # Center the targets with the dataset statistics from the config.
                mel_truth = mel_truth - config['fbank']['mel_mean']
                duration = duration - config['duration_predictor']['duration_mean']

                output = model(tokens,
                               seq_len,
                               mel_len,
                               duration,
                               max_src_len=max_src_len,
                               max_mel_len=max_mel_len)
                mel_pred, mel_postnet, d_pred, src_mask, mel_mask, mel_len = output
                mel_loss, mel_postnet_loss, d_loss = loss_fn(d_pred, duration, mel_pred, mel_postnet, mel_truth,
                                                             ~src_mask, ~mel_mask)

                total_loss = mel_postnet_loss + d_loss + mel_loss
                ml, pl, dl, tl = avg.update(mel_loss, mel_postnet_loss, d_loss, total_loss)
                lr = optimizer.param_groups[0]['lr']
                msg = f'epoch:{epoch},step:{current_step}|{step_per_epoch},loss:{tl:.3},mel:{ml:.3},'
                msg += f'mel_postnet:{pl:.3},duration:{dl:.3},{lr:.3}'
                if current_step % config['training']['log_step'] == 0:
                    logger.info(msg)

                # Gradient accumulation: scale the loss, accumulate gradients, and only
                # update the weights once every acc_step micro-batches.
                total_loss = total_loss / config['training']['acc_step']
                total_loss.backward()
                current_step += 1
                if current_step % config['training']['acc_step'] != 0:
                    continue

                if current_step < config['optimizer']['n_warm_up_step']:
                    # Warm-up phase: set the learning rate manually from the linear ramp.
                    lr = lrs[current_step]
                    optimizer.param_groups[0]['lr'] = lr
                    optimizer.step()
                    optimizer.zero_grad()
                else:
                    # After warm-up, hand control of the learning rate to the scheduler.
                    optimizer.step()
                    lr_scheduler.step()
                    optimizer.zero_grad()

                if current_step % config['training']['synth_step'] == 0:
                    # Periodically dump a ground-truth/predicted mel pair for inspection.
                    mel_pred = mel_pred.detach().cpu().numpy()
                    mel_truth = mel_truth.detach().cpu().numpy()
                    saved_path = os.path.join(config['training']['log_path'], f'{current_step}.png')
                    save_image(mel_truth[0][:mel_len[0]], mel_pred[0][:mel_len[0]], saved_path)
                    np.save(saved_path + '.npy', mel_pred[0])

                if current_step % config['training']['log_step'] == 0:
                    train_logger.add_scalar('total_loss', tl, current_step)
                    train_logger.add_scalar('mel_loss', ml, current_step)
                    train_logger.add_scalar('mel_postnet_loss', pl, current_step)
                    train_logger.add_scalar('duration_loss', dl, current_step)

                if current_step % config['training']['checkpoint_step'] == 0:
                    if not os.path.exists(config['training']['checkpoint_path']):
                        os.makedirs(config['training']['checkpoint_path'])
                    ckpt_file = os.path.join(config['training']['checkpoint_path'],
                                             'checkpoint_{}.pth.tar'.format(current_step))
                    content = {'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'step': current_step}
                    torch.save(content, ckpt_file)
                    logger.info(f'Saved model at step {current_step} to {ckpt_file}')

    logger.info(f"End of training after {config['training']['epochs']} epochs")
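
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not executed). The flags match the argparse
# definitions above; file names are placeholders:
#
#   python train.py -c ./config.yaml -d cuda                       # train from scratch
#   python train.py -c ./config.yaml -r checkpoint_100000.pth.tar  # resume from a checkpoint
#
# Minimal config.yaml skeleton covering only the keys this script reads
# directly. Dataset(config) and FastSpeech2(config) need additional data/model
# keys, and every value below is a placeholder assumption, not a project default:
#
#   training:
#     batch_size: 16
#     batch_split: 1
#     acc_step: 1
#     num_workers: 4
#     epochs: 100
#     log_step: 10
#     synth_step: 1000
#     checkpoint_step: 10000
#     log_path: ./logs
#     checkpoint_path: ./checkpoints
#   optimizer:
#     type: Adam            # resolved by eval() against the torch.optim imports
#     n_warm_up_step: 4000
#     params:
#       lr: 0.001
#   lr_scheduler:
#     type: OneCycleLR      # resolved by eval() against the lr_scheduler imports
#     params: {}            # scheduler-specific arguments go here
#   fbank:
#     mel_mean: 0.0
#   duration_predictor:
#     duration_mean: 0.0
# ---------------------------------------------------------------------------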