MilesCranmer committed on
Commit
86d9c0b
1 Parent(s): 1ec3ee8

Expand colab notebook

Files changed (1)
  1. examples/pysr_demo.ipynb +71 -8
examples/pysr_demo.ipynb CHANGED
@@ -1262,6 +1262,7 @@
1262
  ]
1263
  },
1264
  {
 
1265
  "cell_type": "markdown",
1266
  "metadata": {
1267
  "id": "nCCIvvAGuyFi"
@@ -1269,7 +1270,60 @@
1269
  "source": [
1270
  "## Learning over the network:\n",
1271
  "\n",
1272
- "Now, let's fit `g` using PySR:"
1273
  ]
1274
  },
1275
  {
@@ -1281,17 +1335,15 @@
1281
  },
1282
  "outputs": [],
1283
  "source": [
1284
- "np.random.seed(1)\n",
1285
- "tmpX = X_for_pysr.detach().numpy().reshape(-1, 5)\n",
1286
- "tmpy = y_i_for_pysr.detach().numpy().reshape(-1)\n",
1287
- "idx2 = np.random.randint(0, tmpy.shape[0], size=500)\n",
1288
  "\n",
1289
  "model = PySRRegressor(\n",
1290
  " niterations=20,\n",
1291
  " binary_operators=[\"plus\", \"sub\", \"mult\"],\n",
1292
  " unary_operators=[\"cos\", \"square\", \"neg\"],\n",
1293
  ")\n",
1294
- "model.fit(X=tmpX[idx2], y=tmpy[idx2])"
1295
  ]
1296
  },
1297
  {
@@ -1310,9 +1362,12 @@
1310
  "id": "6WuaeqyqbDhe"
1311
  },
1312
  "source": [
1313
- "Recall we are searching for $y_i$ above:\n",
 
 
 
1314
  "\n",
1315
- "$$ z = y^2,\\quad y = \\frac{1}{10} \\sum(y_i),\\quad y_i = x_{i0}^2 + 6 \\cos(2 x_{i2})$$"
1316
  ]
1317
  },
1318
  {
@@ -1384,7 +1439,15 @@
1384
  "name": "main_ipynb"
1385
  },
1386
  "language_info": {
 
 
 
 
 
 
1387
  "name": "python",
 
 
1388
  "version": "3.10.9"
1389
  }
1390
  },
 
1262
  ]
1263
  },
1264
  {
1265
+ "attachments": {},
1266
  "cell_type": "markdown",
1267
  "metadata": {
1268
  "id": "nCCIvvAGuyFi"
 
1270
  "source": [
1271
  "## Learning over the network:\n",
1272
  "\n",
1273
+ "Now, let's fit `g` using PySR.\n",
1274
+ "\n",
1275
+ "> **Warning**\n",
1276
+ ">\n",
1277
+ "> First, let's save the data, because PyTorch's and PyJulia's C bindings sometimes interfere with each other and crash the Colab kernel. If we need to restart, we can just load the data without having to retrain the network:"
1278
+ ]
1279
+ },
1280
+ {
1281
+ "cell_type": "code",
1282
+ "execution_count": null,
1283
+ "metadata": {},
1284
+ "outputs": [],
1285
+ "source": [
1286
+ "nnet_recordings = {\n",
1287
+ " \"g_input\": X_for_pysr.detach().cpu().numpy().reshape(-1, 5),\n",
1288
+ " \"g_output\": y_i_for_pysr.detach().cpu().numpy().reshape(-1),\n",
1289
+ " \"f_input\": y_for_pysr.detach().cpu().numpy().reshape(-1, 1),\n",
1290
+ " \"f_output\": z_for_pysr.detach().cpu().numpy().reshape(-1),\n",
1291
+ "}\n",
1292
+ "\n",
1293
+ "# Save the data for later use:\n",
1294
+ "import pickle as pkl\n",
1295
+ "\n",
1296
+ "with open(\"nnet_recordings.pkl\", \"wb\") as f:\n",
1297
+ " pkl.dump(nnet_recordings, f)"
1298
+ ]
1299
+ },
1300
+ {
1301
+ "attachments": {},
1302
+ "cell_type": "markdown",
1303
+ "metadata": {},
1304
+ "source": [
1305
+ "We can now load the data:"
1306
+ ]
1307
+ },
1308
+ {
1309
+ "cell_type": "code",
1310
+ "execution_count": null,
1311
+ "metadata": {},
1312
+ "outputs": [],
1313
+ "source": [
1314
+ "nnet_recordings = pkl.load(open(\"nnet_recordings.pkl\", \"rb\"))\n",
1315
+ "f_input = nnet_recordings[\"f_input\"]\n",
1316
+ "f_output = nnet_recordings[\"f_output\"]\n",
1317
+ "g_input = nnet_recordings[\"g_input\"]\n",
1318
+ "g_output = nnet_recordings[\"g_output\"]"
1319
+ ]
1320
+ },
1321
+ {
1322
+ "attachments": {},
1323
+ "cell_type": "markdown",
1324
+ "metadata": {},
1325
+ "source": [
1326
+ "And now fit using a subsample of the data (symbolic regression usually needs only a small sample to find a good equation):"
1327
  ]
1328
  },
1329
  {
 
1335
  },
1336
  "outputs": [],
1337
  "source": [
1338
+ "rstate = np.random.RandomState(0)\n",
1339
+ "f_sample_idx = rstate.choice(g_input.shape[0], size=500, replace=False)\n",
 
 
1340
  "\n",
1341
  "model = PySRRegressor(\n",
1342
  " niterations=20,\n",
1343
  " binary_operators=[\"plus\", \"sub\", \"mult\"],\n",
1344
  " unary_operators=[\"cos\", \"square\", \"neg\"],\n",
1345
  ")\n",
1346
+ "model.fit(g_input[f_sample_idx], g_output[f_sample_idx])"
1347
  ]
1348
  },
1349
  {
 
1362
  "id": "6WuaeqyqbDhe"
1363
  },
1364
  "source": [
1365
+ "Recall we are searching for $f$ and $g$ such that:\n",
1366
+ "$$z=f(\\sum g(x_i))$$ \n",
1367
+ "which approximates the true relation:\n",
1368
+ "$$ z = y^2,\\quad y = \\frac{1}{10} \\sum(y_i),\\quad y_i = x_{i0}^2 + 6 \\cos(2 x_{i2})$$\n",
1369
  "\n",
1370
+ "Let's see how well we did in recovering $g$:"
1371
  ]
1372
  },
1373
  {
 
1439
  "name": "main_ipynb"
1440
  },
1441
  "language_info": {
1442
+ "codemirror_mode": {
1443
+ "name": "ipython",
1444
+ "version": 3
1445
+ },
1446
+ "file_extension": ".py",
1447
+ "mimetype": "text/x-python",
1448
  "name": "python",
1449
+ "nbconvert_exporter": "python",
1450
+ "pygments_lexer": "ipython3",
1451
  "version": "3.10.9"
1452
  }
1453
  },
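
The save, reload, and subsample steps added in this commit can be sketched outside the notebook as follows. This is only a minimal illustration, not the notebook's code: the helper names `save_recordings`, `load_recordings`, and `subsample_rows` are hypothetical, the arrays are random stand-ins for the network recordings, and the file is written to a temporary directory. Unlike the notebook cell, the files are opened with `with` blocks so the handles are closed, and the sample indices are drawn from the row count of the array being indexed.

```python
import pickle
import tempfile
from pathlib import Path

import numpy as np


def save_recordings(recordings, path):
    """Pickle a dict of NumPy arrays to `path` (hypothetical helper)."""
    with open(path, "wb") as f:
        pickle.dump(recordings, f)


def load_recordings(path):
    """Load the pickled dict back, closing the file handle afterwards."""
    with open(path, "rb") as f:
        return pickle.load(f)


def subsample_rows(X, y, size, seed=0):
    """Draw `size` distinct rows from X and y using a seeded RandomState."""
    rstate = np.random.RandomState(seed)
    size = min(size, X.shape[0])  # never request more rows than exist
    idx = rstate.choice(X.shape[0], size=size, replace=False)
    return X[idx], y[idx]


# Synthetic stand-ins for the recorded inputs and outputs of `g`
# (assumed shapes: 5 features per row, one scalar output per row).
rng = np.random.RandomState(1)
recordings = {
    "g_input": rng.randn(1000, 5),
    "g_output": rng.randn(1000),
}

path = Path(tempfile.mkdtemp()) / "nnet_recordings.pkl"
save_recordings(recordings, path)
restored = load_recordings(path)

X_sub, y_sub = subsample_rows(restored["g_input"], restored["g_output"], size=500)
print(X_sub.shape, y_sub.shape)
```

The subsampled arrays could then be passed to `model.fit(X_sub, y_sub)` as in the notebook cell. After a kernel restart the `pickle` import has to be run again before loading, which is why the sketch keeps the import next to the load helper.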