MilesCranmer commited on
Commit
e530637
1 Parent(s): 68ea1be

Force conversion to Vector

Browse files
Files changed (2) hide show
  1. pysr/julia_helpers.py +2 -0
  2. pysr/sr.py +25 -18
pysr/julia_helpers.py CHANGED
@@ -6,6 +6,8 @@ import juliapkg
6
 
7
  jl = juliacall.newmodule("PySR")
8
 
 
 
9
  juliainfo = None
10
  julia_initialized = False
11
  julia_kwargs_at_initialization = None
 
6
 
7
  jl = juliacall.newmodule("PySR")
8
 
9
+ from juliacall import convert as jl_convert
10
+
11
  juliainfo = None
12
  julia_initialized = False
13
  julia_kwargs_at_initialization = None
pysr/sr.py CHANGED
@@ -32,7 +32,7 @@ from .export_numpy import sympy2numpy
32
  from .export_sympy import assert_valid_sympy_symbol, create_sympy_symbols, pysr2sympy
33
  from .export_torch import sympy2torch
34
  from .feature_selection import run_feature_selection
35
- from .julia_helpers import _escape_filename, _load_cluster_manager, jl
36
  from .utils import (
37
  _csv_filename_to_pkl_filename,
38
  _preprocess_julia_floats,
@@ -1609,12 +1609,11 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1609
 
1610
  # Call to Julia backend.
1611
  # See https://github.com/MilesCranmer/SymbolicRegression.jl/blob/master/src/OptionsStruct.jl
1612
- print(bin_constraints)
1613
  options = SymbolicRegression.Options(
1614
  binary_operators=jl.seval(str(binary_operators).replace("'", "")),
1615
  unary_operators=jl.seval(str(unary_operators).replace("'", "")),
1616
- bin_constraints=bin_constraints,
1617
- una_constraints=una_constraints,
1618
  complexity_of_operators=complexity_of_operators,
1619
  complexity_of_constants=self.complexity_of_constants,
1620
  complexity_of_variables=self.complexity_of_variables,
@@ -1679,18 +1678,18 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1679
  np_dtype = {32: np.complex64, 64: np.complex128}[self.precision]
1680
 
1681
  # This converts the data into a Julia array:
1682
- Main.X = np.array(X, dtype=np_dtype).T
1683
  if len(y.shape) == 1:
1684
- Main.y = np.array(y, dtype=np_dtype)
1685
  else:
1686
- Main.y = np.array(y, dtype=np_dtype).T
1687
  if weights is not None:
1688
  if len(weights.shape) == 1:
1689
- Main.weights = np.array(weights, dtype=np_dtype)
1690
  else:
1691
- Main.weights = np.array(weights, dtype=np_dtype).T
1692
  else:
1693
- Main.weights = None
1694
 
1695
  if self.procs == 0 and not multithreading:
1696
  parallelism = "serial"
@@ -1703,22 +1702,30 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1703
  None if parallelism in ["serial", "multithreading"] else int(self.procs)
1704
  )
1705
 
1706
- y_variable_names = None
1707
  if len(y.shape) > 1:
1708
  # We set these manually so that they respect Python's 0 indexing
1709
  # (by default Julia will use y1, y2...)
1710
- y_variable_names = [f"y{_subscriptify(i)}" for i in range(y.shape[1])]
 
 
 
 
 
 
 
 
 
1711
 
1712
  # Call to Julia backend.
1713
  # See https://github.com/MilesCranmer/SymbolicRegression.jl/blob/master/src/SymbolicRegression.jl
1714
  self.raw_julia_state_ = SymbolicRegression.equation_search(
1715
- Main.X,
1716
- Main.y,
1717
- weights=Main.weights,
1718
  niterations=int(self.niterations),
1719
- variable_names=self.feature_names_in_.tolist(),
1720
- display_variable_names=self.display_feature_names_in_.tolist(),
1721
- y_variable_names=y_variable_names,
1722
  X_units=self.X_units_,
1723
  y_units=self.y_units_,
1724
  options=options,
 
32
  from .export_sympy import assert_valid_sympy_symbol, create_sympy_symbols, pysr2sympy
33
  from .export_torch import sympy2torch
34
  from .feature_selection import run_feature_selection
35
+ from .julia_helpers import _escape_filename, _load_cluster_manager, jl, jl_convert
36
  from .utils import (
37
  _csv_filename_to_pkl_filename,
38
  _preprocess_julia_floats,
 
1609
 
1610
  # Call to Julia backend.
1611
  # See https://github.com/MilesCranmer/SymbolicRegression.jl/blob/master/src/OptionsStruct.jl
 
1612
  options = SymbolicRegression.Options(
1613
  binary_operators=jl.seval(str(binary_operators).replace("'", "")),
1614
  unary_operators=jl.seval(str(unary_operators).replace("'", "")),
1615
+ bin_constraints=jl_convert(jl.Vector, bin_constraints),
1616
+ una_constraints=jl_convert(jl.Vector, una_constraints),
1617
  complexity_of_operators=complexity_of_operators,
1618
  complexity_of_constants=self.complexity_of_constants,
1619
  complexity_of_variables=self.complexity_of_variables,
 
1678
  np_dtype = {32: np.complex64, 64: np.complex128}[self.precision]
1679
 
1680
  # This converts the data into a Julia array:
1681
+ jl_X = jl_convert(jl.Array, np.array(X, dtype=np_dtype).T)
1682
  if len(y.shape) == 1:
1683
+ jl_y = jl_convert(jl.Vector, np.array(y, dtype=np_dtype))
1684
  else:
1685
+ jl_y = jl_convert(jl.Array, np.array(y, dtype=np_dtype).T)
1686
  if weights is not None:
1687
  if len(weights.shape) == 1:
1688
+ jl_weights = jl_convert(jl.Vector, np.array(weights, dtype=np_dtype))
1689
  else:
1690
+ jl_weights = jl_convert(jl.Array, np.array(weights, dtype=np_dtype).T)
1691
  else:
1692
+ jl_weights = None
1693
 
1694
  if self.procs == 0 and not multithreading:
1695
  parallelism = "serial"
 
1702
  None if parallelism in ["serial", "multithreading"] else int(self.procs)
1703
  )
1704
 
 
1705
  if len(y.shape) > 1:
1706
  # We set these manually so that they respect Python's 0 indexing
1707
  # (by default Julia will use y1, y2...)
1708
+ jl_y_variable_names = jl_convert(
1709
+ jl.Vector, [f"y{_subscriptify(i)}" for i in range(y.shape[1])]
1710
+ )
1711
+ else:
1712
+ jl_y_variable_names = None
1713
+
1714
+ jl_feature_names = jl_convert(jl.Vector, self.feature_names_in_.tolist())
1715
+ jl_display_feature_names = jl_convert(
1716
+ jl.Vector, self.display_feature_names_in_.tolist()
1717
+ )
1718
 
1719
  # Call to Julia backend.
1720
  # See https://github.com/MilesCranmer/SymbolicRegression.jl/blob/master/src/SymbolicRegression.jl
1721
  self.raw_julia_state_ = SymbolicRegression.equation_search(
1722
+ jl_X,
1723
+ jl_y,
1724
+ weights=jl_weights,
1725
  niterations=int(self.niterations),
1726
+ variable_names=jl_feature_names,
1727
+ display_variable_names=jl_display_feature_names,
1728
+ y_variable_names=jl_y_variable_names,
1729
  X_units=self.X_units_,
1730
  y_units=self.y_units_,
1731
  options=options,