MilesCranmer commited on
Commit
69bfcd2
1 Parent(s): 6b46e9f

Set up `julia_kwargs` to initialize Julia binary

Browse files
Files changed (2) hide show
  1. pysr/julia_helpers.py +7 -4
  2. pysr/sr.py +24 -8
pysr/julia_helpers.py CHANGED
@@ -143,13 +143,16 @@ def _check_for_conflicting_libraries(): # pragma: no cover
143
  )
144
 
145
 
146
- def init_julia(julia_project=None, quiet=False):
147
  """Initialize julia binary, turning off compiled modules if needed."""
148
  global julia_initialized
149
 
150
  if not julia_initialized:
151
  _check_for_conflicting_libraries()
152
 
 
 
 
153
  from julia.core import JuliaInfo, UnsupportedPythonError
154
 
155
  _julia_version_assertion()
@@ -167,16 +170,16 @@ def init_julia(julia_project=None, quiet=False):
167
  if not info.is_pycall_built():
168
  raise ImportError(_import_error())
169
 
 
170
  Main = None
171
  try:
 
172
  from julia import Main as _Main
173
 
174
  Main = _Main
175
  except UnsupportedPythonError:
176
  # Static python binary, so we turn off pre-compiled modules.
177
- from julia.core import Julia
178
-
179
- jl = Julia(compiled_modules=False)
180
  from julia import Main as _Main
181
 
182
  Main = _Main
 
143
  )
144
 
145
 
146
+ def init_julia(julia_project=None, quiet=False, julia_kwargs=None):
147
  """Initialize julia binary, turning off compiled modules if needed."""
148
  global julia_initialized
149
 
150
  if not julia_initialized:
151
  _check_for_conflicting_libraries()
152
 
153
+ if julia_kwargs is None:
154
+ julia_kwargs = {}
155
+
156
  from julia.core import JuliaInfo, UnsupportedPythonError
157
 
158
  _julia_version_assertion()
 
170
  if not info.is_pycall_built():
171
  raise ImportError(_import_error())
172
 
173
+ from julia.core import Julia
174
  Main = None
175
  try:
176
+ jl = Julia(**julia_kwargs)
177
  from julia import Main as _Main
178
 
179
  Main = _Main
180
  except UnsupportedPythonError:
181
  # Static python binary, so we turn off pre-compiled modules.
182
+ jl = Julia(compiled_modules=False, **julia_kwargs)
 
 
183
  from julia import Main as _Main
184
 
185
  Main = _Main
pysr/sr.py CHANGED
@@ -581,10 +581,15 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
581
  inputting to PySR. Can help PySR fit noisy data.
582
  Default is `False`.
583
  select_k_features : int
584
- whether to run feature selection in Python using random forests,
585
- before passing to the symbolic regression code. None means no
586
- feature selection; an int means select that many features.
587
- Default is `None`.
 
 
 
 
 
588
  **kwargs : dict
589
  Supports deprecated keyword arguments. Other arguments will
590
  result in an error.
@@ -733,6 +738,7 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
733
  extra_jax_mappings=None,
734
  denoise=False,
735
  select_k_features=None,
 
736
  **kwargs,
737
  ):
738
 
@@ -827,6 +833,7 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
827
  # Pre-modelling transformation
828
  self.denoise = denoise
829
  self.select_k_features = select_k_features
 
830
 
831
  # Once all valid parameters have been assigned handle the
832
  # deprecated kwargs
@@ -1259,6 +1266,17 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1259
  + len(packed_modified_params["unary_operators"])
1260
  > 0
1261
  )
 
 
 
 
 
 
 
 
 
 
 
1262
  return packed_modified_params
1263
 
1264
  def _validate_and_set_fit_params(self, X, y, Xresampled, weights, variable_names):
@@ -1469,13 +1487,11 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1469
  batch_size = mutated_params["batch_size"]
1470
  update_verbosity = mutated_params["update_verbosity"]
1471
  progress = mutated_params["progress"]
 
1472
 
1473
  # Start julia backend processes
1474
  if Main is None:
1475
- if multithreading:
1476
- os.environ["JULIA_NUM_THREADS"] = str(self.procs)
1477
-
1478
- Main = init_julia(self.julia_project)
1479
 
1480
  if cluster_manager is not None:
1481
  cluster_manager = _load_cluster_manager(cluster_manager)
 
581
  inputting to PySR. Can help PySR fit noisy data.
582
  Default is `False`.
583
  select_k_features : int
584
+ Whether to run feature selection in Python using random forests,
585
+ before passing to the symbolic regression code. None means no
586
+ feature selection; an int means select that many features.
587
+ Default is `None`.
588
+ julia_kwargs : dict
589
+ Keyword arguments to pass to `julia.core.Julia(...)` to initialize
590
+ the Julia runtime. The default, when `None`, is to set `threads` equal
591
+ to `procs`, and `optimize` to 3.
592
+ Default is `None`.
593
  **kwargs : dict
594
  Supports deprecated keyword arguments. Other arguments will
595
  result in an error.
 
738
  extra_jax_mappings=None,
739
  denoise=False,
740
  select_k_features=None,
741
+ julia_kwargs=None,
742
  **kwargs,
743
  ):
744
 
 
833
  # Pre-modelling transformation
834
  self.denoise = denoise
835
  self.select_k_features = select_k_features
836
+ self.julia_kwargs = julia_kwargs
837
 
838
  # Once all valid parameters have been assigned handle the
839
  # deprecated kwargs
 
1266
  + len(packed_modified_params["unary_operators"])
1267
  > 0
1268
  )
1269
+
1270
+ julia_kwargs = {}
1271
+ if self.julia_kwargs is not None:
1272
+ for key, value in self.julia_kwargs.items():
1273
+ julia_kwargs[key] = value
1274
+ if "optimize" not in julia_kwargs:
1275
+ julia_kwargs["optimize"] = 3
1276
+ if "threads" not in julia_kwargs and packed_modified_params["multithreading"]:
1277
+ julia_kwargs["threads"] = self.procs
1278
+ packed_modified_params["julia_kwargs"] = julia_kwargs
1279
+
1280
  return packed_modified_params
1281
 
1282
  def _validate_and_set_fit_params(self, X, y, Xresampled, weights, variable_names):
 
1487
  batch_size = mutated_params["batch_size"]
1488
  update_verbosity = mutated_params["update_verbosity"]
1489
  progress = mutated_params["progress"]
1490
+ julia_kwargs = mutated_params["julia_kwargs"]
1491
 
1492
  # Start julia backend processes
1493
  if Main is None:
1494
+ Main = init_julia(self.julia_project, julia_kwargs=julia_kwargs)
 
 
 
1495
 
1496
  if cluster_manager is not None:
1497
  cluster_manager = _load_cluster_manager(cluster_manager)