Spaces:
Running
Running
MilesCranmer
commited on
Commit
•
2ceb526
1
Parent(s):
66dcb6d
Add JAX export functionality
Browse files- pysr/__init__.py +1 -0
- pysr/export.py +158 -0
- pysr/sr.py +2 -2
pysr/__init__.py
CHANGED
@@ -1,2 +1,3 @@
|
|
1 |
from .sr import pysr, get_hof, best, best_tex, best_callable, best_row
|
2 |
from .feynman_problems import Problem, FeynmanProblem
|
|
|
|
1 |
from .sr import pysr, get_hof, best, best_tex, best_callable, best_row
|
2 |
from .feynman_problems import Problem, FeynmanProblem
|
3 |
+
from .export import sympy2jax
|
pysr/export.py
ADDED
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import functools as ft
|
2 |
+
import sympy
|
3 |
+
import string
|
4 |
+
import random
|
5 |
+
|
6 |
+
try:
|
7 |
+
import jax
|
8 |
+
from jax import numpy as jnp
|
9 |
+
from jax.scipy import special as jsp
|
10 |
+
|
11 |
+
# Special since need to reduce arguments.
|
12 |
+
MUL = 0
|
13 |
+
ADD = 1
|
14 |
+
|
15 |
+
_jnp_func_lookup = {
|
16 |
+
sympy.Mul: MUL,
|
17 |
+
sympy.Add: ADD,
|
18 |
+
sympy.div: "jnp.div",
|
19 |
+
sympy.Abs: "jnp.abs",
|
20 |
+
sympy.sign: "jnp.sign",
|
21 |
+
# Note: May raise error for ints.
|
22 |
+
sympy.ceiling: "jnp.ceil",
|
23 |
+
sympy.floor: "jnp.floor",
|
24 |
+
sympy.log: "jnp.log",
|
25 |
+
sympy.exp: "jnp.exp",
|
26 |
+
sympy.sqrt: "jnp.sqrt",
|
27 |
+
sympy.cos: "jnp.cos",
|
28 |
+
sympy.acos: "jnp.acos",
|
29 |
+
sympy.sin: "jnp.sin",
|
30 |
+
sympy.asin: "jnp.asin",
|
31 |
+
sympy.tan: "jnp.tan",
|
32 |
+
sympy.atan: "jnp.atan",
|
33 |
+
sympy.atan2: "jnp.atan2",
|
34 |
+
# Note: Also may give NaN for complex results.
|
35 |
+
sympy.cosh: "jnp.cosh",
|
36 |
+
sympy.acosh: "jnp.acosh",
|
37 |
+
sympy.sinh: "jnp.sinh",
|
38 |
+
sympy.asinh: "jnp.asinh",
|
39 |
+
sympy.tanh: "jnp.tanh",
|
40 |
+
sympy.atanh: "jnp.atanh",
|
41 |
+
sympy.Pow: "jnp.power",
|
42 |
+
sympy.re: "jnp.real",
|
43 |
+
sympy.im: "jnp.imag",
|
44 |
+
sympy.arg: "jnp.angle",
|
45 |
+
# Note: May raise error for ints and complexes
|
46 |
+
sympy.erf: "jsp.erf",
|
47 |
+
sympy.erfc: "jsp.erfc",
|
48 |
+
sympy.LessThan: "jnp.le",
|
49 |
+
sympy.GreaterThan: "jnp.ge",
|
50 |
+
sympy.And: "jnp.logical_and",
|
51 |
+
sympy.Or: "jnp.logical_or",
|
52 |
+
sympy.Not: "jnp.logical_not",
|
53 |
+
sympy.Max: "jnp.max",
|
54 |
+
sympy.Min: "jnp.min",
|
55 |
+
sympy.Mod: "jnp.mod",
|
56 |
+
sympy.round: 'jnp.round'
|
57 |
+
}
|
58 |
+
except ImportError:
|
59 |
+
...
|
60 |
+
|
61 |
+
def sympy2jaxtext(expr, parameters, symbols_in):
|
62 |
+
if issubclass(expr.func, sympy.Float):
|
63 |
+
parameters.append(float(expr))
|
64 |
+
return f"parameters[{len(parameters) - 1}]"
|
65 |
+
elif issubclass(expr.func, sympy.Integer):
|
66 |
+
return "{int(expr)}"
|
67 |
+
elif issubclass(expr.func, sympy.Symbol):
|
68 |
+
return f"X[:, {[i for i in range(len(symbols_in)) if symbols_in[i] == expr][0]}]"
|
69 |
+
else:
|
70 |
+
_func = _jnp_func_lookup[expr.func]
|
71 |
+
args = [sympy2jaxtext(arg, parameters, symbols_in) for arg in expr.args]
|
72 |
+
if _func == MUL:
|
73 |
+
return ' * '.join(['(' + arg + ')' for arg in args])
|
74 |
+
elif _func == ADD:
|
75 |
+
return ' + '.join(['(' + arg + ')' for arg in args])
|
76 |
+
else:
|
77 |
+
return f'{_func}({", ".join(args)})'
|
78 |
+
|
79 |
+
def sympy2jax(equation, symbols_in):
|
80 |
+
"""Returns a function f and its parameters;
|
81 |
+
the function takes an input matrix, and a list of arguments:
|
82 |
+
f(X, parameters)
|
83 |
+
where the parameters appear in the JAX equation.
|
84 |
+
|
85 |
+
# Examples:
|
86 |
+
|
87 |
+
Let's create a function in SymPy:
|
88 |
+
```python
|
89 |
+
x, y = symbols('x y')
|
90 |
+
cosx = 1.0 * sympy.cos(x) + 3.2 * y
|
91 |
+
```
|
92 |
+
Let's get the JAX version. We pass the equation, and
|
93 |
+
the symbols required.
|
94 |
+
```python
|
95 |
+
f, params = sympy2jax(cosx, [x, y])
|
96 |
+
```
|
97 |
+
The order you supply the symbols is the same order
|
98 |
+
you should supply the features when calling
|
99 |
+
the function `f` (shape `[nrows, nfeatures]`).
|
100 |
+
In this case, features=2 for x and y.
|
101 |
+
The `params` in this case will be
|
102 |
+
`jnp.array([1.0, 3.2])`. You pass these parameters
|
103 |
+
when calling the function, which will let you change them
|
104 |
+
and take gradients.
|
105 |
+
|
106 |
+
Let's generate some JAX data to pass:
|
107 |
+
```python
|
108 |
+
key = random.PRNGKey(0)
|
109 |
+
X = random.normal(key, (10, 2))
|
110 |
+
```
|
111 |
+
|
112 |
+
We can call the function with:
|
113 |
+
```python
|
114 |
+
f(X, params)
|
115 |
+
|
116 |
+
#> DeviceArray([-2.6080756 , 0.72633684, -6.7557726 , -0.2963162 ,
|
117 |
+
# 6.6014843 , 5.032483 , -0.810931 , 4.2520013 ,
|
118 |
+
# 3.5427954 , -2.7479894 ], dtype=float32)
|
119 |
+
```
|
120 |
+
|
121 |
+
We can take gradients with respect
|
122 |
+
to the parameters for each row with JAX
|
123 |
+
gradient parameters now:
|
124 |
+
```python
|
125 |
+
jac_f = jax.jacobian(f, argnums=1)
|
126 |
+
jac_f(X, params)
|
127 |
+
|
128 |
+
#> DeviceArray([[ 0.49364874, -0.9692889 ],
|
129 |
+
# [ 0.8283714 , -0.0318858 ],
|
130 |
+
# [-0.7447336 , -1.8784496 ],
|
131 |
+
# [ 0.70755106, -0.3137085 ],
|
132 |
+
# [ 0.944834 , 1.767703 ],
|
133 |
+
# [ 0.51673377, 1.4111717 ],
|
134 |
+
# [ 0.87347716, -0.52637756],
|
135 |
+
# [ 0.8760679 , 1.0549792 ],
|
136 |
+
# [ 0.9961824 , 0.79581654],
|
137 |
+
# [-0.88465923, -0.5822907 ]], dtype=float32)
|
138 |
+
```
|
139 |
+
|
140 |
+
We can also JIT-compile our function:
|
141 |
+
```python
|
142 |
+
compiled_f = jax.jit(f)
|
143 |
+
compiled_f(X, params)
|
144 |
+
|
145 |
+
#> DeviceArray([-2.6080756 , 0.72633684, -6.7557726 , -0.2963162 ,
|
146 |
+
# 6.6014843 , 5.032483 , -0.810931 , 4.2520013 ,
|
147 |
+
# 3.5427954 , -2.7479894 ], dtype=float32)
|
148 |
+
```
|
149 |
+
"""
|
150 |
+
parameters = []
|
151 |
+
functional_form_text = sympy2jaxtext(equation, parameters, symbols_in)
|
152 |
+
hash_string = 'A' + str(hash([equation, symbols_in]))
|
153 |
+
text = f"def {hash_string}(X, parameters):\n"
|
154 |
+
text += " return "
|
155 |
+
text += functional_form_text
|
156 |
+
ldict = {}
|
157 |
+
exec(text, globals(), ldict)
|
158 |
+
return ldict['f'], jnp.array(parameters)
|
pysr/sr.py
CHANGED
@@ -47,8 +47,8 @@ sympy_mappings = {
|
|
47 |
'erf': lambda x : sympy.erf(x),
|
48 |
'erfc': lambda x : sympy.erfc(x),
|
49 |
'logm': lambda x : sympy.log(abs(x)),
|
50 |
-
'logm10':lambda x : sympy.log(abs(x),
|
51 |
-
'logm2': lambda x : sympy.log(abs(x),
|
52 |
'log1p': lambda x : sympy.log(x + 1),
|
53 |
'floor': lambda x : sympy.floor(x),
|
54 |
'ceil': lambda x : sympy.ceil(x),
|
|
|
47 |
'erf': lambda x : sympy.erf(x),
|
48 |
'erfc': lambda x : sympy.erfc(x),
|
49 |
'logm': lambda x : sympy.log(abs(x)),
|
50 |
+
'logm10':lambda x : sympy.log(abs(x), 10),
|
51 |
+
'logm2': lambda x : sympy.log(abs(x), 2),
|
52 |
'log1p': lambda x : sympy.log(x + 1),
|
53 |
'floor': lambda x : sympy.floor(x),
|
54 |
'ceil': lambda x : sympy.ceil(x),
|