File size: 9,365 Bytes
cd5ed10
 
 
576d564
cd5ed10
 
576d564
 
cd5ed10
 
 
576d564
cd5ed10
 
576d564
 
cd5ed10
576d564
cd5ed10
 
576d564
 
 
cd5ed10
576d564
 
 
 
 
 
 
cd5ed10
 
 
 
 
576d564
cd5ed10
 
 
 
 
576d564
 
cd5ed10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
576d564
cd5ed10
576d564
cd5ed10
 
 
 
576d564
 
cd5ed10
576d564
 
cd5ed10
 
576d564
 
cd5ed10
 
 
576d564
cd5ed10
 
576d564
 
cd5ed10
 
576d564
 
 
cd5ed10
 
 
 
 
 
576d564
 
cd5ed10
 
576d564
cd5ed10
 
 
576d564
cd5ed10
576d564
cd5ed10
 
 
 
 
576d564
cd5ed10
 
 
 
 
 
 
576d564
 
 
cd5ed10
 
576d564
cd5ed10
576d564
cd5ed10
 
 
576d564
cd5ed10
 
 
 
576d564
cd5ed10
 
 
576d564
cd5ed10
 
576d564
cd5ed10
 
 
 
576d564
cd5ed10
576d564
 
 
cd5ed10
 
 
 
576d564
cd5ed10
576d564
cd5ed10
 
 
576d564
cd5ed10
 
 
 
 
 
 
 
576d564
cd5ed10
 
576d564
 
 
 
 
cd5ed10
 
 
576d564
 
 
 
cd5ed10
 
 
 
576d564
cd5ed10
 
 
 
 
 
 
 
 
 
 
 
576d564
cd5ed10
576d564
 
 
cd5ed10
 
 
 
 
 
576d564
 
cd5ed10
 
 
 
576d564
cd5ed10
576d564
cd5ed10
 
 
 
 
 
 
 
576d564
cd5ed10
 
576d564
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
# %%

import itertools

import numpy as np
import pytorch_lightning as pl
import torch
from transformers import BertForTokenClassification, BertJapaneseTokenizer


# %%
# 8-16
# PyTorch Lightningのモデル
class BertForTokenClassification_pl(pl.LightningModule):

    def __init__(self, num_labels, model='sociocom/MedNERN-CR-JA', lr=1e-4):
        super().__init__()
        self.lr = lr
        self.save_hyperparameters()
        self.bert_tc = BertForTokenClassification.from_pretrained(
            model,
            num_labels=num_labels,
            ignore_mismatched_sizes=True
        )

    @classmethod
    def from_pretrained_bin(cls, model_path, num_labels):
        model = cls(num_labels)
        model.load_state_dict(torch.load(model_path))
        return model

    def training_step(self, batch, batch_idx):
        output = self.bert_tc(**batch)
        loss = output.loss
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        output = self.bert_tc(**batch)
        val_loss = output.loss
        self.log('val_loss', val_loss)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)


