import json
import os
from collections import Counter

import jieba

# Register the domain-specific term list with jieba so key technical words
# are segmented as single tokens.
abs_path = os.path.dirname(os.path.abspath(__file__))
jieba.load_userdict(os.path.join(abs_path, '../../utils/key_technical_words.txt'))


class Tokenizer(object):
    def __init__(self, args):
        self.ann_path = args.ann_path
        self.threshold = args.threshold
        self.ann = json.loads(open(self.ann_path, 'r', encoding='utf-8-sig').read())
        self.dict_pth = args.dict_pth
        self.token2idx, self.idx2token = self.create_vocabulary()

    def create_vocabulary(self):
        if self.dict_pth != ' ':
            # Reload a previously saved vocabulary; JSON keys are strings, so
            # restore the integer keys of the idx2token mapping.
            word_dict = json.loads(open(self.dict_pth, 'r', encoding='utf-8-sig').read())
            word_dict[1] = {int(k): v for k, v in word_dict[1].items()}
            return word_dict[0], word_dict[1]
        else:
            # Build the vocabulary from the 'finding' field of every split.
            total_tokens = []
            split_list = ['train', 'test', 'val']
            for split in split_list:
                for example in self.ann[split]:
                    tokens = list(jieba.lcut(example['finding']))
                    for token in tokens:
                        total_tokens.append(token)

            counter = Counter(total_tokens)
            vocab = [k for k, v in counter.items()] + ['<unk>']
            # Ids start at 1; id 0 is reserved for the start/end token.
            token2idx, idx2token = {}, {}
            for idx, token in enumerate(vocab):
                token2idx[token] = idx + 1
                idx2token[idx + 1] = token
            # Persist the vocabulary so later runs can reload it via dict_pth.
            with open('E:/Captionv0/Code/SGF/utils/breast_dict.txt', 'w', encoding='utf-8-sig') as f:
                f.write(json.dumps([token2idx, idx2token]))
            return token2idx, idx2token

    def get_token_by_id(self, id):
        return self.idx2token[id]

    def get_id_by_token(self, token):
        # Unknown tokens fall back to the '<unk>' id.
        if token not in self.token2idx:
            return self.token2idx['<unk>']
        return self.token2idx[token]

    def get_vocab_size(self):
        return len(self.token2idx)

    def __call__(self, report):
        # Segment the report with jieba, map tokens to ids, and wrap the
        # sequence with the start/end id 0.
        tokens = list(jieba.cut(report))
        ids = []
        for token in tokens:
            ids.append(self.get_id_by_token(token))
        ids = [0] + ids + [0]
        return ids

    def decode(self, ids):
        # Join tokens with spaces, stopping at the first 0 (start/end id).
        txt = ''
        for i, idx in enumerate(ids):
            if idx > 0:
                if i >= 1:
                    txt += ' '
                txt += self.idx2token[idx]
            else:
                break
        return txt

    def decode_list(self, ids):
        # Keep every position; id 0 is rendered as a '<start/end>' marker.
        txt = []
        for i, idx in enumerate(ids):
            if idx > 0:
                txt.append(self.idx2token[idx])
            else:
                txt.append('<start/end>')
        return txt

    def decode_batch(self, ids_batch):
        out = []
        for ids in ids_batch:
            out.append(self.decode(ids))
        return out

    def decode_batch_list(self, ids_batch):
        out = []
        for ids in ids_batch:
            out.append(self.decode_list(ids))
        return out
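

# A minimal usage sketch, not part of the training pipeline: the file layouts,
# sample findings, and temp paths below are illustrative assumptions, and it
# assumes the key_technical_words.txt user dictionary loaded above exists. A
# pre-saved vocabulary is written first so create_vocabulary() takes the
# reload branch rather than writing to the hardcoded breast_dict.txt path.
if __name__ == '__main__':
    import tempfile
    from argparse import Namespace

    # Tiny annotation file in the expected {'train'/'val'/'test': [{'finding': ...}]} layout.
    ann_file = tempfile.NamedTemporaryFile('w', suffix='.json', delete=False, encoding='utf-8-sig')
    json.dump({'train': [{'finding': '双侧乳腺未见明显肿块。'}], 'val': [], 'test': []},
              ann_file, ensure_ascii=False)
    ann_file.close()

    # Pre-saved vocabulary in the [token2idx, idx2token] format produced by create_vocabulary().
    vocab = ['双侧', '乳腺', '未见', '明显', '肿块', '。', '<unk>']
    dict_file = tempfile.NamedTemporaryFile('w', suffix='.txt', delete=False, encoding='utf-8-sig')
    json.dump([{t: i + 1 for i, t in enumerate(vocab)},
               {i + 1: t for i, t in enumerate(vocab)}], dict_file, ensure_ascii=False)
    dict_file.close()

    args = Namespace(ann_path=ann_file.name, threshold=3, dict_pth=dict_file.name)
    tokenizer = Tokenizer(args)
    ids = tokenizer('双侧乳腺未见明显肿块。')
    print(ids)                         # e.g. [0, 1, 2, 3, 4, 5, 6, 0]
    print(tokenizer.decode(ids[1:]))   # decode() stops at the first 0, so skip the leading start id
    print(tokenizer.decode_list(ids))  # keeps the '<start/end>' markers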