Spaces:
Running
Running
MilesCranmer
commited on
Commit
•
42cd6af
1
Parent(s):
4c39e04
Add jax, pytorch, sympy output from Regressor
Browse files- docs/options.md +15 -11
- pysr/sklearn.py +13 -5
docs/options.md
CHANGED
@@ -198,17 +198,18 @@ over `X` (as a PyTorch tensor). This is differentiable, and the
|
|
198 |
parameters of this PyTorch module correspond to the learned parameters
|
199 |
in the equation, and are trainable.
|
200 |
```python
|
201 |
-
|
202 |
-
|
203 |
```
|
|
|
204 |
|
205 |
-
For JAX, you can equivalently
|
206 |
This will return a dictionary containing a `'callable'` (a JAX function),
|
207 |
and `'parameters'` (a list of parameters in the equation).
|
208 |
You can execute this function with:
|
209 |
```python
|
210 |
-
|
211 |
-
|
212 |
```
|
213 |
Since the parameter list is a jax array, this therefore lets you also
|
214 |
train the parameters within JAX (and is differentiable).
|
@@ -226,26 +227,29 @@ Here are some additional examples:
|
|
226 |
|
227 |
abs(x-y) loss
|
228 |
```python
|
229 |
-
|
230 |
```
|
231 |
Note that the function name doesn't matter:
|
232 |
```python
|
233 |
-
|
234 |
```
|
235 |
With weights:
|
236 |
```python
|
237 |
-
|
|
|
238 |
```
|
239 |
Weights can be used in arbitrary ways:
|
240 |
```python
|
241 |
-
|
|
|
242 |
```
|
243 |
Built-in loss (faster) (see [losses](https://astroautomata.com/SymbolicRegression.jl/dev/losses/)).
|
244 |
This one computes the L3 norm:
|
245 |
```python
|
246 |
-
|
247 |
```
|
248 |
Can also uses these losses for weighted (weighted-average):
|
249 |
```python
|
250 |
-
|
|
|
251 |
```
|
|
|
198 |
parameters of this PyTorch module correspond to the learned parameters
|
199 |
in the equation, and are trainable.
|
200 |
```python
|
201 |
+
torch_model = model.pytorch()
|
202 |
+
torch_model(X)
|
203 |
```
|
204 |
+
**Warning: If you are using custom operators, you must define `extra_torch_mappings` or `extra_jax_mappings` (both are `dict` of callables) to provide an equivalent definition of the functions.** (At any time you can set these parameters or any others with `model.set_params`.)
|
205 |
|
206 |
+
For JAX, you can equivalently call `model.jax()`
|
207 |
This will return a dictionary containing a `'callable'` (a JAX function),
|
208 |
and `'parameters'` (a list of parameters in the equation).
|
209 |
You can execute this function with:
|
210 |
```python
|
211 |
+
jax_model = model.jax()
|
212 |
+
jax_model['callable'](X, jax_model['parameters'])
|
213 |
```
|
214 |
Since the parameter list is a jax array, this therefore lets you also
|
215 |
train the parameters within JAX (and is differentiable).
|
|
|
227 |
|
228 |
abs(x-y) loss
|
229 |
```python
|
230 |
+
PySRRegressor(..., loss="f(x, y) = abs(x - y)^1.5")
|
231 |
```
|
232 |
Note that the function name doesn't matter:
|
233 |
```python
|
234 |
+
PySRRegressor(..., loss="loss(x, y) = abs(x * y)")
|
235 |
```
|
236 |
With weights:
|
237 |
```python
|
238 |
+
model = PySRRegressor(..., loss="myloss(x, y, w) = w * abs(x - y)")
|
239 |
+
model.fit(..., weights=weights)
|
240 |
```
|
241 |
Weights can be used in arbitrary ways:
|
242 |
```python
|
243 |
+
model = PySRRegressor(..., weights=weights, loss="myloss(x, y, w) = abs(x - y)^2/w^2")
|
244 |
+
model.fit(..., weights=weights)
|
245 |
```
|
246 |
Built-in loss (faster) (see [losses](https://astroautomata.com/SymbolicRegression.jl/dev/losses/)).
|
247 |
This one computes the L3 norm:
|
248 |
```python
|
249 |
+
PySRRegressor(..., loss="LPDistLoss{3}()")
|
250 |
```
|
251 |
Can also uses these losses for weighted (weighted-average):
|
252 |
```python
|
253 |
+
model = PySRRegressor(..., weights=weights, loss="LPDistLoss{3}()")
|
254 |
+
model.fit(..., weights=weights)
|
255 |
```
|
pysr/sklearn.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
from pysr import pysr, best_row
|
2 |
from sklearn.base import BaseEstimator, RegressorMixin
|
3 |
import inspect
|
4 |
import pandas as pd
|
@@ -94,14 +94,22 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
|
|
94 |
return self
|
95 |
|
96 |
def predict(self, X):
|
97 |
-
|
98 |
-
np_format = equation_row["lambda_format"]
|
99 |
-
|
100 |
return np_format(X)
|
101 |
|
|
|
|
|
102 |
|
103 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
104 |
|
|
|
|
|
105 |
_pysr_docstring_split = []
|
106 |
_start_recording = False
|
107 |
for line in inspect.getdoc(pysr).split("\n"):
|
|
|
1 |
+
from pysr import pysr, best_row, get_hof
|
2 |
from sklearn.base import BaseEstimator, RegressorMixin
|
3 |
import inspect
|
4 |
import pandas as pd
|
|
|
94 |
return self
|
95 |
|
96 |
def predict(self, X):
|
97 |
+
np_format = self.get_best()["lambda_format"]
|
|
|
|
|
98 |
return np_format(X)
|
99 |
|
100 |
+
def sympy(self):
|
101 |
+
return self.get_best()["sympy_format"]
|
102 |
|
103 |
+
def jax(self):
|
104 |
+
self.equations = get_hof(output_jax_format=True)
|
105 |
+
return self.get_best()["jax_format"]
|
106 |
+
|
107 |
+
def pytorch(self):
|
108 |
+
self.equations = get_hof(output_torch_format=True)
|
109 |
+
return self.get_best()["torch_format"]
|
110 |
|
111 |
+
|
112 |
+
# Add the docs from pysr() to PySRRegressor():
|
113 |
_pysr_docstring_split = []
|
114 |
_start_recording = False
|
115 |
for line in inspect.getdoc(pysr).split("\n"):
|