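"""Text front-end for MTTS: turns a pipe-delimited transcript line into a
(name, token_tensor) pair with one tokenized row per configured embedding type."""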
import torch

from mtts.datasets.dataset import Tokenizer
from mtts.utils.logging import get_logger

logger = get_logger(__file__)


class TextProcessor:
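    """Tokenizes pipe-delimited input text with one Tokenizer per key in
    config['dataset']['train'] whose name starts with 'emb_type'."""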
    def __init__(self, config):
        conf = config['dataset']['train']
        self.emb_tokenizers = []
        # Build one Tokenizer per embedding type declared in the config.
        for key in conf.keys():
            if key.startswith('emb_type'):
                emb_tok = Tokenizer(conf[key]['vocab'])
                self.emb_tokenizers.append(emb_tok)
                logger.info('processed emb {}'.format(conf[key]['_name']))

    def _process(self, input: str):
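        """Split `input` on '|' into a name plus one segment per tokenizer,
        broadcast shorter segments to the longest length, and tokenize."""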
        segments = input.split('|')
        name = segments[0]
        segments = segments[1:]
        if len(segments) != len(self.emb_tokenizers):
            raise ValueError(f'Number of segments does not match number of emb_tokenizers: {segments}')
        seg_lens = [len(s.split()) for s in segments]
        n = max(seg_lens)
        # for k in seg_lens:
        #     if k != n and k != 1:
        #         raise ValueError(f'Input segments should share the same length, but {k}!={n} for text {input}')
        # Repeat short segments (e.g. a single speaker id) so every segment has n tokens.
        segments = [' '.join((s.split() * n)[:n]) if len(s.split()) != n else s for s in segments]
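        # Tokenize each segment with its matching tokenizer and stack the
        # results into a single (num_emb_types, n) tensor.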
        token_tensor = []
        for seg, tokenizer in zip(segments, self.emb_tokenizers):
            tokens = tokenizer.tokenize(seg)
            token_tensor.append(torch.unsqueeze(tokens, 0))
        token_tensor = torch.cat(token_tensor, 0)
        return name, token_tensor

    def __call__(self, input):
        return self._process(input)


if __name__ == '__main__':
    import yaml

    with open('./config.yaml') as f:
        config = yaml.safe_load(f)
    text_processor = TextProcessor(config)
    # '|'-separated input: the first field is taken as the utterance name,
    # the remaining fields must match the configured emb_tokenizers.
    text = 'sil ni3 qu4 zuo4 fan4 ba5 sil|sil 你 去 做 饭 吧 sil|0 0 0 0 0 0 0'
    name, tensor = text_processor(text)
    print(tensor)
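    # Expected output (assuming config.yaml declares one emb_type per remaining
    # field): a tensor of shape (num_emb_types, seq_len); the exact values
    # depend on the vocab files referenced in the config.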