|
from multiprocessing.pool import Pool

import matplotlib

from utils.pl_utils import data_loader
from utils.training_utils import RSQRTSchedule
from network.vocoders.base_vocoder import get_vocoder_cls, BaseVocoder
from modules.fastspeech.pe import PitchExtractor

# Select a non-interactive backend before pyplot is imported anywhere.
matplotlib.use('Agg')
import os
import numpy as np
from tqdm import tqdm
import torch.distributed as dist

from training.task.base_task import BaseTask
from utils.hparams import hparams
from utils.text_encoder import TokenTextEncoder
import json

from preprocessing.hubertinfer import Hubertencoder
import torch
import torch.optim
import torch.utils.data
import utils

|
class TtsTask(BaseTask):
    def __init__(self, *args, **kwargs):
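        """Set up shared TTS components: the HuBERT content encoder, a vocoder
        slot, and placeholders for the asynchronous result-saving pool."""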
|
        self.vocoder = None
        # HuBERT features serve as the "phone" content representation here.
        self.phone_encoder = Hubertencoder(hparams['hubert_path'])
        self.saving_result_pool = None
        self.saving_results_futures = None
        self.stats = {}
        super().__init__(*args, **kwargs)
|
def build_scheduler(self, optimizer): |
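        """Warmup followed by reciprocal-square-root LR decay (as the name
        RSQRTSchedule suggests)."""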
|
        return RSQRTSchedule(optimizer)
|
def build_optimizer(self, model): |
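        """AdamW over all model parameters, with the learning rate from hparams."""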
|
        self.optimizer = optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=hparams['lr'])
        return optimizer
|
    def build_dataloader(self, dataset, shuffle, max_tokens=None, max_sentences=None,
                         required_batch_size_multiple=-1, endless=False, batch_by_size=True):
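        """Build a DataLoader over pre-computed batches of dataset indices.

        `max_tokens` and `max_sentences` cap each batch per device and are
        scaled by the device count; under DDP each batch is sharded across
        ranks, and `endless=True` repeats the batch list 1000 times.
        """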
|
        devices_cnt = torch.cuda.device_count()
        if devices_cnt == 0:
            devices_cnt = 1
        if required_batch_size_multiple == -1:
            required_batch_size_multiple = devices_cnt

        def shuffle_batches(batches):
            np.random.shuffle(batches)
            return batches

        # The token/sentence budgets are per device; scale them to the device count.
        if max_tokens is not None:
            max_tokens *= devices_cnt
        if max_sentences is not None:
            max_sentences *= devices_cnt
        indices = dataset.ordered_indices()
        if batch_by_size:
            # Pack indices into batches capped by the token and sentence budgets.
            batch_sampler = utils.batch_by_size(
                indices, dataset.num_tokens, max_tokens=max_tokens, max_sentences=max_sentences,
                required_batch_size_multiple=required_batch_size_multiple,
            )
        else:
            # Fixed-size batches of `max_sentences` consecutive indices.
            batch_sampler = []
            for i in range(0, len(indices), max_sentences):
                batch_sampler.append(indices[i:i + max_sentences])

        if shuffle:
            batches = shuffle_batches(list(batch_sampler))
            if endless:
                batches = [b for _ in range(1000) for b in shuffle_batches(list(batch_sampler))]
        else:
            batches = batch_sampler
            if endless:
                batches = [b for _ in range(1000) for b in batches]
        num_workers = dataset.num_workers
        if self.trainer.use_ddp:
            # Shard each batch across DDP ranks; batches that do not divide
            # evenly are dropped so that all ranks stay in step.
            num_replicas = dist.get_world_size()
            rank = dist.get_rank()
            batches = [x[rank::num_replicas] for x in batches if len(x) % num_replicas == 0]
        return torch.utils.data.DataLoader(dataset,
                                           collate_fn=dataset.collater,
                                           batch_sampler=batches,
                                           num_workers=num_workers,
                                           pin_memory=False)
|
def test_start(self): |
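        """Prepare inference: result-saving pool, vocoder, and optional pitch extractor."""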
|
        self.saving_result_pool = Pool(8)
        self.saving_results_futures = []
        self.vocoder: BaseVocoder = get_vocoder_cls(hparams)()
        if hparams.get('pe_enable') is not None and hparams['pe_enable']:
            # Optional neural pitch extractor, restored from a pretrained checkpoint.
            self.pe = PitchExtractor().cuda()
            utils.load_ckpt(self.pe, hparams['pe_ckpt'], 'model', strict=True)
            self.pe.eval()
|
def test_end(self, outputs): |
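        """Drain and join the result-saving pool before ending the test run."""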
|
        self.saving_result_pool.close()
        # Block until every queued saving job has finished.
        for f in tqdm(self.saving_results_futures):
            f.get()
        self.saving_result_pool.join()
        return {}
|
def weights_nonzero_speech(self, target): |
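        """Return a mask that is 1.0 for non-padded frames and 0.0 elsewhere.

        target: [B, T, mel]; a frame counts as padded when all of its mel
        values are zero. The mask has the same shape as target.
        """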
        dim = target.size(-1)
        return target.abs().sum(-1, keepdim=True).ne(0).float().repeat(1, 1, dim)
|
|
if __name__ == '__main__':
    TtsTask.start()
|
|