File size: 387 Bytes
154ca7b
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
from torch import nn
from transformers import BertForSequenceClassification

class ERCBCM(nn.Module):
    """Thin wrapper around a pretrained BERT sequence-classification model.

    The wrapped Hugging Face ``BertForSequenceClassification`` is stored as
    ``self.encoder``; ``forward`` returns its loss and logits as a pair.
    """

    def __init__(self, pretrained_model_name='bert-base-uncased', **encoder_kwargs):
        """Load the pretrained encoder.

        Args:
            pretrained_model_name: Model id or local path passed to
                ``BertForSequenceClassification.from_pretrained``. Defaults to
                ``'bert-base-uncased'``, the checkpoint this class always used.
            **encoder_kwargs: Extra keyword arguments forwarded unchanged to
                ``from_pretrained`` (e.g. ``num_labels``); none by default.
        """
        super().__init__()
        self.encoder = BertForSequenceClassification.from_pretrained(
            pretrained_model_name, **encoder_kwargs
        )

    def forward(self, text, label=None, attention_mask=None):
        """Run the classifier.

        Args:
            text: Passed as the encoder's first positional argument
                (``input_ids``); presumably a LongTensor of token ids shaped
                (batch, seq_len) -- TODO confirm against the caller.
            label: Target labels forwarded as ``labels``. When ``None`` the
                encoder computes no loss.
            attention_mask: Optional mask forwarded to the encoder so padded
                positions can be ignored; ``None`` keeps the encoder default.

        Returns:
            A ``(loss, logits)`` tuple; ``loss`` is ``None`` when ``label`` is
            ``None``.
        """
        outputs = self.encoder(text, labels=label, attention_mask=attention_mask)
        if label is None:
            # Without labels the encoder omits the loss, so element 0 is the
            # logits; blindly unpacking [:2] here would mislabel or fail.
            return None, outputs[0]
        # With labels the encoder output begins with (loss, logits, ...).
        return outputs[0], outputs[1]