import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForSequenceClassification


class BertClassifier(nn.Module):
    """Thin wrapper that returns raw logits from a sequence-classification BERT."""

    def __init__(self, bert):
        super().__init__()
        self.bert = bert

    def forward(self, input_ids, attention_mask):
        output = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        return output.logits


tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
bert = AutoModelForSequenceClassification.from_pretrained("bert-base-cased")

# Swap the default single-layer head for a deeper two-class classifier.
# (Hidden size 768 matches bert-base-cased.)
bert.classifier = nn.Sequential(
    nn.Linear(768, 1024),
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.Linear(1024, 2),
)

model = BertClassifier(bert)
model.load_state_dict(
    torch.load("./bert/bert_model.pth", map_location=torch.device("cpu"), weights_only=True)
)
model.eval()  # inference mode: disables dropout


def BERT_predict(text):
    """Return (predicted class index, probability of class 1) for a single string."""
    tokenized_input = tokenizer(
        text,
        padding="max_length",
        truncation=True,
        max_length=30,
        return_tensors="pt",
    )
    with torch.no_grad():
        logits = model(tokenized_input["input_ids"], tokenized_input["attention_mask"])
        probabilities = F.softmax(logits, dim=-1)
        prediction = torch.argmax(probabilities, dim=-1).item()
    return prediction, probabilities[0][1].item()
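
# Minimal usage sketch. The sample sentence and the reading of class 1 as the
# "positive" label are illustrative assumptions, not taken from the original
# code; running this requires the checkpoint at ./bert/bert_model.pth.
if __name__ == "__main__":
    label, prob_class_1 = BERT_predict("The movie was surprisingly good.")
    print(f"predicted class: {label}, P(class 1): {prob_class_1:.4f}")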