Update modeling_st2.py
modeling_st2.py  CHANGED  (+2 -89)
@@ -62,16 +62,7 @@ class ST2ModelV2(nn.Module):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
     ):
-        r"""
-        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
-            Labels for position (index) of the start of the labelled span for computing the token classification loss.
-            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
-            are not taken into account for computing the loss.
-        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
-            Labels for position (index) of the end of the labelled span for computing the token classification loss.
-            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
-            are not taken into account for computing the loss.
-        """
+
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
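For readers tracing what the deleted docstring described: in HF-style span models, `start_positions`/`end_positions` of shape `(batch_size,)` are typically fed to a cross-entropy loss per boundary, with out-of-sequence labels clamped to an ignored index. A minimal sketch of that convention (the shapes and the two-boundary loss are an assumption about the usual pattern, not code from this repo):

import torch
import torch.nn as nn

seq_len = 128
start_logits = torch.randn(4, seq_len)            # per-token start scores
end_logits = torch.randn(4, seq_len)              # per-token end scores
start_positions = torch.tensor([5, 17, 300, 42])  # 300 is out of range
end_positions = torch.tensor([9, 20, 310, 50])

# Clamp out-of-sequence labels to seq_len, then ignore that index in the loss.
start_positions = start_positions.clamp(0, seq_len)
end_positions = end_positions.clamp(0, seq_len)
loss_fct = nn.CrossEntropyLoss(ignore_index=seq_len)
loss = (loss_fct(start_logits, start_positions)
        + loss_fct(end_logits, end_positions)) / 2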
@@ -116,9 +107,7 @@ class ST2ModelV2(nn.Module):
         signal_classification_logits = self.signal_classifier(sequence_output[:, 0, :])
         # start_logits = start_logits.squeeze(-1).contiguous()
         # end_logits = end_logits.squeeze(-1).contiguous()
-
-
-
+
 
         return {
             'start_arg0_logits': start_arg0_logits,
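The surviving `signal_classifier` line classifies from the [CLS] embedding: `sequence_output[:, 0, :]` selects the first token's hidden state for every item in the batch, giving a `(batch, hidden)` matrix for a sequence-level head. A toy illustration (the hidden size and label count here are made up):

import torch
import torch.nn as nn

batch, seq_len, hidden = 8, 128, 768      # assumed encoder dimensions
sequence_output = torch.randn(batch, seq_len, hidden)
signal_classifier = nn.Linear(hidden, 2)  # hypothetical binary signal head

cls_embedding = sequence_output[:, 0, :]  # (batch, hidden): the [CLS] token
signal_classification_logits = signal_classifier(cls_embedding)  # (batch, 2)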
@@ -131,81 +120,6 @@ class ST2ModelV2(nn.Module):
 
         }
 
-
-
-    def position_selector(
-        self,
-        start_cause_logits,
-        start_effect_logits,
-        end_cause_logits,
-        end_effect_logits,
-        attention_mask,
-        word_ids,
-    ):
-        # basic post processing (removing logits from [CLS], [SEP], [PAD])
-        start_cause_logits -= (1 - attention_mask) * 1e4
-        end_cause_logits -= (1 - attention_mask) * 1e4
-        start_effect_logits -= (1 - attention_mask) * 1e4
-        end_effect_logits -= (1 - attention_mask) * 1e4
-
-        start_cause_logits[0] = -1e4
-        end_cause_logits[0] = -1e4
-        start_effect_logits[0] = -1e4
-        end_effect_logits[0] = -1e4
-
-        start_cause_logits[len(word_ids) - 1] = -1e4
-        end_cause_logits[len(word_ids) - 1] = -1e4
-        start_effect_logits[len(word_ids) - 1] = -1e4
-        end_effect_logits[len(word_ids) - 1] = -1e4
-
-        start_cause_logits = torch.log(torch.softmax(start_cause_logits, dim=-1))
-        end_cause_logits = torch.log(torch.softmax(end_cause_logits, dim=-1))
-        start_effect_logits = torch.log(torch.softmax(start_effect_logits, dim=-1))
-        end_effect_logits = torch.log(torch.softmax(end_effect_logits, dim=-1))
-
-        max_arg0_before_arg1 = None
-        for i in range(len(end_cause_logits)):
-            if attention_mask[i] == 0:
-                break
-            for j in range(i + 1, len(start_effect_logits)):
-                if attention_mask[j] == 0:
-                    break
-
-                if max_arg0_before_arg1 is None:
-                    max_arg0_before_arg1 = ((i, j), end_cause_logits[i] + start_effect_logits[j])
-                else:
-                    if end_cause_logits[i] + start_effect_logits[j] > max_arg0_before_arg1[1]:
-                        max_arg0_before_arg1 = ((i, j), end_cause_logits[i] + start_effect_logits[j])
-
-        max_arg0_after_arg1 = None
-        for i in range(len(end_effect_logits)):
-            if attention_mask[i] == 0:
-                break
-            for j in range(i + 1, len(start_cause_logits)):
-                if attention_mask[j] == 0:
-                    break
-                if max_arg0_after_arg1 is None:
-                    max_arg0_after_arg1 = ((i, j), start_cause_logits[j] + end_effect_logits[i])
-                else:
-                    if start_cause_logits[j] + end_effect_logits[i] > max_arg0_after_arg1[1]:
-                        max_arg0_after_arg1 = ((i, j), start_cause_logits[j] + end_effect_logits[i])
-
-        if max_arg0_before_arg1[1].item() > max_arg0_after_arg1[1].item():
-            end_cause, start_effect = max_arg0_before_arg1[0]
-            start_cause_logits[end_cause + 1:] = -1e4
-            start_cause = start_cause_logits.argmax().item()
-
-            end_effect_logits[:start_effect] = -1e4
-            end_effect = end_effect_logits.argmax().item()
-        else:
-            end_effect, start_cause = max_arg0_after_arg1[0]
-            end_cause_logits[:start_cause] = -1e4
-            end_cause = end_cause_logits.argmax().item()
-
-            start_effect_logits[end_effect + 1:] = -1e4
-            start_effect = start_effect_logits.argmax().item()
-
-        return start_cause, end_cause, start_effect, end_effect
 
 
     def beam_search_position_selector(
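The bulk of this commit deletes position_selector, a greedy decoder over the four boundary distributions: it masks [PAD]/[CLS]/[SEP] positions, moves to log-probabilities so position scores can be added, picks the better of the two orderings (cause span before effect span, or effect before cause), and then completes the two spans with constrained argmaxes. A condensed, self-contained sketch of the deleted logic for reference (the function and helper names here are mine, not from the repo):

import torch

def select_positions(start_cause_logits, start_effect_logits,
                     end_cause_logits, end_effect_logits,
                     attention_mask, word_ids):
    # Mask [PAD] tokens, then [CLS] (index 0) and the last real token ([SEP]).
    for logits in (start_cause_logits, start_effect_logits,
                   end_cause_logits, end_effect_logits):
        logits -= (1 - attention_mask) * 1e4
        logits[0] = -1e4
        logits[len(word_ids) - 1] = -1e4

    # Log-probabilities, so scores of two positions can be summed.
    start_cause_logits = torch.log_softmax(start_cause_logits, dim=-1)
    end_cause_logits = torch.log_softmax(end_cause_logits, dim=-1)
    start_effect_logits = torch.log_softmax(start_effect_logits, dim=-1)
    end_effect_logits = torch.log_softmax(end_effect_logits, dim=-1)

    def best_boundary_pair(end_first, start_second):
        # Highest-scoring (i, j) with i < j over non-padding tokens:
        # the first span ends at i and the second span starts at j.
        # Assumes attention_mask[0] == 1, as in the original code.
        best = None
        for i in range(len(end_first)):
            if attention_mask[i] == 0:
                break
            for j in range(i + 1, len(start_second)):
                if attention_mask[j] == 0:
                    break
                score = end_first[i] + start_second[j]
                if best is None or score > best[1]:
                    best = ((i, j), score)
        return best

    cause_first = best_boundary_pair(end_cause_logits, start_effect_logits)
    effect_first = best_boundary_pair(end_effect_logits, start_cause_logits)

    if cause_first[1].item() > effect_first[1].item():
        # Cause precedes effect: fix the inner boundaries, then complete
        # each span with an argmax constrained to the legal side.
        end_cause, start_effect = cause_first[0]
        start_cause_logits[end_cause + 1:] = -1e4
        start_cause = start_cause_logits.argmax().item()
        end_effect_logits[:start_effect] = -1e4
        end_effect = end_effect_logits.argmax().item()
    else:
        end_effect, start_cause = effect_first[0]
        end_cause_logits[:start_cause] = -1e4
        end_cause = end_cause_logits.argmax().item()
        start_effect_logits[end_effect + 1:] = -1e4
        start_effect = start_effect_logits.argmax().item()

    return start_cause, end_cause, start_effect, end_effect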
@@ -216,7 +130,6 @@ class ST2ModelV2(nn.Module):
         end_effect_logits,
         topk=5
     ):
-        # basic post processing (removing logits from [CLS], [SEP], [PAD])
 
         start_cause_logits = torch.log(torch.softmax(start_cause_logits, dim=-1))
         end_cause_logits = torch.log(torch.softmax(end_cause_logits, dim=-1))
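A small aside on the lines kept in beam_search_position_selector: torch.log(torch.softmax(x, dim=-1)) computes log-softmax in two steps; PyTorch's fused torch.log_softmax yields the same values with better numerical stability for extreme logits:

import torch

logits = torch.randn(16)
two_step = torch.log(torch.softmax(logits, dim=-1))  # pattern used in the diff
fused = torch.log_softmax(logits, dim=-1)            # fused, numerically stabler
print(torch.allclose(two_step, fused, atol=1e-6))    # True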