aliabd commited on
Commit
d0ba97a
·
1 Parent(s): cbcf957
Files changed (1) hide show
  1. app.py +15 -8
app.py CHANGED
@@ -1,6 +1,7 @@
1
- import random
2
-
3
  import gradio as gr
 
4
  import matplotlib
5
  import matplotlib.pyplot as plt
6
  import pandas as pd
@@ -8,16 +9,13 @@ import shap
8
  import xgboost as xgb
9
  from datasets import load_dataset
10
 
11
- print(gr.__version__)
12
 
 
13
  matplotlib.use("Agg")
14
-
15
  dataset = load_dataset("scikit-learn/adult-census-income")
16
-
17
  X_train = dataset["train"].to_pandas()
18
  _ = X_train.pop("fnlwgt")
19
  _ = X_train.pop("race")
20
-
21
  y_train = X_train.pop("income")
22
  y_train = (y_train == ">50K").astype(int)
23
  categorical_columns = [
@@ -30,12 +28,11 @@ categorical_columns = [
30
  "native.country",
31
  ]
32
  X_train = X_train.astype({col: "category" for col in categorical_columns})
33
-
34
-
35
  data = xgb.DMatrix(X_train, label=y_train, enable_categorical=True)
36
  model = xgb.train(params={"objective": "binary:logistic"}, dtrain=data)
37
  explainer = shap.TreeExplainer(model)
38
 
 
39
 
40
  def predict(*args):
41
  df = pd.DataFrame([args], columns=X_train.columns)
@@ -67,12 +64,17 @@ unique_occupation = sorted(X_train["occupation"].unique())
67
  unique_sex = sorted(X_train["sex"].unique())
68
  unique_country = sorted(X_train["native.country"].unique())
69
 
 
 
70
  with gr.Blocks() as demo:
 
71
  gr.Markdown("""
72
  **Income Classification with XGBoost 💰**: This demo uses an XGBoost classifier predicts income based on demographic factors, along with Shapley value-based *explanations*. The [source code for this Gradio demo is here](https://huggingface.co/spaces/gradio/xgboost-income-prediction-with-explainability/blob/main/app.py).
73
  """)
 
74
  with gr.Row():
75
  with gr.Column():
 
76
  age = gr.Slider(label="Age", minimum=17, maximum=90, step=1, randomize=True)
77
  work_class = gr.Dropdown(
78
  label="Workclass",
@@ -128,11 +130,14 @@ with gr.Blocks() as demo:
128
  value=lambda: random.choice(unique_country),
129
  )
130
  with gr.Column():
 
131
  label = gr.Label()
132
  plot = gr.Plot()
133
  with gr.Row():
 
134
  predict_btn = gr.Button(value="Predict")
135
  interpret_btn = gr.Button(value="Explain")
 
136
  predict_btn.click(
137
  predict,
138
  inputs=[
@@ -151,6 +156,7 @@ with gr.Blocks() as demo:
151
  ],
152
  outputs=[label],
153
  )
 
154
  interpret_btn.click(
155
  interpret,
156
  inputs=[
@@ -170,4 +176,5 @@ with gr.Blocks() as demo:
170
  outputs=[plot],
171
  )
172
 
 
173
  demo.launch()
 
1
+ # URL: https://huggingface.co/spaces/gradio/xgboost-income-prediction-with-explainability
2
+ # imports
3
  import gradio as gr
4
+ import random
5
  import matplotlib
6
  import matplotlib.pyplot as plt
7
  import pandas as pd
 
9
  import xgboost as xgb
10
  from datasets import load_dataset
11
 
 
12
 
13
+ # loading the model and setting up
14
  matplotlib.use("Agg")
 
15
  dataset = load_dataset("scikit-learn/adult-census-income")
 
16
  X_train = dataset["train"].to_pandas()
17
  _ = X_train.pop("fnlwgt")
18
  _ = X_train.pop("race")
 
19
  y_train = X_train.pop("income")
20
  y_train = (y_train == ">50K").astype(int)
21
  categorical_columns = [
 
28
  "native.country",
29
  ]
30
  X_train = X_train.astype({col: "category" for col in categorical_columns})
 
 
31
  data = xgb.DMatrix(X_train, label=y_train, enable_categorical=True)
32
  model = xgb.train(params={"objective": "binary:logistic"}, dtrain=data)
33
  explainer = shap.TreeExplainer(model)
34
 
35
+ # defining the two core fns
36
 
37
  def predict(*args):
38
  df = pd.DataFrame([args], columns=X_train.columns)
 
64
  unique_sex = sorted(X_train["sex"].unique())
65
  unique_country = sorted(X_train["native.country"].unique())
66
 
67
+ # starting the block
68
+
69
  with gr.Blocks() as demo:
70
+ # defining text on the page
71
  gr.Markdown("""
72
  **Income Classification with XGBoost 💰**: This demo uses an XGBoost classifier predicts income based on demographic factors, along with Shapley value-based *explanations*. The [source code for this Gradio demo is here](https://huggingface.co/spaces/gradio/xgboost-income-prediction-with-explainability/blob/main/app.py).
73
  """)
74
+ # defining the layout
75
  with gr.Row():
76
  with gr.Column():
77
+ # defining the inputs
78
  age = gr.Slider(label="Age", minimum=17, maximum=90, step=1, randomize=True)
79
  work_class = gr.Dropdown(
80
  label="Workclass",
 
130
  value=lambda: random.choice(unique_country),
131
  )
132
  with gr.Column():
133
+ # defining the outputs
134
  label = gr.Label()
135
  plot = gr.Plot()
136
  with gr.Row():
137
+ # defining the buttons
138
  predict_btn = gr.Button(value="Predict")
139
  interpret_btn = gr.Button(value="Explain")
140
+ # defining the fn that will run when predict is clicked, what it will get as inputs, and which output it will update
141
  predict_btn.click(
142
  predict,
143
  inputs=[
 
156
  ],
157
  outputs=[label],
158
  )
159
+ # defining the fn that will run when interpret is clicked, what it will get as inputs, and which output it will update
160
  interpret_btn.click(
161
  interpret,
162
  inputs=[
 
176
  outputs=[plot],
177
  )
178
 
179
+ # launch
180
  demo.launch()