File size: 1,787 Bytes
14d1720
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import torch

from mtts.datasets.dataset import Tokenizer
from mtts.utils.logging import get_logger

# Module-level logger, created by passing this file's path to the project's
# get_logger helper.
logger = get_logger(__file__)


class TextProcessor:
    """Turn a ``name|segment|segment|...`` text line into a stacked tensor.

    One ``Tokenizer`` is created for every key of
    ``config['dataset']['train']`` that starts with ``'emb_type'``, in the
    order the keys are iterated.  Each input line must carry exactly one
    ``'|'``-separated segment per tokenizer, after the leading name field.
    """

    def __init__(self, config):
        """Build the tokenizers described by ``config``.

        Args:
            config: Nested mapping.  Every ``emb_type*`` entry under
                ``config['dataset']['train']`` must provide a ``'vocab'``
                value (handed to ``Tokenizer``) and a ``'_name'`` value
                (used only in the log message).
        """
        conf = config['dataset']['train']
        self.emb_tokenizers = []
        for key, emb_conf in conf.items():
            if not key.startswith('emb_type'):
                continue
            self.emb_tokenizers.append(Tokenizer(emb_conf['vocab']))
            logger.info('processed emb {}'.format(emb_conf['_name']))

    def _process(self, input: str):
        """Tokenize a single text line.

        Args:
            input: Line of the form ``name|seg_1|...|seg_k``.  The first
                ``'|'``-separated field is returned as the name; each
                remaining field is a whitespace-separated token sequence
                that is fed to the tokenizer at the same position.

        Returns:
            A ``(name, token_tensor)`` tuple.  ``token_tensor`` stacks the
            outputs of ``tokenizer.tokenize`` along a new leading dimension,
            one row per segment.

        Raises:
            ValueError: If the number of segments differs from the number
                of tokenizers, or if there are no segments to tokenize.
        """
        name, *segments = input.split('|')
        if len(segments) != len(self.emb_tokenizers):
            raise ValueError(
                f'Input text and emb_tokenizers are different, {segments}')
        if not segments:
            # Without this guard the max() below fails with an unhelpful
            # "max() arg is an empty sequence" message.
            raise ValueError(
                f'Input text has no segments to tokenize, {input}')

        # Split every segment once and reuse the result below.
        split_segments = [seg.split() for seg in segments]
        n = max(len(words) for words in split_segments)

        # Segments with fewer tokens than the longest one are repeated
        # cyclically and cut to n tokens, so a one-token segment is broadcast
        # to every position.  Segments that already have n tokens are passed
        # through untouched.
        segments = [
            seg if len(words) == n else ' '.join((words * n)[:n])
            for seg, words in zip(segments, split_segments)
        ]

        rows = [
            tokenizer.tokenize(seg)
            for seg, tokenizer in zip(segments, self.emb_tokenizers)
        ]
        # Equivalent to unsqueezing each row at dim 0 and concatenating.
        token_tensor = torch.stack(rows, 0)
        return name, token_tensor

    def __call__(self, input):
        """Process ``input``; see :meth:`_process` for the contract."""
        return self._process(input)


if __name__ == '__main__':
    # Manual smoke test: load ./config.yaml, build a processor from it and
    # print the result for one hard-coded sample line.
    import yaml

    with open('./config.yaml') as config_file:
        demo_config = yaml.safe_load(config_file)

    # NOTE(review): TextProcessor treats the first '|'-separated field as the
    # utterance name, so the pinyin field below is consumed as the name and
    # only the last two fields are tokenized - confirm this is intended.
    sample_line = 'sil ni3 qu4 zuo4 fan4 ba5 sil|sil 你 去 做 饭 吧 sil|0 0 0 0 0 0 0'
    processor = TextProcessor(demo_config)

    # The call returns a (name, token_tensor) pair, printed as-is.
    print(processor(sample_line))