File size: 424 Bytes
154ca7b
 
 
 
 
 
 
68f573d
 
154ca7b
 
68f573d
154ca7b
1
2
3
4
5
6
7
8
9
10
11
12
13
from torch import nn
from transformers import BertForSequenceClassification

class ERCBCM(nn.Module):
    """Thin ``nn.Module`` wrapper around Hugging Face's ``BertForSequenceClassification``.

    ``forward`` returns the classification loss (when labels are given) and the
    logits produced by the wrapped model.
    """

    def __init__(self, pretrained_model_name='bert-base-uncased'):
        """Load the pretrained BERT sequence-classification model.

        Args:
            pretrained_model_name: checkpoint name or local path passed to
                ``BertForSequenceClassification.from_pretrained``. Defaults to
                ``'bert-base-uncased'``, the checkpoint this class always used.
        """
        super().__init__()
        print('>>> ERCBCM Init!')
        self.bert_base = BertForSequenceClassification.from_pretrained(pretrained_model_name)

    def forward(self, text, label=None, attention_mask=None):
        """Run the wrapped classifier.

        Args:
            text: token ids, passed as the model's first positional argument
                (``input_ids``); presumably a LongTensor of shape
                (batch, seq_len) — TODO confirm against callers.
            label: target labels forwarded as ``labels``. When None, no loss is
                computed (inference mode).
            attention_mask: optional mask forwarded to the model. Pass it for
                padded batches so that padding tokens are not attended to; None
                keeps the model's default (attend to every position).

        Returns:
            ``(loss, logits)``. ``loss`` is None when ``label`` is None.
        """
        outputs = self.bert_base(text, attention_mask=attention_mask, labels=label)
        if label is None:
            # Without labels the model emits no loss entry, so the logits come
            # first. Unpacking ``outputs[:2]`` here would raise ValueError.
            return None, outputs[0]
        # Positional indexing works for both legacy tuple outputs and
        # ModelOutput objects: with labels, the order is (loss, logits, ...).
        loss, logits = outputs[:2]
        return loss, logits