MilesCranmer committed • 4c39e04
Parent(s): bb97e2c
Update documentation for sklearn interface
Files changed: docs/options.md (+59 -42)

docs/options.md CHANGED

```diff
@@ -1,10 +1,8 @@
 # Features and Options
 
-You likely don't need to tune the hyperparameters yourself,
-but if you would like, you can use `hyperparamopt.py` as an example.
-
 Some configurable features and options in `PySR` which you
 may find useful include:
+- `model_selection`
 - `binary_operators`, `unary_operators`
 - `niterations`
 - `ncyclesperiteration`
```
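
In the updated, sklearn-style interface shown in the hunks below, each of these options is a keyword argument to `PySRRegressor`. A minimal sketch with arbitrary illustrative values:

```python
from pysr import PySRRegressor

# Every keyword below is one of the options listed above; the values are
# illustrative only.
model = PySRRegressor(
    model_selection="best",
    binary_operators=["plus", "mult"],
    unary_operators=["cos"],
    niterations=5,
    ncyclesperiteration=100,
)
```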

```diff
@@ -21,18 +19,31 @@ may find useful include:
 
 These are described below
 
-The program will output a pandas DataFrame containing the equations
-
+The program will output a pandas DataFrame containing the equations
+to `PySRRegressor.equations` containing the loss value
+and complexity.
+
+It will also dump to a csv
 at the end of every iteration,
-which is `hall_of_fame_{date_time}.csv` by default.
-equations to stdout.
+which is `hall_of_fame_{date_time}.csv` by default.
+It also prints the equations to stdout.
+
+## Model selection
+
+By default, `PySRRegressor` uses `model_selection='best'`
+which selects an equation from `PySRRegressor.equations` using
+a combination of accuracy and complexity.
+You can also select `model_selection='accuracy'`.
+
+By printing a model (i.e., `print(model)`), you can see
+the equation selection with the arrow shown in the `pick` column.
 
 ## Operators
 
 A list of operators can be found on the operators page.
 One can define custom operators in Julia by passing a string:
 ```python
-
+PySRRegressor(niterations=100,
     binary_operators=["mult", "plus", "special(x, y) = x^2 + y"],
     extra_sympy_mappings={'special': lambda x, y: x**2 + y},
     unary_operators=["cos"])
```
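
A rough end-to-end sketch of the workflow this hunk describes, using the custom operator and `extra_sympy_mappings` from the snippet above; the toy dataset and its shape are made up for illustration:

```python
import numpy as np
from pysr import PySRRegressor

# Toy data where y really does use the custom operator special(x0, x1) = x0^2 + x1.
X = np.random.randn(100, 2)
y = X[:, 0] ** 2 + X[:, 1]

model = PySRRegressor(
    niterations=100,
    binary_operators=["mult", "plus", "special(x, y) = x^2 + y"],
    extra_sympy_mappings={"special": lambda x, y: x**2 + y},
    unary_operators=["cos"],
    model_selection="best",  # select an equation balancing accuracy and complexity
)
model.fit(X, y)

print(model)            # the `pick` column marks the selected equation
print(model.equations)  # pandas DataFrame with the loss and complexity of each equation
```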

```diff
@@ -51,8 +62,6 @@ so that the SymPy code can understand the output equation from Julia,
 when constructing a useable function. This step is optional, but
 is necessary for the `lambda_format` to work.
 
-One can also edit `operators.jl`.
-
 ## Iterations
 
 This is the total number of generations that `pysr` will run for.
```
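
For example (an arbitrary value), the length of the search is set when constructing the regressor:

```python
# niterations is the total number of generations the search runs for; 40 here is arbitrary.
model = PySRRegressor(niterations=40)
```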

```diff
@@ -78,15 +87,15 @@ each population stay closer to the best current equations.
 
 One can adjust the number of workers used by Julia with the
 `procs` option. You should set this equal to the number of cores
-you want `pysr` to use.
-populations simultaneously by default.
+you want `pysr` to use.
 
 ## Populations
 
-By default, `populations=
-number of populations with this option.
+By default, `populations=20`, but you can set a different
+number of populations with this option.
+More populations may increase
 the diversity of equations discovered, though will take longer to train.
-However, it
+However, it is usually more efficient to have `populations>procs`,
 as there are multiple populations running
 on each core.
 
```
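
A sketch of the parallelism settings this hunk discusses; the specific numbers are illustrative:

```python
# Use 4 Julia worker processes, and keep more populations than processes so
# every core always has a population to evolve (the hunk above suggests populations > procs).
model = PySRRegressor(procs=4, populations=20)
```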

```diff
@@ -100,7 +109,8 @@ instead of the usual 4, which creates more populations
 sigma = ...
 weights = 1/sigma**2
 
-
+model = PySRRegressor(procs=10)
+model.fit(X, y, weights=weights)
 ```
 
 ## Max size
```
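
A self-contained version of the weighted-fit snippet in this hunk; the data and the per-point noise level `sigma` are made up so the example runs as-is:

```python
import numpy as np
from pysr import PySRRegressor

X = np.random.randn(200, 2)
y = 2.0 * np.cos(X[:, 0]) + X[:, 1] ** 2
sigma = 0.1 * np.ones_like(y)   # assumed measurement uncertainty per point
weights = 1 / sigma**2          # inverse-variance weights, as in the snippet above

model = PySRRegressor(procs=10)
model.fit(X, y, weights=weights)
```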

```diff
@@ -147,55 +157,62 @@ expressions of complexity 5 (e.g., 5.0 + x2 exp(x3)).
 
 ## LaTeX, SymPy
 
-
-
-
+After running `model.fit(...)`, you can look at
+`model.equations` which is a pandas dataframe.
+The `sympy_format` column gives sympy equations,
+and the `lambda_format` gives callable functions.
+You can optionally pass a pandas dataframe to the callable function,
+if you called `.fit` on a pandas dataframe as well.
 
 There are also some helper functions for doing this quickly.
-
-
-
-
-for the best equation, using the `score` column to sort equations.
-`best_latex()` returns the LaTeX form of this, and `best_callable()`
-returns a callable function.
+- `model.latex()` will generate a TeX formatted output of your equation.
+- `model.sympy()` will return the SymPy representation.
+- `model.jax()` will return a callable JAX function combined with parameters (see below)
+- `model.pytorch()` will return a PyTorch model (see below).
 
 
 ## Callable exports: numpy, pytorch, jax
 
 By default, the dataframe of equations will contain columns
-with the identifier `lambda_format`.
-which correspond to the equation, but executed
-with numpy functions.
-
+with the identifier `lambda_format`.
+These are simple functions which correspond to the equation, but executed
+with numpy functions.
+You can pass your `X` matrix to these functions
+just as you did to the `model.fit` call. Thus, this allows
 you to numerically evaluate the equations over different output.
 
+Calling `model.predict` will execute the `lambda_format` of
+the best equation, and return the result. If you selected
+`model_selection="best"`, this will use an equation that combines
+accuracy with simplicity. For `model_selection="accuracy"`, this will just
+look at accuracy.
 
 One can do the same thing for PyTorch, which uses code
 from [sympytorch](https://github.com/patrick-kidger/sympytorch),
 and for JAX, which uses code from
 [sympy2jax](https://github.com/MilesCranmer/sympy2jax).
 
-
-
-is a PyTorch module which runs the equation, using PyTorch functions,
+Calling `model.pytorch()` will return
+a PyTorch module which runs the equation, using PyTorch functions,
 over `X` (as a PyTorch tensor). This is differentiable, and the
 parameters of this PyTorch module correspond to the learned parameters
 in the equation, and are trainable.
+```python
+output = model.pytorch()
+output['callable'](X)
+```
 
-For
-will
-is a dictionary containing a `'callable'` (a JAX function),
+For JAX, you can equivalently set the argument `output_jax_format=True`.
+This will return a dictionary containing a `'callable'` (a JAX function),
 and `'parameters'` (a list of parameters in the equation).
-
+You can execute this function with:
+```python
+output = model.jax()
+output['callable'](X, output['parameters'])
+```
 Since the parameter list is a jax array, this therefore lets you also
 train the parameters within JAX (and is differentiable).
 
-If you forget to turn these on when calling the function initially,
-you can re-run `get_hof(output_jax_format=True)`, and it will re-use
-the equations and other state properties, assuming you haven't
-re-run `pysr` in the meantime!
-
 ## `loss`
 
 The default loss is mean-square error, and weighted mean-square error.
```
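
Assuming a fitted `model`, the helper methods listed in the hunk above would be called like this (a sketch; the printed output depends on whatever equations were discovered):

```python
print(model.latex())    # LaTeX string for the selected equation
print(model.sympy())    # SymPy expression for the selected equation
print(model.equations)  # the underlying pandas DataFrame, including sympy_format and lambda_format
```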
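
A sketch of evaluating the numpy-backed `lambda_format` callables on new data, assuming `model` has been fit and `X_new` has the same number of features as the training `X`:

```python
# Each row of model.equations carries a numpy-backed callable in its `lambda_format` column.
row = model.equations.iloc[-1]          # assumption: taking the last (most complex) row here
y_new = row["lambda_format"](X_new)     # evaluate that equation on new inputs
```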
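
The `model.predict` call described above, as a one-line sketch (again assuming a fitted `model` and a compatible `X_new`):

```python
y_pred = model.predict(X_new)  # evaluates the equation chosen by model_selection
```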
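
The prose describes `model.pytorch()` as returning a PyTorch module, while the snippet in the hunk indexes a `'callable'` key, so the exact return type may depend on the version; assuming the module form, gradient flow would look roughly like this:

```python
import torch

torch_module = model.pytorch()  # assumed here to be a torch.nn.Module wrapping the equation
X_t = torch.tensor(X, dtype=torch.float32, requires_grad=True)
y_t = torch_module(X_t)         # evaluate the learned equation over the tensor X_t
y_t.sum().backward()            # gradients flow to X_t and to the module's trainable parameters
```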
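
Building on the JAX snippet in the hunk, the exported callable is a pure function of `(X, parameters)`, so the constants can be tuned directly with `jax.grad`; `X` and `y` are assumed to be arrays from an earlier fit:

```python
import jax
import jax.numpy as jnp

output = model.jax()                        # {'callable': ..., 'parameters': ...}
f, params = output["callable"], output["parameters"]

def mse(p):
    # Mean squared error of the exported equation, as a function of its constants.
    return jnp.mean((f(X, p) - y) ** 2)

grads = jax.grad(mse)(params)   # differentiate w.r.t. the equation's constants
params = params - 0.01 * grads  # one illustrative gradient-descent step
```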