File size: 5,444 Bytes
cd5ed10
c837b79
 
cd5ed10
 
 
 
 
 
c837b79
 
 
cd5ed10
 
c837b79
cd5ed10
 
c837b79
 
cd5ed10
 
 
c837b79
cd5ed10
 
 
 
c837b79
cd5ed10
 
 
c837b79
cd5ed10
c837b79
cd5ed10
 
c837b79
 
 
 
cd5ed10
c837b79
cd5ed10
 
 
c837b79
cd5ed10
 
 
 
 
 
 
 
 
c837b79
 
cd5ed10
 
 
 
 
c837b79
cd5ed10
 
c837b79
 
cd5ed10
 
 
 
 
 
 
 
 
c837b79
 
cd5ed10
 
 
 
c837b79
cd5ed10
 
 
 
 
c837b79
cd5ed10
 
 
 
 
c837b79
cd5ed10
 
 
 
 
 
 
 
 
c837b79
 
cd5ed10
 
 
 
 
 
c837b79
cd5ed10
 
c837b79
cd5ed10
c837b79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cd5ed10
c837b79
 
 
 
cd5ed10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c837b79
 
cd5ed10
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
# %%
import argparse

from tqdm import tqdm
import unicodedata
import re
import pickle
import torch
import NER_medNLP as ner

from EntityNormalizer import EntityNormalizer, DiseaseDict, DrugDict

# Device used for inference: the first CUDA GPU when one is available,
# otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# %% used as a global variable:
# cache of attribute value -> attribute name, filled in by value_to_key().
dict_key = dict()


# %%
# Unpickled contents of key_attr.pkl, loaded lazily and then reused so the
# file is read once per process instead of once per sentence.
_key_attr_cache = None


def _load_key_attr():
    """Return the contents of key_attr.pkl, reading the file only on first use."""
    global _key_attr_cache
    if _key_attr_cache is None:
        with open("key_attr.pkl", "rb") as tf:
            _key_attr_cache = pickle.load(tf)
    return _key_attr_cache


def to_xml(data, key_attr=None, tags=None):
    """Render one predicted sentence as text with inline XML-style entity tags.

    Every predicted entity is wrapped as ``<tag attr="value" norm="...">...</tag>``
    directly inside ``data['text']``.

    Args:
        data: dict with ``'text'`` (the sentence) and ``'entities_predicted'``
            (list of entity dicts carrying ``'span'``, ``'type_id'`` and,
            optionally, ``'norm'``).
        key_attr: mapping of attribute name -> iterable of attribute values,
            as consumed by ``value_to_key``. Loaded (once) from
            ``key_attr.pkl`` when omitted.
        tags: mapping of type id -> ``"<tag>_<attribute value>"`` label.
            Defaults to the module-level ``id_to_tags`` set up by the script.

    Returns:
        The sentence text with opening and closing tags inserted.

    Raises:
        ValueError: if a label's attribute value is not listed in ``key_attr``.
    """
    if key_attr is None:
        key_attr = _load_key_attr()
    if tags is None:
        tags = id_to_tags

    text = data['text']
    # Characters inserted so far. Spans index the untagged text, so every
    # later insertion point is shifted by this amount.
    # NOTE(review): assumes entities are sorted by span start and do not
    # overlap -- confirm against convert_bert_output_to_entities.
    offset = 0
    for entity in data['entities_predicted']:
        if entity == "":
            # Placeholder entry with nothing to tag: skip it instead of
            # discarding the whole sentence.
            continue
        span = entity['span']
        start, end = span[0], span[1]

        label_parts = tags[entity['type_id']].split('_')
        tag = label_parts[0]
        # Labels without an "_<value>" suffix carry no attribute.
        attr_value = label_parts[1] if len(label_parts) > 1 else ""

        attr = ""
        if attr_value != "":
            attr_name = value_to_key(attr_value, key_attr)
            if attr_name is None:
                raise ValueError(
                    "attribute value " + repr(attr_value)
                    + " is not listed in key_attr"
                )
            attr = ' ' + attr_name + '="' + attr_value + '"'

        if 'norm' in entity:
            attr += ' norm="' + str(entity['norm']) + '"'

        open_tag = "<" + tag + attr + ">"
        text = text[:start + offset] + open_tag + text[start + offset:]
        offset += len(open_tag)

        close_tag = "</" + tag + ">"
        text = text[:end + offset] + close_tag + text[end + offset:]
        offset += len(close_tag)
    return text


