T-Almeida committed on
Commit
909a7d4
·
verified ·
1 Parent(s): 9ba689d

Upload model

Browse files
Files changed (3) hide show
  1. config.json +2 -1
  2. model.safetensors +3 -0
  3. modeling_bionexttagger.py +440 -0
config.json CHANGED
@@ -7,7 +7,8 @@
7
  "attention_probs_dropout_prob": 0.1,
8
  "augmentation": "unk",
9
  "auto_map": {
10
- "AutoConfig": "configuration_bionexttager.BioNextTaggerConfig"
 
11
  },
12
  "classifier_dropout": null,
13
  "context_size": 2,
 
7
  "attention_probs_dropout_prob": 0.1,
8
  "augmentation": "unk",
9
  "auto_map": {
10
+ "AutoConfig": "configuration_bionexttager.BioNextTaggerConfig",
11
+ "AutoModel": "modeling_bionexttagger.BioNextTaggerModel"
12
  },
13
  "classifier_dropout": null,
14
  "context_size": 2,
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e0f62f81f49e7c3d3704ea79b6c9714ce76f8eaf27c70d2b4339ded3be5aed95
3
+ size 1334004696
modeling_bionexttagger.py ADDED
@@ -0,0 +1,440 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import os
3
+ from typing import Optional, Union
4
+ from transformers import AutoModel, PreTrainedModel, AutoConfig
5
+ from transformers.modeling_outputs import TokenClassifierOutput
6
+ from torch import nn
7
+ from torch.nn import CrossEntropyLoss
8
+
9
+ from typing import List, Optional
10
+
11
+ import torch
12
+ from itertools import islice
13
+ from .configuration_bionexttager import BioNextTaggerConfig
14
+
15
+
16
# Number of parameter tensors assumed per backbone encoder layer. It converts
# "freeze the first N encoder layers" into "freeze the first N * NUM_PER_LAYER
# parameters" of ``self.bert.encoder``.
# NOTE(review): 16 presumably matches a standard BERT layer (8 weight/bias
# pairs) -- confirm before using a backbone with a different layer layout.
NUM_PER_LAYER = 16


class BioNextTaggerModel(PreTrainedModel):
    """Token tagger: transformer backbone -> dense + GELU -> linear -> CRF.

    The backbone is loaded with ``AutoModel.from_pretrained`` from
    ``config._name_or_path`` using the backbone configuration returned by
    ``config.get_backbonemodel_config()``. Emission scores produced by the
    linear classifier are scored (training) or decoded (inference) by a
    linear-chain ``CRF`` with ``config.num_labels`` tags.

    Config attributes read here: ``num_labels``, ``_name_or_path``,
    ``classifier_dropout``, ``hidden_dropout_prob``, ``hidden_size``,
    ``crf_reduction``, ``freeze`` and ``num_frozen_encoder``.
    """

    config_class = BioNextTaggerConfig
    _keys_to_ignore_on_load_unexpected = [r"pooler"]

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.bert = AutoModel.from_pretrained(
            config._name_or_path,
            config=config.get_backbonemodel_config(),
            add_pooling_layer=False,
        )

        # Prefer the classifier-specific dropout when the config provides one;
        # previously this value was computed and then ignored.
        classifier_dropout = (
            config.classifier_dropout
            if config.classifier_dropout is not None
            else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dense_activation = nn.GELU(approximate='none')
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.crf = CRF(num_tags=config.num_labels, batch_first=True)
        self.reduction = config.crf_reduction

        if self.config.freeze:
            self.manage_freezing()

    def manage_freezing(self):
        """Disable gradients for the embeddings and the first encoder layers.

        All embedding parameters are frozen. In addition, the first
        ``config.num_frozen_encoder * NUM_PER_LAYER`` parameters yielded by
        ``self.bert.encoder.named_parameters()`` are frozen.
        """
        for _, param in self.bert.embeddings.named_parameters():
            param.requires_grad = False

        num_encoders_to_freeze = self.config.num_frozen_encoder
        if num_encoders_to_freeze > 0:
            frozen_count = num_encoders_to_freeze * NUM_PER_LAYER
            for _, param in islice(self.bert.encoder.named_parameters(), frozen_count):
                param.requires_grad = False

    def forward(self,
                input_ids=None,
                attention_mask=None,
                token_type_ids=None,
                position_ids=None,
                head_mask=None,
                inputs_embeds=None,
                labels=None,
                output_attentions=None,
                output_hidden_states=None,
                return_dict=None
                ):
        """Run the backbone and the CRF head.

        Args:
            input_ids, attention_mask, token_type_ids, position_ids, head_mask,
            inputs_embeds, output_attentions, output_hidden_states, return_dict:
                Forwarded unchanged to the backbone model.
            labels: Optional tag ids with the same leading two dimensions as
                the emission scores. When given, the CRF loss is computed.

        Returns:
            * With ``labels``: a tuple ``(loss, logits)`` where ``loss`` is the
              CRF negative log likelihood reduced with ``config.crf_reduction``
              and ``logits`` are the emission scores.
            * Without ``labels``: a ``torch.Tensor`` built from the Viterbi
              decoded tag ids (default float dtype, created on CPU from a
              Python list).

        Note:
            ``attention_mask`` is only used by the backbone; it is not passed
            to the CRF, so every position (padding included) takes part in
            the loss and in decoding.
        """
        # Default `model.config.use_return_dict` is `True`
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids,
                            position_ids=position_ids,
                            head_mask=head_mask,
                            inputs_embeds=inputs_embeds,
                            output_attentions=output_attentions,
                            output_hidden_states=output_hidden_states,
                            return_dict=return_dict)

        # outputs[0] is the last hidden state: (batch, seq, hidden)
        sequence_output = self.dropout(outputs[0])
        dense_output = self.dense_activation(self.dense(sequence_output))
        logits = self.classifier(dense_output)

        if labels is not None:
            # Training / evaluation: labels are not passed during inference.
            return self.crf(logits, labels, reduction=self.reduction), logits

        # Inference: decoded tags.
        # NOTE: multi-GPU gather does not work on this output because the
        # tensor is built on CPU from the decoded Python lists.
        return torch.Tensor(self.crf.decode(logits))
