"""Token-classification models built on frozen BERT-family encoders.

``BertForTokenClassification`` trains only a linear head on top of a frozen
BERT encoder.  The ``*PrefixForTokenClassification`` classes additionally
train a ``PrefixEncoder`` whose output is handed to the frozen encoder as
``past_key_values``.
"""
import torch
import torch.nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn import CrossEntropyLoss
from transformers import BertModel, BertPreTrainedModel
from transformers import RobertaModel, RobertaPreTrainedModel
from transformers.modeling_outputs import TokenClassifierOutput

from model.prefix_encoder import PrefixEncoder
from model.deberta import DebertaModel, DebertaPreTrainedModel
from model.debertaV2 import DebertaV2Model, DebertaV2PreTrainedModel


def _freeze(module):
    """Disable gradient computation for every parameter of ``module``."""
    for param in module.parameters():
        param.requires_grad = False


def _report_extra_params(model, backbone):
    """Print how many parameters ``model`` holds beyond those of ``backbone``."""
    backbone_count = sum(param.numel() for param in backbone.parameters())
    overall_count = sum(param.numel() for param in model.parameters())
    print('total param is {}'.format(overall_count - backbone_count))


def _setup_prefix(model, backbone, config):
    """Freeze ``backbone`` and attach the prefix-tuning state to ``model``.

    Sets ``pre_seq_len``, ``n_layer``, ``n_head``, ``n_embd``,
    ``prefix_tokens`` and ``prefix_encoder`` on ``model`` and then prints the
    number of non-backbone parameters.
    """
    _freeze(backbone)

    model.pre_seq_len = config.pre_seq_len
    model.n_layer = config.num_hidden_layers
    model.n_head = config.num_attention_heads
    model.n_embd = config.hidden_size // config.num_attention_heads

    # Plain attribute (not a registered buffer): it is moved to the backbone's
    # device on every ``get_prompt`` call instead.
    model.prefix_tokens = torch.arange(model.pre_seq_len).long()
    model.prefix_encoder = PrefixEncoder(config)

    _report_extra_params(model, backbone)


def _build_prompt(model, backbone, batch_size):
    """Run the prefix encoder and reshape its output into per-layer chunks.

    Returns a tuple of ``n_layer`` tensors, each of shape
    ``(2, batch_size, n_head, pre_seq_len, n_embd)``.
    """
    prefix_tokens = (
        model.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(backbone.device)
    )
    past_key_values = model.prefix_encoder(prefix_tokens)
    past_key_values = past_key_values.view(
        batch_size,
        model.pre_seq_len,
        model.n_layer * 2,
        model.n_head,
        model.n_embd,
    )
    past_key_values = model.dropout(past_key_values)
    # -> (n_layer * 2, batch, n_head, pre_seq_len, n_embd), then chunks of 2.
    return past_key_values.permute([2, 0, 3, 1, 4]).split(2)


def _prepare_prefix_inputs(model, backbone, input_ids, inputs_embeds, attention_mask):
    """Build the prefix ``past_key_values`` and the prefix-extended mask.

    Returns ``(past_key_values, extended_mask, token_mask)`` where
    ``extended_mask`` is ``token_mask`` with ``pre_seq_len`` columns of ones
    prepended, and ``token_mask`` is the caller's ``attention_mask`` (or an
    all-ones mask when the caller passed ``None``).

    Raises:
        ValueError: if both ``input_ids`` and ``inputs_embeds`` are ``None``.
    """
    reference = input_ids if input_ids is not None else inputs_embeds
    if reference is None:
        raise ValueError("You have to specify either input_ids or inputs_embeds")
    batch_size, seq_len = reference.shape[0], reference.shape[1]

    if attention_mask is None:
        # No mask supplied: treat every token as attended so the prefix
        # columns can still be prepended.
        attention_mask = torch.ones(batch_size, seq_len, device=backbone.device)

    past_key_values = model.get_prompt(batch_size=batch_size)
    prefix_mask = torch.ones(batch_size, model.pre_seq_len).to(backbone.device)
    extended_mask = torch.cat((prefix_mask, attention_mask), dim=1)
    return past_key_values, extended_mask, attention_mask


