Spaces:
Sleeping
Sleeping
MilesCranmer
commited on
Commit
•
86d9c0b
1
Parent(s):
1ec3ee8
Expand colab notebook
Browse files- examples/pysr_demo.ipynb +71 -8
examples/pysr_demo.ipynb
CHANGED
@@ -1262,6 +1262,7 @@
|
|
1262 |
]
|
1263 |
},
|
1264 |
{
|
|
|
1265 |
"cell_type": "markdown",
|
1266 |
"metadata": {
|
1267 |
"id": "nCCIvvAGuyFi"
|
@@ -1269,7 +1270,60 @@
|
|
1269 |
"source": [
|
1270 |
"## Learning over the network:\n",
|
1271 |
"\n",
|
1272 |
-
"Now, let's fit `g` using PySR
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1273 |
]
|
1274 |
},
|
1275 |
{
|
@@ -1281,17 +1335,15 @@
|
|
1281 |
},
|
1282 |
"outputs": [],
|
1283 |
"source": [
|
1284 |
-
"np.random.
|
1285 |
-
"
|
1286 |
-
"tmpy = y_i_for_pysr.detach().numpy().reshape(-1)\n",
|
1287 |
-
"idx2 = np.random.randint(0, tmpy.shape[0], size=500)\n",
|
1288 |
"\n",
|
1289 |
"model = PySRRegressor(\n",
|
1290 |
" niterations=20,\n",
|
1291 |
" binary_operators=[\"plus\", \"sub\", \"mult\"],\n",
|
1292 |
" unary_operators=[\"cos\", \"square\", \"neg\"],\n",
|
1293 |
")\n",
|
1294 |
-
"model.fit(
|
1295 |
]
|
1296 |
},
|
1297 |
{
|
@@ -1310,9 +1362,12 @@
|
|
1310 |
"id": "6WuaeqyqbDhe"
|
1311 |
},
|
1312 |
"source": [
|
1313 |
-
"Recall we are searching for $
|
|
|
|
|
|
|
1314 |
"\n",
|
1315 |
-
"
|
1316 |
]
|
1317 |
},
|
1318 |
{
|
@@ -1384,7 +1439,15 @@
|
|
1384 |
"name": "main_ipynb"
|
1385 |
},
|
1386 |
"language_info": {
|
|
|
|
|
|
|
|
|
|
|
|
|
1387 |
"name": "python",
|
|
|
|
|
1388 |
"version": "3.10.9"
|
1389 |
}
|
1390 |
},
|
|
|
1262 |
]
|
1263 |
},
|
1264 |
{
|
1265 |
+
"attachments": {},
|
1266 |
"cell_type": "markdown",
|
1267 |
"metadata": {
|
1268 |
"id": "nCCIvvAGuyFi"
|
|
|
1270 |
"source": [
|
1271 |
"## Learning over the network:\n",
|
1272 |
"\n",
|
1273 |
+
"Now, let's fit `g` using PySR.\n",
|
1274 |
+
"\n",
|
1275 |
+
"> **Warning**\n",
|
1276 |
+
">\n",
|
1277 |
+
"> First, let's save the data, because sometimes PyTorch and PyJulia's C bindings interfere and cause the colab kernel to crash. If we need to restart, we can just load the data without having to retrain the network:"
|
1278 |
+
]
|
1279 |
+
},
|
1280 |
+
{
|
1281 |
+
"cell_type": "code",
|
1282 |
+
"execution_count": null,
|
1283 |
+
"metadata": {},
|
1284 |
+
"outputs": [],
|
1285 |
+
"source": [
|
1286 |
+
"nnet_recordings = {\n",
|
1287 |
+
" \"g_input\": X_for_pysr.detach().cpu().numpy().reshape(-1, 5),\n",
|
1288 |
+
" \"g_output\": y_i_for_pysr.detach().cpu().numpy().reshape(-1),\n",
|
1289 |
+
" \"f_input\": y_for_pysr.detach().cpu().numpy().reshape(-1, 1),\n",
|
1290 |
+
" \"f_output\": z_for_pysr.detach().cpu().numpy().reshape(-1),\n",
|
1291 |
+
"}\n",
|
1292 |
+
"\n",
|
1293 |
+
"# Save the data for later use:\n",
|
1294 |
+
"import pickle as pkl\n",
|
1295 |
+
"\n",
|
1296 |
+
"with open(\"nnet_recordings.pkl\", \"wb\") as f:\n",
|
1297 |
+
" pkl.dump(nnet_recordings, f)"
|
1298 |
+
]
|
1299 |
+
},
|
1300 |
+
{
|
1301 |
+
"attachments": {},
|
1302 |
+
"cell_type": "markdown",
|
1303 |
+
"metadata": {},
|
1304 |
+
"source": [
|
1305 |
+
"We can now load the data:"
|
1306 |
+
]
|
1307 |
+
},
|
1308 |
+
{
|
1309 |
+
"cell_type": "code",
|
1310 |
+
"execution_count": null,
|
1311 |
+
"metadata": {},
|
1312 |
+
"outputs": [],
|
1313 |
+
"source": [
|
1314 |
+
"nnet_recordings = pkl.load(open(\"nnet_recordings.pkl\", \"rb\"))\n",
|
1315 |
+
"f_input = nnet_recordings[\"f_input\"]\n",
|
1316 |
+
"f_output = nnet_recordings[\"f_output\"]\n",
|
1317 |
+
"g_input = nnet_recordings[\"g_input\"]\n",
|
1318 |
+
"g_output = nnet_recordings[\"g_output\"]"
|
1319 |
+
]
|
1320 |
+
},
|
1321 |
+
{
|
1322 |
+
"attachments": {},
|
1323 |
+
"cell_type": "markdown",
|
1324 |
+
"metadata": {},
|
1325 |
+
"source": [
|
1326 |
+
"And now fit using a subsample of the data (symbolic regression only needs a small sample to find the best equation):"
|
1327 |
]
|
1328 |
},
|
1329 |
{
|
|
|
1335 |
},
|
1336 |
"outputs": [],
|
1337 |
"source": [
|
1338 |
+
"rstate = np.random.RandomState(0)\n",
|
1339 |
+
"f_sample_idx = rstate.choice(f_input.shape[0], size=500, replace=False)\n",
|
|
|
|
|
1340 |
"\n",
|
1341 |
"model = PySRRegressor(\n",
|
1342 |
" niterations=20,\n",
|
1343 |
" binary_operators=[\"plus\", \"sub\", \"mult\"],\n",
|
1344 |
" unary_operators=[\"cos\", \"square\", \"neg\"],\n",
|
1345 |
")\n",
|
1346 |
+
"model.fit(g_input[f_sample_idx], g_output[f_sample_idx])"
|
1347 |
]
|
1348 |
},
|
1349 |
{
|
|
|
1362 |
"id": "6WuaeqyqbDhe"
|
1363 |
},
|
1364 |
"source": [
|
1365 |
+
"Recall we are searching for $f$ and $g$ such that:\n",
|
1366 |
+
"$$z=f(\\sum g(x_i))$$ \n",
|
1367 |
+
"which approximates the true relation:\n",
|
1368 |
+
"$$ z = y^2,\\quad y = \\frac{1}{10} \\sum(y_i),\\quad y_i = x_{i0}^2 + 6 \\cos(2 x_{i2})$$\n",
|
1369 |
"\n",
|
1370 |
+
"Let's see how well we did in recovering $g$:"
|
1371 |
]
|
1372 |
},
|
1373 |
{
|
|
|
1439 |
"name": "main_ipynb"
|
1440 |
},
|
1441 |
"language_info": {
|
1442 |
+
"codemirror_mode": {
|
1443 |
+
"name": "ipython",
|
1444 |
+
"version": 3
|
1445 |
+
},
|
1446 |
+
"file_extension": ".py",
|
1447 |
+
"mimetype": "text/x-python",
|
1448 |
"name": "python",
|
1449 |
+
"nbconvert_exporter": "python",
|
1450 |
+
"pygments_lexer": "ipython3",
|
1451 |
"version": "3.10.9"
|
1452 |
}
|
1453 |
},
|