import os
from typing import Dict, Optional, Tuple

import torch
from safetensors.torch import save_file
from torch import nn
from transformers import (
    AutoConfig,
    AutoModel,
    AutoModelForSequenceClassification,
    AutoTokenizer,
)

# Log-probability written over positions that must not be picked as a span
# boundary during beam search.
_MASKED_LOG_PROB = -1e9


class SignalDetector(nn.Module):
    """Inference-only wrapper around a sequence-classification checkpoint.

    The tokenizer and the classification model are both loaded from the same
    path. The model is put in eval mode and moved to the GPU when one is
    available (it stays on the CPU otherwise).
    """

    def __init__(self, model_and_tokenizer_path) -> None:
        """Load tokenizer and classifier from ``model_and_tokenizer_path``."""
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(model_and_tokenizer_path)
        self.signal_detector = AutoModelForSequenceClassification.from_pretrained(
            model_and_tokenizer_path
        )
        self.signal_detector.eval()
        # Only move to CUDA when it exists, so CPU-only hosts can still run.
        if torch.cuda.is_available():
            self.signal_detector.cuda()

    @torch.no_grad()
    def predict(self, text: str) -> int:
        """Return the argmax index over the classifier output for ``text``.

        The text is encoded as a single-example batch and placed on whatever
        device the classifier currently lives on.
        """
        device = next(self.signal_detector.parameters()).device
        input_ids = torch.tensor([self.tokenizer.encode(text)], device=device)
        outputs = self.signal_detector(input_ids)
        # outputs[0] is presumably the classification logits (no labels are
        # passed, so no loss precedes them) -- HF convention, confirm if the
        # checkpoint type changes.
        return outputs[0].argmax().item()


