import json
import os
from collections import Counter

import jieba

# Load the domain-specific user dictionary so jieba keeps key technical terms intact.
abs_path = os.path.dirname(os.path.abspath(__file__))
jieba.load_userdict(os.path.join(abs_path, '../../utils/key_technical_words.txt'))


class Tokenizer(object):
    def __init__(self, args):
        self.ann_path = args.ann_path
        self.threshold = args.threshold
        self.ann = json.loads(open(self.ann_path, 'r', encoding='utf-8-sig').read())
        self.dict_pth = args.dict_pth
        self.token2idx, self.idx2token = self.create_vocabulary()

    def create_vocabulary(self):
        if self.dict_pth != ' ':
            # Load a prebuilt vocabulary stored as [token2idx, idx2token].
            # JSON keys are strings, so the idx2token keys are cast back to int.
            word_dict = json.loads(open(self.dict_pth, 'r', encoding='utf-8-sig').read())
            word_dict[1] = {int(k): v for k, v in word_dict[1].items()}
            return word_dict[0], word_dict[1]
        else:
            # Build the vocabulary from the 'finding' field of every split.
            total_tokens = []
            for split in ['train', 'test', 'val']:
                for example in self.ann[split]:
                    total_tokens.extend(jieba.lcut(example['finding']))

            counter = Counter(total_tokens)
            # Keep every observed token; '' serves as the unknown-token entry.
            vocab = [k for k, v in counter.items()] + ['']
            token2idx, idx2token = {}, {}
            for idx, token in enumerate(vocab):
                token2idx[token] = idx + 1  # index 0 is reserved for the begin/end marker
                idx2token[idx + 1] = token
            # Save the built vocabulary next to the user dictionary in utils/.
            with open(os.path.join(abs_path, '../../utils/breast_dict.txt'), 'w',
                      encoding='utf-8-sig') as f:
                f.write(json.dumps([token2idx, idx2token]))
            return token2idx, idx2token

    def get_token_by_id(self, id):
        return self.idx2token[id]

    def get_id_by_token(self, token):
        # Unknown tokens fall back to the '' entry of the vocabulary.
        if token not in self.token2idx:
            return self.token2idx['']
        return self.token2idx[token]

    def get_vocab_size(self):
        return len(self.token2idx)

    def __call__(self, report):
        # Segment a report with jieba and map tokens to ids,
        # wrapping the sequence with 0 as the begin/end marker.
        tokens = list(jieba.cut(report))
        ids = [self.get_id_by_token(token) for token in tokens]
        return [0] + ids + [0]

    def decode(self, ids):
        # Convert ids back to a space-separated string, stopping at the first 0.
        txt = ''
        for i, idx in enumerate(ids):
            if idx > 0:
                if i >= 1:
                    txt += ' '
                txt += self.idx2token[idx]
            else:
                break
        return txt

    def decode_list(self, ids):
        # Convert ids to a token list; non-positive ids become empty strings.
        txt = []
        for idx in ids:
            if idx > 0:
                txt.append(self.idx2token[idx])
            else:
                txt.append('')
        return txt

    def decode_batch(self, ids_batch):
        return [self.decode(ids) for ids in ids_batch]

    def decode_batch_list(self, ids_batch):
        return [self.decode_list(ids) for ids in ids_batch]
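

# Minimal usage sketch, not part of the original module: the argument names
# (ann_path, threshold, dict_pth) mirror what __init__ reads from `args`; the
# paths below are placeholders, and the report string is an arbitrary example
# finding ("bilateral breast glandular tissue evenly distributed").
if __name__ == '__main__':
    from argparse import Namespace

    args = Namespace(
        ann_path='data/annotation.json',  # hypothetical annotation file with train/test/val splits
        threshold=3,                      # stored by __init__ but unused in this module
        dict_pth=' ',                     # a single space triggers building the vocabulary from ann_path
    )
    tokenizer = Tokenizer(args)

    ids = tokenizer('双侧乳腺腺体分布均匀')
    print(ids)                            # id sequence wrapped with 0 begin/end markers
    print(tokenizer.decode(ids[1:-1]))    # decode stops at the first 0, so strip the markers first
    print(tokenizer.get_vocab_size())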