import torch
import torch.nn as nn
from transformers.modeling_outputs import SequenceClassifierOutput
from typing import Optional, Union, Tuple

from .configuration_glm2 import gLM2Config
from .modeling_glm2 import gLM2Model, gLM2PreTrainedModel


class gLM2ClassicationConfig(gLM2Config):
    """gLM2 config extended with a `num_classes` field for sequence classification."""

    def __init__(self, num_classes: int = 2, **kwargs):
        super().__init__(**kwargs)
        self.num_classes = num_classes
        # Register the classification head so that
        # AutoModelForSequenceClassification.from_pretrained(..., trust_remote_code=True)
        # resolves to the class defined below. Guard against the base config not
        # having an auto_map attribute when constructed outside from_pretrained.
        if getattr(self, "auto_map", None) is None:
            self.auto_map = {}
        self.auto_map["AutoModelForSequenceClassification"] = (
            "extension_glm2.gLM2ForSequenceClassification"
        )


class gLM2ForSequenceClassification(gLM2PreTrainedModel):
    """gLM2 encoder with a linear classification head on the first token."""

    config_class = gLM2ClassicationConfig

    def __init__(self, config: gLM2ClassicationConfig):
        super().__init__(config)
        self.glm2 = gLM2Model(config)
        self.score = nn.Linear(config.dim, config.num_classes, bias=False)
        self.post_init()

    def get_input_embeddings(self):
        return self.glm2.tok_embeddings

    def set_input_embeddings(self, value):
        self.glm2.tok_embeddings = value

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, SequenceClassifierOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.glm2(
            input_ids,
            attention_mask=attention_mask,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        token_embeddings = outputs[0]

        # Use the embedding of the first token (<+>) as the CLS representation.
        cls_token = token_embeddings[:, 0, :]
        logits = self.score(cls_token)

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.num_classes), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
        )
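

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the module API). Assumptions: the default
# gLM2Config values are usable as-is, the classification head is randomly
# initialized rather than loaded from a trained checkpoint, and dummy token
# ids stand in for the output of the real gLM2 tokenizer.
# When this file is saved alongside a checkpoint, the auto_map entry set above
# lets AutoModelForSequenceClassification.from_pretrained(path,
# trust_remote_code=True) resolve to gLM2ForSequenceClassification instead.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = gLM2ClassicationConfig(num_classes=2)
    model = gLM2ForSequenceClassification(config)  # randomly initialized weights

    # Dummy batch: 2 sequences of 16 token ids. Real inputs would come from the
    # gLM2 tokenizer, with each sequence starting at the <+> token used as CLS.
    input_ids = torch.randint(0, 20, (2, 16))
    labels = torch.tensor([0, 1])

    out = model(input_ids=input_ids, labels=labels)
    print(out.loss.item(), out.logits.shape)  # scalar loss, shape (2, num_classes)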