import torch

from mtts.datasets.dataset import Tokenizer
from mtts.utils.logging import get_logger

logger = get_logger(__file__)


class TextProcessor:
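    """Convert a '|'-separated input line into per-embedding token tensors.

    One Tokenizer is built for each 'emb_type*' entry found under
    config['dataset']['train']; each tokenizer consumes one segment
    of the input line.
    """
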
    def __init__(self, config):
        conf = config['dataset']['train']
        self.emb_tokenizers = []
        # Build one tokenizer per 'emb_type*' entry in the training config.
        for key in conf.keys():
            if key.startswith('emb_type'):
                emb_tok = Tokenizer(conf[key]['vocab'])
                self.emb_tokenizers += [emb_tok]
                logger.info('processed emb {}'.format(conf[key]['_name']))

    def _process(self, input: str):
        # Expected format: name|segment_1|...|segment_k, with one
        # token segment per embedding tokenizer.
        segments = input.split('|')
        name = segments[0]
        segments = segments[1:]
        if len(segments) != len(self.emb_tokenizers):
            raise ValueError(f'Number of segments does not match emb_tokenizers: {segments}')
        seg_lens = [len(s.split()) for s in segments]
        n = max(seg_lens)
        # for k in seg_lens:
        #     if k != n and k != 1:
        #         raise ValueError(f'Input segments should share the same length, but {k}!={n} for text {input}')
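        # Segments shorter than n (e.g. a single speaker id) are repeated
        # cyclically and truncated so that every segment has n tokens.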
        segments = [' '.join((s.split() * n)[:n]) if len(s.split()) != n else s for s in segments]
        token_tensor = []
        for seg, tokenizer in zip(segments, self.emb_tokenizers):
            tokens = tokenizer.tokenize(seg)
            token_tensor.append(torch.unsqueeze(tokens, 0))
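        # Stack the per-tokenizer rows into a single tensor, one row per
        # embedding type.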
        token_tensor = torch.cat(token_tensor, 0)
        return name, token_tensor

    def __call__(self, input):
        return self._process(input)


if __name__ == '__main__':
    import yaml

    with open('./config.yaml') as f:
        config = yaml.safe_load(f)
    text_processor = TextProcessor(config)
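    # Example input: pinyin segment | hanzi segment | per-token ids
    # (what the ids encode depends on the configured emb_type entries).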
    text = 'sil ni3 qu4 zuo4 fan4 ba5 sil|sil 你 去 做 饭 吧 sil|0 0 0 0 0 0 0'
    name, tensor = text_processor(text)
    print(name, tensor)