import json
import os
from collections import Counter

import jieba

# Load the user dictionary of key technical words so that jieba keeps domain
# terms as single tokens instead of splitting them.
abs_path = os.path.dirname(os.path.abspath(__file__))
jieba.load_userdict(os.path.join(abs_path, '../../utils/key_technical_words.txt'))


class Tokenizer(object):
    """Word-level tokenizer for Chinese report findings, built on jieba."""

    def __init__(self, args):
        self.ann_path = args.ann_path
        self.threshold = args.threshold
        self.dict_pth = args.dict_pth
        # Annotations are a JSON object with 'train'/'val'/'test' splits, each a
        # list of examples carrying a 'finding' report string.
        with open(self.ann_path, 'r', encoding='utf-8-sig') as f:
            self.ann = json.load(f)
        self.token2idx, self.idx2token = self.create_vocabulary()

    def create_vocabulary(self):
        # A single space is the sentinel value meaning "no prebuilt dictionary".
        if self.dict_pth != ' ':
            # Reuse a previously dumped [token2idx, idx2token] pair. JSON turns
            # the integer keys of idx2token into strings, so cast them back.
            with open(self.dict_pth, 'r', encoding='utf-8-sig') as f:
                word_dict = json.load(f)
            word_dict[1] = {int(k): v for k, v in word_dict[1].items()}
            return word_dict[0], word_dict[1]
        else:
            # Build the vocabulary from every split so that validation and test
            # reports can also be encoded.
            total_tokens = []
            for split in ['train', 'test', 'val']:
                for example in self.ann[split]:
                    total_tokens.extend(jieba.lcut(example['finding']))
            counter = Counter(total_tokens)
            # Keep tokens that occur at least `threshold` times; rarer tokens
            # fall back to '<unk>' at encoding time.
            vocab = [k for k, v in counter.items() if v >= self.threshold] + ['<unk>']
            # Index 0 is reserved as the <start>/<end>/padding id, so real
            # tokens are numbered from 1.
            token2idx, idx2token = {}, {}
            for idx, token in enumerate(vocab):
                token2idx[token] = idx + 1
                idx2token[idx + 1] = token
            # Dump the vocabulary beside the jieba user dictionary so that
            # later runs can reload it through dict_pth.
            dump_path = os.path.join(abs_path, '../../utils/breast_dict.txt')
            with open(dump_path, 'w', encoding='utf-8-sig') as f:
                f.write(json.dumps([token2idx, idx2token]))
            return token2idx, idx2token

    def get_token_by_id(self, id):
        return self.idx2token[id]

    def get_id_by_token(self, token):
        # Out-of-vocabulary tokens fall back to the '<unk>' id.
        if token not in self.token2idx:
            return self.token2idx['<unk>']
        return self.token2idx[token]

    def get_vocab_size(self):
        return len(self.token2idx)

    def __call__(self, report):
        # Segment the raw report with jieba, map each token to its id and wrap
        # the sequence with id 0 as the <start>/<end> marker.
        tokens = jieba.lcut(report)
        ids = [self.get_id_by_token(token) for token in tokens]
        return [0] + ids + [0]

    def decode(self, ids):
        # Join decoded tokens with spaces, stopping at the first id 0
        # (the <end>/padding marker).
        txt = ''
        for i, idx in enumerate(ids):
            if idx > 0:
                if i >= 1:
                    txt += ' '
                txt += self.idx2token[idx]
            else:
                break
        return txt

    def decode_list(self, ids):
        # Like decode(), but keeps every position and marks id 0 explicitly
        # instead of stopping there.
        txt = []
        for idx in ids:
            if idx > 0:
                txt.append(self.idx2token[idx])
            else:
                txt.append('<start/end>')
        return txt

    def decode_batch(self, ids_batch):
        out = []
        for ids in ids_batch:
            out.append(self.decode(ids))
        return out

    def decode_batch_list(self, ids_batch):
        out = []
        for ids in ids_batch:
            out.append(self.decode_list(ids))
        return out
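

if __name__ == '__main__':
    # Minimal smoke-test sketch (not part of the original module). It assumes
    # the repository layout is in place so that the jieba user dictionary and
    # the utils/ dump directory exist, and it fakes the `args` object with a
    # SimpleNamespace carrying the three attributes the Tokenizer reads.
    import tempfile
    from types import SimpleNamespace

    # A tiny annotation file in the expected layout: splits mapping to lists of
    # examples, each with a 'finding' report string.
    ann = {
        'train': [{'finding': '双侧乳腺腺体结构清晰'},
                  {'finding': '双侧乳腺未见明显肿块'}],
        'val': [],
        'test': [],
    }
    with tempfile.NamedTemporaryFile('w', suffix='.json', delete=False,
                                     encoding='utf-8-sig') as tmp:
        json.dump(ann, tmp, ensure_ascii=False)
        ann_path = tmp.name

    # dict_pth=' ' triggers vocabulary building from the annotation file.
    args = SimpleNamespace(ann_path=ann_path, threshold=1, dict_pth=' ')
    tokenizer = Tokenizer(args)

    ids = tokenizer('双侧乳腺腺体结构清晰')
    print('ids:', ids)
    # Skip the leading id 0, otherwise decode() stops immediately.
    print('decoded:', tokenizer.decode(ids[1:]))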