lilingxi01's picture
[ERCBCM] Optimize the model interfaces and prints.
8b850ac
raw
history blame contribute delete
424 Bytes
from torch import nn
from transformers import BertForSequenceClassification
class ERCBCM(nn.Module):
    """Classifier module wrapping a pretrained ``BertForSequenceClassification``.

    The wrapped Hugging Face model is stored as ``self.bert_base``; ``forward``
    returns the model's loss (when labels are supplied) and its logits.
    """

    def __init__(self, pretrained_model_name='bert-base-uncased'):
        """Load the pretrained BERT sequence classifier.

        Args:
            pretrained_model_name: Model id or local path handed to
                ``BertForSequenceClassification.from_pretrained``. Defaults to
                ``'bert-base-uncased'``, the checkpoint this class always used.
        """
        super().__init__()
        print('>>> ERCBCM Init!')
        self.bert_base = BertForSequenceClassification.from_pretrained(
            pretrained_model_name)

    def forward(self, text, label=None, attention_mask=None):
        """Run the wrapped classifier.

        Args:
            text: Token ids forwarded to the model as its first positional
                argument (presumably a LongTensor of shape (batch, seq_len)
                produced by a BERT tokenizer -- TODO confirm with callers).
            label: Optional targets forwarded as ``labels``; when given, the
                wrapped model also returns a loss.
            attention_mask: Optional mask forwarded to the model. Supply it for
                padded batches; ``None`` leaves the mask unset, as before.

        Returns:
            A ``(loss, logits)`` pair; ``loss`` is ``None`` when ``label`` is
            ``None``.
        """
        outputs = self.bert_base(text, attention_mask=attention_mask,
                                 labels=label)
        # Index positionally so both tuple and ModelOutput returns work. The
        # loss entry is only present when labels were supplied, so without
        # labels the logits are the first element.
        if label is None:
            return None, outputs[0]
        loss, logits = outputs[:2]
        return loss, logits