MilesCranmer commited on
Commit
9bfcbfa
1 Parent(s): 84e4a47

Add tests for jax/torch format

Browse files
Files changed (2) hide show
  1. test/test_jax.py +23 -1
  2. test/test_torch.py +23 -1
test/test_jax.py CHANGED
@@ -1,6 +1,7 @@
1
  import unittest
2
  import numpy as np
3
- from pysr import sympy2jax
 
4
  from jax import numpy as jnp
5
  from jax import random
6
  from jax import grad
@@ -15,3 +16,24 @@ class TestJAX(unittest.TestCase):
15
  true = 1.0 * jnp.cos(X[:, 0]) + X[:, 1]
16
  f, params = sympy2jax(cosx, [x, y, z])
17
  self.assertTrue(jnp.all(jnp.isclose(f(X, params), true)).item())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import unittest
2
  import numpy as np
3
+ from pysr import sympy2jax, get_hof
4
+ import pandas as pd
5
  from jax import numpy as jnp
6
  from jax import random
7
  from jax import grad
 
16
  true = 1.0 * jnp.cos(X[:, 0]) + X[:, 1]
17
  f, params = sympy2jax(cosx, [x, y, z])
18
  self.assertTrue(jnp.all(jnp.isclose(f(X, params), true)).item())
19
+ def test_pipeline(self):
20
+ X = np.random.randn(100, 2)
21
+ equations = pd.DataFrame({
22
+ 'Equation': ['1.0', 'cos(x0)', 'square(cos(x0))'],
23
+ 'MSE': [1.0, 0.1, 1e-5],
24
+ 'Complexity': [1, 2, 3]
25
+ })
26
+
27
+ equations['Complexity MSE Equation'.split(' ')].to_csv(
28
+ 'equation_file.csv.bkup', sep='|')
29
+
30
+ equations = get_hof(
31
+ 'equation_file.csv', n_features=2, variables_names='x0 x1'.split(' '),
32
+ extra_sympy_mappings={}, output_jax_format=True,
33
+ multioutput=False, nout=1)
34
+
35
+ jformat = equations.iloc[-1].jax_format
36
+ np.testing.assert_almost_equal(
37
+ np.array(jformat['callable'](jnp.array(X), jformat['parameters'])),
38
+ np.square(np.cos(X[:, 0]))
39
+ )
test/test_torch.py CHANGED
@@ -1,6 +1,7 @@
1
  import unittest
2
  import numpy as np
3
- from pysr import sympy2torch
 
4
  import torch
5
  import sympy
6
 
@@ -14,3 +15,24 @@ class TestTorch(unittest.TestCase):
14
  self.assertTrue(
15
  np.all(np.isclose(torch_module(X).detach().numpy(), true.detach().numpy()))
16
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import unittest
2
  import numpy as np
3
+ import pandas as pd
4
+ from pysr import sympy2torch, get_hof
5
  import torch
6
  import sympy
7
 
 
15
  self.assertTrue(
16
  np.all(np.isclose(torch_module(X).detach().numpy(), true.detach().numpy()))
17
  )
18
+ def test_pipeline(self):
19
+ X = np.random.randn(100, 2)
20
+ equations = pd.DataFrame({
21
+ 'Equation': ['1.0', 'cos(x0)', 'square(cos(x0))'],
22
+ 'MSE': [1.0, 0.1, 1e-5],
23
+ 'Complexity': [1, 2, 3]
24
+ })
25
+
26
+ equations['Complexity MSE Equation'.split(' ')].to_csv(
27
+ 'equation_file.csv.bkup', sep='|')
28
+
29
+ equations = get_hof(
30
+ 'equation_file.csv', n_features=2, variables_names='x0 x1'.split(' '),
31
+ extra_sympy_mappings={}, output_torch_format=True,
32
+ multioutput=False, nout=1)
33
+
34
+ tformat = equations.iloc[-1].torch_format
35
+ np.testing.assert_almost_equal(
36
+ tformat(torch.tensor(X)).detach().numpy(),
37
+ np.square(np.cos(X[:, 0]))
38
+ )