Spaces:
Running
Running
File size: 4,709 Bytes
26925fd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
from multiprocessing.pool import Pool
import matplotlib
from utils.pl_utils import data_loader
from utils.training_utils import RSQRTSchedule
from vocoders.base_vocoder import get_vocoder_cls, BaseVocoder
from modules.fastspeech.pe import PitchExtractor
matplotlib.use('Agg')
import os
import numpy as np
from tqdm import tqdm
import torch.distributed as dist
from tasks.base_task import BaseTask
from utils.hparams import hparams
from utils.text_encoder import TokenTextEncoder
import json
import torch
import torch.optim
import torch.utils.data
import utils
class TtsTask(BaseTask):
    """Task that adds a phone encoder, optimizer/scheduler factories, batched
    data loading and test-time setup/teardown on top of ``BaseTask``.

    Configuration is read from the global ``hparams`` mapping
    (``binary_data_dir``, ``lr``, and optionally ``pe_enable`` / ``pe_ckpt``).
    """

    def __init__(self, *args, **kwargs):
        """Initialise encoder-derived indices and test-time state.

        All positional and keyword arguments are forwarded unchanged to the
        base-class initialiser, which is invoked last.
        """
        self.vocoder = None
        self.phone_encoder = self.build_phone_encoder(hparams['binary_data_dir'])
        # Special-token indices are taken from the encoder.
        self.padding_idx = self.phone_encoder.pad()
        self.eos_idx = self.phone_encoder.eos()
        self.seg_idx = self.phone_encoder.seg()
        # Populated by test_start(); drained by test_end().
        self.saving_result_pool = None
        self.saving_results_futures = None
        self.stats = {}
        super().__init__(*args, **kwargs)

    def build_scheduler(self, optimizer):
        """Return an ``RSQRTSchedule`` wrapping ``optimizer``."""
        return RSQRTSchedule(optimizer)

    def build_optimizer(self, model):
        """Create, remember and return an AdamW optimizer for ``model``.

        The learning rate comes from ``hparams['lr']``; the optimizer is also
        stored on ``self.optimizer``.
        """
        self.optimizer = optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=hparams['lr'])
        return optimizer

    def build_dataloader(self, dataset, shuffle, max_tokens=None, max_sentences=None,
                         required_batch_size_multiple=-1, endless=False, batch_by_size=True):
        """Build a ``DataLoader`` over pre-computed batches of ``dataset``.

        Args:
            dataset: Object providing ``ordered_indices()``, ``num_tokens``,
                ``num_workers`` and ``collater``.
            shuffle: If true, shuffle the order of the batches.
            max_tokens: Optional token limit; multiplied by the device count
                before use.
            max_sentences: Optional sentence limit; multiplied by the device
                count before use. Required when ``batch_by_size`` is false.
            required_batch_size_multiple: Passed to ``utils.batch_by_size``;
                ``-1`` means "use the device count".
            endless: If true, the batch list is repeated 1000 times
                (re-shuffled on every repetition when ``shuffle`` is set).
            batch_by_size: If true, batches come from ``utils.batch_by_size``;
                otherwise the ordered indices are cut into consecutive chunks
                of ``max_sentences``.

        Returns:
            A ``torch.utils.data.DataLoader`` using the batches as its
            ``batch_sampler``.

        Raises:
            ValueError: If ``batch_by_size`` is false and ``max_sentences`` is
                not given.
        """
        # Treat a CPU-only host as a single device.
        devices_cnt = torch.cuda.device_count()
        if devices_cnt == 0:
            devices_cnt = 1
        if required_batch_size_multiple == -1:
            required_batch_size_multiple = devices_cnt

        def shuffle_batches(batches):
            # In-place shuffle; the list is returned for convenient chaining.
            np.random.shuffle(batches)
            return batches

        # Limits are scaled by the device count before batching.
        if max_tokens is not None:
            max_tokens *= devices_cnt
        if max_sentences is not None:
            max_sentences *= devices_cnt

        indices = dataset.ordered_indices()
        if batch_by_size:
            batch_sampler = utils.batch_by_size(
                indices, dataset.num_tokens, max_tokens=max_tokens, max_sentences=max_sentences,
                required_batch_size_multiple=required_batch_size_multiple,
            )
        else:
            if max_sentences is None:
                # Previously this surfaced as an opaque TypeError from range().
                raise ValueError(
                    'max_sentences must be set when batch_by_size is False')
            batch_sampler = [
                indices[start:start + max_sentences]
                for start in range(0, len(indices), max_sentences)
            ]

        # Number of passes materialised up front when `endless` is requested.
        endless_repeats = 1000
        if shuffle:
            # The initial shuffle is kept even when `endless` overwrites its
            # result, so the sequence of RNG draws is unchanged.
            batches = shuffle_batches(list(batch_sampler))
            if endless:
                batches = [
                    b
                    for _ in range(endless_repeats)
                    for b in shuffle_batches(list(batch_sampler))
                ]
        else:
            batches = batch_sampler
            if endless:
                batches = [b for _ in range(endless_repeats) for b in batches]

        num_workers = dataset.num_workers
        if self.trainer.use_ddp:
            num_replicas = dist.get_world_size()
            rank = dist.get_rank()
            # Each rank takes a strided slice of every batch; batches whose
            # length is not divisible by the world size are dropped.
            batches = [x[rank::num_replicas] for x in batches if len(x) % num_replicas == 0]
        return torch.utils.data.DataLoader(dataset,
                                           collate_fn=dataset.collater,
                                           batch_sampler=batches,
                                           num_workers=num_workers,
                                           pin_memory=False)

    def build_phone_encoder(self, data_dir):
        """Load ``phone_set.json`` from ``data_dir`` and build a ``TokenTextEncoder``.

        The JSON content is passed as ``vocab_list``; ``replace_oov`` is ``','``.
        """
        phone_list_file = os.path.join(data_dir, 'phone_set.json')
        # Use a context manager so the file handle is closed deterministically.
        with open(phone_list_file) as f:
            phone_list = json.load(f)
        return TokenTextEncoder(None, vocab_list=phone_list, replace_oov=',')

    def test_start(self):
        """Prepare test-time resources.

        Creates an 8-process pool and an empty futures list for result
        saving, instantiates the vocoder class selected by
        ``get_vocoder_cls(hparams)``, and, when ``hparams['pe_enable']`` is
        truthy, loads a ``PitchExtractor`` from ``hparams['pe_ckpt']`` onto
        CUDA in eval mode.
        """
        self.saving_result_pool = Pool(8)
        self.saving_results_futures = []
        self.vocoder: BaseVocoder = get_vocoder_cls(hparams)()
        # .get() yields None for a missing key, so one truthiness test covers
        # both "absent" and "disabled".
        if hparams.get('pe_enable'):
            self.pe = PitchExtractor().cuda()
            utils.load_ckpt(self.pe, hparams['pe_ckpt'], 'model', strict=True)
            self.pe.eval()

    def test_end(self, outputs):
        """Wait for all pending save jobs, shut the pool down, and return ``{}``.

        ``outputs`` is accepted for interface compatibility and is not used.
        """
        self.saving_result_pool.close()
        # .get() blocks until the job finishes and re-raises any worker error.
        for future in tqdm(self.saving_results_futures):
            future.get()
        self.saving_result_pool.join()
        return {}

    ##########
    # utils
    ##########
    def weights_nonzero_speech(self, target):
        """Return a float mask marking the non-zero frames of ``target``.

        A position along the second-to-last axis gets 1.0 when the absolute
        values over the last axis do not sum to zero, and 0.0 otherwise
        (i.e. all-zero padding frames). The mask is repeated along the last
        axis so that, for the documented ``B x T x mel`` input, it has the
        same shape as ``target``.
        """
        # target : B x T x mel
        dim = target.size(-1)
        return target.abs().sum(-1, keepdim=True).ne(0).float().repeat(1, 1, dim)
# Script entry point: when run directly, hand control to the task's launcher.
if __name__ == '__main__':
    TtsTask.start()
|