anamargarida committed
Commit e4e2e9d · verified · 1 Parent(s): c55aca5

Rename ST2ModelV2_6.py to modeling_st2.py

Files changed (1)
  1. ST2ModelV2_6.py → modeling_st2.py +256 -67
ST2ModelV2_6.py → modeling_st2.py RENAMED
@@ -7,79 +7,199 @@ from transformers import (
     AutoConfig,
     AutoModelForSequenceClassification
 )
-from statistics import mode
-from safetensors.torch import load_file
-
-
+import os
+from safetensors.torch import save_file
+
+
+
+class SignalDetector(nn.Module):
+    def __init__(self, model_and_tokenizer_path) -> None:
+        super().__init__()
+        self.tokenizer = AutoTokenizer.from_pretrained(model_and_tokenizer_path)
+        self.signal_detector = AutoModelForSequenceClassification.from_pretrained(model_and_tokenizer_path)
+        self.signal_detector.eval()
+        self.signal_detector.cuda()
+
+    @torch.no_grad()
+    def predict(self, text: str) -> int:
+        input_ids = self.tokenizer.encode(text)
+        input_ids = torch.tensor([input_ids]).cuda()
+        outputs = self.signal_detector(input_ids)
+        return outputs[0].argmax().item()
+
+
 class ST2ModelV2(nn.Module):
-    def __init__(self, args, config):
+    def __init__(self, args):
         super(ST2ModelV2, self).__init__()
         self.args = args
-        self.config = config

-        # Load the base model (e.g., Roberta)
-        self.model = AutoModel.from_pretrained("roberta-large", config=config)
+        self.config = AutoConfig.from_pretrained(args.model_name_or_path)
+        self.model = AutoModel.from_pretrained(args.model_name_or_path)
+        self.tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)

-        # Define classifier layers
         classifier_dropout = self.args.dropout
         self.dropout = nn.Dropout(classifier_dropout)
         self.classifier = nn.Linear(self.config.hidden_size, 6)

+        if args.mlp:
+            self.classifier = nn.Sequential(
+                nn.Linear(self.config.hidden_size, self.config.hidden_size),
+                nn.ReLU(),
+                nn.Linear(self.config.hidden_size, 6),
+                nn.Tanh(),
+                nn.Linear(6, 6),
+            )
+
+        if args.add_signal_bias:
+            self.signal_phrases_layer = nn.Parameter(
+                torch.normal(
+                    mean=self.model.embeddings.word_embeddings.weight.data.mean(),
+                    std=self.model.embeddings.word_embeddings.weight.data.std(),
+                    size=(1, self.config.hidden_size),
+                )
+            )
+
         if self.args.signal_classification and not self.args.pretrained_signal_detector:
             self.signal_classifier = nn.Linear(self.config.hidden_size, 2)

-
-
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        signal_bias_mask=None,
-        head_mask=None,
-        inputs_embeds=None,
-        start_positions=None,
-        end_positions=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        signal_bias_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        start_positions: Optional[torch.Tensor] = None,  # [batch_size, 3]
+        end_positions: Optional[torch.Tensor] = None,  # [batch_size, 3]
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
     ):
+        r"""
+        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for position (index) of the start of the labelled span for computing the token classification loss.
+            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
+            are not taken into account for computing the loss.
+        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for position (index) of the end of the labelled span for computing the token classification loss.
+            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
+            are not taken into account for computing the loss.
+        """
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

