# Trial / modeling_st2.py
# Last commit by anamargarida: "Update modeling_st2.py" (b990576, verified)
import torch
from torch import nn
from typing import Optional
from transformers import (
AutoModel,
AutoTokenizer,
AutoConfig,
AutoModelForSequenceClassification
)
import os
from safetensors.torch import save_file
class SignalDetector(nn.Module):
    """Wrap a fine-tuned sequence-classification model used as a signal detector.

    ``predict`` tokenizes one string, runs the classifier and returns the
    argmax class id. The model is put in eval mode at construction time.
    """

    def __init__(self, model_and_tokenizer_path, device="cuda") -> None:
        """Load tokenizer and classifier from ``model_and_tokenizer_path``.

        Args:
            model_and_tokenizer_path: Hugging Face model id or local directory
                holding both the tokenizer and the classification model.
            device: Where to place the classifier. This is anything accepted by
                ``nn.Module.to``. It defaults to ``"cuda"``, which matches the
                previous hard-coded ``.cuda()`` call.
        """
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(model_and_tokenizer_path)
        self.signal_detector = AutoModelForSequenceClassification.from_pretrained(model_and_tokenizer_path)
        self.signal_detector.eval()
        self.signal_detector.to(device)

    @torch.no_grad()
    def predict(self, text: str) -> int:
        """Return the predicted class id (argmax over the classifier logits) for ``text``.

        Inputs longer than the tokenizer's maximum length are truncated
        instead of overflowing the model's position embeddings.
        """
        # Follow wherever the classifier currently lives instead of assuming
        # CUDA, so this keeps working after ``.to("cpu")`` or on CPU-only hosts.
        first_param = next(self.signal_detector.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

        token_ids = self.tokenizer.encode(text, truncation=True)
        input_ids = torch.tensor([token_ids], device=device)  # batch of one
        logits = self.signal_detector(input_ids)[0]
        return logits.argmax().item()
class ST2ModelV2(nn.Module):
    """Token-level span tagger on top of a pretrained transformer encoder.

    A single linear head emits six logits per token. ``forward`` splits them
    into start/end logits for two argument spans (``arg0`` / ``arg1``) and a
    signal span (``sig``). When ``args.signal_classification`` is truthy and
    ``args.pretrained_signal_detector`` is not, an extra 2-way head is applied
    to the first token's representation.

    ``args`` must provide ``dropout``, ``signal_classification`` and
    ``pretrained_signal_detector``.
    """

    def __init__(self, args, model_name: str = "roberta-large"):
        """Load config, encoder and tokenizer for ``model_name`` and build the heads.

        Args:
            args: Namespace-like configuration (see the class docstring).
            model_name: Hugging Face model id or path. It defaults to the
                previously hard-coded ``"roberta-large"``.
        """
        super().__init__()
        self.args = args
        self.config = AutoConfig.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.dropout = nn.Dropout(self.args.dropout)
        self.classifier = nn.Linear(self.config.hidden_size, 6)
        if self._has_signal_head():
            self.signal_classifier = nn.Linear(self.config.hidden_size, 2)

    def _has_signal_head(self) -> bool:
        """True when signal classification is on and no pretrained signal detector is configured."""
        return bool(self.args.signal_classification and not self.args.pretrained_signal_detector)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        offset_mapping=None,
        signal_bias_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        start_positions: Optional[torch.Tensor] = None,  # [batch_size, 3]
        end_positions: Optional[torch.Tensor] = None,  # [batch_size, 3]
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """Encode the batch and return per-token span logits.

        The encoder-related arguments are forwarded unchanged to ``self.model``.
        ``offset_mapping``, ``signal_bias_mask``, ``start_positions`` and
        ``end_positions`` are accepted but not used here. Presumably they are
        accepted so whole batches can be passed as keyword arguments; confirm
        against the callers.

        Returns:
            dict with ``start/end_arg0_logits``, ``start/end_arg1_logits`` and
            ``start/end_sig_logits``, each being one channel of the
            ``[batch_size, max_seq_length, 6]`` classifier output. It also
            contains ``signal_classification_logits``: the 2-way head output
            for the first token, or ``None`` when that head is disabled.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # First encoder output; presumably [batch_size, max_seq_length, hidden].
        sequence_output = self.dropout(outputs[0])
        logits = self.classifier(sequence_output)  # [batch_size, max_seq_length, 6]
        (
            start_arg0_logits,
            end_arg0_logits,
            start_arg1_logits,
            end_arg1_logits,
            start_sig_logits,
            end_sig_logits,
        ) = (channel.contiguous() for channel in logits.unbind(dim=-1))

        signal_classification_logits = None
        if self._has_signal_head():
            # Sequence-level prediction from the first token's representation.
            signal_classification_logits = self.signal_classifier(sequence_output[:, 0, :])

        return {
            'start_arg0_logits': start_arg0_logits,
            'end_arg0_logits': end_arg0_logits,
            'start_arg1_logits': start_arg1_logits,
            'end_arg1_logits': end_arg1_logits,
            'start_sig_logits': start_sig_logits,
            'end_sig_logits': end_sig_logits,
            'signal_classification_logits': signal_classification_logits,
        }

    @staticmethod
    def _rank_anchor_pairs(start_cause, start_effect, end_cause, end_effect, topk):
        """Return the ``topk`` best ``(order, i, j, score)`` anchor pairs, best first.

        There are two kinds of anchor pair:

        * ``"before"``: the cause ends at ``i`` and the effect starts at
          ``j > i``. The score is ``end_cause[i] + start_effect[j]``.
        * ``"after"``: the effect ends at ``i`` and the cause starts at
          ``j > i``. The score is ``start_cause[j] + end_effect[i]``.

        Scores are summed in float64, so they equal the Python-float sums of
        the individual values. Ties keep enumeration order (all ``"before"``
        pairs row-major, then all ``"after"`` pairs), as a stable sort would.
        """
        device = end_cause.device
        # triu_indices(..., offset=1) enumerates (i, j) with i < j in row-major order.
        before_i, before_j = torch.triu_indices(
            end_cause.size(0), start_effect.size(0), offset=1, device=device
        )
        after_i, after_j = torch.triu_indices(
            end_effect.size(0), start_cause.size(0), offset=1, device=device
        )
        before = end_cause.double()[before_i] + start_effect.double()[before_j]
        after = start_cause.double()[after_j] + end_effect.double()[after_i]
        ranked_scores, ranked_flat = torch.sort(
            torch.cat([before, after]), descending=True, stable=True
        )

        n_before = before.size(0)
        anchors = []
        for score, flat in zip(ranked_scores[:topk].tolist(), ranked_flat[:topk].tolist()):
            if flat < n_before:
                anchors.append(("before", before_i[flat].item(), before_j[flat].item(), score))
            else:
                flat -= n_before
                anchors.append(("after", after_i[flat].item(), after_j[flat].item(), score))
        return anchors

    @staticmethod
    def _masked_topk(log_probs, topk, lo=0, hi=None):
        """Top-k ``(values, indices)`` lists of ``log_probs`` with positions outside ``[lo, hi)`` set to -1e9.

        Masked positions keep their absolute index. They can still be returned
        (with a -1e9 score) when the window holds fewer than k positions.
        k is clamped to the sequence length so short inputs do not raise.
        """
        masked = log_probs.clone()
        masked[:lo] = -1e9
        if hi is not None:
            masked[hi:] = -1e9
        values, indices = masked.topk(min(topk, masked.size(0)))
        return values.tolist(), indices.tolist()

    def beam_search_position_selector(
        self,
        start_cause_logits,
        start_effect_logits,
        end_cause_logits,
        end_effect_logits,
        topk=5
    ):
        """Pick the two best ``(start_cause, end_cause, start_effect, end_effect)`` quadruples.

        The inputs are 1-D logit tensors over the positions of one sequence.
        Each is normalised with log-softmax, so a quadruple's score is the sum
        of its four log-probabilities. The search runs in two stages:

        1. Rank anchor pairs where the two spans meet. Either the cause ends
           before the effect starts, or the effect ends before the cause
           starts. Keep the ``topk`` best.
        2. For each anchor, take the ``topk`` best positions for the two
           remaining boundaries inside their allowed window. For example, when
           the cause comes first, ``start_cause <= end_cause`` and
           ``end_effect >= start_effect``.

        Returns:
            ``(best, second, best_score, second_score, topk_scores)``, where
            ``best`` and ``second`` are int 4-tuples and ``topk_scores`` maps
            ``str(quadruple)`` to its score. The string keys are kept for
            backward compatibility.

        Raises:
            ValueError: fewer than two candidates exist (sequences shorter
                than two positions).
        """
        start_cause_logits = torch.log_softmax(start_cause_logits, dim=-1)
        end_cause_logits = torch.log_softmax(end_cause_logits, dim=-1)
        start_effect_logits = torch.log_softmax(start_effect_logits, dim=-1)
        end_effect_logits = torch.log_softmax(end_effect_logits, dim=-1)

        anchors = self._rank_anchor_pairs(
            start_cause_logits, start_effect_logits, end_cause_logits, end_effect_logits, topk
        )

        candidates = {}
        for order, i, j, anchor_score in anchors:
            if order == "before":
                # Cause ends at i; effect starts at j (> i).
                end_cause, start_effect = i, j
                sc_values, sc_indices = self._masked_topk(start_cause_logits, topk, hi=end_cause + 1)
                ee_values, ee_indices = self._masked_topk(end_effect_logits, topk, lo=start_effect)
                for sc_value, start_cause in zip(sc_values, sc_indices):
                    for ee_value, end_effect in zip(ee_values, ee_indices):
                        key = (start_cause, end_cause, start_effect, end_effect)
                        candidates[key] = anchor_score + sc_value + ee_value
            else:
                # Effect ends at i; cause starts at j (> i).
                end_effect, start_cause = i, j
                ec_values, ec_indices = self._masked_topk(end_cause_logits, topk, lo=start_cause)
                se_values, se_indices = self._masked_topk(start_effect_logits, topk, hi=end_effect + 1)
                for ec_value, end_cause in zip(ec_values, ec_indices):
                    for se_value, start_effect in zip(se_values, se_indices):
                        key = (start_cause, end_cause, start_effect, end_effect)
                        candidates[key] = anchor_score + ec_value + se_value

        if len(candidates) < 2:
            raise ValueError(
                "beam_search_position_selector needs sequences of at least 2 positions "
                "to propose two span candidates"
            )
        (first_key, first_score), (second_key, second_score) = sorted(
            candidates.items(), key=lambda item: item[1], reverse=True
        )[:2]
        topk_scores = {str(key): score for key, score in candidates.items()}
        return first_key, second_key, first_score, second_score, topk_scores