Spaces:
Running
Running
import torch | |
from torch import nn | |
from typing import Optional | |
from transformers import ( | |
AutoModel, | |
AutoTokenizer, | |
AutoConfig, | |
AutoModelForSequenceClassification | |
) | |
import os | |
from safetensors.torch import save_file | |
class SignalDetector(nn.Module):
    """Wrap a sequence-classification model so it can label a single text.

    The tokenizer and an ``AutoModelForSequenceClassification`` are loaded from
    the same path, and :meth:`predict` returns the index of the highest logit.
    """

    def __init__(self, model_and_tokenizer_path, device=None) -> None:
        """Load the tokenizer and classifier, then put the classifier in eval mode.

        Args:
            model_and_tokenizer_path: Path or hub id passed to both
                ``from_pretrained`` calls.
            device: Device to place the classifier on. ``None`` (default)
                selects ``"cuda"`` when CUDA is available and ``"cpu"``
                otherwise, so the detector also works on CPU-only hosts.
        """
        super().__init__()
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = AutoTokenizer.from_pretrained(model_and_tokenizer_path)
        self.signal_detector = AutoModelForSequenceClassification.from_pretrained(model_and_tokenizer_path)
        self.signal_detector.eval()
        self.signal_detector.to(device)

    def predict(self, text: str) -> int:
        """Return the index of the largest logit the classifier produces for ``text``."""
        # Follow the model to wherever it currently lives (it may have been
        # moved with .to() after construction) instead of assuming CUDA.
        device = next(self.signal_detector.parameters()).device
        # truncation=True keeps over-long texts within the tokenizer's
        # model_max_length rather than letting the forward pass fail.
        input_ids = self.tokenizer.encode(text, truncation=True)
        batch = torch.tensor([input_ids], device=device)
        # Pure inference: skip autograd graph construction and its memory cost.
        with torch.no_grad():
            outputs = self.signal_detector(batch)
        return outputs[0].argmax().item()