-        # Ensure that self.model is not None before calling forward
-        if self.model is None:
-            raise ValueError("The model weights have not been loaded. Use from_pretrained() to load them.")
-
-        outputs = self.model(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            token_type_ids=token_type_ids,
-            position_ids=position_ids,
-            head_mask=head_mask,
-            inputs_embeds=inputs_embeds,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-        )
+        if signal_bias_mask is not None and not self.args.signal_bias_on_top_of_lm:
+            inputs_embeds = self.signal_phrases_bias(input_ids, signal_bias_mask)
+
+            outputs = self.model(
+                # input_ids,
+                attention_mask=attention_mask,
+                token_type_ids=token_type_ids,
+                position_ids=position_ids,
+                head_mask=head_mask,
+                inputs_embeds=inputs_embeds,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+            )
+        else:
+            if self.args.model_name_or_path in ['facebook/bart-large', 'facebook/bart-base', 'facebook/bart-large-cnn']:
+                outputs = self.model(
+                    input_ids,
+                    attention_mask=attention_mask,
+                    head_mask=head_mask,
+                    inputs_embeds=inputs_embeds,
+                    output_attentions=output_attentions,
+                    output_hidden_states=output_hidden_states,
+                    return_dict=return_dict,
+                )
+            elif self.args.model_name_or_path in ['microsoft/deberta-base']:
+                outputs = self.model(
+                    input_ids,
+                    attention_mask=attention_mask,
+                    token_type_ids=token_type_ids,
+                    position_ids=position_ids,
+                    inputs_embeds=inputs_embeds,
+                    output_attentions=output_attentions,
+                    output_hidden_states=output_hidden_states,
+                    return_dict=return_dict,
+                )
+
+            else:
+                outputs = self.model(
+                    input_ids,
+                    attention_mask=attention_mask,
+                    token_type_ids=token_type_ids,
+                    position_ids=position_ids,
+                    head_mask=head_mask,
+                    inputs_embeds=inputs_embeds,
+                    output_attentions=output_attentions,
+                    output_hidden_states=output_hidden_states,
+                    return_dict=return_dict,
+                )

         sequence_output = outputs[0]
+        if signal_bias_mask is not None and self.args.signal_bias_on_top_of_lm:
+            sequence_output[signal_bias_mask == 1] += self.signal_phrases_layer

         sequence_output = self.dropout(sequence_output)
-        logits = self.classifier(sequence_output)
-
-        # Split logits
+        logits = self.classifier(sequence_output)  # [batch_size, max_seq_length, 6]
         start_arg0_logits, end_arg0_logits, start_arg1_logits, end_arg1_logits, start_sig_logits, end_sig_logits = logits.split(1, dim=-1)
-        start_arg0_logits = start_arg0_logits.squeeze(-1)
-        end_arg0_logits = end_arg0_logits.squeeze(-1)
-        start_arg1_logits = start_arg1_logits.squeeze(-1)
-        end_arg1_logits = end_arg1_logits.squeeze(-1)
-        start_sig_logits = start_sig_logits.squeeze(-1)
-        end_sig_logits = end_sig_logits.squeeze(-1)
+        start_arg0_logits = start_arg0_logits.squeeze(-1).contiguous()
+        end_arg0_logits = end_arg0_logits.squeeze(-1).contiguous()
+        start_arg1_logits = start_arg1_logits.squeeze(-1).contiguous()
+        end_arg1_logits = end_arg1_logits.squeeze(-1).contiguous()
+        start_sig_logits = start_sig_logits.squeeze(-1).contiguous()
+        end_sig_logits = end_sig_logits.squeeze(-1).contiguous()
+
+        # start_arg0_logits -= (1 - attention_mask) * 1e4
+        # end_arg0_logits -= (1 - attention_mask) * 1e4
+        # start_arg1_logits -= (1 - attention_mask) * 1e4
+        # end_arg1_logits -= (1 - attention_mask) * 1e4
+
+        # start_arg0_logits[:, 0] = -1e4
+        # end_arg0_logits[:, 0] = -1e4
+        # start_arg1_logits[:, 0] = -1e4
+        # end_arg1_logits[:, 0] = -1e4

         signal_classification_logits = None
         if self.args.signal_classification and not self.args.pretrained_signal_detector:
             signal_classification_logits = self.signal_classifier(sequence_output[:, 0, :])