def _token_classification_loss(logits, labels, attention_mask, num_labels):
    """Cross-entropy over tokens, skipping positions whose mask is not 1.

    When ``attention_mask`` is ``None`` every position contributes.
    """
    loss_fct = CrossEntropyLoss()
    flat_logits = logits.view(-1, num_labels)
    flat_labels = labels.reshape(-1)
    if attention_mask is None:
        return loss_fct(flat_logits, flat_labels)

    # Only keep active parts of the loss: masked-out positions get the
    # criterion's ``ignore_index`` as their label.
    active = attention_mask.reshape(-1) == 1
    ignored = torch.tensor(loss_fct.ignore_index).type_as(labels)
    return loss_fct(flat_logits, torch.where(active, flat_labels, ignored))


def _package_outputs(loss, logits, outputs, return_dict):
    """Return a tuple or a ``TokenClassifierOutput`` depending on ``return_dict``."""
    if not return_dict:
        output = (logits,) + outputs[2:]
        return ((loss,) + output) if loss is not None else output

    return TokenClassifierOutput(
        loss=loss,
        logits=logits,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )


class BertForTokenClassification(BertPreTrainedModel):
    """BERT encoder (no pooler) followed by dropout and a linear token head."""

    _keys_to_ignore_on_load_unexpected = [r"pooler"]

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.bert = BertModel(config, add_pooling_layer=False)
        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
        self.classifier = torch.nn.Linear(config.hidden_size, config.num_labels)

        only_cls_head = True  # False in SRL
        if only_cls_head:
            _freeze(self.bert)

        self.init_weights()

        _report_extra_params(self, self.bert)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""Classify every token and optionally compute the loss.

        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for computing the token classification loss. Indices should be in
            ``[0, ..., config.num_labels - 1]``.

        Returns a :class:`TokenClassifierOutput` when ``return_dict`` is truthy
        (defaulting to ``config.use_return_dict``), otherwise a tuple
        ``(loss?, logits, *extra_backbone_outputs)``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        logits = self.classifier(self.dropout(outputs[0]))

        loss = None
        if labels is not None:
            loss = _token_classification_loss(
                logits, labels, attention_mask, self.num_labels
            )

        return _package_outputs(loss, logits, outputs, return_dict)


class BertPrefixForTokenClassification(BertPreTrainedModel):
    """Frozen BERT encoder driven by a trainable prefix, plus a token head."""

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.bert = BertModel(config, add_pooling_layer=False)
        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
        self.classifier = torch.nn.Linear(config.hidden_size, config.num_labels)

        # NOTE(review): unlike the sibling prefix classes, this constructor
        # does not call ``self.init_weights()`` - confirm that is intended.
        # Manual switch: flip to warm-start the head from a local checkpoint.
        # ``torch.load`` unpickles the file, so only use a trusted checkpoint.
        from_pretrained = False
        if from_pretrained:
            self.classifier.load_state_dict(torch.load('model/checkpoint.pkl'))

        _setup_prefix(self, self.bert, config)

    def get_prompt(self, batch_size):
        """Return the prefix ``past_key_values`` for ``batch_size`` sequences."""
        return _build_prompt(self, self.bert, batch_size)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Classify every token using the prefix-conditioned encoder.

        ``attention_mask`` may be omitted, in which case all tokens are
        attended.  The batch size is taken from ``input_ids`` or, when that is
        ``None``, from ``inputs_embeds``.  Returns a
        :class:`TokenClassifierOutput` or a tuple, as selected by
        ``return_dict``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        past_key_values, extended_mask, token_mask = _prepare_prefix_inputs(
            self, self.bert, input_ids, inputs_embeds, attention_mask
        )

        outputs = self.bert(
            input_ids,
            attention_mask=extended_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            past_key_values=past_key_values,
        )

        logits = self.classifier(self.dropout(outputs[0]))

        loss = None
        if labels is not None:
            # The loss is masked by the token-level mask (prefix columns excluded).
            loss = _token_classification_loss(
                logits, labels, token_mask, self.num_labels
            )

        return _package_outputs(loss, logits, outputs, return_dict)


