import torch
import torch.nn as nn
from transformers import DebertaV2Model, DebertaV2PreTrainedModel


class MultiHeadDebertaForSequenceClassification(DebertaV2PreTrainedModel):
    """DeBERTa-v2 encoder with several independent classification heads,
    each predicting 4 classes from the first ([CLS]) token representation."""

    def __init__(self, config, num_heads=5):
        super().__init__(config)
        self.num_heads = num_heads
        self.deberta = DebertaV2Model(config)
        # One linear classifier per head, each mapping hidden_size -> 4 classes
        self.heads = nn.ModuleList(
            [nn.Linear(config.hidden_size, 4) for _ in range(num_heads)]
        )
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # Initialize weights and apply final processing (Hugging Face convention)
        self.post_init()

    def forward(self, input_ids=None, attention_mask=None):
        outputs = self.deberta(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = outputs[0]  # (batch, seq_len, hidden_size)
        # Classify from the [CLS] token; one set of logits per head
        logits_list = [
            head(self.dropout(sequence_output[:, 0, :])) for head in self.heads
        ]
        logits = torch.stack(logits_list, dim=1)  # (batch, num_heads, 4)
        return logits
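A minimal usage sketch follows. The checkpoint name "microsoft/deberta-v3-small" is an illustrative assumption (any DeBERTa-v2/v3 checkpoint should work); the classification heads are newly initialized, so the model would still need fine-tuning before its outputs are meaningful.

from transformers import AutoTokenizer

# Assumed checkpoint; swap in whichever DeBERTa-v2/v3 model you are using.
model_name = "microsoft/deberta-v3-small"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = MultiHeadDebertaForSequenceClassification.from_pretrained(model_name, num_heads=5)
model.eval()

# Tokenize a toy batch and run a forward pass.
batch = tokenizer(["an example sentence"], return_tensors="pt", padding=True, truncation=True)
with torch.no_grad():
    logits = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])

print(logits.shape)  # torch.Size([1, 5, 4]) -> (batch, num_heads, classes per head)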