anamargarida committed on
Commit b990576 · verified · 1 Parent(s): 1c03e76

Update modeling_st2.py

Files changed (1)
  1. modeling_st2.py +2 -89
modeling_st2.py CHANGED
@@ -62,16 +62,7 @@ class ST2ModelV2(nn.Module):
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
-        r"""
-        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
-            Labels for position (index) of the start of the labelled span for computing the token classification loss.
-            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
-            are not taken into account for computing the loss.
-        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
-            Labels for position (index) of the end of the labelled span for computing the token classification loss.
-            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
-            are not taken into account for computing the loss.
-        """
+
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict


@@ -116,9 +107,7 @@ class ST2ModelV2(nn.Module):
        signal_classification_logits = self.signal_classifier(sequence_output[:, 0, :])
        # start_logits = start_logits.squeeze(-1).contiguous()
        # end_logits = end_logits.squeeze(-1).contiguous()
-
-
-
+

        return {
            'start_arg0_logits': start_arg0_logits,
@@ -131,81 +120,6 @@ class ST2ModelV2(nn.Module):

        }

-
-
-    def position_selector(
-        self,
-        start_cause_logits,
-        start_effect_logits,
-        end_cause_logits,
-        end_effect_logits,
-        attention_mask,
-        word_ids,
-    ):
-        # basic post processing (removing logits from [CLS], [SEP], [PAD])
-        start_cause_logits -= (1 - attention_mask) * 1e4
-        end_cause_logits -= (1 - attention_mask) * 1e4
-        start_effect_logits -= (1 - attention_mask) * 1e4
-        end_effect_logits -= (1 - attention_mask) * 1e4
-
-        start_cause_logits[0] = -1e4
-        end_cause_logits[0] = -1e4
-        start_effect_logits[0] = -1e4
-        end_effect_logits[0] = -1e4
-
-        start_cause_logits[len(word_ids) - 1] = -1e4
-        end_cause_logits[len(word_ids) - 1] = -1e4
-        start_effect_logits[len(word_ids) - 1] = -1e4
-        end_effect_logits[len(word_ids) - 1] = -1e4
-
-        start_cause_logits = torch.log(torch.softmax(start_cause_logits, dim=-1))
-        end_cause_logits = torch.log(torch.softmax(end_cause_logits, dim=-1))
-        start_effect_logits = torch.log(torch.softmax(start_effect_logits, dim=-1))
-        end_effect_logits = torch.log(torch.softmax(end_effect_logits, dim=-1))
-
-        max_arg0_before_arg1 = None
-        for i in range(len(end_cause_logits)):
-            if attention_mask[i] == 0:
-                break
-            for j in range(i + 1, len(start_effect_logits)):
-                if attention_mask[j] == 0:
-                    break
-
-                if max_arg0_before_arg1 is None:
-                    max_arg0_before_arg1 = ((i, j), end_cause_logits[i] + start_effect_logits[j])
-                else:
-                    if end_cause_logits[i] + start_effect_logits[j] > max_arg0_before_arg1[1]:
-                        max_arg0_before_arg1 = ((i, j), end_cause_logits[i] + start_effect_logits[j])
-
-        max_arg0_after_arg1 = None
-        for i in range(len(end_effect_logits)):
-            if attention_mask[i] == 0:
-                break
-            for j in range(i + 1, len(start_cause_logits)):
-                if attention_mask[j] == 0:
-                    break
-                if max_arg0_after_arg1 is None:
-                    max_arg0_after_arg1 = ((i, j), start_cause_logits[j] + end_effect_logits[i])
-                else:
-                    if start_cause_logits[j] + end_effect_logits[i] > max_arg0_after_arg1[1]:
-                        max_arg0_after_arg1 = ((i, j), start_cause_logits[j] + end_effect_logits[i])
-
-        if max_arg0_before_arg1[1].item() > max_arg0_after_arg1[1].item():
-            end_cause, start_effect = max_arg0_before_arg1[0]
-            start_cause_logits[end_cause + 1:] = -1e4
-            start_cause = start_cause_logits.argmax().item()
-
-            end_effect_logits[:start_effect] = -1e4
-            end_effect = end_effect_logits.argmax().item()
-        else:
-            end_effect, start_cause = max_arg0_after_arg1[0]
-            end_cause_logits[:start_cause] = -1e4
-            end_cause = end_cause_logits.argmax().item()
-
-            start_effect_logits[end_effect + 1:] = -1e4
-            start_effect = start_effect_logits.argmax().item()
-
-        return start_cause, end_cause, start_effect, end_effect


    def beam_search_position_selector(
@@ -216,7 +130,6 @@ class ST2ModelV2(nn.Module):
        end_effect_logits,
        topk=5
    ):
-        # basic post processing (removing logits from [CLS], [SEP], [PAD])

        start_cause_logits = torch.log(torch.softmax(start_cause_logits, dim=-1))
        end_cause_logits = torch.log(torch.softmax(end_cause_logits, dim=-1))
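For reference, the core idea behind the span selection removed in this commit can be summarized in a short standalone sketch. The helper below is hypothetical (it is not part of this commit or of modeling_st2.py); under the same assumptions as the deleted position_selector (1D start/end logit tensors and a 0/1 attention mask), it scores every (end-of-cause, start-of-effect) pair by summing log-probabilities while forcing the effect to start after the cause ends.

import torch

def select_best_pair(end_cause_logits, start_effect_logits, attention_mask):
    """Hypothetical sketch: pick (i, j) maximizing the joint log-probability
    that the cause ends at i and the effect starts at j, with j > i."""
    # Suppress padded positions before normalizing, as the removed method did.
    end_cause = torch.log_softmax(end_cause_logits - (1 - attention_mask) * 1e4, dim=-1)
    start_effect = torch.log_softmax(start_effect_logits - (1 - attention_mask) * 1e4, dim=-1)

    # scores[i, j] = log P(cause ends at i) + log P(effect starts at j);
    # mask the lower triangle (j <= i) so the effect must start after the cause ends.
    scores = end_cause.unsqueeze(1) + start_effect.unsqueeze(0)
    scores = scores.masked_fill(torch.ones_like(scores, dtype=torch.bool).tril(), float("-inf"))

    best = scores.flatten().argmax().item()
    seq_len = scores.size(-1)
    return best // seq_len, best % seq_len  # (end_cause, start_effect)

The deleted position_selector evaluated this comparison in both argument orders (cause before effect and effect before cause), kept the higher-scoring direction, and then filled in the remaining start/end positions by constrained argmax; beam_search_position_selector, which remains in the file, starts from the same log-softmax normalization visible at the end of this diff.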