# Uploaded by NithitEiEi: "upload model and app" (commit d8f4336, verified).
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForSequenceClassification

class BertClassifier(nn.Module):
    """Wrap a sequence-classification model so that calling it returns logits only.

    Args:
        bert: the wrapped model. Its forward must accept ``input_ids`` and
            ``attention_mask`` keyword arguments and return an object that
            has a ``.logits`` attribute (as Hugging Face
            ``...ForSequenceClassification`` outputs do).
    """

    def __init__(self, bert):
        super().__init__()
        # Stored as a child module, so (when ``bert`` is an nn.Module) its
        # parameters show up under the ``bert.`` prefix in ``state_dict()``.
        self.bert = bert

    def forward(self, input_id, attention_mask):
        """Run the wrapped model and hand back just its ``logits``."""
        outputs = self.bert(
            input_ids=input_id,
            attention_mask=attention_mask,
        )
        logits = outputs.logits
        return logits


# Base backbone and the fine-tuned weights loaded on top of it.
BASE_MODEL_NAME = 'microsoft/deberta-v3-base'
CHECKPOINT_PATH = "./deberta/fastai_QIQC-deberta-v3.pth"

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_NAME)
# No .train() here: the wrapper is switched to eval mode below anyway.
bert = AutoModelForSequenceClassification.from_pretrained(BASE_MODEL_NAME)

# Replace the stock linear head with an MLP head (presumably the one the
# checkpoint was trained with). Its input width comes from the backbone
# config rather than a hard-coded 768, so it stays consistent with
# whatever backbone is loaded (768 for deberta-v3-base).
classifier = nn.Sequential(
    nn.Linear(bert.config.hidden_size, 1024),
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.Linear(1024, 2),
)
bert.classifier = classifier
model = BertClassifier(bert)

state_dict = torch.load(
    CHECKPOINT_PATH,
    map_location=torch.device('cpu'),
    weights_only=True,
)
# strict=False tolerates key mismatches, but ignoring them silently can leave
# layers (e.g. the custom head above) at random initialisation. Surface any
# mismatch instead of failing quietly.
_load_result = model.load_state_dict(state_dict, strict=False)
if _load_result.missing_keys or _load_result.unexpected_keys:
    warnings.warn(
        f"Checkpoint {CHECKPOINT_PATH!r} does not fully match the model: "
        f"{len(_load_result.missing_keys)} missing key(s) "
        f"{_load_result.missing_keys}, "
        f"{len(_load_result.unexpected_keys)} unexpected key(s) "
        f"{_load_result.unexpected_keys}",
        RuntimeWarning,
    )
# Disable dropout for inference.
model.eval()

def deBERTa_predict(text):
    """Classify ``text`` with the module-level DeBERTa ``model``.

    The text is tokenized to exactly 30 tokens (padded or truncated) and run
    through ``model`` with gradient tracking disabled.

    Args:
        text: the input passed to ``tokenizer``. Only the first row of the
            resulting batch is reported, so presumably a single string.

    Returns:
        A ``(prediction, probability)`` tuple. ``prediction`` is the argmax
        class index (int). ``probability`` is the softmax probability of
        class index 1 for the first input (float).
    """
    encoded = tokenizer(
        text,
        padding="max_length",
        truncation=True,
        max_length=30,
        return_tensors="pt",
    )
    # Re-assert eval mode on every call so the head's dropout stays off even
    # if something switched the shared model back to training mode.
    model.eval()
    with torch.no_grad():
        scores = model(encoded['input_ids'], encoded['attention_mask'])
        probs = F.softmax(scores, dim=-1)
        label = probs.argmax(dim=-1).item()
    positive_prob = probs[0, 1].item()
    return label, positive_prob