caslabs commited on
Commit
6a535fc
·
verified ·
1 Parent(s): 3311e46

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -29
app.py CHANGED
@@ -1,39 +1,27 @@
1
- from fastapi import FastAPI, HTTPException
2
- from pydantic import BaseModel
3
  import pandas as pd
4
  import xgboost as xgb
5
  from huggingface_hub import hf_hub_download
6
 
7
- # Download and load the model from the correct repo
8
  model_path = hf_hub_download(repo_id="caslabs/xgboost-home-price-predictor", filename="xgboost_model.json")
9
  model = xgb.XGBRegressor()
10
  model.load_model(model_path)
11
 
12
- # Initialize FastAPI app
13
- app = FastAPI()
 
 
 
 
14
 
15
- # Define the expected input format
16
- class PredictionRequest(BaseModel):
17
- Site_Area_sqft: float
18
- Actual_Age_Years: int
19
- Total_Rooms: int
20
- Bedrooms: int
21
- Bathrooms: float
22
- Gross_Living_Area_sqft: float
23
- Design_Style_Code: int
24
- Condition_Code: int
25
- Energy_Efficient_Code: int
26
- Garage_Carport_Code: int
27
 
28
- # Define the /predict route
29
- @app.post("/predict")
30
- async def predict(request: PredictionRequest):
31
- # Convert the input data to a DataFrame
32
- data = pd.DataFrame([request.dict()])
33
-
34
- # Make a prediction
35
- try:
36
- predicted_price = model.predict(data)[0]
37
- return {"predicted_price": predicted_price}
38
- except Exception as e:
39
- raise HTTPException(status_code=500, detail=str(e))
 
1
+ import gradio as gr
 
2
  import pandas as pd
3
  import xgboost as xgb
4
  from huggingface_hub import hf_hub_download
5
 
6
+ # Load the model from the Hugging Face Hub
7
  model_path = hf_hub_download(repo_id="caslabs/xgboost-home-price-predictor", filename="xgboost_model.json")
8
  model = xgb.XGBRegressor()
9
  model.load_model(model_path)
10
 
11
+ # Define the prediction function
12
+ def predict_price(features):
13
+ # Convert the JSON input to a DataFrame
14
+ df = pd.DataFrame([features])
15
+ predicted_price = model.predict(df)[0]
16
+ return {"predicted_price": predicted_price}
17
 
18
+ # Set up the Gradio interface
19
+ iface = gr.Interface(
20
+ fn=predict_price,
21
+ inputs=gr.JSON(), # Accept JSON input
22
+ outputs=gr.JSON(), # Return JSON output
23
+ title="Home Price Prediction API",
24
+ description="Predict home price based on input features"
25
+ )
 
 
 
 
26
 
27
+ iface.launch()