anamargarida commited on
Commit
5803690
·
verified ·
1 Parent(s): 2ce452b

Create ST2ModelV2_6.py

Browse files
Files changed (1) hide show
  1. ST2ModelV2_6.py +247 -0
ST2ModelV2_6.py ADDED
@@ -0,0 +1,247 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from typing import Optional
4
+ from transformers import (
5
+ AutoModel,
6
+ AutoTokenizer,
7
+ AutoConfig,
8
+ AutoModelForSequenceClassification
9
+ )
10
+ from statistics import mode
11
+ from safetensors.torch import load_file
12
+
13
+
14
+ class ST2ModelV2(nn.Module):
15
+ def __init__(self, args, config):
16
+ super(ST2ModelV2, self).__init__()
17
+ self.args = args
18
+ self.config = config
19
+
20
+ # Load the base model (e.g., Roberta)
21
+ self.model = AutoModel.from_pretrained("roberta-large", config=config)
22
+
23
+ # Define classifier layers
24
+ classifier_dropout = self.args.dropout
25
+ self.dropout = nn.Dropout(classifier_dropout)
26
+ self.classifier = nn.Linear(self.config.hidden_size, 6)
27
+
28
+ if self.args.signal_classification and not self.args.pretrained_signal_detector:
29
+ self.signal_classifier = nn.Linear(self.config.hidden_size, 2)
30
+
31
+
32
+
33
+ def forward(
34
+ self,
35
+ input_ids=None,
36
+ attention_mask=None,
37
+ token_type_ids=None,
38
+ position_ids=None,
39
+ signal_bias_mask=None,
40
+ head_mask=None,
41
+ inputs_embeds=None,
42
+ start_positions=None,
43
+ end_positions=None,
44
+ output_attentions=None,
45
+ output_hidden_states=None,
46
+ return_dict=None,
47
+ ):
48
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
49
+
50
+ # Ensure that self.model is not None before calling forward
51
+ if self.model is None:
52
+ raise ValueError("The model weights have not been loaded. Use from_pretrained() to load them.")
53
+
54
+ outputs = self.model(
55
+ input_ids=input_ids,
56
+ attention_mask=attention_mask,
57
+ token_type_ids=token_type_ids,
58
+ position_ids=position_ids,
59
+ head_mask=head_mask,
60
+ inputs_embeds=inputs_embeds,
61
+ output_attentions=output_attentions,
62
+ output_hidden_states=output_hidden_states,
63
+ return_dict=return_dict,
64
+ )
65
+
66
+ sequence_output = outputs[0]
67
+
68
+ sequence_output = self.dropout(sequence_output)
69
+ logits = self.classifier(sequence_output)
70
+
71
+ # Split logits
72
+ start_arg0_logits, end_arg0_logits, start_arg1_logits, end_arg1_logits, start_sig_logits, end_sig_logits = logits.split(1, dim=-1)
73
+ start_arg0_logits = start_arg0_logits.squeeze(-1)
74
+ end_arg0_logits = end_arg0_logits.squeeze(-1)
75
+ start_arg1_logits = start_arg1_logits.squeeze(-1)
76
+ end_arg1_logits = end_arg1_logits.squeeze(-1)
77
+ start_sig_logits = start_sig_logits.squeeze(-1)
78
+ end_sig_logits = end_sig_logits.squeeze(-1)
79
+
80
+ signal_classification_logits = None
81
+ if self.args.signal_classification and not self.args.pretrained_signal_detector:
82
+ signal_classification_logits = self.signal_classifier(sequence_output[:, 0, :])
83
+
84
+ return {
85
+ 'start_arg0_logits': start_arg0_logits,
86
+ 'end_arg0_logits': end_arg0_logits,
87
+ 'start_arg1_logits': start_arg1_logits,
88
+ 'end_arg1_logits': end_arg1_logits,
89
+ 'start_sig_logits': start_sig_logits,
90
+ 'end_sig_logits': end_sig_logits,
91
+ 'signal_classification_logits': signal_classification_logits
92
+ }
93
+
94
+ @classmethod
95
+ def from_pretrained(cls, model_name, config=None, args=None, **kwargs):
96
+ """
97
+ Custom from_pretrained method to load the model from Hugging Face and initialize
98
+ any additional components such as the classifier.
99
+ """
100
+ # Load the configuration
101
+ config = AutoConfig.from_pretrained(model_name) if config is None else config
102
+
103
+ # Instantiate the model
104
+ model = cls(args, config)
105
+
106
+ # Load the pre-trained weights into the model
107
+ model.model = AutoModel.from_pretrained(model_name, config=config, **kwargs, use_safetensors=False)
108
+
109
+
110
+ return model
111
+
112
+ def position_selector(
113
+ self,
114
+ start_cause_logits,
115
+ start_effect_logits,
116
+ end_cause_logits,
117
+ end_effect_logits,
118
+ attention_mask,
119
+ word_ids,
120
+ ):
121
+ # basic post processing (removing logits from [CLS], [SEP], [PAD])
122
+ start_cause_logits -= (1 - attention_mask) * 1e4
123
+ end_cause_logits -= (1 - attention_mask) * 1e4
124
+ start_effect_logits -= (1 - attention_mask) * 1e4
125
+ end_effect_logits -= (1 - attention_mask) * 1e4
126
+
127
+ start_cause_logits[0] = -1e4
128
+ end_cause_logits[0] = -1e4
129
+ start_effect_logits[0] = -1e4
130
+ end_effect_logits[0] = -1e4
131
+
132
+ start_cause_logits[len(word_ids) - 1] = -1e4
133
+ end_cause_logits[len(word_ids) - 1] = -1e4
134
+ start_effect_logits[len(word_ids) - 1] = -1e4
135
+ end_effect_logits[len(word_ids) - 1] = -1e4
136
+
137
+ start_cause_logits = torch.log(torch.softmax(start_cause_logits, dim=-1))
138
+ end_cause_logits = torch.log(torch.softmax(end_cause_logits, dim=-1))
139
+ start_effect_logits = torch.log(torch.softmax(start_effect_logits, dim=-1))
140
+ end_effect_logits = torch.log(torch.softmax(end_effect_logits, dim=-1))
141
+
142
+ max_arg0_before_arg1 = None
143
+ for i in range(len(end_cause_logits)):
144
+ if attention_mask[i] == 0:
145
+ break
146
+ for j in range(i + 1, len(start_effect_logits)):
147
+ if attention_mask[j] == 0:
148
+ break
149
+
150
+ if max_arg0_before_arg1 is None:
151
+ max_arg0_before_arg1 = ((i, j), end_cause_logits[i] + start_effect_logits[j])
152
+ else:
153
+ if end_cause_logits[i] + start_effect_logits[j] > max_arg0_before_arg1[1]:
154
+ max_arg0_before_arg1 = ((i, j), end_cause_logits[i] + start_effect_logits[j])
155
+
156
+ max_arg0_after_arg1 = None
157
+ for i in range(len(end_effect_logits)):
158
+ if attention_mask[i] == 0:
159
+ break
160
+ for j in range(i + 1, len(start_cause_logits)):
161
+ if attention_mask[j] == 0:
162
+ break
163
+ if max_arg0_after_arg1 is None:
164
+ max_arg0_after_arg1 = ((i, j), start_cause_logits[j] + end_effect_logits[i])
165
+ else:
166
+ if start_cause_logits[j] + end_effect_logits[i] > max_arg0_after_arg1[1]:
167
+ max_arg0_after_arg1 = ((i, j), start_cause_logits[j] + end_effect_logits[i])
168
+
169
+ if max_arg0_before_arg1[1].item() > max_arg0_after_arg1[1].item():
170
+ end_cause, start_effect = max_arg0_before_arg1[0]
171
+ start_cause_logits[end_cause + 1:] = -1e4
172
+ start_cause = start_cause_logits.argmax().item()
173
+
174
+ end_effect_logits[:start_effect] = -1e4
175
+ end_effect = end_effect_logits.argmax().item()
176
+ else:
177
+ end_effect, start_cause = max_arg0_after_arg1[0]
178
+ end_cause_logits[:start_cause] = -1e4
179
+ end_cause = end_cause_logits.argmax().item()
180
+
181
+ start_effect_logits[end_effect + 1:] = -1e4
182
+ start_effect = start_effect_logits.argmax().item()
183
+
184
+ return start_cause, end_cause, start_effect, end_effect
185
+
186
+
187
+ def beam_search_position_selector(
188
+ self,
189
+ start_cause_logits,
190
+ start_effect_logits,
191
+ end_cause_logits,
192
+ end_effect_logits,
193
+ topk=5
194
+ ):
195
+
196
+ start_cause_logits = torch.log(torch.softmax(start_cause_logits, dim=-1))
197
+ end_cause_logits = torch.log(torch.softmax(end_cause_logits, dim=-1))
198
+ start_effect_logits = torch.log(torch.softmax(start_effect_logits, dim=-1))
199
+ end_effect_logits = torch.log(torch.softmax(end_effect_logits, dim=-1))
200
+
201
+ scores = dict()
202
+ for i in range(len(end_cause_logits)):
203
+
204
+ for j in range(i + 1, len(start_effect_logits)):
205
+ scores[str((i, j, "before"))] = end_cause_logits[i].item() + start_effect_logits[j].item()
206
+
207
+ for i in range(len(end_effect_logits)):
208
+ for j in range(i + 1, len(start_cause_logits)):
209
+ scores[str((i, j, "after"))] = start_cause_logits[j].item() + end_effect_logits[i].item()
210
+
211
+
212
+ topk_scores = dict()
213
+ for i, (index, score) in enumerate(sorted(scores.items(), key=lambda x: x[1], reverse=True)[:topk]):
214
+ if eval(index)[2] == 'before':
215
+ end_cause = eval(index)[0]
216
+ start_effect = eval(index)[1]
217
+
218
+ this_start_cause_logits = start_cause_logits.clone()
219
+ this_start_cause_logits[end_cause + 1:] = -1e9
220
+ start_cause_values, start_cause_indices = this_start_cause_logits.topk(topk)
221
+
222
+ this_end_effect_logits = end_effect_logits.clone()
223
+ this_end_effect_logits[:start_effect] = -1e9
224
+ end_effect_values, end_effect_indices = this_end_effect_logits.topk(topk)
225
+
226
+ for m in range(len(start_cause_values)):
227
+ for n in range(len(end_effect_values)):
228
+ topk_scores[str((start_cause_indices[m].item(), end_cause, start_effect, end_effect_indices[n].item()))] = score + start_cause_values[m].item() + end_effect_values[n].item()
229
+
230
+ elif eval(index)[2] == 'after':
231
+ start_cause = eval(index)[1]
232
+ end_effect = eval(index)[0]
233
+
234
+ this_end_cause_logits = end_cause_logits.clone()
235
+ this_end_cause_logits[:start_cause] = -1e9
236
+ end_cause_values, end_cause_indices = this_end_cause_logits.topk(topk)
237
+
238
+ this_start_effect_logits = start_effect_logits.clone()
239
+ this_start_effect_logits[end_effect + 1:] = -1e9
240
+ start_effect_values, start_effect_indices = this_start_effect_logits.topk(topk)
241
+
242
+ for m in range(len(end_cause_values)):
243
+ for n in range(len(start_effect_values)):
244
+ topk_scores[str((start_cause, end_cause_indices[m].item(), start_effect_indices[n].item(), end_effect))] = score + end_cause_values[m].item() + start_effect_values[n].item()
245
+
246
+ first, second = sorted(topk_scores.items(), key=lambda x: x[1], reverse=True)[:2]
247
+ return eval(first[0]), eval(second[0]), first[1], second[1], topk_scores