# app.py — Hugging Face Space by caslabs.
# Last commit: 3f36ea6 ("Update app.py", verified); file size 1.84 kB.
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import gradio as gr
import pandas as pd
import xgboost as xgb
from huggingface_hub import hf_hub_download
import uvicorn
# Location of the trained regressor on the Hugging Face Hub.
MODEL_REPO_ID = "caslabs/xgboost-home-price-predictor"
MODEL_FILENAME = "xgboost_model.json"

# Download the serialized model file and load it into a scikit-learn style
# XGBoost regressor; this runs once, at import time.
model_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=MODEL_FILENAME)
model = xgb.XGBRegressor()
model.load_model(model_path)

# The FastAPI application that the endpoint and startup hook attach to.
app = FastAPI()
# Request-body schema for the FastAPI prediction endpoint.
class PredictionRequest(BaseModel):
    """Feature values for a single home-price prediction.

    Instances are dumped to a dict and turned into a one-row DataFrame, so
    the declaration order below becomes the column order the model sees.
    Keep it in sync with the features the model was trained on
    (presumably this exact order — confirm against the training pipeline).
    """

    # Numeric size / age / room-count features.
    Site_Area_sqft: float
    Actual_Age_Years: int
    Total_Rooms: int
    Bedrooms: int
    Bathrooms: float
    Gross_Living_Area_sqft: float

    # Integer-coded categorical features; the meaning of each code value is
    # not defined in this file — NOTE(review): document the encoding.
    Design_Style_Code: int
    Condition_Code: int
    Energy_Efficient_Code: int
    Garage_Carport_Code: int
# FastAPI prediction endpoint.
@app.post("/predict")
async def predict(request: PredictionRequest):
    """Predict a home price from one validated feature record.

    Args:
        request: Request body, already validated by FastAPI against
            ``PredictionRequest``.

    Returns:
        ``{"predicted_price": <float>}``.

    Raises:
        HTTPException: Status 500 with the error text if building the
            feature frame or running the model fails.
    """
    try:
        # Pydantic v2 exposes ``model_dump``; v1 only has ``dict``.
        record = (
            request.model_dump()
            if hasattr(request, "model_dump")
            else request.dict()
        )
        # One-row frame; columns follow PredictionRequest's field order.
        data = pd.DataFrame([record])
        # ``predict`` yields NumPy scalars, which FastAPI's JSON encoder
        # cannot serialize; convert to a built-in float here, inside the
        # ``try``, instead of failing later during response rendering.
        predicted_price = float(model.predict(data)[0])
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"predicted_price": predicted_price}
# Gradio prediction function.
def gradio_predict_price(features):
    """Predict a home price from a JSON object entered in the Gradio UI.

    The object is validated with ``PredictionRequest``, the same schema the
    ``/predict`` endpoint uses. That gives both entry points the same type
    coercion, and the model gets columns in the declared field order no
    matter what key order the user typed.

    Args:
        features: Mapping of feature name to value, as parsed by ``gr.JSON``.

    Returns:
        ``{"predicted_price": <float>}``.

    Raises:
        ValueError: If ``features`` is not a JSON object, or (as pydantic's
            ``ValidationError``, a ``ValueError`` subclass) if a field is
            missing or has an invalid value.
    """
    if not isinstance(features, dict):
        raise ValueError("Expected a JSON object mapping feature names to values.")
    request = PredictionRequest(**features)
    # Pydantic v2 exposes ``model_dump``; v1 only has ``dict``.
    record = (
        request.model_dump() if hasattr(request, "model_dump") else request.dict()
    )
    df = pd.DataFrame([record])
    # Return a built-in float rather than a NumPy scalar so the JSON output
    # component receives a plain, serializable value.
    predicted_price = float(model.predict(df)[0])
    return {"predicted_price": predicted_price}
# Gradio UI: a free-form JSON object of features in, the prediction dict out.
iface = gr.Interface(
    title="Home Price Prediction API",
    description="Predict home price based on input features",
    fn=gradio_predict_price,
    inputs=gr.JSON(),   # feature-name -> value object typed by the user
    outputs=gr.JSON(),  # whatever gradio_predict_price returns
)
# Launch the Gradio UI on its own port when the FastAPI app starts.
@app.on_event("startup")
async def startup_event():
    """Start the Gradio interface alongside the FastAPI app.

    Gradio runs its own server on port 7860, separate from the FastAPI
    server; it is not mounted as a route of ``app``.

    ``prevent_thread_lock=True`` makes ``launch`` return once the Gradio
    server is up. Without it, ``launch`` blocks the calling thread when not
    in an interactive session. Here that would stall this startup hook, and
    with it the FastAPI event loop, so the API would never start serving.
    """
    iface.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        prevent_thread_lock=True,
    )
# Script entry point: serve the FastAPI app with uvicorn. Its startup hook
# also launches the Gradio interface.
if __name__ == "__main__":
    api_host, api_port = "0.0.0.0", 8000
    uvicorn.run(app, host=api_host, port=api_port)