MilesCranmer commited on
Commit
6d58816
1 Parent(s): 49b163d

Refactor backend loading

Browse files
Files changed (2) hide show
  1. pysr/julia_helpers.py +22 -0
  2. pysr/sr.py +7 -14
pysr/julia_helpers.py CHANGED
@@ -4,6 +4,7 @@ import subprocess
4
  import warnings
5
  from pathlib import Path
6
  import os
 
7
 
8
  from .version import __version__, __symbolic_regression_jl_version__
9
 
@@ -230,3 +231,24 @@ def _version_assertion():
230
  "PySR requires Julia 1.6.0 or greater. "
231
  "Please update your Julia installation."
232
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  import warnings
5
  from pathlib import Path
6
  import os
7
+ from julia.api import JuliaError
8
 
9
  from .version import __version__, __symbolic_regression_jl_version__
10
 
 
231
  "PySR requires Julia 1.6.0 or greater. "
232
  "Please update your Julia installation."
233
  )
234
+
235
+
236
+ def _load_cluster_manager(Main, cluster_manager):
237
+ Main.eval(f"import ClusterManagers: addprocs_{cluster_manager}")
238
+ return Main.eval(f"addprocs_{cluster_manager}")
239
+
240
+
241
+ def _update_julia_project(Main, julia_project, is_shared, io_arg):
242
+ try:
243
+ if is_shared:
244
+ _add_sr_to_julia_project(Main, io_arg)
245
+ Main.eval(f"Pkg.resolve({io_arg})")
246
+ except (JuliaError, RuntimeError) as e:
247
+ raise ImportError(_import_error_string(julia_project)) from e
248
+
249
+
250
+ def _load_backend(Main, julia_project):
251
+ try:
252
+ Main.eval("using SymbolicRegression")
253
+ except (JuliaError, RuntimeError) as e:
254
+ raise ImportError(_import_error_string(julia_project)) from e
pysr/sr.py CHANGED
@@ -26,8 +26,9 @@ from .julia_helpers import (
26
  _process_julia_project,
27
  is_julia_version_greater_eq,
28
  _escape_filename,
29
- _add_sr_to_julia_project,
30
- _import_error_string,
 
31
  )
32
  from .export_numpy import CallableEquation
33
  from .export_latex import generate_single_table, generate_multiple_tables, to_latex
@@ -1453,8 +1454,7 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1453
  Main = init_julia(self.julia_project)
1454
 
1455
  if cluster_manager is not None:
1456
- Main.eval(f"import ClusterManagers: addprocs_{cluster_manager}")
1457
- cluster_manager = Main.eval(f"addprocs_{cluster_manager}")
1458
 
1459
  if not already_ran:
1460
  julia_project, is_shared = _process_julia_project(self.julia_project)
@@ -1470,16 +1470,9 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1470
  from julia.api import JuliaError
1471
 
1472
  if self.update:
1473
- try:
1474
- if is_shared:
1475
- _add_sr_to_julia_project(Main, io_arg)
1476
- Main.eval(f"Pkg.resolve({io_arg})")
1477
- except (JuliaError, RuntimeError) as e:
1478
- raise ImportError(_import_error_string(julia_project)) from e
1479
- try:
1480
- Main.eval("using SymbolicRegression")
1481
- except (JuliaError, RuntimeError) as e:
1482
- raise ImportError(_import_error_string(julia_project)) from e
1483
 
1484
  Main.plus = Main.eval("(+)")
1485
  Main.sub = Main.eval("(-)")
 
26
  _process_julia_project,
27
  is_julia_version_greater_eq,
28
  _escape_filename,
29
+ _load_cluster_manager,
30
+ _update_julia_project,
31
+ _load_backend,
32
  )
33
  from .export_numpy import CallableEquation
34
  from .export_latex import generate_single_table, generate_multiple_tables, to_latex
 
1454
  Main = init_julia(self.julia_project)
1455
 
1456
  if cluster_manager is not None:
1457
+ cluster_manager = _load_cluster_manager(cluster_manager)
 
1458
 
1459
  if not already_ran:
1460
  julia_project, is_shared = _process_julia_project(self.julia_project)
 
1470
  from julia.api import JuliaError
1471
 
1472
  if self.update:
1473
+ _update_julia_project(Main, julia_project, is_shared, io_arg)
1474
+
1475
+ _load_backend(Main, julia_project)
 
 
 
 
 
 
 
1476
 
1477
  Main.plus = Main.eval("(+)")
1478
  Main.sub = Main.eval("(-)")