Embedded / routers /predict.py
yamunasivan's picture
create a new file
d1e319b
raw
history blame contribute delete
459 Bytes
from fastapi import APIRouter
from schema.schemas import PredictionInput, PredictionOutput
from service.classifier import load_model, predict
# Every route registered on this router is served under the "/predict" path
# prefix and grouped under the "Prediction" tag in the OpenAPI schema.
router = APIRouter(
    tags=["Prediction"],
    prefix="/predict",
)

# load_model() is called a single time, when this module is imported; the
# returned model and vectorizer are kept as module globals and reused by the
# request handler instead of being reloaded on every call.
model, vectorizer = load_model()
@router.post("/", response_model=PredictionOutput)
def make_prediction(input_data: PredictionInput):
    """Run the classifier on the request's text and return the result.

    Args:
        input_data: Parsed request body; only its ``text`` attribute is used.

    Returns:
        A dict with a single ``"prediction"`` key holding whatever
        ``predict`` returned. FastAPI validates/serializes this dict against
        ``PredictionOutput`` because of ``response_model`` on the decorator.
    """
    # Reuse the module-level model/vectorizer loaded once at import time.
    raw_text = input_data.text
    outcome = predict(raw_text, model, vectorizer)
    return {"prediction": outcome}