# curfox_test_api / main.py
# Commit 40dddb0 (verified) by Arafath10: "Update main.py" — file size 2.18 kB
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import pandas as pd
import numpy as np
import joblib
# Load the trained model and the encoder mapping once, at import time, so
# every request reuses the same objects. `encoders` is looked up by column
# name further down (including a 'status.name' entry for the target).
# NOTE(review): joblib.load unpickles arbitrary objects — only point these
# paths at artifacts you produced yourself. The relative paths are resolved
# against the process working directory, not this file's directory.
xgb_model, encoders = (
    joblib.load("xgb_model.joblib"),
    joblib.load("encoders.joblib"),
)
# Function to handle unseen labels during encoding
def safe_transform(encoder, column):
    """Encode every value of *column*, mapping labels the encoder never saw to -1.

    Args:
        encoder: a fitted encoder exposing ``classes_`` and ``transform``
            (presumably a scikit-learn LabelEncoder — confirm against the
            training code).
        column: iterable of raw values (a pandas Series at the call site).

    Returns:
        A list with one entry per input value: ``encoder.transform([value])[0]``
        for known labels, ``-1`` for labels not in ``encoder.classes_``.
    """
    known_labels = encoder.classes_
    encoded = []
    for value in column:
        if value in known_labels:
            encoded.append(encoder.transform([value])[0])
        else:
            # Unseen at training time: use a sentinel instead of letting
            # transform() raise.
            encoded.append(-1)
    return encoded
# Create the FastAPI application.
app = FastAPI()

# CORS policy: accept any origin, method and header, with credentials allowed.
# NOTE(review): a wildcard origin combined with allow_credentials=True is very
# permissive; if cookies or auth headers matter, switch to an explicit list.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
def _parse_number(field: str, raw: str) -> float:
    """Parse a numeric query parameter, rejecting bad input with HTTP 422.

    Args:
        field: query-parameter name, used only in the error message.
        raw: the raw string value sent by the client.

    Returns:
        ``float(raw)``.

    Raises:
        HTTPException: status 422 when *raw* is not a valid number (a bare
            ``ValueError`` would otherwise surface as an opaque HTTP 500).
    """
    try:
        return float(raw)
    except ValueError:
        raise HTTPException(
            status_code=422,
            detail=f"{field!r} must be a number, got {raw!r}",
        ) from None


# Endpoint for making predictions
@app.post("/predict")
def predict(customer_name: str,
            customer_address: str,
            customer_phone: str,
            customer_email: str,
            cod: str,
            weight: str,
            pickup_address: str,
            origin_city_name: str,
            destination_city_name: str,
            destination_country: str,
            rigin_country: str):
    """Score one shipment and return a probability as a percentage (2 decimals).

    All inputs arrive as query-string parameters. ``cod`` and ``weight`` must
    parse as numbers. Every other feature column that has an entry in
    ``encoders`` is encoded with the training-time encoder, and unseen
    labels become -1.

    ``destination_country`` and ``rigin_country`` are accepted but not fed to
    the model. ``rigin_country`` looks like a typo for ``origin_country``; it
    is kept because it is the public parameter name callers already send.

    Returns:
        ``{"Probability": float}`` — the model's probability (0-100) for its
        predicted class, reported as ``100 - p`` when that class decodes to
        "RETURN TO CLIENT". The number is therefore presumably the chance the
        parcel is *not* returned; confirm this with the API consumer.
        ``{"Probability": "Unknown"}`` if the prediction is -1.

    Raises:
        HTTPException: 422 if ``cod`` or ``weight`` is not numeric.
    """
    # One feature row. The keys double as encoder lookup keys below, so they
    # presumably mirror the training column names (note the dotted
    # 'origin_city.name' / 'destination_city.name').
    input_data = {
        'customer_name': customer_name,
        'customer_address': customer_address,
        'customer_phone': customer_phone,
        'customer_email': customer_email,
        'cod': _parse_number('cod', cod),
        'weight': _parse_number('weight', weight),
        'pickup_address': pickup_address,
        'origin_city.name': origin_city_name,
        'destination_city.name': destination_city_name,
    }
    input_df = pd.DataFrame([input_data])

    # Encode categorical variables using the same encoders used during training
    for col in input_df.columns:
        if col in encoders:
            input_df[col] = safe_transform(encoders[col], input_df[col])

    pred = xgb_model.predict(input_df)
    pred_proba = xgb_model.predict_proba(input_df)
    label = pred[0]

    if label == -1:
        # Defensive guard kept from the original. Previously this path
        # reached round("Unknown", 2) and raised TypeError.
        return {"Probability": "Unknown"}

    # Pass a 1-D input (the original passed [pred], which is 2-D).
    predicted_status = encoders['status.name'].inverse_transform([label])[0]
    # Assumes the predicted label is a 0-based column index into predict_proba
    # (TODO confirm). float() converts the NumPy scalar (e.g. float32, which
    # is not a Python float subclass) into a JSON-serializable Python float.
    probability = float(pred_proba[0][label]) * 100
    if predicted_status == "RETURN TO CLIENT":
        probability = 100 - probability
    return {"Probability": round(probability, 2)}