class RobertaPrefixForTokenClassification(RobertaPreTrainedModel):
    """Frozen RoBERTa encoder driven by a trainable prefix, plus a token head."""

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.roberta = RobertaModel(config, add_pooling_layer=False)
        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
        self.classifier = torch.nn.Linear(config.hidden_size, config.num_labels)
        # Called before the prefix encoder is attached, as in the original.
        self.init_weights()

        _setup_prefix(self, self.roberta, config)

    def get_prompt(self, batch_size):
        """Return the prefix ``past_key_values`` for ``batch_size`` sequences."""
        return _build_prompt(self, self.roberta, batch_size)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Classify every token using the prefix-conditioned encoder.

        ``attention_mask`` may be omitted, in which case all tokens are
        attended.  The batch size is taken from ``input_ids`` or, when that is
        ``None``, from ``inputs_embeds``.  Returns a
        :class:`TokenClassifierOutput` or a tuple, as selected by
        ``return_dict``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        past_key_values, extended_mask, token_mask = _prepare_prefix_inputs(
            self, self.roberta, input_ids, inputs_embeds, attention_mask
        )

        outputs = self.roberta(
            input_ids,
            attention_mask=extended_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            past_key_values=past_key_values,
        )

        logits = self.classifier(self.dropout(outputs[0]))

        loss = None
        if labels is not None:
            # The loss is masked by the token-level mask (prefix columns excluded).
            loss = _token_classification_loss(
                logits, labels, token_mask, self.num_labels
            )

        return _package_outputs(loss, logits, outputs, return_dict)


class DebertaPrefixForTokenClassification(DebertaPreTrainedModel):
    """Frozen DeBERTa encoder driven by a trainable prefix, plus a token head."""

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.deberta = DebertaModel(config)
        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
        self.classifier = torch.nn.Linear(config.hidden_size, config.num_labels)
        # Called before the prefix encoder is attached, as in the original.
        self.init_weights()

        _setup_prefix(self, self.deberta, config)

    def get_prompt(self, batch_size):
        """Return the prefix ``past_key_values`` for ``batch_size`` sequences."""
        return _build_prompt(self, self.deberta, batch_size)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Classify every token using the prefix-conditioned encoder.

        ``head_mask`` is accepted for signature parity with the other models
        but is not forwarded to the backbone.  ``attention_mask`` may be
        omitted, in which case all tokens are attended.  The batch size is
        taken from ``input_ids`` or, when that is ``None``, from
        ``inputs_embeds``.  Returns a :class:`TokenClassifierOutput` or a
        tuple, as selected by ``return_dict``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        past_key_values, extended_mask, token_mask = _prepare_prefix_inputs(
            self, self.deberta, input_ids, inputs_embeds, attention_mask
        )

        outputs = self.deberta(
            input_ids,
            attention_mask=extended_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            past_key_values=past_key_values,
        )

        logits = self.classifier(self.dropout(outputs[0]))

        loss = None
        if labels is not None:
            # The loss is masked by the token-level mask (prefix columns excluded).
            loss = _token_classification_loss(
                logits, labels, token_mask, self.num_labels
            )

        return _package_outputs(loss, logits, outputs, return_dict)


class DebertaV2PrefixForTokenClassification(DebertaV2PreTrainedModel):
    """Frozen DeBERTa-v2 encoder driven by a trainable prefix, plus a token head."""

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.deberta = DebertaV2Model(config)
        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
        self.classifier = torch.nn.Linear(config.hidden_size, config.num_labels)
        # Called before the prefix encoder is attached, as in the original.
        self.init_weights()

        _setup_prefix(self, self.deberta, config)

    def get_prompt(self, batch_size):
        """Return the prefix ``past_key_values`` for ``batch_size`` sequences."""
        return _build_prompt(self, self.deberta, batch_size)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Classify every token using the prefix-conditioned encoder.

        ``head_mask`` is accepted for signature parity with the other models
        but is not forwarded to the backbone.  ``attention_mask`` may be
        omitted, in which case all tokens are attended.  The batch size is
        taken from ``input_ids`` or, when that is ``None``, from
        ``inputs_embeds``.  Returns a :class:`TokenClassifierOutput` or a
        tuple, as selected by ``return_dict``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        past_key_values, extended_mask, token_mask = _prepare_prefix_inputs(
            self, self.deberta, input_ids, inputs_embeds, attention_mask
        )

        outputs = self.deberta(
            input_ids,
            attention_mask=extended_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            past_key_values=past_key_values,
        )

        logits = self.classifier(self.dropout(outputs[0]))

        loss = None
        if labels is not None:
            # The loss is masked by the token-level mask (prefix columns excluded).
            loss = _token_classification_loss(
                logits, labels, token_mask, self.num_labels
            )

        return _package_outputs(loss, logits, outputs, return_dict)