Spaces:
Sleeping
Sleeping
MilesCranmer
commited on
Commit
•
04f3f2f
1
Parent(s):
5db0d89
Fully document jax/torch export
Browse files- docs/options.md +39 -2
docs/options.md
CHANGED
@@ -15,7 +15,8 @@ may find useful include:
|
|
15 |
- `batching`, `batchSize`
|
16 |
- `variable_names` (or pandas input)
|
17 |
- Constraining operator complexity
|
18 |
-
- LaTeX, SymPy
|
|
|
19 |
- `loss`
|
20 |
|
21 |
These are described below
|
@@ -144,7 +145,7 @@ The other terms say that each multiplication can only have sub-expressions
|
|
144 |
of up to complexity 3 (e.g., 5.0 + x2) in each side, and cosine can only operate on
|
145 |
expressions of complexity 5 (e.g., 5.0 + x2 exp(x3)).
|
146 |
|
147 |
-
## LaTeX, SymPy
|
148 |
|
149 |
The `pysr` command will return a pandas dataframe. The `sympy_format`
|
150 |
column gives sympy equations, and the `lambda_format` gives callable
|
@@ -159,6 +160,42 @@ for the best equation, using the `score` column to sort equations.
|
|
159 |
`best_latex()` returns the LaTeX form of this, and `best_callable()`
|
160 |
returns a callable function.
|
161 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
162 |
## `loss`
|
163 |
|
164 |
The default loss is mean-square error, and weighted mean-square error.
|
|
|
15 |
- `batching`, `batchSize`
|
16 |
- `variable_names` (or pandas input)
|
17 |
- Constraining operator complexity
|
18 |
+
- LaTeX, SymPy
|
19 |
+
- Callable exports: numpy, pytorch, jax
|
20 |
- `loss`
|
21 |
|
22 |
These are described below
|
|
|
145 |
of up to complexity 3 (e.g., 5.0 + x2) in each side, and cosine can only operate on
|
146 |
expressions of complexity 5 (e.g., 5.0 + x2 exp(x3)).
|
147 |
|
148 |
+
## LaTeX, SymPy
|
149 |
|
150 |
The `pysr` command will return a pandas dataframe. The `sympy_format`
|
151 |
column gives sympy equations, and the `lambda_format` gives callable
|
|
|
160 |
`best_latex()` returns the LaTeX form of this, and `best_callable()`
|
161 |
returns a callable function.
|
162 |
|
163 |
+
|
164 |
+
## Callable exports: numpy, pytorch, jax
|
165 |
+
|
166 |
+
By default, the dataframe of equations will contain columns
|
167 |
+
with the identifier `lambda_format`. These are simple functions
|
168 |
+
which correspond to the equation, but executed
|
169 |
+
with numpy functions. You can pass your `X` matrix to these functions
|
170 |
+
just as you did to the `pysr` call. Thus, this allows
|
171 |
+
you to numerically evaluate the equations over different output.
|
172 |
+
|
173 |
+
|
174 |
+
One can do the same thing for PyTorch, which uses code
|
175 |
+
from [sympytorch](https://github.com/patrick-kidger/sympytorch),
|
176 |
+
and for JAX, which uses code from
|
177 |
+
[sympy2jax](https://github.com/MilesCranmer/sympy2jax).
|
178 |
+
|
179 |
+
For torch, set the argument `output_torch_format=True`, which
|
180 |
+
will generate a column `torch_format`. Each element of this column
|
181 |
+
is a PyTorch module which runs the equation, using PyTorch functions,
|
182 |
+
over `X` (as a PyTorch tensor). This is differentiable, and the
|
183 |
+
parameters of this PyTorch module correspond to the learned parameters
|
184 |
+
in the equation, and are trainable.
|
185 |
+
|
186 |
+
For jax, set the argument `output_jax_format=True`, which
|
187 |
+
will generate a column `jax_format`. Each element of this column
|
188 |
+
is a dictionary containing a `'callable'` (a JAX function),
|
189 |
+
and `'parameters'` (a list of parameters in the equation).
|
190 |
+
One can execute this function with: `element['callable'](X, element['parameters'])`.
|
191 |
+
Since the parameter list is a jax array, this therefore lets you also
|
192 |
+
train the parameters within JAX (and is differentiable).
|
193 |
+
|
194 |
+
If you forget to turn these on when calling the function initially,
|
195 |
+
you can re-run `get_hof(output_jax_format=True)`, and it will re-use
|
196 |
+
the equations and other state properties, assuming you haven't
|
197 |
+
re-run `pysr` in the meantime!
|
198 |
+
|
199 |
## `loss`
|
200 |
|
201 |
The default loss is mean-square error, and weighted mean-square error.
|