def predict_entities(modelpath, sentences_list, len_num_entity_type):
    """Predict named entities for every sentence in every dataset.

    Args:
        modelpath: model name or path passed to
            ``ner.BertForTokenClassification_pl``.
        sentences_list: list of datasets, each an iterable of sentence strings.
        len_num_entity_type: number of entity types, forwarded to the
            tokenizer as ``num_entity_type``.

    Returns:
        One list per dataset, each holding
        ``{'text': sentence, 'entities_predicted': [...]}`` dicts in input order.
    """
    # NOTE(review): num_labels is fixed at 81 independently of
    # len_num_entity_type; it must match the checkpoint's classifier head.
    model = ner.BertForTokenClassification_pl(modelpath, num_labels=81, lr=1e-5)
    bert_tc = model.bert_tc.to(device)
    # Inference only: make sure dropout and similar layers are disabled so
    # predictions are deterministic.
    bert_tc.eval()

    MODEL_NAME = 'cl-tohoku/bert-base-japanese-whole-word-masking'
    tokenizer = ner.NER_tokenizer_BIO.from_pretrained(
        MODEL_NAME,
        num_entity_type=len_num_entity_type  # must match the number of entity types
    )

    text_entities_set = []
    for dataset in sentences_list:
        text_entities = []
        for text in tqdm(dataset):
            encoding, spans = tokenizer.encode_plus_untagged(
                text, return_tensors='pt'
            )
            encoding = {k: v.to(device) for k, v in encoding.items()}

            with torch.no_grad():
                output = bert_tc(**encoding)
                # Scores of the single sequence in this batch of one.
                scores = output.logits[0].cpu().numpy().tolist()

            # Convert the classification scores into entity records.
            entities_predicted = tokenizer.convert_bert_output_to_entities(
                text, scores, spans
            )
            text_entities.append({'text': text, 'entities_predicted': entities_predicted})
        text_entities_set.append(text_entities)
    return text_entities_set


def combine_sentences(text_entities_set, insert: str):
    """Render each document's sentences with ``to_xml`` and join them.

    Args:
        text_entities_set: list of documents, each a list of sentence dicts
            accepted by ``to_xml``.
        insert: separator placed between the rendered sentences of a document
            (previously ignored in favour of a hard-coded newline).

    Returns:
        One tagged string per document, in input order.
    """
    return [
        insert.join(to_xml(sentence) for sentence in text_entities)
        for text_entities in tqdm(text_entities_set)
    ]


def value_to_key(value, key_attr):
    """Return the attribute name under which *value* is listed in *key_attr*.

    Successful lookups are memoised in the module-level ``dict_key`` cache.
    Only hits are cached, so a miss rescans ``key_attr`` next time.

    Args:
        value: attribute value to look up.
        key_attr: mapping of attribute name -> iterable of attribute values.

    Returns:
        The first attribute name (in ``key_attr`` iteration order) whose
        values contain *value*, or ``None`` if none does.
    """
    cached = dict_key.get(value)
    if cached is not None:
        return cached
    for key, values in key_attr.items():
        if any(v == value for v in values):
            dict_key[value] = key
            return key
    return None


# %%
def normalize_entities(text_entities_set, tags=None):
    """Attach a dictionary-normalized name to drug and disease entities, in place.

    Entities whose tag is ``m-key`` are normalized against ``DrugDict`` and
    entities whose tag is ``d`` against ``DiseaseDict``. The result is stored
    as a string under the entity's ``'norm'`` key; all other entities are
    left untouched.

    Args:
        text_entities_set: list of documents as returned by
            ``predict_entities``.
        tags: mapping of type id -> ``"<tag>_<attribute value>"`` label.
            Defaults to the module-level ``id_to_tags`` set up by the script.
    """
    if tags is None:
        tags = id_to_tags

    # Tag -> normalizer. Both use matching_threshold=50.
    normalizers = {
        'd': EntityNormalizer(DiseaseDict(), matching_threshold=50),
        'm-key': EntityNormalizer(DrugDict(), matching_threshold=50),
    }

    for document in text_entities_set:
        for sentence in document:
            for entity in sentence['entities_predicted']:
                if not entity:
                    # Empty placeholder entry: nothing to normalize.
                    continue
                tag = tags[entity['type_id']].split('_')[0]
                normalizer = normalizers.get(tag)
                if normalizer is None:
                    continue
                normalization, _score = normalizer.normalize(entity['name'])
                entity['norm'] = str(normalization)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Predict entities from text')
    parser.add_argument('--normalize', action=argparse.BooleanOptionalAction, help='Enable entity normalization')
    args = parser.parse_args()

    # Bound at module level on purpose: to_xml() and normalize_entities()
    # read id_to_tags as a global.
    with open("id_to_tags.pkl", "rb") as tf:
        id_to_tags = pickle.load(tf)
    # Decode explicitly so reading the (Japanese) input does not depend on
    # the platform's default locale encoding.
    with open('text.txt', encoding='utf-8') as f:
        articles_raw = f.read()

    # The model runs on NFKC-normalized text; the output keeps the raw text.
    article_norm = unicodedata.normalize('NFKC', articles_raw)

    sentences_raw = [s for s in re.split(r'\n', articles_raw) if s != '']
    sentences_norm = [s for s in re.split(r'\n', article_norm) if s != '']
    # Raw and normalized sentences are paired by position below.
    if len(sentences_raw) != len(sentences_norm):
        raise RuntimeError(
            f'NFKC normalization changed the sentence count: '
            f'{len(sentences_raw)} raw vs {len(sentences_norm)} normalized'
        )

    text_entities_set = predict_entities("sociocom/RealMedNLP_CR_JA", [sentences_norm], len(id_to_tags))

    # Put the raw sentences back in for output.
    # NOTE(review): predicted spans are offsets into the NFKC-normalized
    # sentence. If normalization changed a sentence's length (e.g. "㍻" ->
    # "平成"), to_xml() will insert the tags at the wrong positions.
    for texts_ent, raw_sentence in zip(text_entities_set[0], sentences_raw):
        texts_ent['text'] = raw_sentence

    if args.normalize:
        normalize_entities(text_entities_set)

    documents = combine_sentences(text_entities_set, '\n')

    print(documents[0])