92
+
93
+
94
+
95
# Adapted from https://github.com/kmkurn/pytorch-crf/blob/master/torchcrf/__init__.py
# Changes: the uint8-mask deprecation warning is fixed, impossible BIO
# transitions are masked at construction time, and ``forward`` returns the
# negative log likelihood.

# Score assigned to transitions that must never be taken.
LARGE_NEGATIVE_NUMBER = -1e9


class CRF(nn.Module):
    """Conditional random field.

    This module implements a conditional random field [LMP01]_. The forward computation
    of this class computes the negative log likelihood of the given sequence of tags and
    emission score tensor. This class also has `~CRF.decode` method which finds
    the best tag sequence given an emission score tensor using `Viterbi algorithm`_.

    The tag layout assumed by `~CRF.mask_impossible_transitions` is index ``0`` for
    ``O`` followed by ``(B-X, I-X)`` pairs, i.e. ``B`` tags on odd indices and ``I``
    tags on even indices.

    Args:
        num_tags: Number of tags.
        batch_first: Whether the first dimension corresponds to the size of a minibatch.

    Attributes:
        start_transitions (`~torch.nn.Parameter`): Start transition score tensor of size
            ``(num_tags,)``.
        end_transitions (`~torch.nn.Parameter`): End transition score tensor of size
            ``(num_tags,)``.
        transitions (`~torch.nn.Parameter`): Transition score tensor of size
            ``(num_tags, num_tags)``.

    .. [LMP01] Lafferty, J., McCallum, A., Pereira, F. (2001).
       "Conditional random fields: Probabilistic models for segmenting and
       labeling sequence data". *Proc. 18th International Conf. on Machine
       Learning*. Morgan Kaufmann. pp. 282–289.
    .. _Viterbi algorithm: https://en.wikipedia.org/wiki/Viterbi_algorithm
    """

    def __init__(self, num_tags: int, batch_first: bool = False) -> None:
        if num_tags <= 0:
            raise ValueError(f'invalid number of tags: {num_tags}')
        super().__init__()
        self.num_tags = num_tags
        self.batch_first = batch_first
        self.start_transitions = nn.Parameter(torch.empty(num_tags))
        self.end_transitions = nn.Parameter(torch.empty(num_tags))
        self.transitions = nn.Parameter(torch.empty(num_tags, num_tags))

        self.reset_parameters()
        self.mask_impossible_transitions()

    def reset_parameters(self) -> None:
        """Initialize the transition parameters.

        The parameters will be initialized randomly from a uniform distribution
        between -0.1 and 0.1.
        """
        nn.init.uniform_(self.start_transitions, -0.1, 0.1)
        nn.init.uniform_(self.end_transitions, -0.1, 0.1)
        nn.init.uniform_(self.transitions, -0.1, 0.1)

    def mask_impossible_transitions(self) -> None:
        """Set the value of impossible transitions to LARGE_NEGATIVE_NUMBER.

        Masked entries:
        - start transition into any I-X
        - transition O -> I-X
        - transition B-X -> I-Y and I-X -> I-Y for X != Y

        The number of entity types is derived from ``num_tags`` (one ``O`` tag
        plus a ``B``/``I`` pair per type), so tag sets smaller or larger than
        13 are handled instead of indexing out of range.
        """
        num_entity_types = (self.num_tags - 1) // 2

        with torch.no_grad():
            for i in range(num_entity_types):
                begin_i = 2 * i + 1
                inside_i = 2 * i + 2

                # A sequence cannot start inside an entity.
                self.start_transitions[inside_i] = LARGE_NEGATIVE_NUMBER
                # O to any I
                self.transitions[0, inside_i] = LARGE_NEGATIVE_NUMBER

                # B/I of one entity type to the I of any other entity type
                for j in range(num_entity_types):
                    if j == i:
                        continue
                    inside_j = 2 * j + 2
                    self.transitions[begin_i, inside_j] = LARGE_NEGATIVE_NUMBER
                    self.transitions[inside_i, inside_j] = LARGE_NEGATIVE_NUMBER

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(num_tags={self.num_tags})'

    def forward(
            self,
            emissions: torch.Tensor,
            tags: torch.LongTensor,
            mask: Optional[torch.ByteTensor] = None,
            reduction: str = 'sum',
    ) -> torch.Tensor:
        """Compute the negative conditional log likelihood of a sequence of tags.

        Args:
            emissions (`~torch.Tensor`): Emission score tensor of size
                ``(seq_length, batch_size, num_tags)`` if ``batch_first`` is ``False``,
                ``(batch_size, seq_length, num_tags)`` otherwise.
            tags (`~torch.LongTensor`): Sequence of tags tensor of size
                ``(seq_length, batch_size)`` if ``batch_first`` is ``False``,
                ``(batch_size, seq_length)`` otherwise.
            mask (`~torch.ByteTensor`): Mask tensor of size ``(seq_length, batch_size)``
                if ``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise.
                Defaults to an all-ones boolean mask.
            reduction: Specifies the reduction to apply to the output:
                ``none|sum|mean|token_mean``. ``none``: no reduction will be applied.
                ``sum``: the output will be summed over batches. ``mean``: the output will be
                averaged over batches. ``token_mean``: the output will be averaged over tokens.

        Returns:
            `~torch.Tensor`: The negative log likelihood (a loss to minimise). This
            will have size ``(batch_size,)`` if reduction is ``none``, ``()`` otherwise.

        Raises:
            ValueError: If the shapes are inconsistent, the first timestep of the
                mask is not fully on, or ``reduction`` is not a supported value.
        """
        self._validate(emissions, tags=tags, mask=mask)
        if reduction not in ('none', 'sum', 'mean', 'token_mean'):
            raise ValueError(f'invalid reduction: {reduction}')
        if mask is None:
            # bool instead of uint8: uint8 masks are deprecated in PyTorch
            mask = torch.ones_like(tags, dtype=torch.bool)

        if self.batch_first:
            emissions = emissions.transpose(0, 1)
            tags = tags.transpose(0, 1)
            mask = mask.transpose(0, 1)

        # shape: (batch_size,)
        numerator = self._compute_score(emissions, tags, mask)
        # shape: (batch_size,)
        denominator = self._compute_normalizer(emissions, mask)
        # shape: (batch_size,) -- negative log likelihood
        nllh = denominator - numerator

        if reduction == 'none':
            return nllh
        if reduction == 'sum':
            return nllh.sum()
        if reduction == 'mean':
            return nllh.mean()
        assert reduction == 'token_mean'
        return nllh.sum() / mask.type_as(emissions).sum()

    def decode(self, emissions: torch.Tensor,
               mask: Optional[torch.ByteTensor] = None) -> List[List[int]]:
        """Find the most likely tag sequence using Viterbi algorithm.

        Args:
            emissions (`~torch.Tensor`): Emission score tensor of size
                ``(seq_length, batch_size, num_tags)`` if ``batch_first`` is ``False``,
                ``(batch_size, seq_length, num_tags)`` otherwise.
            mask (`~torch.ByteTensor`): Mask tensor of size ``(seq_length, batch_size)``
                if ``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise.
                Defaults to an all-ones boolean mask.

        Returns:
            List of list containing the best tag sequence for each batch.
        """
        self._validate(emissions, mask=mask)
        if mask is None:
            mask = emissions.new_ones(emissions.shape[:2], dtype=torch.bool)

        if self.batch_first:
            emissions = emissions.transpose(0, 1)
            mask = mask.transpose(0, 1)

        return self._viterbi_decode(emissions, mask)

    def _validate(
            self,
            emissions: torch.Tensor,
            tags: Optional[torch.LongTensor] = None,
            mask: Optional[torch.ByteTensor] = None) -> None:
        """Check the shapes of the inputs (in caller layout, before transposing).

        Raises:
            ValueError: On any shape mismatch or when some sequence has its
                first timestep masked out.
        """
        if emissions.dim() != 3:
            raise ValueError(f'emissions must have dimension of 3, got {emissions.dim()}')
        if emissions.size(2) != self.num_tags:
            raise ValueError(
                f'expected last dimension of emissions is {self.num_tags}, '
                f'got {emissions.size(2)}')

        if tags is not None:
            if emissions.shape[:2] != tags.shape:
                raise ValueError(
                    'the first two dimensions of emissions and tags must match, '
                    f'got {tuple(emissions.shape[:2])} and {tuple(tags.shape)}')

        if mask is not None:
            if emissions.shape[:2] != mask.shape:
                raise ValueError(
                    'the first two dimensions of emissions and mask must match, '
                    f'got {tuple(emissions.shape[:2])} and {tuple(mask.shape)}')
            # The first timestep lives on a different axis depending on layout.
            first_timestep = mask[:, 0] if self.batch_first else mask[0]
            if not first_timestep.all():
                raise ValueError('mask of the first timestep must all be on')

    def _compute_score(
            self, emissions: torch.Tensor, tags: torch.LongTensor,
            mask: torch.ByteTensor) -> torch.Tensor:
        """Return the unnormalised score of ``tags`` for every batch element."""
        # emissions: (seq_length, batch_size, num_tags)
        # tags: (seq_length, batch_size)
        # mask: (seq_length, batch_size)
        assert emissions.dim() == 3 and tags.dim() == 2
        assert emissions.shape[:2] == tags.shape
        assert emissions.size(2) == self.num_tags
        assert mask.shape == tags.shape
        assert mask[0].all()

        seq_length, batch_size = tags.shape
        mask = mask.type_as(emissions)
        batch_index = torch.arange(batch_size, device=tags.device)

        # Start transition score and first emission
        # shape: (batch_size,)
        score = self.start_transitions[tags[0]]
        score = score + emissions[0, batch_index, tags[0]]

        for i in range(1, seq_length):
            # Transition score to next tag, only added if next timestep is valid (mask == 1)
            # shape: (batch_size,)
            score = score + self.transitions[tags[i - 1], tags[i]] * mask[i]

            # Emission score for next tag, only added if next timestep is valid (mask == 1)
            # shape: (batch_size,)
            score = score + emissions[i, batch_index, tags[i]] * mask[i]

        # End transition score, taken at the last valid timestep of each sequence
        # shape: (batch_size,)
        seq_ends = mask.long().sum(dim=0) - 1
        # shape: (batch_size,)
        last_tags = tags[seq_ends, batch_index]
        # shape: (batch_size,)
        score = score + self.end_transitions[last_tags]

        return score

    def _compute_normalizer(
            self, emissions: torch.Tensor, mask: torch.ByteTensor) -> torch.Tensor:
        """Return the log partition function for every batch element."""
        # emissions: (seq_length, batch_size, num_tags)
        # mask: (seq_length, batch_size)
        assert emissions.dim() == 3 and mask.dim() == 2
        assert emissions.shape[:2] == mask.shape
        assert emissions.size(2) == self.num_tags
        assert mask[0].all()

        seq_length = emissions.size(0)

        # Start transition score and first emission; score has size of
        # (batch_size, num_tags) where for each batch, the j-th column stores
        # the score that the first timestep has tag j
        # shape: (batch_size, num_tags)
        score = self.start_transitions + emissions[0]

        for i in range(1, seq_length):
            # Broadcast score for every possible next tag
            # shape: (batch_size, num_tags, 1)
            broadcast_score = score.unsqueeze(2)

            # Broadcast emission score for every possible current tag
            # shape: (batch_size, 1, num_tags)
            broadcast_emissions = emissions[i].unsqueeze(1)

            # Compute the score tensor of size (batch_size, num_tags, num_tags) where
            # for each sample, entry at row i and column j stores the sum of scores of all
            # possible tag sequences so far that end with transitioning from tag i to tag j
            # and emitting
            # shape: (batch_size, num_tags, num_tags)
            next_score = broadcast_score + self.transitions + broadcast_emissions

            # Sum over all possible current tags, but we're in score space, so a sum
            # becomes a log-sum-exp: for each sample, entry i stores the sum of scores of
            # all possible tag sequences so far, that end in tag i
            # shape: (batch_size, num_tags)
            next_score = torch.logsumexp(next_score, dim=1)

            # Set score to the next score if this timestep is valid (mask == 1)
            # shape: (batch_size, num_tags)
            score = torch.where(mask[i].unsqueeze(1).bool(), next_score, score)

        # End transition score
        # shape: (batch_size, num_tags)
        score = score + self.end_transitions

        # Sum (log-sum-exp) over all possible tags
        # shape: (batch_size,)
        return torch.logsumexp(score, dim=1)

    def _viterbi_decode(self, emissions: torch.FloatTensor,
                        mask: torch.ByteTensor) -> List[List[int]]:
        """Return the highest-scoring tag sequence for every batch element."""
        # emissions: (seq_length, batch_size, num_tags)
        # mask: (seq_length, batch_size)
        assert emissions.dim() == 3 and mask.dim() == 2
        assert emissions.shape[:2] == mask.shape
        assert emissions.size(2) == self.num_tags
        assert mask[0].all()

        seq_length, batch_size = mask.shape

        # Start transition and first emission
        # shape: (batch_size, num_tags)
        score = self.start_transitions + emissions[0]
        history = []

        # score is a tensor of size (batch_size, num_tags) where for every batch,
        # value at column j stores the score of the best tag sequence so far that ends
        # with tag j
        # history saves where the best tags candidate transitioned from; this is used
        # when we trace back the best tag sequence

        # Viterbi algorithm recursive case: we compute the score of the best tag sequence
        # for every possible next tag
        for i in range(1, seq_length):
            # Broadcast viterbi score for every possible next tag
            # shape: (batch_size, num_tags, 1)
            broadcast_score = score.unsqueeze(2)

            # Broadcast emission score for every possible current tag
            # shape: (batch_size, 1, num_tags)
            broadcast_emission = emissions[i].unsqueeze(1)

            # Compute the score tensor of size (batch_size, num_tags, num_tags) where
            # for each sample, entry at row i and column j stores the score of the best
            # tag sequence so far that ends with transitioning from tag i to tag j and emitting
            # shape: (batch_size, num_tags, num_tags)
            next_score = broadcast_score + self.transitions + broadcast_emission

            # Find the maximum score over all possible current tag
            # shape: (batch_size, num_tags)
            next_score, indices = next_score.max(dim=1)

            # Set score to the next score if this timestep is valid (mask == 1)
            # and save the index that produces the next score
            # shape: (batch_size, num_tags)
            score = torch.where(mask[i].unsqueeze(1).bool(), next_score, score)
            history.append(indices)

        # End transition score
        # shape: (batch_size, num_tags)
        score = score + self.end_transitions

        # Now, compute the best path for each sample

        # Index of the last valid timestep per sequence, as plain Python ints so
        # they can be used directly as slice bounds below.
        seq_ends = (mask.long().sum(dim=0) - 1).tolist()
        best_tags_list = []

        for idx in range(batch_size):
            # Find the tag which maximizes the score at the last timestep; this is our best tag
            # for the last timestep
            _, best_last_tag = score[idx].max(dim=0)
            best_tags = [best_last_tag.item()]

            # We trace back where the best last tag comes from, append that to our best tag
            # sequence, and trace it back again, and so on
            for hist in reversed(history[:seq_ends[idx]]):
                best_last_tag = hist[idx][best_tags[-1]]
                best_tags.append(best_last_tag.item())

            # Reverse the order because we start from the last timestep
            best_tags.reverse()
            best_tags_list.append(best_tags)

        return best_tags_list
439
+
440
+