anamargarida committed
Commit f239892 · verified · 1 Parent(s): 5d7828b

Update modeling_st2.py

Files changed (1)
  1. modeling_st2.py +18 -160
modeling_st2.py CHANGED
@@ -41,23 +41,7 @@ class ST2ModelV2(nn.Module):
         self.dropout = nn.Dropout(classifier_dropout)
         self.classifier = nn.Linear(self.config.hidden_size, 6)

-        if args.mlp:
-            self.classifier = nn.Sequential(
-                nn.Linear(self.config.hidden_size, self.config.hidden_size),
-                nn.ReLU(),
-                nn.Linear(self.config.hidden_size, 6),
-                nn.Tanh(),
-                nn.Linear(6, 6),
-            )
-
-        if args.add_signal_bias:
-            self.signal_phrases_layer = nn.Parameter(
-                torch.normal(
-                    mean=self.model.embeddings.word_embeddings.weight.data.mean(),
-                    std=self.model.embeddings.word_embeddings.weight.data.std(),
-                    size=(1, self.config.hidden_size),
-                )
-            )
+

         if self.args.signal_classification and not self.args.pretrained_signal_detector:
             self.signal_classifier = nn.Linear(self.config.hidden_size, 2)
@@ -89,37 +73,22 @@ class ST2ModelV2(nn.Module):
         """
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

-        if signal_bias_mask is not None and not self.args.signal_bias_on_top_of_lm:
-            inputs_embeds = self.signal_phrases_bias(input_ids, signal_bias_mask)
-
-            outputs = self.model(
-                # input_ids,
-                attention_mask=attention_mask,
-                token_type_ids=token_type_ids,
-                position_ids=position_ids,
-                head_mask=head_mask,
-                inputs_embeds=inputs_embeds,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-                return_dict=return_dict,
-            )
-        else:
+

-            outputs = self.model(
-                input_ids,
-                attention_mask=attention_mask,
-                token_type_ids=token_type_ids,
-                position_ids=position_ids,
-                head_mask=head_mask,
-                inputs_embeds=inputs_embeds,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-                return_dict=return_dict,
-            )
+        outputs = self.model(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )

         sequence_output = outputs[0]
-        if signal_bias_mask is not None and self.args.signal_bias_on_top_of_lm:
-            sequence_output[signal_bias_mask == 1] += self.signal_phrases_layer
+

         sequence_output = self.dropout(sequence_output)
         logits = self.classifier(sequence_output)  # [batch_size, max_seq_length, 6]
@@ -147,36 +116,7 @@ class ST2ModelV2(nn.Module):
         # start_logits = start_logits.squeeze(-1).contiguous()
         # end_logits = end_logits.squeeze(-1).contiguous()

-        arg0_loss = None
-        arg1_loss = None
-        sig_loss = None
-        total_loss = None
-        signal_classification_loss = None
-        if start_positions is not None and end_positions is not None:
-            loss_fct = nn.CrossEntropyLoss()
-
-            start_arg0_loss = loss_fct(start_arg0_logits, start_positions[:, 0])
-            end_arg0_loss = loss_fct(end_arg0_logits, end_positions[:, 0])
-            arg0_loss = (start_arg0_loss + end_arg0_loss) / 2
-
-            start_arg1_loss = loss_fct(start_arg1_logits, start_positions[:, 1])
-            end_arg1_loss = loss_fct(end_arg1_logits, end_positions[:, 1])
-            arg1_loss = (start_arg1_loss + end_arg1_loss) / 2
-
-            # sig_loss = 0.
-            start_sig_loss = loss_fct(start_sig_logits, start_positions[:, 2])
-            end_sig_loss = loss_fct(end_sig_logits, end_positions[:, 2])
-            sig_loss = (start_sig_loss + end_sig_loss) / 2
-
-            if sig_loss.isnan():
-                sig_loss = 0.
-
-            if self.args.signal_classification and not self.args.pretrained_signal_detector:
-                signal_classification_labels = end_positions[:, 2] != -100
-                signal_classification_loss = loss_fct(signal_classification_logits, signal_classification_labels.long())
-                total_loss = (arg0_loss + arg1_loss + sig_loss + signal_classification_loss) / 4
-            else:
-                total_loss = (arg0_loss + arg1_loss + sig_loss) / 3
+


         return {
@@ -187,69 +127,10 @@ class ST2ModelV2(nn.Module):
             'start_sig_logits': start_sig_logits,
             'end_sig_logits': end_sig_logits,
             'signal_classification_logits': signal_classification_logits,
-            'arg0_loss': arg0_loss,
-            'arg1_loss': arg1_loss,
-            'sig_loss': sig_loss,
-            'signal_classification_loss': signal_classification_loss,
-            'loss': total_loss,
+
         }

-    """
-    def save_pretrained(self, save_directory):
-
-        # Save model state dict as safetensor, configuration, and tokenizer files.
-
-        # Ensure the directory exists
-        os.makedirs(save_directory, exist_ok=True)
-
-        # Save model state dict as safetensor (use torch.save for PyTorch model)
-        model_path = os.path.join(save_directory, "model.safetensor")
-        save_file(self.state_dict(), model_path)
-
-        # Save config if available
-        config_save_path = os.path.join(save_directory, 'config.json')
-        self.config.to_json_file(config_save_path)
-
-
-        # Save tokenizer
-        if hasattr(self, 'tokenizer') and self.tokenizer is not None:
-            tokenizer_save_path = os.path.join(save_directory, 'tokenizer')
-            self.tokenizer.save_pretrained(tokenizer_save_path)
-    """
-
-
-    def save_pretrained(self, save_directory):
-        """
-        Save model state dict as safetensor, PyTorch .bin format, configuration, and tokenizer files.
-        """
-        # Ensure the directory exists
-        os.makedirs(save_directory, exist_ok=True)
-
-        # Save model state dict as safetensor
-        model_path_safetensor = os.path.join(save_directory, "model.safetensors")
-        save_file(self.state_dict(), model_path_safetensor)  # Save as .safetensors
-
-        # Save model state dict as PyTorch .bin (traditional format)
-        model_path_bin = os.path.join(save_directory, "pytorch_model.bin")
-        torch.save(self.state_dict(), model_path_bin)  # Save as .bin using PyTorch's torch.save()
-
-        # Save config if available
-        config_save_path = os.path.join(save_directory, 'config.json')
-        self.config.to_json_file(config_save_path)
-
-        """
-        # Save tokenizer if it exists
-        if hasattr(self, 'tokenizer') and self.tokenizer is not None:
-            tokenizer_save_path = os.path.join(save_directory, 'tokenizer')
-            self.tokenizer.save_pretrained(tokenizer_save_path)
-        """
-
-
-    def signal_phrases_bias(self, input_ids, signal_bias_mask):
-        inputs_embeds = self.model.get_input_embeddings()(input_ids)
-        inputs_embeds[signal_bias_mask == 1] += self.signal_phrases_layer  # self.signal_phrases_layer(inputs_embeds[signal_bias_mask == 1])
-
-        return inputs_embeds
+

     def position_selector(
         self,
@@ -332,27 +213,11 @@ class ST2ModelV2(nn.Module):
         start_effect_logits,
         end_cause_logits,
         end_effect_logits,
-        attention_mask,
         word_ids,
         topk=5
     ):
        # basic post processing (removing logits from [CLS], [SEP], [PAD])

-        start_cause_logits -= (1 - attention_mask) * 1e4
-        end_cause_logits -= (1 - attention_mask) * 1e4
-        start_effect_logits -= (1 - attention_mask) * 1e4
-        end_effect_logits -= (1 - attention_mask) * 1e4
-
-        start_cause_logits[0] = -1e4
-        end_cause_logits[0] = -1e4
-        start_effect_logits[0] = -1e4
-        end_effect_logits[0] = -1e4
-
-        start_cause_logits[len(word_ids) - 1] = -1e4
-        end_cause_logits[len(word_ids) - 1] = -1e4
-        start_effect_logits[len(word_ids) - 1] = -1e4
-        end_effect_logits[len(word_ids) - 1] = -1e4
-
        start_cause_logits = torch.log(torch.softmax(start_cause_logits, dim=-1))
        end_cause_logits = torch.log(torch.softmax(end_cause_logits, dim=-1))
        start_effect_logits = torch.log(torch.softmax(start_effect_logits, dim=-1))
@@ -360,19 +225,12 @@ class ST2ModelV2(nn.Module):

        scores = dict()
        for i in range(len(end_cause_logits)):
-            if attention_mask[i] == 0:
-                break
+
            for j in range(i + 1, len(start_effect_logits)):
-                if attention_mask[j] == 0:
-                    break
                scores[str((i, j, "before"))] = end_cause_logits[i].item() + start_effect_logits[j].item()

        for i in range(len(end_effect_logits)):
-            if attention_mask[i] == 0:
-                break
            for j in range(i + 1, len(start_cause_logits)):
-                if attention_mask[j] == 0:
-                    break
                scores[str((i, j, "after"))] = start_cause_logits[j].item() + end_effect_logits[i].item()

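Note on the last two hunks: after this change, beam_search_position_selector no longer receives attention_mask, so the padding-based masking and the early breaks are gone and only the log-softmax plus pair-scoring loop remains. The following is a minimal standalone sketch of that remaining loop, for reading convenience only; it is not part of modeling_st2.py, the function name pair_scores and the dummy inputs are illustrative, and each logits argument is assumed to be a 1-D tensor of per-token scores as in the diff context above.

import torch

def pair_scores(start_cause_logits, end_cause_logits,
                start_effect_logits, end_effect_logits):
    # Log-softmax each per-token logit vector, as the retained code does.
    start_cause_logits = torch.log(torch.softmax(start_cause_logits, dim=-1))
    end_cause_logits = torch.log(torch.softmax(end_cause_logits, dim=-1))
    start_effect_logits = torch.log(torch.softmax(start_effect_logits, dim=-1))
    end_effect_logits = torch.log(torch.softmax(end_effect_logits, dim=-1))

    # Score every ordered pair (i, j) with i < j: "before" puts the cause end
    # at position i and the effect start at position j; "after" swaps the roles.
    scores = dict()
    for i in range(len(end_cause_logits)):
        for j in range(i + 1, len(start_effect_logits)):
            scores[str((i, j, "before"))] = end_cause_logits[i].item() + start_effect_logits[j].item()
    for i in range(len(end_effect_logits)):
        for j in range(i + 1, len(start_cause_logits)):
            scores[str((i, j, "after"))] = start_cause_logits[j].item() + end_effect_logits[i].item()
    return scores

# Illustrative call with random per-token logits for a 6-token sequence.
torch.manual_seed(0)
dummy = [torch.randn(6) for _ in range(4)]
scores = pair_scores(*dummy)
print(max(scores, key=scores.get))  # highest-scoring (i, j, order) triple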