import torch
import torch.nn as nn
from transformers import DebertaV2Model, DebertaV2PreTrainedModel
class MultiHeadDebertaForSequenceClassification(DebertaV2PreTrainedModel):
    """DeBERTa-v2 encoder topped with several independent classification heads.

    Each head is a linear classifier applied to the first-token ([CLS])
    hidden state; the heads' logits are stacked along a new dimension so the
    model returns one tensor of shape ``(batch, num_heads, num_labels_per_head)``.
    """

    def __init__(self, config, num_heads=5, num_labels_per_head=4):
        """
        Args:
            config: ``DebertaV2Config`` providing ``hidden_size`` and
                ``hidden_dropout_prob``.
            num_heads: Number of independent classifier heads.
            num_labels_per_head: Output classes per head (previously a
                hard-coded ``4``; default preserves old behavior).
        """
        super().__init__(config)
        self.num_heads = num_heads
        self.deberta = DebertaV2Model(config)
        # One independent linear classifier per head, all reading the same
        # pooled representation.
        self.heads = nn.ModuleList(
            nn.Linear(config.hidden_size, num_labels_per_head)
            for _ in range(num_heads)
        )
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # transformers convention: run weight init / final setup hooks.
        self.post_init()

    def forward(self, input_ids=None, attention_mask=None):
        """Encode the inputs and classify with every head.

        Args:
            input_ids: Token id tensor of shape ``(batch, seq_len)``.
            attention_mask: Optional mask of the same shape.

        Returns:
            Tensor of logits with shape
            ``(batch, num_heads, num_labels_per_head)``.
        """
        outputs = self.deberta(input_ids=input_ids, attention_mask=attention_mask)
        # First element of the encoder output is the last hidden state:
        # (batch, seq_len, hidden_size).
        sequence_output = outputs[0]
        # Dropout is applied inside the loop on purpose: in training mode each
        # head then sees an independent dropout mask, matching the original
        # behavior of this module.
        logits_list = [
            head(self.dropout(sequence_output[:, 0, :])) for head in self.heads
        ]
        # Stack per-head logits into one tensor along a new "head" dimension.
        logits = torch.stack(logits_list, dim=1)
        return logits
|