# FastAPI intent-classification Space (commit 8e14057, "fix: predict intent").
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import RobertaTokenizerFast, RobertaForSequenceClassification, TextClassificationPipeline
import uvicorn
# Define FastAPI app
app = FastAPI()

# Load Model on Startup
HUGGINGFACE_MODEL_PATH = "bespin-global/klue-roberta-small-3i4k-intent-classification"

print("Loading model...")  # Log message
try:
    # Fetch tokenizer and classifier weights from the Hugging Face hub.
    _tokenizer = RobertaTokenizerFast.from_pretrained(HUGGINGFACE_MODEL_PATH)
    _model = RobertaForSequenceClassification.from_pretrained(HUGGINGFACE_MODEL_PATH)
    # Create Text Classification Pipeline
    # return_all_scores=True -> a score for every intent label, not just the top one.
    text_classifier = TextClassificationPipeline(
        model=_model,
        tokenizer=_tokenizer,
        return_all_scores=True,
    )
    print("Model loaded successfully.")  # Log message
except Exception as e:
    # Best-effort startup: keep the app serving (health check still works)
    # even when the model cannot be fetched; /predict reports the failure.
    print(f"Error loading model: {e}")
    text_classifier = None
# Health Check Endpoint
@app.get("/")
def hello():
    """Liveness probe: report that the Space is up and serving."""
    status_payload = {"Message": "Space is running Good.", "Status": "Healthy"}
    return status_payload
# Define Pydantic Model for Input Validation
class PredictionRequest(BaseModel):
    # Request body for POST /predict: the single utterance to classify.
    sentence: str
# Prediction Endpoint
@app.post("/predict")
def predict_intent(request: PredictionRequest):
    """Classify the intent of ``request.sentence``.

    Returns the highest-scoring intent label with its confidence.
    Raises HTTP 503 if the model failed to load at startup and
    HTTP 400 for a blank/whitespace-only sentence.
    """
    if text_classifier is None:
        # Model failed to load at startup; surface a real HTTP error instead
        # of a 200 response with an error payload, so clients and monitors
        # can detect the outage.
        raise HTTPException(status_code=503, detail="Model not found")
    sentence = request.sentence.strip()
    if not sentence:
        # Nothing to classify — reject rather than feed "" to the pipeline.
        raise HTTPException(status_code=400, detail="sentence must not be empty")
    # With return_all_scores=True the pipeline yields one list of
    # {label, score} dicts per input; we sent one sentence, so index 0.
    preds_list = text_classifier(sentence)
    best_pred = max(preds_list[0], key=lambda pred: pred["score"])  # Get highest-scoring intent
    return {"predicted_intent": best_pred["label"], "confidence": best_pred["score"]}