import json
import os
from collections import Counter

import jieba

abs_path = os.path.dirname(os.path.abspath(__file__))
jieba.load_userdict(os.path.join(abs_path, '../../utils/key_technical_words.txt'))


class Tokenizer(object):

    def __init__(self, args):
        self.ann_path = args.ann_path
        self.threshold = args.threshold
        self.dict_pth = args.dict_pth
        with open(self.ann_path, 'r', encoding='utf-8-sig') as f:
            self.ann = json.loads(f.read())
        self.token2idx, self.idx2token = self.create_vocabulary()
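
    # The annotation file is expected to be a dict with 'train'/'test'/'val'
    # splits, each a list of examples whose 'finding' field holds the report
    # text that is tokenized below.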
    def create_vocabulary(self):
        # Load a pre-built dictionary if a path was given.
        if self.dict_pth != ' ':
            with open(self.dict_pth, 'r', encoding='utf-8-sig') as f:
                word_dict = json.loads(f.read())
            # JSON object keys are always strings, so restore the integer ids.
            word_dict[1] = {int(k): v for k, v in word_dict[1].items()}
            return word_dict[0], word_dict[1]
        else:
            # Otherwise build the vocabulary from every report in the dataset.
            total_tokens = []
            for split in ['train', 'test', 'val']:
                for example in self.ann[split]:
                    tokens = jieba.lcut(example['finding'])
                    total_tokens.extend(tokens)
            counter = Counter(total_tokens)
            # Keep tokens that occur at least `threshold` times; rarer tokens
            # fall back to '<unk>'.
            vocab = [k for k, v in counter.items() if v >= self.threshold] + ['<unk>']
            token2idx, idx2token = {}, {}
            for idx, token in enumerate(vocab):
                token2idx[token] = idx + 1
                idx2token[idx + 1] = token
            # Persist the dictionary next to the user dictionary so it can be
            # reloaded later via `dict_pth`.
            dict_out = os.path.join(abs_path, '../../utils/breast_dict.txt')
            with open(dict_out, 'w', encoding='utf-8-sig') as f:
                f.write(json.dumps([token2idx, idx2token]))
            return token2idx, idx2token
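
    # Vocabulary layout: real tokens are numbered from 1, with '<unk>' as the
    # last entry; id 0 is left free and is used by __call__/decode as the
    # start/end marker of a report.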

    def get_token_by_id(self, id):
        return self.idx2token[id]

    def get_id_by_token(self, token):
        # Unknown tokens are mapped to the '<unk>' id.
        if token not in self.token2idx:
            return self.token2idx['<unk>']
        return self.token2idx[token]

    def get_vocab_size(self):
        return len(self.token2idx)

    def __call__(self, report):
        tokens = list(jieba.cut(report))
        ids = []
        for token in tokens:
            ids.append(self.get_id_by_token(token))
        # Wrap the report with id 0 as the start/end marker.
        ids = [0] + ids + [0]
        return ids

    def decode(self, ids):
        txt = ''
        for i, idx in enumerate(ids):
            if idx > 0:
                # Separate tokens with a space; the first id 0 ends the report.
                if i >= 1:
                    txt += ' '
                txt += self.idx2token[idx]
            else:
                break
        return txt

    def decode_list(self, ids):
        txt = []
        for i, idx in enumerate(ids):
            if idx > 0:
                txt.append(self.idx2token[idx])
            else:
                txt.append('<start/end>')
        return txt

    def decode_batch(self, ids_batch):
        out = []
        for ids in ids_batch:
            out.append(self.decode(ids))
        return out

    def decode_batch_list(self, ids_batch):
        out = []
        for ids in ids_batch:
            out.append(self.decode_list(ids))
        return out
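

# Minimal usage sketch. The attribute names on `args` mirror what __init__
# reads; the paths, threshold, and sample report below are hypothetical
# placeholders, not the project's real configuration.
if __name__ == '__main__':
    from argparse import Namespace

    args = Namespace(
        ann_path=os.path.join(abs_path, '../../data/annotation.json'),  # hypothetical annotation file
        threshold=3,   # hypothetical frequency cut-off
        dict_pth=' ',  # ' ' -> build the vocabulary from the annotations
    )
    tokenizer = Tokenizer(args)
    ids = tokenizer('双侧乳腺腺体分布均匀')          # encode a report to ids
    print(tokenizer.get_vocab_size())
    print(tokenizer.decode(ids[1:]))               # skip the leading start marker before decoding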