MilesCranmer commited on
Commit
144e3ff
·
unverified ·
1 Parent(s): c3293a8

fix: symbolic numbers in jax

Browse files
Files changed (2) hide show
  1. pysr/export_jax.py +3 -1
  2. pysr/test/test_jax.py +36 -8
pysr/export_jax.py CHANGED
@@ -55,7 +55,9 @@ def sympy2jaxtext(expr, parameters, symbols_in, extra_jax_mappings=None):
55
  if issubclass(expr.func, sympy.Float):
56
  parameters.append(float(expr))
57
  return f"parameters[{len(parameters) - 1}]"
58
- elif issubclass(expr.func, sympy.Rational):
 
 
59
  return f"{float(expr)}"
60
  elif issubclass(expr.func, sympy.Integer):
61
  return f"{int(expr)}"
 
55
  if issubclass(expr.func, sympy.Float):
56
  parameters.append(float(expr))
57
  return f"parameters[{len(parameters) - 1}]"
58
+ elif issubclass(expr.func, sympy.Rational) or issubclass(
59
+ expr.func, sympy.NumberSymbol
60
+ ):
61
  return f"{float(expr)}"
62
  elif issubclass(expr.func, sympy.Integer):
63
  return f"{int(expr)}"
pysr/test/test_jax.py CHANGED
@@ -5,27 +5,29 @@ import numpy as np
5
  import pandas as pd
6
  import sympy
7
 
 
8
  from pysr import PySRRegressor, sympy2jax
9
 
10
 
11
  class TestJAX(unittest.TestCase):
12
  def setUp(self):
13
  np.random.seed(0)
 
 
 
14
 
15
  def test_sympy2jax(self):
16
- from jax import numpy as jnp
17
  from jax import random
18
 
19
  x, y, z = sympy.symbols("x y z")
20
  cosx = 1.0 * sympy.cos(x) + y
21
  key = random.PRNGKey(0)
22
  X = random.normal(key, (1000, 2))
23
- true = 1.0 * jnp.cos(X[:, 0]) + X[:, 1]
24
  f, params = sympy2jax(cosx, [x, y, z])
25
- self.assertTrue(jnp.all(jnp.isclose(f(X, params), true)).item())
26
 
27
  def test_pipeline_pandas(self):
28
- from jax import numpy as jnp
29
 
30
  X = pd.DataFrame(np.random.randn(100, 10))
31
  y = np.ones(X.shape[0])
@@ -52,14 +54,12 @@ class TestJAX(unittest.TestCase):
52
  jformat = model.jax()
53
 
54
  np.testing.assert_almost_equal(
55
- np.array(jformat["callable"](jnp.array(X), jformat["parameters"])),
56
  np.square(np.cos(X.values[:, 1])), # Select feature 1
57
  decimal=3,
58
  )
59
 
60
  def test_pipeline(self):
61
- from jax import numpy as jnp
62
-
63
  X = np.random.randn(100, 10)
64
  y = np.ones(X.shape[0])
65
  model = PySRRegressor(progress=False, max_evals=10000, output_jax_format=True)
@@ -81,11 +81,39 @@ class TestJAX(unittest.TestCase):
81
  jformat = model.jax()
82
 
83
  np.testing.assert_almost_equal(
84
- np.array(jformat["callable"](jnp.array(X), jformat["parameters"])),
85
  np.square(np.cos(X[:, 1])), # Select feature 1
86
  decimal=3,
87
  )
88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  def test_feature_selection_custom_operators(self):
90
  rstate = np.random.RandomState(0)
91
  X = pd.DataFrame({f"k{i}": rstate.randn(2000) for i in range(10, 21)})
 
5
  import pandas as pd
6
  import sympy
7
 
8
+ import pysr
9
  from pysr import PySRRegressor, sympy2jax
10
 
11
 
12
  class TestJAX(unittest.TestCase):
13
  def setUp(self):
14
  np.random.seed(0)
15
+ from jax import numpy as jnp
16
+
17
+ self.jnp = jnp
18
 
19
  def test_sympy2jax(self):
 
20
  from jax import random
21
 
22
  x, y, z = sympy.symbols("x y z")
23
  cosx = 1.0 * sympy.cos(x) + y
24
  key = random.PRNGKey(0)
25
  X = random.normal(key, (1000, 2))
26
+ true = 1.0 * self.jnp.cos(X[:, 0]) + X[:, 1]
27
  f, params = sympy2jax(cosx, [x, y, z])
28
+ self.assertTrue(self.jnp.all(self.jnp.isclose(f(X, params), true)).item())
29
 
30
  def test_pipeline_pandas(self):
 
31
 
32
  X = pd.DataFrame(np.random.randn(100, 10))
33
  y = np.ones(X.shape[0])
 
54
  jformat = model.jax()
55
 
56
  np.testing.assert_almost_equal(
57
+ np.array(jformat["callable"](self.jnp.array(X), jformat["parameters"])),
58
  np.square(np.cos(X.values[:, 1])), # Select feature 1
59
  decimal=3,
60
  )
61
 
62
  def test_pipeline(self):
 
 
63
  X = np.random.randn(100, 10)
64
  y = np.ones(X.shape[0])
65
  model = PySRRegressor(progress=False, max_evals=10000, output_jax_format=True)
 
81
  jformat = model.jax()
82
 
83
  np.testing.assert_almost_equal(
84
+ np.array(jformat["callable"](self.jnp.array(X), jformat["parameters"])),
85
  np.square(np.cos(X[:, 1])), # Select feature 1
86
  decimal=3,
87
  )
88
 
89
+ def test_avoid_simplification(self):
90
+ ex = pysr.export_sympy.pysr2sympy(
91
+ "square(exp(sign(0.44796443))) + 1.5 * x1",
92
+ feature_names_in=["x1"],
93
+ extra_sympy_mappings={"square": lambda x: x**2},
94
+ )
95
+ f, params = pysr.export_jax.sympy2jax(ex, [sympy.symbols("x1")])
96
+ key = np.random.RandomState(0)
97
+ X = key.randn(10, 1)
98
+ np.testing.assert_almost_equal(
99
+ np.array(f(self.jnp.array(X), params)),
100
+ np.square(np.exp(np.sign(0.44796443))) + 1.5 * X[:, 0],
101
+ decimal=3,
102
+ )
103
+
104
+ def test_issue_656(self):
105
+ import sympy
106
+
107
+ E_plus_x1 = sympy.exp(1) + sympy.symbols("x1")
108
+ f, params = pysr.export_jax.sympy2jax(E_plus_x1, [sympy.symbols("x1")])
109
+ key = np.random.RandomState(0)
110
+ X = key.randn(10, 1)
111
+ np.testing.assert_almost_equal(
112
+ np.array(f(self.jnp.array(X), params)),
113
+ np.exp(1) + X[:, 0],
114
+ decimal=3,
115
+ )
116
+
117
  def test_feature_selection_custom_operators(self):
118
  rstate = np.random.RandomState(0)
119
  X = pd.DataFrame({f"k{i}": rstate.randn(2000) for i in range(10, 21)})