class ST2ModelV2(nn.Module):
    """``roberta-large`` encoder with per-token span heads and an optional sequence head.

    A single linear layer over the encoder's token representations yields six
    per-token logits: start/end for argument 0, argument 1 and the signal span.
    When enabled by ``args``, a second linear layer classifies the sequence
    into two classes from the first token's representation.
    """

    def __init__(self, args):
        """Load the pretrained ``roberta-large`` config, encoder and tokenizer and build the heads.

        Args:
            args: Namespace-like object. Reads ``dropout`` (probability for the
                dropout applied to token representations), plus
                ``signal_classification`` and ``pretrained_signal_detector``
                (see :meth:`_has_signal_head`).
        """
        super().__init__()
        self.args = args
        self.config = AutoConfig.from_pretrained("roberta-large")
        self.model = AutoModel.from_pretrained("roberta-large")
        self.tokenizer = AutoTokenizer.from_pretrained("roberta-large")
        self.dropout = nn.Dropout(self.args.dropout)
        # Six per-token outputs, in the order unpacked by forward().
        self.classifier = nn.Linear(self.config.hidden_size, 6)
        if self._has_signal_head():
            self.signal_classifier = nn.Linear(self.config.hidden_size, 2)

    def _has_signal_head(self):
        """True when this model owns its own 2-way signal classifier.

        That is the case when ``args.signal_classification`` is set and
        ``args.pretrained_signal_detector`` is not.
        """
        return bool(self.args.signal_classification and not self.args.pretrained_signal_detector)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        offset_mapping=None,
        signal_bias_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        start_positions: Optional[torch.Tensor] = None,  # [batch_size, 3]
        end_positions: Optional[torch.Tensor] = None,  # [batch_size, 3]
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """Encode the batch and return per-token span logits.

        ``offset_mapping``, ``signal_bias_mask``, ``start_positions`` and
        ``end_positions`` are accepted but not used here; no loss is computed.
        (Presumably they exist so a whole batch dict can be passed as
        ``**batch`` — verify against callers.)

        Returns:
            dict with ``start_arg0_logits``, ``end_arg0_logits``,
            ``start_arg1_logits``, ``end_arg1_logits``, ``start_sig_logits`` and
            ``end_sig_logits``, each ``[batch_size, seq_len]`` and contiguous,
            plus ``signal_classification_logits`` (``[batch_size, 2]``, or
            ``None`` when the model has no signal head).
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict
        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = self.dropout(outputs[0])
        logits = self.classifier(sequence_output)  # [batch_size, max_seq_length, 6]
        (
            start_arg0_logits,
            end_arg0_logits,
            start_arg1_logits,
            end_arg1_logits,
            start_sig_logits,
            end_sig_logits,
        ) = (channel.contiguous() for channel in logits.unbind(dim=-1))

        signal_classification_logits = None
        if self._has_signal_head():
            # Sequence-level decision from the first token's representation.
            signal_classification_logits = self.signal_classifier(sequence_output[:, 0, :])

        return {
            'start_arg0_logits': start_arg0_logits,
            'end_arg0_logits': end_arg0_logits,
            'start_arg1_logits': start_arg1_logits,
            'end_arg1_logits': end_arg1_logits,
            'start_sig_logits': start_sig_logits,
            'end_sig_logits': end_sig_logits,
            'signal_classification_logits': signal_classification_logits,
        }

    @staticmethod
    def _masked_topk(log_probs, k, lo=0, hi=None):
        """Top-``k`` of a 1-D tensor restricted to positions ``lo..hi`` (inclusive).

        Positions outside the range are set to -1e9 so they rank last. If fewer
        than ``k`` positions are in range, masked positions fill the remaining
        slots with value -1e9. ``k`` is clamped to the tensor length, so
        ``Tensor.topk`` never raises on short sequences.

        Returns:
            ``(values, indices)`` as Python lists.
        """
        masked = log_probs.clone()
        masked[:lo] = -1e9
        if hi is not None:
            masked[hi + 1:] = -1e9
        values, indices = masked.topk(min(k, masked.numel()))
        return values.tolist(), indices.tolist()

    def beam_search_position_selector(
        self,
        start_cause_logits,
        start_effect_logits,
        end_cause_logits,
        end_effect_logits,
        topk=5,
    ):
        """Jointly choose non-overlapping cause and effect spans for one sequence.

        Each argument is a 1-D tensor of per-token logits. Every logit vector is
        turned into log-probabilities.

        The search works in two stages:

        1. Seeds. Every ordered pair of a span boundary on one side and a span
           boundary on the other side is scored: either the cause ends before
           the effect starts (``"before"``), or the effect ends before the cause
           starts (``"after"``). The best ``topk`` seeds are kept.
        2. Expansion. Each kept seed is completed with the ``topk`` best
           remaining boundaries that keep both spans well-ordered.

        Args:
            start_cause_logits, start_effect_logits, end_cause_logits,
            end_effect_logits: 1-D per-token logit tensors.
            topk: Beam width used both for seeds and for each expansion.

        Returns:
            ``(first, second, first_score, second_score, topk_scores)``.
            ``first`` and ``second`` are the two best
            ``(start_cause, end_cause, start_effect, end_effect)`` index tuples,
            and the scores are their summed log-probabilities. ``topk_scores``
            maps ``str()`` of every candidate tuple to its score.

        Raises:
            ValueError: if fewer than two candidates can be formed, e.g. a
                sequence shorter than two tokens or ``topk`` below 2.
        """
        import heapq

        # log_softmax is the numerically stable form of log(softmax(x)); the
        # latter yields -inf wherever softmax underflows to zero.
        start_cause_lp = torch.log_softmax(start_cause_logits, dim=-1)
        end_cause_lp = torch.log_softmax(end_cause_logits, dim=-1)
        start_effect_lp = torch.log_softmax(start_effect_logits, dim=-1)
        end_effect_lp = torch.log_softmax(end_effect_logits, dim=-1)

        # Move values to Python once; a per-element .item() inside the O(n^2)
        # seed loop forces a device synchronisation on every call.
        sc = start_cause_lp.tolist()
        ec = end_cause_lp.tolist()
        se = start_effect_lp.tolist()
        ee = end_effect_lp.tolist()

        def seeds():
            # Cause before effect: end_cause = i < start_effect = j.
            for i, ec_val in enumerate(ec):
                for j in range(i + 1, len(se)):
                    yield (i, j, "before"), ec_val + se[j]
            # Effect before cause: end_effect = i < start_cause = j.
            for i, ee_val in enumerate(ee):
                for j in range(i + 1, len(sc)):
                    yield (i, j, "after"), sc[j] + ee_val

        # nlargest == sorted(..., reverse=True)[:topk], ties kept in generation order.
        best_seeds = heapq.nlargest(topk, seeds(), key=lambda item: item[1])

        candidates = {}
        for (i, j, order), seed_score in best_seeds:
            if order == "before":
                end_cause, start_effect = i, j
                # start_cause <= end_cause, end_effect >= start_effect.
                sc_vals, sc_idx = self._masked_topk(start_cause_lp, topk, hi=end_cause)
                ee_vals, ee_idx = self._masked_topk(end_effect_lp, topk, lo=start_effect)
                for sc_v, sc_pos in zip(sc_vals, sc_idx):
                    for ee_v, ee_pos in zip(ee_vals, ee_idx):
                        candidates[(sc_pos, end_cause, start_effect, ee_pos)] = seed_score + sc_v + ee_v
            else:  # "after"
                end_effect, start_cause = i, j
                # end_cause >= start_cause, start_effect <= end_effect.
                ec_vals, ec_idx = self._masked_topk(end_cause_lp, topk, lo=start_cause)
                se_vals, se_idx = self._masked_topk(start_effect_lp, topk, hi=end_effect)
                for ec_v, ec_pos in zip(ec_vals, ec_idx):
                    for se_v, se_pos in zip(se_vals, se_idx):
                        candidates[(start_cause, ec_pos, se_pos, end_effect)] = seed_score + ec_v + se_v

        if len(candidates) < 2:
            raise ValueError(
                "beam search produced fewer than two candidate spans; "
                "need a sequence of at least 2 tokens and topk >= 2"
            )

        (first, first_score), (second, second_score) = heapq.nlargest(
            2, candidates.items(), key=lambda item: item[1]
        )
        # Keep the historical string-keyed format for callers of topk_scores.
        topk_scores = {str(key): score for key, score in candidates.items()}
        return first, second, first_score, second_score, topk_scores