class ST2ModelV2(nn.Module):
    """Encoder with a 6-channel token-level head for span boundary logits.

    The six channels are, in order: start/end of ``arg0``, start/end of
    ``arg1`` and start/end of ``sig``. An optional 2-way sentence-level
    ``signal_classifier`` is added when ``args.signal_classification`` is set
    and ``args.pretrained_signal_detector`` is not.
    """

    # Checkpoint used for the config, the encoder and the tokenizer.
    BACKBONE_NAME = "roberta-large"

    # Output keys, in the channel order produced by ``self.classifier``.
    _SPAN_LOGIT_NAMES = (
        'start_arg0_logits',
        'end_arg0_logits',
        'start_arg1_logits',
        'end_arg1_logits',
        'start_sig_logits',
        'end_sig_logits',
    )

    def __init__(self, args):
        """Build the encoder and heads.

        Args:
            args: Namespace-like object providing ``dropout``,
                ``signal_classification`` and ``pretrained_signal_detector``.
        """
        super().__init__()
        self.args = args

        self.config = AutoConfig.from_pretrained(self.BACKBONE_NAME)
        self.model = AutoModel.from_pretrained(self.BACKBONE_NAME)
        self.tokenizer = AutoTokenizer.from_pretrained(self.BACKBONE_NAME)

        self.dropout = nn.Dropout(self.args.dropout)
        self.classifier = nn.Linear(self.config.hidden_size, 6)

        if self.args.signal_classification and not self.args.pretrained_signal_detector:
            self.signal_classifier = nn.Linear(self.config.hidden_size, 2)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        offset_mapping=None,
        signal_bias_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        start_positions: Optional[torch.Tensor] = None,  # [batch_size, 3]
        end_positions: Optional[torch.Tensor] = None,  # [batch_size, 3]
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Dict[str, Optional[torch.Tensor]]:
        """Run the encoder and return per-token span boundary logits.

        ``offset_mapping``, ``signal_bias_mask``, ``start_positions`` and
        ``end_positions`` are accepted but not used here; no loss is computed
        and no masking of the logits is applied.

        Returns:
            A dict with the six ``*_logits`` entries (each the corresponding
            channel of the ``[batch_size, max_seq_length, 6]`` head output
            with the channel axis squeezed away) plus
            ``'signal_classification_logits'``, which is ``None`` unless the
            in-model signal classifier is enabled, in which case it is that
            classifier applied to the first token's hidden state.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = self.dropout(outputs[0])
        logits = self.classifier(sequence_output)  # [batch_size, max_seq_length, 6]

        result: Dict[str, Optional[torch.Tensor]] = {
            name: channel.squeeze(-1).contiguous()
            for name, channel in zip(self._SPAN_LOGIT_NAMES, logits.split(1, dim=-1))
        }

        signal_classification_logits = None
        if self.args.signal_classification and not self.args.pretrained_signal_detector:
            signal_classification_logits = self.signal_classifier(sequence_output[:, 0, :])
        result['signal_classification_logits'] = signal_classification_logits

        return result

    @staticmethod
    def _masked_topk(log_probs: torch.Tensor, blocked: slice, k: int):
        """Top-k of ``log_probs`` after overwriting ``blocked`` positions.

        Returns ``(values, indices)`` as plain Python lists. ``k`` is clamped
        to the tensor length so short sequences do not make ``topk`` raise.
        """
        masked = log_probs.clone()
        masked[blocked] = _MASKED_LOG_PROB
        values, indices = masked.topk(min(k, masked.shape[0]))
        return values.tolist(), indices.tolist()

    def beam_search_position_selector(
        self,
        start_cause_logits: torch.Tensor,
        start_effect_logits: torch.Tensor,
        end_cause_logits: torch.Tensor,
        end_effect_logits: torch.Tensor,
        topk: int = 5,
    ) -> Tuple[
        Tuple[int, int, int, int],
        Tuple[int, int, int, int],
        float,
        float,
        Dict[str, float],
    ]:
        """Pick the two best non-overlapping (cause, effect) span pairs.

        All four inputs are indexed as 1-D tensors of per-position logits
        (i.e. a single example, not a batch).

        Stage 1 scores every "inner boundary" pair: either the cause ends at
        ``i`` and the effect starts at ``j > i`` ("before"), or the effect
        ends at ``i`` and the cause starts at ``j > i`` ("after"). The
        ``topk`` best pairs are kept. Stage 2 expands each kept pair with the
        ``topk`` best outer boundaries on each side, where positions that
        would cross the inner boundary are overwritten with ``-1e9``.

        NOTE: when fewer than ``topk`` outer positions are valid, masked
        positions are still returned by ``topk`` and therefore appear in
        ``topk_scores`` with a very low score.

        Args:
            start_cause_logits: Logits for the cause start position.
            start_effect_logits: Logits for the effect start position.
            end_cause_logits: Logits for the cause end position.
            end_effect_logits: Logits for the effect end position.
            topk: Beam width for both stages.

        Returns:
            ``(best_span, second_span, best_score, second_score, topk_scores)``
            where each span is
            ``(start_cause, end_cause, start_effect, end_effect)`` and
            ``topk_scores`` maps ``str(span)`` to its summed log-probability
            for every expanded candidate.

        Raises:
            ValueError: If fewer than two candidates can be formed (e.g. a
                sequence of length < 2 or ``topk < 2``).
        """
        # log_softmax is the numerically stable form of log(softmax(x)).
        start_cause_lp = nn.functional.log_softmax(start_cause_logits, dim=-1)
        end_cause_lp = nn.functional.log_softmax(end_cause_logits, dim=-1)
        start_effect_lp = nn.functional.log_softmax(start_effect_logits, dim=-1)
        end_effect_lp = nn.functional.log_softmax(end_effect_logits, dim=-1)

        # One host transfer per tensor instead of one .item() per element.
        start_cause_list = start_cause_lp.tolist()
        end_cause_list = end_cause_lp.tolist()
        start_effect_list = start_effect_lp.tolist()
        end_effect_list = end_effect_lp.tolist()

        # Stage 1: (score, i, j, order) for every admissible inner boundary.
        inner_pairs = [
            (end_cause_list[i] + start_effect_list[j], i, j, "before")
            for i in range(len(end_cause_list))
            for j in range(i + 1, len(start_effect_list))
        ]
        inner_pairs.extend(
            (start_cause_list[j] + end_effect_list[i], i, j, "after")
            for i in range(len(end_effect_list))
            for j in range(i + 1, len(start_cause_list))
        )
        # Stable sort: ties keep their enumeration order.
        beam = sorted(inner_pairs, key=lambda cand: cand[0], reverse=True)[:topk]

        # Stage 2: expand each surviving pair with its outer boundaries.
        span_scores: Dict[Tuple[int, int, int, int], float] = {}
        for score, i, j, order in beam:
            if order == "before":
                end_cause, start_effect = i, j
                # Cause must start no later than it ends ...
                sc_values, sc_indices = self._masked_topk(
                    start_cause_lp, slice(end_cause + 1, None), topk
                )
                # ... and effect must end no earlier than it starts.
                ee_values, ee_indices = self._masked_topk(
                    end_effect_lp, slice(None, start_effect), topk
                )
                for sc_value, sc_index in zip(sc_values, sc_indices):
                    for ee_value, ee_index in zip(ee_values, ee_indices):
                        span = (sc_index, end_cause, start_effect, ee_index)
                        span_scores[span] = score + sc_value + ee_value
            else:  # "after": the effect span precedes the cause span
                end_effect, start_cause = i, j
                ec_values, ec_indices = self._masked_topk(
                    end_cause_lp, slice(None, start_cause), topk
                )
                se_values, se_indices = self._masked_topk(
                    start_effect_lp, slice(end_effect + 1, None), topk
                )
                for ec_value, ec_index in zip(ec_values, ec_indices):
                    for se_value, se_index in zip(se_values, se_indices):
                        span = (start_cause, ec_index, se_index, end_effect)
                        span_scores[span] = score + ec_value + se_value

        ranked = sorted(span_scores.items(), key=lambda item: item[1], reverse=True)
        if len(ranked) < 2:
            raise ValueError(
                "beam search needs at least two span candidates, got "
                f"{len(ranked)} (sequence too short or topk < 2)"
            )
        (first_span, first_score), (second_span, second_score) = ranked[:2]

        # Callers receive string keys, exactly as str(tuple) renders them.
        topk_scores = {str(span): value for span, value in span_scores.items()}
        return first_span, second_span, first_score, second_score, topk_scores