MilesCranmer committed
Commit 04f3f2f
1 Parent(s): 5db0d89

Fully document jax/torch export

Files changed (1):
  docs/options.md (+39, -2)
docs/options.md CHANGED
@@ -15,7 +15,8 @@ may find useful include:
- `batching`, `batchSize`
- `variable_names` (or pandas input)
- Constraining operator complexity
- - LaTeX, SymPy, and callable equation output
+ - LaTeX, SymPy
+ - Callable exports: numpy, pytorch, jax
- `loss`

These are described below

@@ -144,7 +145,7 @@ The other terms say that each multiplication can only have sub-expressions
of up to complexity 3 (e.g., 5.0 + x2) in each side, and cosine can only operate on
expressions of complexity 5 (e.g., 5.0 + x2 exp(x3)).

- ## LaTeX, SymPy, callables
+ ## LaTeX, SymPy

The `pysr` command will return a pandas dataframe. The `sympy_format`
column gives sympy equations, and the `lambda_format` gives callable

@@ -159,6 +160,42 @@ for the best equation, using the `score` column to sort equations.
`best_latex()` returns the LaTeX form of this, and `best_callable()`
returns a callable function.
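For instance, a minimal sketch of this workflow (the dataset, operator choices, and `niterations` value are illustrative; it assumes `pysr` and `best_callable` are importable from the `pysr` package):

```python
import numpy as np
from pysr import pysr, best_callable

# Illustrative dataset: y = 2.5 * cos(x3) + x0^2 - 2
X = np.random.randn(100, 5)
y = 2.5 * np.cos(X[:, 3]) + X[:, 0] ** 2 - 2

# Run the symbolic regression; returns a pandas dataframe of equations
equations = pysr(X, y, niterations=5,
                 binary_operators=["plus", "mult"],
                 unary_operators=["cos"])

# Callable (numpy) version of the best equation, chosen via the `score` column
f = best_callable(equations)
y_pred = f(X)
```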

+
+ ## Callable exports: numpy, pytorch, jax
+
+ By default, the dataframe of equations will contain columns
+ with the identifier `lambda_format`. These are simple functions
+ which correspond to the equation, but are executed
+ with numpy functions. You can pass your `X` matrix to these functions
+ just as you did to the `pysr` call. Thus, this allows
+ you to numerically evaluate the equations over new inputs.
+
+
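Continuing the sketch above, the same callables can also be pulled straight out of the dataframe; the row choice here is illustrative:

```python
# `equations` is the dataframe returned by the pysr(...) call above.
f = equations.iloc[-1]["lambda_format"]   # callable for the last (most complex) equation

X_new = np.random.randn(50, 5)            # new data with the same five features
y_new = f(X_new)                          # numpy evaluation of that equation
```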
+ One can do the same thing for PyTorch, which uses code
+ from [sympytorch](https://github.com/patrick-kidger/sympytorch),
+ and for JAX, which uses code from
+ [sympy2jax](https://github.com/MilesCranmer/sympy2jax).
+
+ For torch, set the argument `output_torch_format=True`, which
+ will generate a column `torch_format`. Each element of this column
+ is a PyTorch module which runs the equation, using PyTorch functions,
+ over `X` (as a PyTorch tensor). This is differentiable, and the
+ parameters of this PyTorch module correspond to the learned parameters
+ in the equation, and are trainable.
+
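A sketch of what this might look like in practice (variable names and dataset are illustrative; assumes PyTorch is installed):

```python
import numpy as np
import torch
from pysr import pysr

X = np.random.randn(100, 5)
y = 2.5 * np.cos(X[:, 3]) + X[:, 0] ** 2 - 2

equations = pysr(X, y, niterations=5,
                 binary_operators=["plus", "mult"],
                 unary_operators=["cos"],
                 output_torch_format=True)

module = equations.iloc[-1]["torch_format"]      # a PyTorch module for that equation
X_torch = torch.tensor(X, dtype=torch.float32)
y_torch = module(X_torch)                        # differentiable evaluation

# The equation's constants are the module's parameters, so gradients flow to them:
y_torch.sum().backward()
```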
+ For jax, set the argument `output_jax_format=True`, which
+ will generate a column `jax_format`. Each element of this column
+ is a dictionary containing a `'callable'` (a JAX function)
+ and `'parameters'` (a list of parameters in the equation).
+ One can execute this function with `element['callable'](X, element['parameters'])`.
+ Since the parameter list is a jax array, this also lets you
+ train the parameters within JAX (and is differentiable).
+
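And a corresponding sketch for JAX (again with illustrative names; assumes JAX is installed):

```python
import numpy as np
import jax
from pysr import pysr

X = np.random.randn(100, 5)
y = 2.5 * np.cos(X[:, 3]) + X[:, 0] ** 2 - 2

equations = pysr(X, y, niterations=5,
                 binary_operators=["plus", "mult"],
                 unary_operators=["cos"],
                 output_jax_format=True)

element = equations.iloc[-1]["jax_format"]
f, params = element["callable"], element["parameters"]

y_pred = f(X, params)                    # evaluate the equation under JAX

# `params` is a jax array, so the constants can be refined with jax.grad:
mse = lambda p: ((f(X, p) - y) ** 2).mean()
grads = jax.grad(mse)(params)
```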
+ If you forget to turn these on when calling `pysr` initially,
+ you can re-run `get_hof(output_jax_format=True)`, and it will re-use
+ the equations and other state properties, assuming you haven't
+ re-run `pysr` in the meantime!
+
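For example (assuming `get_hof` is importable from the `pysr` package and that `pysr(...)` has already been run in the current session):

```python
from pysr import get_hof

# Re-export the equations already found, now with the JAX format attached,
# without re-running the search:
equations = get_hof(output_jax_format=True)
element = equations.iloc[-1]["jax_format"]
```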
## `loss`

The default loss is mean-square error, and weighted mean-square error.