abullard1's picture
Create handler.py
68e5def verified
raw
history blame
1.64 kB
from transformers import pipeline
import torch
from typing import Dict, Any
class EndpointHandler:
    """Inference handler for the Steam review constructiveness classifier.

    Wraps a ``transformers`` text-classification pipeline and turns its
    per-label scores into a flat response dict plus a one-line summary.
    """

    def __init__(self, path=""):
        """Build the text-classification pipeline.

        Args:
            path: Unused; the model is always loaded from the hard-coded
                hub id below.
        """
        # Use GPU 0 with half precision when CUDA is available, otherwise
        # run on CPU (device=-1) in full precision.
        use_cuda = torch.cuda.is_available()
        self.classifier = pipeline(
            task="text-classification",
            model="abullard1/albert-v2-steam-review-constructiveness-classifier",
            tokenizer="albert-base-v2",
            device=0 if use_cuda else -1,
            top_k=None,  # return a score for every label, not only the best one
            truncation=True,
            max_length=512,
            torch_dtype=torch.float16 if use_cuda else torch.float32,
        )

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Classify ``data["inputs"]`` and report both label scores.

        Args:
            data: Request payload; the text to classify is read from the
                ``"inputs"`` key (an empty string is used when it is missing).

        Returns:
            A dict with ``label_1``/``score_1`` and ``label_2``/``score_2``
            taken from the first two entries of the pipeline's output for the
            first input, plus a formatted ``prediction_text``.
        """
        input_text = data.get("inputs", "")
        # NOTE(review): assumes the pipeline returns one list of
        # {"label", "score"} dicts per input, with at least two labels —
        # the indexing below requires that shape.
        label_scores = self.classifier(input_text)[0]
        label_1, score_1 = label_scores[0]["label"], label_scores[0]["score"]
        label_2, score_2 = label_scores[1]["label"], label_scores[1]["score"]
        return {
            "label_1": label_1,
            "score_1": score_1,
            "label_2": label_2,
            "score_2": score_2,
            "prediction_text": self.format_prediction_text(
                label_1, score_1, label_2, score_2
            ),
        }

    def format_prediction_text(self, label_1, score_1, label_2, score_2) -> str:
        """Summarise the higher-scoring of two labels as a one-line message.

        ``"LABEL_1"`` is reported as "Constructive"; any other label is
        reported as "Not Constructive". The emoji is chosen from the predicted
        class (thumbs-up for constructive, thumbs-down otherwise) rather than
        from which argument slot won, so the message stays consistent no
        matter how the caller orders the two labels.

        Returns:
            e.g. ``"Constructive with a score of 0.93. 👍🏻"``.
        """

        def label_to_constructiveness(label):
            return "Constructive" if label == "LABEL_1" else "Not Constructive"

        # Ties go to the first label, as with the original ``>=`` comparison.
        if score_1 >= score_2:
            label, score = label_1, score_1
        else:
            label, score = label_2, score_2

        verdict = label_to_constructiveness(label)
        # Thumbs up / thumbs down, each followed by the light-skin-tone
        # modifier; escapes keep the literal immune to file mis-encoding.
        if verdict == "Constructive":
            emoji = "\U0001F44D\U0001F3FB"
        else:
            emoji = "\U0001F44E\U0001F3FB"
        return f"{verdict} with a score of {score:.2f}. {emoji}"