+            # start_logits = start_logits.squeeze(-1).contiguous()
+            # end_logits = end_logits.squeeze(-1).contiguous()
+
+        arg0_loss = None
+        arg1_loss = None
+        sig_loss = None
+        total_loss = None
+        signal_classification_loss = None
+        if start_positions is not None and end_positions is not None:
+            loss_fct = nn.CrossEntropyLoss()
+
+            start_arg0_loss = loss_fct(start_arg0_logits, start_positions[:, 0])
+            end_arg0_loss = loss_fct(end_arg0_logits, end_positions[:, 0])
+            arg0_loss = (start_arg0_loss + end_arg0_loss) / 2
+
+            start_arg1_loss = loss_fct(start_arg1_logits, start_positions[:, 1])
+            end_arg1_loss = loss_fct(end_arg1_logits, end_positions[:, 1])
+            arg1_loss = (start_arg1_loss + end_arg1_loss) / 2
+
+            # sig_loss = 0.
+            start_sig_loss = loss_fct(start_sig_logits, start_positions[:, 2])
+            end_sig_loss = loss_fct(end_sig_logits, end_positions[:, 2])
+            sig_loss = (start_sig_loss + end_sig_loss) / 2
+
+            if sig_loss.isnan():
+                sig_loss = 0.
+
+            if self.args.signal_classification and not self.args.pretrained_signal_detector:
+                signal_classification_labels = end_positions[:, 2] != -100
+                signal_classification_loss = loss_fct(signal_classification_logits, signal_classification_labels.long())
+                total_loss = (arg0_loss + arg1_loss + sig_loss + signal_classification_loss) / 4
+            else:
+                total_loss = (arg0_loss + arg1_loss + sig_loss) / 3
+

         return {
             'start_arg0_logits': start_arg0_logits,
@@ -88,27 +208,71 @@ class ST2ModelV2(nn.Module):
             'end_arg1_logits': end_arg1_logits,
             'start_sig_logits': start_sig_logits,
             'end_sig_logits': end_sig_logits,
-            'signal_classification_logits': signal_classification_logits
+            'signal_classification_logits': signal_classification_logits,
+            'arg0_loss': arg0_loss,
+            'arg1_loss': arg1_loss,
+            'sig_loss': sig_loss,
+            'signal_classification_loss': signal_classification_loss,
+            'loss': total_loss,
         }

