legacydemo / src /upvote_predictor.py
gupta-amulya's picture
Enhance SemanticSearcher integration and refine UpvotePredictor output handling
20df6e4
import numpy as np
import torch
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset
from transformers import BertTokenizer
class UpvotePredictor:
    """Classify an answer as a credible or non-credible suggestion.

    Wraps a serialized sequence-classification model (loaded with
    ``torch.load``) together with the ``bert-base-uncased`` tokenizer.
    The model is moved to CUDA when available, otherwise it stays on the CPU,
    and is put in evaluation mode.
    """

    # Answers are padded/truncated to this many tokens before inference.
    MAX_SEQ_LENGTH = 256
    # Human-readable labels returned by ``get_upvote_prediction``.
    NOT_CREDIBLE_LABEL = "Not credible suggestion"
    CREDIBLE_LABEL = "Credible suggestion"

    def __init__(self, model_path: str):
        """Load the model found at *model_path* and the BERT tokenizer.

        Args:
            model_path: Path handed to ``torch.load``; expected to contain a
                fully pickled model object (not just a state dict), since the
                result is used directly as a callable module.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # SECURITY NOTE: ``weights_only=False`` unpickles arbitrary Python
        # objects, which can execute code. Only load checkpoints from a
        # trusted source.
        self.upvote_ml_model = torch.load(
            model_path, map_location=torch.device("cpu"), weights_only=False
        )
        self.tokenizer = BertTokenizer.from_pretrained(
            "bert-base-uncased", do_lower_case=True
        )
        # Deserialize on CPU first, then move to the selected device.
        self.upvote_ml_model.to(self.device)
        self.upvote_ml_model.eval()

    def get_upvote_prediction(
        self, question: str, answer: str, question_context: str = None
    ) -> str:
        """Return the credibility label predicted for *answer*.

        Only ``answer`` is fed to the model. ``question`` and
        ``question_context`` are accepted for interface compatibility but are
        currently not used in the prediction.

        Args:
            question: The question being answered (unused).
            answer: The answer text to classify.
            question_context: Optional extra context (unused).

        Returns:
            ``"Not credible suggestion"`` when the highest-scoring class index
            is 0, otherwise ``"Credible suggestion"``.
        """
        encoded = self.tokenizer.encode_plus(
            answer,
            add_special_tokens=True,
            max_length=self.MAX_SEQ_LENGTH,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            return_tensors="pt",
        )
        # ``return_tensors="pt"`` already yields a leading batch dimension of
        # one, so the tensors can go to the model directly; no Dataset or
        # DataLoader is needed for a single sample.
        input_ids = encoded["input_ids"].to(self.device)
        attention_mask = encoded["attention_mask"].to(self.device)

        with torch.no_grad():
            output = self.upvote_ml_model(
                input_ids, token_type_ids=None, attention_mask=attention_mask
            )

        logits = output.logits.detach().cpu().numpy()
        predicted_class = int(np.argmax(logits, axis=1)[0])

        if predicted_class == 0:
            return self.NOT_CREDIBLE_LABEL
        return self.CREDIBLE_LABEL