File size: 3,375 Bytes
0449aa7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b20b499
0449aa7
b20b499
0449aa7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "eea9f1b8-4240-4a68-a412-2b4071b2c04a",
   "metadata": {},
   "outputs": [],
   "source": [
    "from drawdata import ScatterWidget"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "207dbcfd-e731-4758-8035-d1f429aa10d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "widget = ScatterWidget()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "b77030f4-c895-4a39-96ce-b79d6a8a6d69",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from IPython.core.display import HTML\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.inspection import DecisionBoundaryDisplay\n",
    "from sklearn.tree import DecisionTreeClassifier\n",
    "\n",
    "import matplotlib.pylab as plt \n",
    "import numpy as np\n",
    "import ipywidgets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9d36c79-3d1d-4084-a1a6-43f3b95c06fe",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6b21ddf1e80a4a34a786547610d52aa0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(ScatterWidget(), Output()))"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "widget = ScatterWidget()\n",
    "output = ipywidgets.Output()\n",
    "\n",
    "\n",
    "@output.capture(clear_output=True)\n",
    "def on_change(change):\n",
    "    df = widget.data_as_pandas\n",
    "    if len(df) and (df['color'].nunique() > 1):\n",
    "        X = df[['x', 'y']].values\n",
    "        y = df['color']\n",
    "        display(HTML(\"<br><br><br>\"))\n",
    "        fig = plt.figure(figsize=(12, 12));\n",
    "        classifier = DecisionTreeClassifier().fit(X, y)\n",
    "        disp = DecisionBoundaryDisplay.from_estimator(\n",
    "            classifier, X, \n",
    "            ax=fig.add_subplot(111),\n",
    "            response_method=\"predict_proba\" if len(np.unique(df['color'])) == 2 else \"predict\",\n",
    "            xlabel=\"x\", ylabel=\"y\",\n",
    "            alpha=0.5,\n",
    "        );\n",
    "        disp.ax_.scatter(X[:, 0], X[:, 1], c=y, edgecolor=\"k\");\n",
    "        plt.title(f\"{classifier.__class__.__name__}\");\n",
    "        disp.ax_.set_title(f\"{classifier.__class__.__name__}\");\n",
    "        plt.show();\n",
    "widget.observe(on_change, names=[\"data\"])\n",
    "on_change(None)\n",
    "page = ipywidgets.HBox([widget, output])\n",
    "page"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5f9bb7ce-8a7f-4879-9ddf-5b256bb0ff64",
   "metadata": {},
   "source": [
    "\n",
    "<br><br><br><br><br><br><br><br><br><br><br><br>"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}