-    @classmethod
-    def from_pretrained(cls, model_name, config=None, args=None, **kwargs):
-        """
-        Custom from_pretrained method to load the model from Hugging Face and initialize
-        any additional components such as the classifier.
-        """
-        # Load the configuration
-        config = AutoConfig.from_pretrained(model_name) if config is None else config
-
-        # Instantiate the model
-        model = cls(args, config)
-
-        # Load the pre-trained weights into the model
-        model.model = AutoModel.from_pretrained(model_name, config=config, **kwargs, use_safetensors=False)
+    """
+    def save_pretrained(self, save_directory):

+        #Save model state dict as safetensor, configuration, and tokenizer files.

-        return model
+        # Ensure the directory exists
+        os.makedirs(save_directory, exist_ok=True)
+
+        # Save model state dict as safetensor (use torch.save for PyTorch model)
+        model_path = os.path.join(save_directory, "model.safetensor")
+        save_file(self.state_dict(), model_path)
+
+        # Save config if available
+        config_save_path = os.path.join(save_directory, 'config.json')
+        self.config.to_json_file(config_save_path)

+        # Save tokenizer
+        if hasattr(self, 'tokenizer') and self.tokenizer is not None:
+            tokenizer_save_path = os.path.join(save_directory, 'tokenizer')
+            self.tokenizer.save_pretrained(tokenizer_save_path)
+    """
+
+
+    def save_pretrained(self, save_directory):
+        """
+        Save model state dict as safetensor, PyTorch .bin format, configuration, and tokenizer files.
+        """
+        # Ensure the directory exists
+        os.makedirs(save_directory, exist_ok=True)
+
+        # Save model state dict as safetensor
+        model_path_safetensor = os.path.join(save_directory, "model.safetensors")
+        save_file(self.state_dict(), model_path_safetensor)  # Save as .safetensors
+
+        # Save model state dict as PyTorch .bin (traditional format)
+        model_path_bin = os.path.join(save_directory, "pytorch_model.bin")
+        torch.save(self.state_dict(), model_path_bin)  # Save as .bin using PyTorch's torch.save()
+
+        # Save config if available
+        config_save_path = os.path.join(save_directory, 'config.json')
+        self.config.to_json_file(config_save_path)
+
+        """
+        # Save tokenizer if it exists
+        if hasattr(self, 'tokenizer') and self.tokenizer is not None:
+            tokenizer_save_path = os.path.join(save_directory, 'tokenizer')
+            self.tokenizer.save_pretrained(tokenizer_save_path)
+        """
+
+
+    def signal_phrases_bias(self, input_ids, signal_bias_mask):
+        inputs_embeds = self.model.get_input_embeddings()(input_ids)
+        inputs_embeds[signal_bias_mask == 1] += self.signal_phrases_layer  # self.signal_phrases_layer(inputs_embeds[signal_bias_mask == 1])
+
+        return inputs_embeds
+
     def position_selector(
         self,
         start_cause_logits,
@@ -117,7 +281,7 @@ class ST2ModelV2(nn.Module):
         end_effect_logits,
         attention_mask,
         word_ids,
-        ):
+    ):
         # basic post processing (removing logits from [CLS], [SEP], [PAD])
         start_cause_logits -= (1 - attention_mask) * 1e4
         end_cause_logits -= (1 - attention_mask) * 1e4
@@ -190,9 +354,27 @@ class ST2ModelV2(nn.Module):
         start_effect_logits,
         end_cause_logits,
         end_effect_logits,
+        attention_mask,
+        word_ids,
         topk=5
-        ):
-
+    ):
+        # basic post processing (removing logits from [CLS], [SEP], [PAD])
+
+        start_cause_logits -= (1 - attention_mask) * 1e4
+        end_cause_logits -= (1 - attention_mask) * 1e4
+        start_effect_logits -= (1 - attention_mask) * 1e4
+        end_effect_logits -= (1 - attention_mask) * 1e4
+
+        start_cause_logits[0] = -1e4
+        end_cause_logits[0] = -1e4
+        start_effect_logits[0] = -1e4
+        end_effect_logits[0] = -1e4
+
+        start_cause_logits[len(word_ids) - 1] = -1e4
+        end_cause_logits[len(word_ids) - 1] = -1e4
+        start_effect_logits[len(word_ids) - 1] = -1e4
+        end_effect_logits[len(word_ids) - 1] = -1e4
+
         start_cause_logits = torch.log(torch.softmax(start_cause_logits, dim=-1))
         end_cause_logits = torch.log(torch.softmax(end_cause_logits, dim=-1))
         start_effect_logits = torch.log(torch.softmax(start_effect_logits, dim=-1))
@@ -200,12 +382,19 @@ class ST2ModelV2(nn.Module):

         scores = dict()
         for i in range(len(end_cause_logits)):
-
+            if attention_mask[i] == 0:
+                break
             for j in range(i + 1, len(start_effect_logits)):
+                if attention_mask[j] == 0:
+                    break
                 scores[str((i, j, "before"))] = end_cause_logits[i].item() + start_effect_logits[j].item()

         for i in range(len(end_effect_logits)):
+            if attention_mask[i] == 0:
+                break
             for j in range(i + 1, len(start_cause_logits)):
+                if attention_mask[j] == 0:
+                    break
                 scores[str((i, j, "after"))] = start_cause_logits[j].item() + end_effect_logits[i].item()

400