Text Classification
Safetensors
deberta-v2
File size: 871 Bytes
b10d6fb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import torch
import torch.nn as nn
from transformers import DebertaV2Model, DebertaV2PreTrainedModel

class MultiHeadDebertaForSequenceClassification(DebertaV2PreTrainedModel):
    """DeBERTa-v2 encoder topped with several independent linear classification heads.

    Every head reads the same first-token representation (``[:, 0, :]``) of the
    encoder's last hidden state and produces its own logits. The per-head logits
    are stacked, so one forward pass scores ``num_heads`` classification tasks.
    """

    def __init__(self, config, num_heads=5, num_labels_per_head=4):
        """Build the encoder, the classification heads and the dropout layer.

        Args:
            config: DeBERTa-v2 model configuration. ``config.hidden_size`` sets the
                input width of every head; ``config.hidden_dropout_prob`` sets the
                dropout rate applied before the heads.
            num_heads: Number of independent classification heads (must be >= 1).
            num_labels_per_head: Number of logits each head outputs (must be >= 1).
                Defaults to 4, the value that used to be hard-coded.

        Raises:
            ValueError: If ``num_heads`` or ``num_labels_per_head`` is < 1.
        """
        super().__init__(config)
        # Fail early: with zero heads, torch.stack in forward() would raise a far
        # less descriptive error on the first batch.
        if num_heads < 1:
            raise ValueError(f"num_heads must be >= 1, got {num_heads}")
        if num_labels_per_head < 1:
            raise ValueError(
                f"num_labels_per_head must be >= 1, got {num_labels_per_head}"
            )
        self.num_heads = num_heads
        self.num_labels_per_head = num_labels_per_head
        self.deberta = DebertaV2Model(config)
        # Module names ("deberta", "heads.<i>", "dropout") are unchanged, so
        # previously saved state dicts still load.
        self.heads = nn.ModuleList(
            [nn.Linear(config.hidden_size, num_labels_per_head) for _ in range(num_heads)]
        )
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # Run the Hugging Face weight initialisation for the newly created heads.
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        inputs_embeds=None,
    ):
        """Encode the inputs and score them with every head.

        Args:
            input_ids: Token ids forwarded to the encoder.
            attention_mask: Attention mask forwarded to the encoder.
            token_type_ids: Optional segment ids forwarded to the encoder. Accepted
                so that ``model(**tokenizer(...))`` works, since DeBERTa-v2
                tokenizers include this key in their output.
            position_ids: Optional position ids forwarded to the encoder.
            inputs_embeds: Optional precomputed input embeddings forwarded to the
                encoder (an alternative to ``input_ids``).

        Returns:
            Tensor of shape ``(batch, num_heads, num_labels_per_head)`` holding each
            head's raw (unnormalized) logits.
        """
        outputs = self.deberta(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
        )
        # Index 0 is the last hidden state whether the encoder returns a tuple or a
        # ModelOutput.
        sequence_output = outputs[0]
        first_token = sequence_output[:, 0, :]
        # Dropout is called once per head, so in training mode each head sees its
        # own dropout mask (kept from the original behaviour).
        logits_list = [head(self.dropout(first_token)) for head in self.heads]
        return torch.stack(logits_list, dim=1)