caslabs commited on
Commit
775aed7
·
verified ·
1 Parent(s): 155f630

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -17
app.py CHANGED
@@ -1,27 +1,39 @@
1
- import gradio as gr
 
2
  import pandas as pd
3
  import xgboost as xgb
4
  from huggingface_hub import hf_hub_download
5
 
6
- # Load the model from the Hugging Face Hub
7
  model_path = hf_hub_download(repo_id="caslabs/xgboost-home-price-predictor", filename="xgboost_model.json")
8
  model = xgb.XGBRegressor()
9
  model.load_model(model_path)
10
 
11
- # Define the prediction function
12
- def predict_price(features):
13
- # Convert the JSON input to a DataFrame
14
- df = pd.DataFrame([features])
15
- predicted_price = model.predict(df)[0]
16
- return {"predicted_price": predicted_price}
17
 
18
- # Set up the Gradio interface
19
- iface = gr.Interface(
20
- fn=predict_price,
21
- inputs=gr.JSON(), # Accept JSON input
22
- outputs=gr.JSON(), # Return JSON output
23
- title="Home Price Prediction API",
24
- description="Predict home price based on input features"
25
- )
 
 
 
 
26
 
27
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from pydantic import BaseModel
3
  import pandas as pd
4
  import xgboost as xgb
5
  from huggingface_hub import hf_hub_download
6
 
7
+ # Download and load the model from the correct repo
8
  model_path = hf_hub_download(repo_id="caslabs/xgboost-home-price-predictor", filename="xgboost_model.json")
9
  model = xgb.XGBRegressor()
10
  model.load_model(model_path)
11
 
12
+ # Initialize FastAPI app
13
+ app = FastAPI()
 
 
 
 
14
 
15
+ # Define the expected input format
16
+ class PredictionRequest(BaseModel):
17
+ Site_Area_sqft: float
18
+ Actual_Age_Years: int
19
+ Total_Rooms: int
20
+ Bedrooms: int
21
+ Bathrooms: float
22
+ Gross_Living_Area_sqft: float
23
+ Design_Style_Code: int
24
+ Condition_Code: int
25
+ Energy_Efficient_Code: int
26
+ Garage_Carport_Code: int
27
 
28
+ # Define the /predict route
29
+ @app.post("/predict")
30
+ async def predict(request: PredictionRequest):
31
+ # Convert the input data to a DataFrame
32
+ data = pd.DataFrame([request.dict()])
33
+
34
+ # Make a prediction
35
+ try:
36
+ predicted_price = model.predict(data)[0]
37
+ return {"predicted_price": predicted_price}
38
+ except Exception as e:
39
+ raise HTTPException(status_code=500, detail=str(e))