|
import os |
|
import torch |
|
from transformers import AutoModelForSequenceClassification, AutoTokenizer |
|
from datetime import datetime |
|
|
|
|
|
MODEL_PATH = "." |
|
model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH) |
|
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH) |
|
model.eval() |
|
|
|
|
|
label_mapping = {0: "Negative", 1: "Positive"} |
|
|
|
def predict(inputs):
    """Run sentiment classification on a single text input.

    :param inputs: Mapping with the text to be analyzed,
        e.g., {'text': 'I love this movie'}
    :return: On success, a dict with label, confidence, input echo,
        model version, status, and timestamp. On failure, a tuple of
        (error dict, HTTP status code): 400 for invalid input,
        500 for unexpected errors.
    """
    try:
        # Validate the payload shape up front so malformed requests get a
        # 400 validation response instead of falling through to the
        # generic 500 handler below.
        if not isinstance(inputs, dict):
            return {"error": "Invalid input, 'text' key is required"}, 400

        input_text = inputs.get("text")
        if not input_text or not isinstance(input_text, str):
            return {"error": "Invalid input, 'text' key is required"}, 400

        # Truncate to the model's maximum context window (512 tokens).
        tokenized_input = tokenizer(
            input_text, return_tensors="pt", truncation=True, max_length=512
        )

        # Inference only — no gradient tracking needed.
        with torch.no_grad():
            outputs = model(**tokenized_input)
            probabilities = torch.nn.functional.softmax(outputs.logits, dim=1)
            # Highest-probability class and its probability (as a percent).
            confidence, label_idx = torch.max(probabilities, dim=1)
            confidence = confidence.item() * 100
            label = label_mapping[label_idx.item()]

        response = {
            "data": {
                "confidence": f"{confidence:.2f}%",
                "input_text": input_text,
                "label": label
            },
            "model_version": "1.0.0",
            "status": "success",
            "timestamp": datetime.now().isoformat()
        }

        return response

    except Exception as e:
        # Top-level boundary: surface unexpected failures as a 500 payload
        # rather than letting the exception propagate to the caller.
        return {"error": str(e)}, 500
|
|