Spaces:
Sleeping
Sleeping
from typing import Optional

import numpy as np
import torch
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset
from transformers import BertTokenizer
class UpvotePredictor:
    """Classify an answer text as a credible or non-credible suggestion.

    Wraps a pickled PyTorch classification model together with the
    ``bert-base-uncased`` tokenizer and exposes a single-example prediction
    method.
    """

    # Token budget for the answer text: longer answers are truncated,
    # shorter ones are padded up to this length.
    _MAX_LENGTH = 256

    def __init__(self, model_path: str):
        """Load the classifier and tokenizer and put the model in eval mode.

        Args:
            model_path: Path to a file readable by ``torch.load`` holding a
                full model object. A bare state dict will not work, because
                this class calls ``.to()``/``.eval()`` on the loaded object and
                invokes it directly.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # SECURITY: weights_only=False unpickles arbitrary Python objects, which
        # can execute code embedded in the file. Only load model files from a
        # trusted source.
        # Load onto CPU first so CUDA-saved checkpoints also open on CPU-only
        # hosts; the model is moved to the selected device below.
        self.upvote_ml_model = torch.load(
            model_path, map_location=torch.device("cpu"), weights_only=False
        )
        self.tokenizer = BertTokenizer.from_pretrained(
            "bert-base-uncased", do_lower_case=True
        )
        self.upvote_ml_model.to(self.device)
        self.upvote_ml_model.eval()

    def get_upvote_prediction(
        self, question: str, answer: str, question_context: Optional[str] = None
    ) -> str:
        """Classify ``answer`` as a credible or non-credible suggestion.

        Only ``answer`` is fed to the model. ``question`` and
        ``question_context`` are accepted for interface compatibility but are
        currently ignored.

        Args:
            question: The question being answered (currently unused).
            answer: Answer text to classify. It is tokenized with special
                tokens and truncated/padded to 256 tokens.
            question_context: Optional extra context (currently unused).

        Returns:
            ``"Not credible suggestion"`` when the model's highest-scoring
            class is 0, otherwise ``"Credible suggestion"``.
        """
        encoded = self.tokenizer.encode_plus(
            answer,
            add_special_tokens=True,
            max_length=self._MAX_LENGTH,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            return_tensors="pt",
        )
        # return_tensors="pt" already yields a batch of one row, so the tensors
        # can be fed to the model directly; no Dataset/DataLoader is needed.
        input_ids = encoded["input_ids"].to(self.device)
        attention_mask = encoded["attention_mask"].to(self.device)

        with torch.no_grad():
            output = self.upvote_ml_model(
                input_ids, token_type_ids=None, attention_mask=attention_mask
            )

        # Argmax over the class axis, taking the prediction for the single row.
        predicted_class = output.logits.argmax(dim=1)[0].item()
        if predicted_class == 0:
            return "Not credible suggestion"
        return "Credible suggestion"