# Vnese_crawl / predictor.py
# (Hugging Face web-view artifacts — uploader SonFox2920, blob b208559, 2.61 kB —
#  converted to comments so the file is valid Python.)
import os

# Never hard-code credentials in source control. Read the Hugging Face token
# from the environment instead; None is acceptable for public checkpoints such
# as bert-base-multilingual-cased.
# NOTE(security): a real token was previously committed on this line — it must
# be revoked on huggingface.co immediately.
hf_token = os.environ.get("HF_TOKEN")
import warnings
warnings.filterwarnings('ignore')  # silence all library warnings for clean output
import logging
logging.disable(logging.WARNING)  # mute WARNING-and-below from all loggers (e.g. transformers)
import torch
import numpy as np
import random
from transformers import AutoModel, AutoTokenizer
from torch.utils.data import DataLoader
# Project-local module: classifier head and tokenized sentence-pair dataset.
from Mbert import MBERTClassifier, SentencePairDataset
import pandas as pd
def set_seed(seed):
    """Seed every RNG source (Python, NumPy, PyTorch CPU/CUDA) for reproducibility.

    Also forces cuDNN into deterministic mode, trading some speed for
    run-to-run stability.
    """
    # Seed each framework's generator with the same value.
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    # Deterministic cuDNN kernels; disable auto-tuning benchmark heuristics.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
# Seed all RNGs with a fixed value so inference is reproducible.
set_seed(42)
# Inference runs entirely on CPU.
device = torch.device("cpu")
modelname = "bert-base-multilingual-cased"
# NOTE(review): this checkpoint is public, so the auth token is presumably
# unnecessary — confirm before relying on it.
tokenizer = AutoTokenizer.from_pretrained(modelname, token=hf_token)
mbert = AutoModel.from_pretrained(modelname, token=hf_token).to(device)
# Wrap mBERT with a 3-way classification head; `predict` below maps the three
# output indices to SUPPORTED / REFUTED / NEI.
model = MBERTClassifier(mbert, num_classes=3).to(device)
# Load fine-tuned weights; assumes 'Model/classifier.pt' exists relative to CWD.
model.load_state_dict(torch.load('Model/classifier.pt', map_location=device))
def predict(context, claim):
    """Classify whether `claim` is supported by `context` with the mBERT classifier.

    Parameters
    ----------
    context : str
        Evidence passage the claim is checked against.
    claim : str
        Claim to verify.

    Returns
    -------
    dict
        {'verdict': 'SUPPORTED' | 'REFUTED' | 'NEI',
         'probabilities': {'SUPPORTED': float, 'REFUTED': float, 'NEI': float}}

    Notes
    -----
    Fixes over the previous version: the self-shadowing zip comprehension and
    the single-row pandas DataFrame round-trip (three chained `.replace()`
    calls plus chained indexing) are replaced by a direct label lookup.
    """
    labels = ("SUPPORTED", "REFUTED", "NEI")  # index i <-> class i of the model head

    # SentencePairDataset takes (claim, context) pairs plus labels; the label
    # is a dummy value — we only run inference here.
    dataset = SentencePairDataset([(claim, context)], [1], tokenizer, 256)
    loader = DataLoader(dataset, batch_size=1)

    model.eval()
    with torch.no_grad():
        # Exactly one batch: the single (claim, context) pair.
        batch = next(iter(loader))
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        logits = model(input_ids, attention_mask)
        probs = torch.nn.functional.softmax(logits, dim=1)[0]
        predicted = int(torch.argmax(logits, dim=1).item())

    return {
        'verdict': labels[predicted],
        'probabilities': {
            'SUPPORTED': float(probs[0]),
            'REFUTED': float(probs[1]),
            'NEI': float(probs[2]),
        },
    }