Spaces:
Sleeping
Sleeping
File size: 1,996 Bytes
4bc5036 f111c66 4bc5036 f111c66 be480c8 e8cc94f be480c8 f111c66 167c984 f111c66 c16ce23 f111c66 55935b2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 |
import logging
import pickle

import pandas as pd
import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel
# Deserialize the bundled training artifacts once, at import time.
# NOTE(review): pickle.load can execute arbitrary code; this file is assumed to
# be a trusted artifact shipped with the app — never load user-supplied pickles.
with open("model_and_key_components.pkl", "rb") as model_file:
    components = pickle.load(model_file)

# Estimator used by the prediction endpoint (only the "model" key is read here).
dt_model = components["model"]

# ASGI application that the route decorators below register against.
app = FastAPI()
class IncomePredictionRequest(BaseModel):
    """Request body for the income prediction endpoint.

    Each field becomes one column of the single-row DataFrame handed to the
    model, so the field names presumably must match the feature names the
    model was trained on (TODO confirm against the pickled pipeline).
    """
    # NOTE: declaration order below determines the DataFrame column order in
    # the prediction endpoint — do not reorder without checking the model.
    age: int
    gender: str
    education: str
    worker_class: str
    marital_status: str
    race: str
    is_hispanic: str
    employment_commitment: str
    employment_stat: int
    wage_per_hour: int
    working_week_per_year: int
    industry_code: int
    industry_code_main: str
    occupation_code: int
    occupation_code_main: str
    total_employed: int
    household_summary: str
    vet_benefit: int
    tax_status: str
    gains: int
    losses: int
    stocks_status: int
    citizenship: str
    # Only non-integer numeric feature; semantics not visible here — confirm.
    importance_of_record: float
class IncomePredictionResponse(BaseModel):
    """Response schema mirroring the dict returned by ``predict_income``."""
    # Human-readable label: "Income over $50K" or "Income under $50K".
    income_prediction: str
    # Value from column 1 of the model's predict_proba output — presumably the
    # probability of the ">$50K" class (assumes classes ordered [0, 1]).
    prediction_probability: float
@app.get("/")
async def root():
    """Greet visitors and point them at the interactive API documentation.

    The greeting is returned as a plain string, which FastAPI encodes as a
    JSON string; the docs URL inside it is text, not a rendered hyperlink.
    """
    return (
        "Welcome to the Income Classification API! This API Provides predictions "
        "for Income based on several inputs. To use this API, please access the API "
        "documentation here: https://rasmodev-income-prediction-fastapi.hf.space/docs/"
    )
@app.post("/predict/", response_model=IncomePredictionResponse)
async def predict_income(data: IncomePredictionRequest):
    """Classify one applicant's income bracket with the loaded model.

    The request is converted into a single-row DataFrame (columns in the
    field order of ``IncomePredictionRequest``) and passed unchanged to
    ``dt_model``; any preprocessing is assumed to live inside the pickled
    model (TODO confirm).

    Returns:
        A dict with ``income_prediction`` ("Income over $50K" when the model
        predicts 1, otherwise "Income under $50K") and
        ``prediction_probability``, the value in column 1 of
        ``predict_proba`` — presumably P(>$50K), assuming classes are ordered
        ``[0, 1]``. It is reported regardless of which label was predicted.

    Raises:
        Exception: any failure is logged with its traceback and re-raised,
            so FastAPI responds with HTTP 500.
    """
    # NOTE(review): model inference is synchronous inside an async handler and
    # blocks the event loop while it runs; consider a plain `def` endpoint.
    try:
        # Pydantic v2 renamed .dict() to .model_dump(); use whichever exists.
        to_dict = getattr(data, "model_dump", None) or data.dict
        input_df = pd.DataFrame([to_dict()])
        prediction = dt_model.predict(input_df)
        prediction_proba = dt_model.predict_proba(input_df)
        prediction_result = "Income over $50K" if prediction[0] == 1 else "Income under $50K"
        # float() turns the NumPy scalar into a plain Python float for JSON.
        return {
            "income_prediction": prediction_result,
            "prediction_probability": float(prediction_proba[0][1]),
        }
    except Exception as e:
        # Log with traceback, then let FastAPI turn the error into a 500.
        logging.exception("Prediction failed: %s", e)
        raise
if __name__ == "__main__":
    # Serve on all interfaces, port 7860, when the script is run directly.
    # reload is intentionally not enabled: uvicorn only supports reload when the
    # application is passed as an import string ("module:app"); with an app
    # object plus reload=True it logs a warning and exits without serving.
    uvicorn.run(app, host="0.0.0.0", port=7860)