MilesCranmer commited on
Commit
9f3b918
1 Parent(s): 88d93a1

refactor: more robust parsing of operators

Browse files
Files changed (3) hide show
  1. pysr/julia_helpers.py +2 -2
  2. pysr/sr.py +18 -3
  3. pysr/test/test.py +10 -0
pysr/julia_helpers.py CHANGED
@@ -41,8 +41,8 @@ def jl_array(x, dtype=None):
41
  return jl_convert(jl.Array[dtype], x)
42
 
43
 
44
- def jl_is_function(f):
45
- return jl.seval("op -> op isa Function")(f)
46
 
47
 
48
  def jl_serialize(obj: Any) -> NDArray[np.uint8]:
 
41
  return jl_convert(jl.Array[dtype], x)
42
 
43
 
44
+ def jl_is_function(f) -> bool:
45
+ return cast(bool, jl.seval("op -> op isa Function")(f))
46
 
47
 
48
  def jl_serialize(obj: Any) -> NDArray[np.uint8]:
pysr/sr.py CHANGED
@@ -13,7 +13,7 @@ from datetime import datetime
13
  from io import StringIO
14
  from multiprocessing import cpu_count
15
  from pathlib import Path
16
- from typing import Callable, Dict, List, Literal, Optional, Tuple, Union, cast
17
 
18
  import numpy as np
19
  import pandas as pd
@@ -44,6 +44,7 @@ from .julia_helpers import (
44
  _load_cluster_manager,
45
  jl_array,
46
  jl_deserialize,
 
47
  jl_serialize,
48
  )
49
  from .julia_import import SymbolicRegression, jl
@@ -1695,11 +1696,25 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1695
  optimize=self.weight_optimize,
1696
  )
1697
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1698
  # Call to Julia backend.
1699
  # See https://github.com/MilesCranmer/SymbolicRegression.jl/blob/master/src/OptionsStruct.jl
1700
  options = SymbolicRegression.Options(
1701
- binary_operators=jl.seval(str(binary_operators).replace("'", "")),
1702
- unary_operators=jl.seval(str(unary_operators).replace("'", "")),
1703
  bin_constraints=jl_array(bin_constraints),
1704
  una_constraints=jl_array(una_constraints),
1705
  complexity_of_operators=complexity_of_operators,
 
13
  from io import StringIO
14
  from multiprocessing import cpu_count
15
  from pathlib import Path
16
+ from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union, cast
17
 
18
  import numpy as np
19
  import pandas as pd
 
44
  _load_cluster_manager,
45
  jl_array,
46
  jl_deserialize,
47
+ jl_is_function,
48
  jl_serialize,
49
  )
50
  from .julia_import import SymbolicRegression, jl
 
1696
  optimize=self.weight_optimize,
1697
  )
1698
 
1699
+ jl_binary_operators: list[Any] = []
1700
+ jl_unary_operators: list[Any] = []
1701
+ for input_list, output_list, name in [
1702
+ (binary_operators, jl_binary_operators, "binary"),
1703
+ (unary_operators, jl_unary_operators, "unary"),
1704
+ ]:
1705
+ for op in input_list:
1706
+ jl_op = jl.seval(op)
1707
+ if not jl_is_function(jl_op):
1708
+ raise ValueError(
1709
+ f"When building `{name}_operators`, `'{op}'` did not return a Julia function"
1710
+ )
1711
+ output_list.append(jl_op)
1712
+
1713
  # Call to Julia backend.
1714
  # See https://github.com/MilesCranmer/SymbolicRegression.jl/blob/master/src/OptionsStruct.jl
1715
  options = SymbolicRegression.Options(
1716
+ binary_operators=jl_array(jl_binary_operators, dtype=jl.Function),
1717
+ unary_operators=jl_array(jl_unary_operators, dtype=jl.Function),
1718
  bin_constraints=jl_array(bin_constraints),
1719
  una_constraints=jl_array(una_constraints),
1720
  complexity_of_operators=complexity_of_operators,
pysr/test/test.py CHANGED
@@ -431,6 +431,16 @@ class TestPipeline(unittest.TestCase):
431
  )
432
  np.testing.assert_allclose(model.predict(self.X), model3.predict(self.X))
433
 
 
 
 
 
 
 
 
 
 
 
434
 
435
  def manually_create_model(equations, feature_names=None):
436
  if feature_names is None:
 
431
  )
432
  np.testing.assert_allclose(model.predict(self.X), model3.predict(self.X))
433
 
434
+ def test_jl_function_error(self):
435
+ # TODO: Move this to better class
436
+ with self.assertRaises(ValueError) as cm:
437
+ PySRRegressor(unary_operators=["1"]).fit([[1]], [1])
438
+
439
+ self.assertIn(
440
+ "When building `unary_operators`, `'1'` did not return a Julia function",
441
+ str(cm.exception),
442
+ )
443
+
444
 
445
  def manually_create_model(equations, feature_names=None):
446
  if feature_names is None: