Spaces:
Running
Running
MilesCranmer
commited on
Add more examples to colab notebook
Browse files- examples/pysr_demo.ipynb +270 -0
examples/pysr_demo.ipynb
CHANGED
@@ -711,6 +711,276 @@
|
|
711 |
"plt.scatter(X[:, 0], y_prediction)\n"
|
712 |
]
|
713 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
714 |
{
|
715 |
"cell_type": "markdown",
|
716 |
"metadata": {
|
|
|
711 |
"plt.scatter(X[:, 0], y_prediction)\n"
|
712 |
]
|
713 |
},
|
714 |
+
{
|
715 |
+
"attachments": {},
|
716 |
+
"cell_type": "markdown",
|
717 |
+
"metadata": {},
|
718 |
+
"source": [
|
719 |
+
"# Multiple outputs"
|
720 |
+
]
|
721 |
+
},
|
722 |
+
{
|
723 |
+
"attachments": {},
|
724 |
+
"cell_type": "markdown",
|
725 |
+
"metadata": {},
|
726 |
+
"source": [
|
727 |
+
"For multiple outputs, multiple equations are returned:"
|
728 |
+
]
|
729 |
+
},
|
730 |
+
{
|
731 |
+
"cell_type": "code",
|
732 |
+
"execution_count": null,
|
733 |
+
"metadata": {},
|
734 |
+
"outputs": [],
|
735 |
+
"source": [
|
736 |
+
"X = 2 * np.random.randn(100, 5)\n",
|
737 |
+
"y = 1 / X[:, [0, 1, 2]]\n"
|
738 |
+
]
|
739 |
+
},
|
740 |
+
{
|
741 |
+
"cell_type": "code",
|
742 |
+
"execution_count": null,
|
743 |
+
"metadata": {},
|
744 |
+
"outputs": [],
|
745 |
+
"source": [
|
746 |
+
"model = PySRRegressor(\n",
|
747 |
+
" binary_operators=[\"+\", \"*\"],\n",
|
748 |
+
" unary_operators=[\"inv(x) = 1/x\"],\n",
|
749 |
+
" extra_sympy_mappings={\"inv\": lambda x: 1/x},\n",
|
750 |
+
")\n",
|
751 |
+
"model.fit(X, y)"
|
752 |
+
]
|
753 |
+
},
|
754 |
+
{
|
755 |
+
"cell_type": "code",
|
756 |
+
"execution_count": null,
|
757 |
+
"metadata": {},
|
758 |
+
"outputs": [],
|
759 |
+
"source": [
|
760 |
+
"model"
|
761 |
+
]
|
762 |
+
},
|
763 |
+
{
|
764 |
+
"attachments": {},
|
765 |
+
"cell_type": "markdown",
|
766 |
+
"metadata": {},
|
767 |
+
"source": [
|
768 |
+
"# Julia packages and types"
|
769 |
+
]
|
770 |
+
},
|
771 |
+
{
|
772 |
+
"attachments": {},
|
773 |
+
"cell_type": "markdown",
|
774 |
+
"metadata": {},
|
775 |
+
"source": [
|
776 |
+
"PySR uses [SymbolicRegression.jl](https://github.com/MilesCranmer/SymbolicRegression.jl)\n",
|
777 |
+
"as its search backend. This is a pure Julia package, and so can interface easily with any other\n",
|
778 |
+
"Julia package.\n",
|
779 |
+
"For some tasks, it may be necessary to load such a package.\n",
|
780 |
+
"\n",
|
781 |
+
"For example, let's say we wish to discovery the following relationship:\n",
|
782 |
+
"\n",
|
783 |
+
"$$ y = p_{3x + 1} - 5, $$\n",
|
784 |
+
"\n",
|
785 |
+
"where $p_i$ is the $i$th prime number, and $x$ is the input feature.\n",
|
786 |
+
"\n",
|
787 |
+
"Let's see if we can discover this using\n",
|
788 |
+
"the [Primes.jl](https://github.com/JuliaMath/Primes.jl) package.\n",
|
789 |
+
"\n",
|
790 |
+
"First, let's get the Julia backend\n",
|
791 |
+
"(here, we manually specify 4 threads and `-O3` - although this will only work if PySR has not yet started):"
|
792 |
+
]
|
793 |
+
},
|
794 |
+
{
|
795 |
+
"cell_type": "code",
|
796 |
+
"execution_count": null,
|
797 |
+
"metadata": {},
|
798 |
+
"outputs": [],
|
799 |
+
"source": [
|
800 |
+
"import pysr\n",
|
801 |
+
"jl = pysr.julia_helpers.init_julia(julia_kwargs={\"threads\": 8, \"optimize\": 3})"
|
802 |
+
]
|
803 |
+
},
|
804 |
+
{
|
805 |
+
"attachments": {},
|
806 |
+
"cell_type": "markdown",
|
807 |
+
"metadata": {},
|
808 |
+
"source": [
|
809 |
+
"\n",
|
810 |
+
"\n",
|
811 |
+
"`jl` stores the Julia runtime.\n",
|
812 |
+
"\n",
|
813 |
+
"Now, let's run some Julia code to add the Primes.jl\n",
|
814 |
+
"package to the PySR environment:"
|
815 |
+
]
|
816 |
+
},
|
817 |
+
{
|
818 |
+
"cell_type": "code",
|
819 |
+
"execution_count": null,
|
820 |
+
"metadata": {},
|
821 |
+
"outputs": [],
|
822 |
+
"source": [
|
823 |
+
"jl.eval(\"\"\"\n",
|
824 |
+
"import Pkg\n",
|
825 |
+
"Pkg.add(\"Primes\")\n",
|
826 |
+
"\"\"\")"
|
827 |
+
]
|
828 |
+
},
|
829 |
+
{
|
830 |
+
"attachments": {},
|
831 |
+
"cell_type": "markdown",
|
832 |
+
"metadata": {},
|
833 |
+
"source": [
|
834 |
+
"This imports the Julia package manager, and uses it to install\n",
|
835 |
+
"`Primes.jl`. Now let's import `Primes.jl`:"
|
836 |
+
]
|
837 |
+
},
|
838 |
+
{
|
839 |
+
"cell_type": "code",
|
840 |
+
"execution_count": null,
|
841 |
+
"metadata": {},
|
842 |
+
"outputs": [],
|
843 |
+
"source": [
|
844 |
+
"jl.eval(\"import Primes\")"
|
845 |
+
]
|
846 |
+
},
|
847 |
+
{
|
848 |
+
"attachments": {},
|
849 |
+
"cell_type": "markdown",
|
850 |
+
"metadata": {},
|
851 |
+
"source": [
|
852 |
+
"\n",
|
853 |
+
"Now, we define a custom operator:\n"
|
854 |
+
]
|
855 |
+
},
|
856 |
+
{
|
857 |
+
"cell_type": "code",
|
858 |
+
"execution_count": null,
|
859 |
+
"metadata": {},
|
860 |
+
"outputs": [],
|
861 |
+
"source": [
|
862 |
+
"jl.eval(\"\"\"\n",
|
863 |
+
"function p(i::T) where T\n",
|
864 |
+
" if (0.5 < i < 1000)\n",
|
865 |
+
" return T(Primes.prime(round(Int, i)))\n",
|
866 |
+
" else\n",
|
867 |
+
" return T(NaN)\n",
|
868 |
+
" end\n",
|
869 |
+
"end\n",
|
870 |
+
"\"\"\")"
|
871 |
+
]
|
872 |
+
},
|
873 |
+
{
|
874 |
+
"attachments": {},
|
875 |
+
"cell_type": "markdown",
|
876 |
+
"metadata": {},
|
877 |
+
"source": [
|
878 |
+
"\n",
|
879 |
+
"We have created a a function `p`, which takes an arbitrary number as input.\n",
|
880 |
+
"`p` first checks whether the input is between 0.5 and 1000.\n",
|
881 |
+
"If out-of-bounds, it returns `NaN`.\n",
|
882 |
+
"If in-bounds, it rounds it to the nearest integer, compures the corresponding prime number, and then\n",
|
883 |
+
"converts it to the same type as input.\n",
|
884 |
+
"\n",
|
885 |
+
"Next, let's generate a list of primes for our test dataset.\n",
|
886 |
+
"Since we are using PyJulia, we can just call `p` directly to do this:\n"
|
887 |
+
]
|
888 |
+
},
|
889 |
+
{
|
890 |
+
"cell_type": "code",
|
891 |
+
"execution_count": null,
|
892 |
+
"metadata": {},
|
893 |
+
"outputs": [],
|
894 |
+
"source": [
|
895 |
+
"primes = {i: jl.p(i*1.0) for i in range(1, 999)}"
|
896 |
+
]
|
897 |
+
},
|
898 |
+
{
|
899 |
+
"attachments": {},
|
900 |
+
"cell_type": "markdown",
|
901 |
+
"metadata": {},
|
902 |
+
"source": [
|
903 |
+
"Next, let's use this list of primes to create a dataset of $x, y$ pairs:"
|
904 |
+
]
|
905 |
+
},
|
906 |
+
{
|
907 |
+
"cell_type": "code",
|
908 |
+
"execution_count": null,
|
909 |
+
"metadata": {},
|
910 |
+
"outputs": [],
|
911 |
+
"source": [
|
912 |
+
"import numpy as np\n",
|
913 |
+
"\n",
|
914 |
+
"X = np.random.randint(0, 100, 100)[:, None]\n",
|
915 |
+
"y = [primes[3*X[i, 0] + 1] - 5 + np.random.randn()*0.001 for i in range(100)]"
|
916 |
+
]
|
917 |
+
},
|
918 |
+
{
|
919 |
+
"attachments": {},
|
920 |
+
"cell_type": "markdown",
|
921 |
+
"metadata": {},
|
922 |
+
"source": [
|
923 |
+
"Note that we have also added a tiny bit of noise to the dataset.\n",
|
924 |
+
"\n",
|
925 |
+
"Finally, let's create a PySR model, and pass the custom operator. We also need to define the sympy equivalent, which we can leave as a placeholder for now:"
|
926 |
+
]
|
927 |
+
},
|
928 |
+
{
|
929 |
+
"cell_type": "code",
|
930 |
+
"execution_count": null,
|
931 |
+
"metadata": {},
|
932 |
+
"outputs": [],
|
933 |
+
"source": [
|
934 |
+
"from pysr import PySRRegressor\n",
|
935 |
+
"import sympy\n",
|
936 |
+
"\n",
|
937 |
+
"class sympy_p(sympy.Function):\n",
|
938 |
+
" pass\n",
|
939 |
+
"\n",
|
940 |
+
"model = PySRRegressor(\n",
|
941 |
+
" binary_operators=[\"+\", \"-\", \"*\", \"/\"],\n",
|
942 |
+
" unary_operators=[\"p\"],\n",
|
943 |
+
" niterations=100,\n",
|
944 |
+
" extra_sympy_mappings={\"p\": sympy_p}\n",
|
945 |
+
")"
|
946 |
+
]
|
947 |
+
},
|
948 |
+
{
|
949 |
+
"attachments": {},
|
950 |
+
"cell_type": "markdown",
|
951 |
+
"metadata": {},
|
952 |
+
"source": [
|
953 |
+
"We are all set to go! Let's see if we can find the true relation:"
|
954 |
+
]
|
955 |
+
},
|
956 |
+
{
|
957 |
+
"cell_type": "code",
|
958 |
+
"execution_count": null,
|
959 |
+
"metadata": {},
|
960 |
+
"outputs": [],
|
961 |
+
"source": [
|
962 |
+
"model.fit(X, y)"
|
963 |
+
]
|
964 |
+
},
|
965 |
+
{
|
966 |
+
"attachments": {},
|
967 |
+
"cell_type": "markdown",
|
968 |
+
"metadata": {},
|
969 |
+
"source": [
|
970 |
+
"if all works out, you should be able to see the true relation (note that the constant offset might not be exactly 1, since it is allowed to round to the nearest integer).\n",
|
971 |
+
"\n",
|
972 |
+
"You can get the sympy version of the best equation with:"
|
973 |
+
]
|
974 |
+
},
|
975 |
+
{
|
976 |
+
"cell_type": "code",
|
977 |
+
"execution_count": null,
|
978 |
+
"metadata": {},
|
979 |
+
"outputs": [],
|
980 |
+
"source": [
|
981 |
+
"model.sympy()"
|
982 |
+
]
|
983 |
+
},
|
984 |
{
|
985 |
"cell_type": "markdown",
|
986 |
"metadata": {
|