caslabs commited on
Commit
3f36ea6
1 Parent(s): c594b07

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -10
app.py CHANGED
@@ -1,28 +1,62 @@
 
 
1
  import gradio as gr
2
  import pandas as pd
3
  import xgboost as xgb
4
  from huggingface_hub import hf_hub_download
 
5
 
6
- # Load the model from the Hugging Face Hub
7
  model_path = hf_hub_download(repo_id="caslabs/xgboost-home-price-predictor", filename="xgboost_model.json")
8
  model = xgb.XGBRegressor()
9
  model.load_model(model_path)
10
 
11
- # Define the prediction function
12
- def predict_price(features):
13
- # Convert the JSON input to a DataFrame
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  df = pd.DataFrame([features])
15
  predicted_price = model.predict(df)[0]
16
  return {"predicted_price": predicted_price}
17
 
18
- # Set up the Gradio interface
19
  iface = gr.Interface(
20
- fn=predict_price,
21
- inputs=gr.JSON(), # Accept JSON input
22
- outputs=gr.JSON(), # Return JSON output
23
  title="Home Price Prediction API",
24
  description="Predict home price based on input features"
25
  )
26
 
27
- # Launch the interface without 'enable_api'
28
- iface.launch()
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from pydantic import BaseModel
3
  import gradio as gr
4
  import pandas as pd
5
  import xgboost as xgb
6
  from huggingface_hub import hf_hub_download
7
+ import uvicorn
8
 
9
+ # Load the model from Hugging Face Hub
10
  model_path = hf_hub_download(repo_id="caslabs/xgboost-home-price-predictor", filename="xgboost_model.json")
11
  model = xgb.XGBRegressor()
12
  model.load_model(model_path)
13
 
14
+ # Initialize FastAPI app
15
+ app = FastAPI()
16
+
17
+ # Define the input data model for FastAPI
18
+ class PredictionRequest(BaseModel):
19
+ Site_Area_sqft: float
20
+ Actual_Age_Years: int
21
+ Total_Rooms: int
22
+ Bedrooms: int
23
+ Bathrooms: float
24
+ Gross_Living_Area_sqft: float
25
+ Design_Style_Code: int
26
+ Condition_Code: int
27
+ Energy_Efficient_Code: int
28
+ Garage_Carport_Code: int
29
+
30
+ # Define a prediction endpoint in FastAPI
31
+ @app.post("/predict")
32
+ async def predict(request: PredictionRequest):
33
+ data = pd.DataFrame([request.dict()])
34
+ try:
35
+ predicted_price = model.predict(data)[0]
36
+ return {"predicted_price": predicted_price}
37
+ except Exception as e:
38
+ raise HTTPException(status_code=500, detail=str(e))
39
+
40
+ # Define the Gradio prediction function
41
+ def gradio_predict_price(features):
42
  df = pd.DataFrame([features])
43
  predicted_price = model.predict(df)[0]
44
  return {"predicted_price": predicted_price}
45
 
46
+ # Set up Gradio interface
47
  iface = gr.Interface(
48
+ fn=gradio_predict_price,
49
+ inputs=gr.JSON(),
50
+ outputs=gr.JSON(),
51
  title="Home Price Prediction API",
52
  description="Predict home price based on input features"
53
  )
54
 
55
+ # Launch Gradio on a separate route
56
+ @app.on_event("startup")
57
+ async def startup_event():
58
+ iface.launch(server_name="0.0.0.0", server_port=7860, share=False)
59
+
60
+ # Run FastAPI app if this script is executed
61
+ if __name__ == "__main__":
62
+ uvicorn.run(app, host="0.0.0.0", port=8000)