Spaces:
Sleeping
Sleeping
MilesCranmer
commited on
Commit
•
41e5fd5
1
Parent(s):
d0788ef
Only test jax on non-Windows systems
Browse files- .github/workflows/CI.yml +7 -1
- test/test.py +1 -13
- test/test_jax.py +15 -0
.github/workflows/CI.yml
CHANGED
@@ -58,9 +58,15 @@ jobs:
|
|
58 |
run: |
|
59 |
python -m pip install --upgrade pip
|
60 |
pip install -r requirements.txt
|
61 |
-
pip install jax jaxlib # (optional import)
|
62 |
python setup.py install
|
|
|
|
|
|
|
63 |
shell: bash
|
64 |
- name: "Run tests"
|
65 |
run: python test/test.py
|
66 |
shell: bash
|
|
|
|
|
|
|
|
|
|
58 |
run: |
|
59 |
python -m pip install --upgrade pip
|
60 |
pip install -r requirements.txt
|
|
|
61 |
python setup.py install
|
62 |
+
- name: "Install JAX"
|
63 |
+
if: matrix.os != 'windows-latest'
|
64 |
+
run: pip install jax jaxlib # (optional import)
|
65 |
shell: bash
|
66 |
- name: "Run tests"
|
67 |
run: python test/test.py
|
68 |
shell: bash
|
69 |
+
- name: "Run JAX tests"
|
70 |
+
if: matrix.os != 'windows-latest'
|
71 |
+
run: python test/test_jax.py
|
72 |
+
shell: bash
|
test/test.py
CHANGED
@@ -1,8 +1,5 @@
|
|
1 |
import numpy as np
|
2 |
-
from pysr import pysr
|
3 |
-
from jax import numpy as jnp
|
4 |
-
from jax import random
|
5 |
-
from jax import grad
|
6 |
import sympy
|
7 |
X = np.random.randn(100, 5)
|
8 |
|
@@ -33,12 +30,3 @@ equations = pysr(X, y,
|
|
33 |
user_input=False)
|
34 |
print(equations)
|
35 |
assert equations.iloc[-1]['MSE'] < 1e-4
|
36 |
-
|
37 |
-
print("Test 4 - text JAX export")
|
38 |
-
x, y, z = sympy.symbols('x y z')
|
39 |
-
cosx = 1.0 * sympy.cos(x) + y
|
40 |
-
key = random.PRNGKey(0)
|
41 |
-
X = random.normal(key, (1000, 2))
|
42 |
-
true = 1.0 * jnp.cos(X[:, 0]) + X[:, 1]
|
43 |
-
f, params = sympy2jax(cosx, [x])
|
44 |
-
assert jnp.all(jnp.isclose(f(X, params), true)).item()
|
|
|
1 |
import numpy as np
|
2 |
+
from pysr import pysr
|
|
|
|
|
|
|
3 |
import sympy
|
4 |
X = np.random.randn(100, 5)
|
5 |
|
|
|
30 |
user_input=False)
|
31 |
print(equations)
|
32 |
assert equations.iloc[-1]['MSE'] < 1e-4
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
test/test_jax.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from pysr import pysr, sympy2jax
|
3 |
+
from jax import numpy as jnp
|
4 |
+
from jax import random
|
5 |
+
from jax import grad
|
6 |
+
import sympy
|
7 |
+
|
8 |
+
print("Test JAX 1 - test export")
|
9 |
+
x, y, z = sympy.symbols('x y z')
|
10 |
+
cosx = 1.0 * sympy.cos(x) + y
|
11 |
+
key = random.PRNGKey(0)
|
12 |
+
X = random.normal(key, (1000, 2))
|
13 |
+
true = 1.0 * jnp.cos(X[:, 0]) + X[:, 1]
|
14 |
+
f, params = sympy2jax(cosx, [x])
|
15 |
+
assert jnp.all(jnp.isclose(f(X, params), true)).item()
|