File size: 7,553 Bytes
6179e3c
1
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "302934307671667531413257853548643485645",
   "metadata": {},
   "source": [
    "# Gradio Demo: xgboost-income-prediction-with-explainability\n",
    "### This demo takes in 12 inputs from the user in dropdowns and sliders and predicts income. It also has a separate button for explaining the prediction.\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "272996653310673477252411125948039410165",
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install -q gradio numpy==1.23.2 matplotlib shap xgboost==1.7.6 pandas datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "288918539441861185822528903084949547379",
   "metadata": {},
   "outputs": [],
   "source": [
    "# type: ignore\n",
    "import gradio as gr\n",
    "import random\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "import shap\n",
    "import xgboost as xgb\n",
    "from datasets import load_dataset\n",
    "\n",
    "# Load the census data, drop `fnlwgt` and `race`, and split off the `income` target.\n",
    "dataset = load_dataset(\"scikit-learn/adult-census-income\")\n",
    "X_train = dataset[\"train\"].to_pandas()\n",
    "_ = X_train.pop(\"fnlwgt\")\n",
    "_ = X_train.pop(\"race\")\n",
    "y_train = X_train.pop(\"income\")\n",
    "y_train = (y_train == \">50K\").astype(int)  # 1 for \">50K\", 0 otherwise\n",
    "categorical_columns = [\n",
    "    \"workclass\",\n",
    "    \"education\",\n",
    "    \"marital.status\",\n",
    "    \"occupation\",\n",
    "    \"relationship\",\n",
    "    \"sex\",\n",
    "    \"native.country\",\n",
    "]\n",
    "X_train = X_train.astype({col: \"category\" for col in categorical_columns})\n",
    "\n",
    "# Train the classifier once at import time and wrap it in a SHAP tree explainer.\n",
    "data = xgb.DMatrix(X_train, label=y_train, enable_categorical=True)\n",
    "model = xgb.train(params={\"objective\": \"binary:logistic\"}, dtrain=data)\n",
    "explainer = shap.TreeExplainer(model)\n",
    "\n",
    "\n",
    "def _to_dmatrix(args):\n",
    "    \"\"\"Turn one row of UI values into a DMatrix encoded like the training data.\n",
    "\n",
    "    The values in `args` are positional and are named by `X_train.columns`.\n",
    "    Categorical columns reuse the training column dtypes so that each value\n",
    "    maps to the same category code the model was trained on. Casting to a bare\n",
    "    \"category\" dtype would infer the categories from this single row, giving\n",
    "    every value code 0 regardless of what was selected.\n",
    "    \"\"\"\n",
    "    row = pd.DataFrame([args], columns=X_train.columns)\n",
    "    row = row.astype({col: X_train[col].dtype for col in categorical_columns})\n",
    "    return xgb.DMatrix(row, enable_categorical=True)\n",
    "\n",
    "\n",
    "def predict(*args):\n",
    "    \"\"\"Return the model's probability for each income class for one input row.\n",
    "\n",
    "    Args:\n",
    "        *args: One value per column of `X_train`, in column order.\n",
    "\n",
    "    Returns:\n",
    "        A dict mapping \">50K\" to the predicted probability and \"<=50K\" to its\n",
    "        complement.\n",
    "    \"\"\"\n",
    "    pos_pred = float(model.predict(_to_dmatrix(args))[0])\n",
    "    return {\">50K\": pos_pred, \"<=50K\": 1 - pos_pred}\n",
    "\n",
    "\n",
    "def interpret(*args):\n",
    "    \"\"\"Return a horizontal bar chart of each feature's SHAP value for one input row.\n",
    "\n",
    "    Args:\n",
    "        *args: One value per column of `X_train`, in column order.\n",
    "\n",
    "    Returns:\n",
    "        A matplotlib figure with one bar per feature, sorted by SHAP value.\n",
    "    \"\"\"\n",
    "    shap_values = explainer.shap_values(_to_dmatrix(args))\n",
    "    scored_features = sorted(zip(shap_values[0], X_train.columns))\n",
    "    fig_m, ax = plt.subplots(tight_layout=True)\n",
    "    # barh draws the features along the y axis and the SHAP values along the x axis.\n",
    "    ax.barh([s[1] for s in scored_features], [s[0] for s in scored_features])\n",
    "    ax.set_title(\"Feature Shap Values\")\n",
    "    ax.set_xlabel(\"Shap Value\")\n",
    "    ax.set_ylabel(\"Feature\")\n",
    "    # Unregister the figure from pyplot so repeated clicks do not accumulate open\n",
    "    # figures; the returned Figure object can still be rendered by the caller.\n",
    "    plt.close(fig_m)\n",
    "    return fig_m\n",
    "\n",
    "\n",
    "# Dropdown choices: the distinct values seen in each categorical training column.\n",
    "unique_class = sorted(X_train[\"workclass\"].unique())\n",
    "unique_education = sorted(X_train[\"education\"].unique())\n",
    "unique_marital_status = sorted(X_train[\"marital.status\"].unique())\n",
    "unique_relationship = sorted(X_train[\"relationship\"].unique())\n",
    "unique_occupation = sorted(X_train[\"occupation\"].unique())\n",
    "unique_sex = sorted(X_train[\"sex\"].unique())\n",
    "unique_country = sorted(X_train[\"native.country\"].unique())\n",
    "\n",
    "with gr.Blocks() as demo:\n",
    "    gr.Markdown(\"\"\"\n",
    "    **Income Classification with XGBoost \ud83d\udcb0**:  This demo uses an XGBoost classifier to predict income based on demographic factors, along with Shapley value-based *explanations*. The [source code for this Gradio demo is here](https://huggingface.co/spaces/gradio/xgboost-income-prediction-with-explainability/blob/main/app.py).\n",
    "    \"\"\")\n",
    "    with gr.Row():\n",
    "        with gr.Column():\n",
    "            age = gr.Slider(label=\"Age\", minimum=17, maximum=90, step=1, randomize=True)\n",
    "            work_class = gr.Dropdown(\n",
    "                label=\"Workclass\",\n",
    "                choices=unique_class,\n",
    "                value=lambda: random.choice(unique_class),\n",
    "            )\n",
    "            education = gr.Dropdown(\n",
    "                label=\"Education Level\",\n",
    "                choices=unique_education,\n",
    "                value=lambda: random.choice(unique_education),\n",
    "            )\n",
    "            years = gr.Slider(\n",
    "                label=\"Years of schooling\",\n",
    "                minimum=1,\n",
    "                maximum=16,\n",
    "                step=1,\n",
    "                randomize=True,\n",
    "            )\n",
    "            marital_status = gr.Dropdown(\n",
    "                label=\"Marital Status\",\n",
    "                choices=unique_marital_status,\n",
    "                value=lambda: random.choice(unique_marital_status),\n",
    "            )\n",
    "            occupation = gr.Dropdown(\n",
    "                label=\"Occupation\",\n",
    "                choices=unique_occupation,\n",
    "                value=lambda: random.choice(unique_occupation),\n",
    "            )\n",
    "            relationship = gr.Dropdown(\n",
    "                label=\"Relationship Status\",\n",
    "                choices=unique_relationship,\n",
    "                value=lambda: random.choice(unique_relationship),\n",
    "            )\n",
    "            sex = gr.Dropdown(\n",
    "                label=\"Sex\", choices=unique_sex, value=lambda: random.choice(unique_sex)\n",
    "            )\n",
    "            capital_gain = gr.Slider(\n",
    "                label=\"Capital Gain\",\n",
    "                minimum=0,\n",
    "                maximum=100000,\n",
    "                step=500,\n",
    "                randomize=True,\n",
    "            )\n",
    "            capital_loss = gr.Slider(\n",
    "                label=\"Capital Loss\", minimum=0, maximum=10000, step=500, randomize=True\n",
    "            )\n",
    "            hours_per_week = gr.Slider(\n",
    "                label=\"Hours Per Week Worked\", minimum=1, maximum=99, step=1\n",
    "            )\n",
    "            country = gr.Dropdown(\n",
    "                label=\"Native Country\",\n",
    "                choices=unique_country,\n",
    "                value=lambda: random.choice(unique_country),\n",
    "            )\n",
    "        with gr.Column():\n",
    "            label = gr.Label()\n",
    "            plot = gr.Plot()\n",
    "            with gr.Row():\n",
    "                predict_btn = gr.Button(value=\"Predict\")\n",
    "                interpret_btn = gr.Button(value=\"Explain\")\n",
    "            # Both handlers name these positional values with X_train.columns,\n",
    "            # so the order here must match the order of the training columns.\n",
    "            feature_inputs = [\n",
    "                age,\n",
    "                work_class,\n",
    "                education,\n",
    "                years,\n",
    "                marital_status,\n",
    "                occupation,\n",
    "                relationship,\n",
    "                sex,\n",
    "                capital_gain,\n",
    "                capital_loss,\n",
    "                hours_per_week,\n",
    "                country,\n",
    "            ]\n",
    "            predict_btn.click(predict, inputs=feature_inputs, outputs=[label])\n",
    "            interpret_btn.click(interpret, inputs=feature_inputs, outputs=[plot])\n",
    "\n",
    "demo.launch()\n"
   ]
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 5
}