# %%
class NER_tokenizer_BIO(BertJapaneseTokenizer):

    # 初期化時に固有表現のカテゴリーの数`num_entity_type`を
    # 受け入れるようにする。
    def __init__(self, *args, **kwargs):
        self.num_entity_type = kwargs.pop('num_entity_type')
        super().__init__(*args, **kwargs)

    def encode_plus_tagged(self, text, entities, max_length):
        """
        文章とそれに含まれる固有表現が与えられた時に、
        符号化とラベル列の作成を行う。
        """
        # 固有表現の前後でtextを分割し、それぞれのラベルをつけておく。
        splitted = []  # 分割後の文字列を追加していく
        position = 0

        for entity in entities:
            start = entity['span'][0]
            end = entity['span'][1]
            label = entity['type_id']
            splitted.append({'text': text[position:start], 'label': 0})
            splitted.append({'text': text[start:end], 'label': label})
            position = end
        splitted.append({'text': text[position:], 'label': 0})
        splitted = [s for s in splitted if s['text']]

        # 分割されたそれぞれの文章をトークン化し、ラベルをつける。
        tokens = []  # トークンを追加していく
        labels = []  # ラベルを追加していく
        for s in splitted:
            tokens_splitted = self.tokenize(s['text'])
            label = s['label']
            if label > 0:  # 固有表現
                # まずトークン全てにI-タグを付与
                # 番号順O-tag:0, B-tag:1~タグの数,I-tag:タグの数〜
                labels_splitted = \
                    [label + self.num_entity_type] * len(tokens_splitted)
                # 先頭のトークンをB-タグにする
                labels_splitted[0] = label
            else:  # それ以外
                labels_splitted = [0] * len(tokens_splitted)

            tokens.extend(tokens_splitted)
            labels.extend(labels_splitted)

        # 符号化を行いBERTに入力できる形式にする。
        input_ids = self.convert_tokens_to_ids(tokens)
        encoding = self.prepare_for_model(
            input_ids,
            max_length=max_length,
            padding='max_length',
            truncation=True
        )

        # ラベルに特殊トークンを追加
        # max_lengthで切り取って,その前後に[CLS]と[SEP]を追加するためのラベルを入れる
        labels = [0] + labels[:max_length - 2] + [0]
        # max_lengthに満たない場合は,満たない分を後ろ側に追加する
        labels = labels + [0] * (max_length - len(labels))
        encoding['labels'] = labels

        return encoding

    def encode_plus_untagged(
            self, text, max_length=None, return_tensors=None
    ):
        """
        文章をトークン化し、それぞれのトークンの文章中の位置も特定しておく。
        IO法のトークナイザのencode_plus_untaggedと同じ
        """
        # 文章のトークン化を行い、
        # それぞれのトークンと文章中の文字列を対応づける。
        tokens = []  # トークンを追加していく。
        tokens_original = []  # トークンに対応する文章中の文字列を追加していく。
        words = self.word_tokenizer.tokenize(text)  # MeCabで単語に分割
        for word in words:
            # 単語をサブワードに分割
            tokens_word = self.subword_tokenizer.tokenize(word)
            tokens.extend(tokens_word)
            if tokens_word[0] == '[UNK]':  # 未知語への対応
                tokens_original.append(word)
            else:
                tokens_original.extend([
                    token.replace('##', '') for token in tokens_word
                ])

        # 各トークンの文章中での位置を調べる。(空白の位置を考慮する)
        position = 0
        spans = []  # トークンの位置を追加していく。
        for token in tokens_original:
            l = len(token)
            while 1:
                if token != text[position:position + l]:
                    position += 1
                else:
                    spans.append([position, position + l])
                    position += l
                    break

        # 符号化を行いBERTに入力できる形式にする。
        input_ids = self.convert_tokens_to_ids(tokens)
        encoding = self.prepare_for_model(
            input_ids,
            max_length=max_length,
            padding='max_length' if max_length else False,
            truncation=True if max_length else False
        )
        sequence_length = len(encoding['input_ids'])
        # 特殊トークン[CLS]に対するダミーのspanを追加。
        spans = [[-1, -1]] + spans[:sequence_length - 2]
        # 特殊トークン[SEP]、[PAD]に対するダミーのspanを追加。
        spans = spans + [[-1, -1]] * (sequence_length - len(spans))

        # 必要に応じてtorch.Tensorにする。
        if return_tensors == 'pt':
            encoding = {k: torch.tensor([v]) for k, v in encoding.items()}

        return encoding, spans

    @staticmethod
    def Viterbi(scores_bert, num_entity_type, penalty=10000):
        """
        Viterbiアルゴリズムで最適解を求める。
        """
        m = 2 * num_entity_type + 1
        penalty_matrix = np.zeros([m, m])
        for i in range(m):
            for j in range(1 + num_entity_type, m):
                if not ((i == j) or (i + num_entity_type == j)):
                    penalty_matrix[i, j] = penalty
        path = [[i] for i in range(m)]
        scores_path = scores_bert[0] - penalty_matrix[0, :]
        scores_bert = scores_bert[1:]

        for scores in scores_bert:
            assert len(scores) == 2 * num_entity_type + 1
            score_matrix = np.array(scores_path).reshape(-1, 1) \
                           + np.array(scores).reshape(1, -1) \
                           - penalty_matrix
            scores_path = score_matrix.max(axis=0)
            argmax = score_matrix.argmax(axis=0)
            path_new = []
            for i, idx in enumerate(argmax):
                path_new.append(path[idx] + [i])
            path = path_new

        labels_optimal = path[np.argmax(scores_path)]
        return labels_optimal

    def convert_bert_output_to_entities(self, text, scores, spans):
        """
        文章、分類スコア、各トークンの位置から固有表現を得る。
        分類スコアはサイズが(系列長、ラベル数)の2次元配列
        """
        assert len(spans) == len(scores)
        num_entity_type = self.num_entity_type

        # 特殊トークンに対応する部分を取り除く
        scores = [score for score, span in zip(scores, spans) if span[0] != -1]
        spans = [span for span in spans if span[0] != -1]

        # Viterbiアルゴリズムでラベルの予測値を決める。
        labels = self.Viterbi(scores, num_entity_type)

        # 同じラベルが連続するトークンをまとめて、固有表現を抽出する。
        entities = []
        for label, group \
                in itertools.groupby(enumerate(labels), key=lambda x: x[1]):

            group = list(group)
            start = spans[group[0][0]][0]
            end = spans[group[-1][0]][1]

            if label != 0:  # 固有表現であれば
                if 1 <= label <= num_entity_type:
                    # ラベルが`B-`ならば、新しいentityを追加
                    entity = {
                        "name": text[start:end],
                        "span": [start, end],
                        "type_id": label
                    }
                    entities.append(entity)
                else:
                    # ラベルが`I-`ならば、直近のentityを更新
                    entity['span'][1] = end
                    entity['name'] = text[entity['span'][0]:entity['span'][1]]

        return entities