MilesCranmer commited on
Commit
41e5fd5
1 Parent(s): d0788ef

Only test jax on non-Windows systems

Browse files
Files changed (3) hide show
  1. .github/workflows/CI.yml +7 -1
  2. test/test.py +1 -13
  3. test/test_jax.py +15 -0
.github/workflows/CI.yml CHANGED
@@ -58,9 +58,15 @@ jobs:
58
  run: |
59
  python -m pip install --upgrade pip
60
  pip install -r requirements.txt
61
- pip install jax jaxlib # (optional import)
62
  python setup.py install
 
 
 
63
  shell: bash
64
  - name: "Run tests"
65
  run: python test/test.py
66
  shell: bash
 
 
 
 
 
58
  run: |
59
  python -m pip install --upgrade pip
60
  pip install -r requirements.txt
 
61
  python setup.py install
62
+ - name: "Install JAX"
63
+ if: matrix.os != 'windows-latest'
64
+ run: pip install jax jaxlib # (optional import)
65
  shell: bash
66
  - name: "Run tests"
67
  run: python test/test.py
68
  shell: bash
69
+ - name: "Run JAX tests"
70
+ if: matrix.os != 'windows-latest'
71
+ run: python test/test_jax.py
72
+ shell: bash
test/test.py CHANGED
@@ -1,8 +1,5 @@
1
  import numpy as np
2
- from pysr import pysr, sympy2jax
3
- from jax import numpy as jnp
4
- from jax import random
5
- from jax import grad
6
  import sympy
7
  X = np.random.randn(100, 5)
8
 
@@ -33,12 +30,3 @@ equations = pysr(X, y,
33
  user_input=False)
34
  print(equations)
35
  assert equations.iloc[-1]['MSE'] < 1e-4
36
-
37
- print("Test 4 - text JAX export")
38
- x, y, z = sympy.symbols('x y z')
39
- cosx = 1.0 * sympy.cos(x) + y
40
- key = random.PRNGKey(0)
41
- X = random.normal(key, (1000, 2))
42
- true = 1.0 * jnp.cos(X[:, 0]) + X[:, 1]
43
- f, params = sympy2jax(cosx, [x])
44
- assert jnp.all(jnp.isclose(f(X, params), true)).item()
 
1
  import numpy as np
2
+ from pysr import pysr
 
 
 
3
  import sympy
4
  X = np.random.randn(100, 5)
5
 
 
30
  user_input=False)
31
  print(equations)
32
  assert equations.iloc[-1]['MSE'] < 1e-4
 
 
 
 
 
 
 
 
 
test/test_jax.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from pysr import pysr, sympy2jax
3
+ from jax import numpy as jnp
4
+ from jax import random
5
+ from jax import grad
6
+ import sympy
7
+
8
+ print("Test JAX 1 - test export")
9
+ x, y, z = sympy.symbols('x y z')
10
+ cosx = 1.0 * sympy.cos(x) + y
11
+ key = random.PRNGKey(0)
12
+ X = random.normal(key, (1000, 2))
13
+ true = 1.0 * jnp.cos(X[:, 0]) + X[:, 1]
14
+ f, params = sympy2jax(cosx, [x])
15
+ assert jnp.all(jnp.isclose(f(X, params), true)).item()