# text_to_speech/mtts/text/text_processor.py
# wuxulong19950206
# First model version
# 14d1720
import torch
from mtts.datasets.dataset import Tokenizer
from mtts.utils.logging import get_logger
# Module-level logger obtained from the project's logging helper; this file's
# path is passed as the argument (presumably used as the logger name).
logger = get_logger(__file__)
class TextProcessor():
    """Turn a ``'|'``-separated text line into a name and a token tensor.

    The expected input layout is ``name|segment_1|...|segment_k`` where ``k``
    equals the number of tokenizers built from the configuration. Each
    segment is a whitespace-separated sequence handled by its own tokenizer.
    """

    def __init__(self, config):
        """Create one tokenizer per ``emb_type*`` entry of the train config.

        Args:
            config: Mapping providing ``config['dataset']['train']``. Every
                key of that section that starts with ``'emb_type'`` must map
                to an entry holding a ``'vocab'`` value (handed to
                ``Tokenizer``) and a ``'_name'`` value (only logged).
        """
        conf = config['dataset']['train']
        self.emb_tokenizers = []
        for key, emb_conf in conf.items():
            if not key.startswith('emb_type'):
                continue
            self.emb_tokenizers.append(Tokenizer(emb_conf['vocab']))
            logger.info('processed emb {}'.format(emb_conf['_name']))

    def _process(self, input: str):
        """Tokenize every segment of ``input`` and stack the results.

        Args:
            input: Text of the form ``name|segment_1|...|segment_k``.

        Returns:
            A ``(name, token_tensor)`` tuple. ``name`` is the text before the
            first ``'|'``. ``token_tensor`` is the concatenation, along a new
            leading dimension, of what each tokenizer returns for its segment
            (one row per segment).

        Raises:
            ValueError: If the number of segments after the name differs
                from the number of configured tokenizers.
        """
        name, *segments = input.split('|')
        if len(segments) != len(self.emb_tokenizers):
            # The f-prefix is required so the offending segments actually
            # appear in the message instead of a literal '{segments}'.
            raise ValueError(f'Input text and emb_tokensizers are different, {segments}')

        words_per_segment = [seg.split() for seg in segments]
        n = max([len(words) for words in words_per_segment])

        # Segments shorter than the longest one are cyclically repeated and
        # truncated to n items (so a single item is broadcast to n copies).
        # A stricter check that only allowed lengths n or 1 was disabled in
        # the original code, so any shorter length is accepted here.
        # Segments that already have n items are passed through untouched.
        aligned = []
        for seg, words in zip(segments, words_per_segment):
            if len(words) != n:
                seg = ' '.join((words * n)[:n])
            aligned.append(seg)

        rows = [
            torch.unsqueeze(tokenizer.tokenize(seg), 0)
            for seg, tokenizer in zip(aligned, self.emb_tokenizers)
        ]
        return name, torch.cat(rows, 0)

    def __call__(self, input):
        """Shorthand for :meth:`_process`; same arguments and return value."""
        return self._process(input)
if __name__ == '__main__':
    # Manual smoke test: load the local config, build a processor and print
    # what it returns for one hard-coded sample line.
    import yaml

    with open('./config.yaml') as config_file:
        config = yaml.safe_load(config_file)

    processor = TextProcessor(config)
    sample = 'sil ni3 qu4 zuo4 fan4 ba5 sil|sil 你 去 做 饭 吧 sil|0 0 0 0 0 0 0'
    # The call returns a (name, token tensor) pair; it is printed as-is.
    result = processor(sample)
    print(result)