Spaces:
Sleeping
Sleeping
File size: 473 Bytes
a450bc7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 |
from transformers import BertForTokenClassification
import torch
class BertModel(torch.nn.Module):
    """Thin ``torch.nn.Module`` wrapper around ``BertForTokenClassification``.

    NOTE(review): this class name shadows ``transformers.BertModel``; it is
    kept unchanged because callers import it under this name.
    """

    def __init__(self, pretrained_model, num_labels):
        """Load a pretrained token-classification BERT.

        Args:
            pretrained_model: Model identifier or local path passed straight to
                ``BertForTokenClassification.from_pretrained``.
            num_labels: Number of token classes; forwarded as ``num_labels`` so
                the classification head is sized accordingly.
        """
        super().__init__()
        self.bert = BertForTokenClassification.from_pretrained(
            pretrained_model, num_labels=num_labels
        )

    def forward(self, input_id, mask, label=None):
        """Run the wrapped model and return its raw tuple output.

        Args:
            input_id: Token ids, passed as ``input_ids``.
            mask: Attention mask, passed as ``attention_mask``.
            label: Optional target labels, passed as ``labels``. Defaults to
                ``None`` so the wrapper can also be used for inference, where
                no labels exist.

        Returns:
            Whatever the wrapped model returns with ``return_dict=False``,
            i.e. a plain tuple rather than a ``ModelOutput``. Per the
            transformers docs, the tuple starts with the loss when ``label``
            is given and with the logits otherwise.
        """
        # return_dict=False keeps the historical tuple output that callers unpack.
        return self.bert(
            input_ids=input_id,
            attention_mask=mask,
            labels=label,
            return_dict=False,
        )