Mark Kittisopikul commited on
Commit
1ec1c46
1 Parent(s): 7e5102a

Set JULIA_PROJECT before loading pyjulia

Browse files

PyCall.jl only needs to be installed in the pysr environment.

Files changed (2) hide show
  1. pysr/julia_helpers.py +10 -3
  2. pysr/sr.py +1 -1
pysr/julia_helpers.py CHANGED
@@ -12,13 +12,17 @@ def install(julia_project=None, quiet=False): # pragma: no cover
12
 
13
  Also updates the local Julia registry.
14
  """
 
 
 
 
 
15
  import julia
16
 
17
  julia.install(quiet=quiet)
18
 
19
- julia_project, is_shared = _get_julia_project(julia_project)
20
 
21
- Main = init_julia()
22
  Main.eval("using Pkg")
23
 
24
  io = "devnull" if quiet else "stderr"
@@ -72,10 +76,13 @@ def is_julia_version_greater_eq(Main, version="1.6"):
72
  return Main.eval(f'VERSION >= v"{version}"')
73
 
74
 
75
- def init_julia():
76
  """Initialize julia binary, turning off compiled modules if needed."""
77
  from julia.core import JuliaInfo, UnsupportedPythonError
78
 
 
 
 
79
  try:
80
  info = JuliaInfo.load(julia="julia")
81
  except FileNotFoundError:
 
12
 
13
  Also updates the local Julia registry.
14
  """
15
+
16
+ # Set JULIA_PROJECT so that we install in the pysr environment
17
+ julia_project, is_shared = _get_julia_project(julia_project)
18
+ os.environ["JULIA_PROJECT"] = "@" + julia_project if is_shared else julia_project
19
+
20
  import julia
21
 
22
  julia.install(quiet=quiet)
23
 
 
24
 
25
+ Main = init_julia(julia_project)
26
  Main.eval("using Pkg")
27
 
28
  io = "devnull" if quiet else "stderr"
 
76
  return Main.eval(f'VERSION >= v"{version}"')
77
 
78
 
79
+ def init_julia(julia_project=None):
80
  """Initialize julia binary, turning off compiled modules if needed."""
81
  from julia.core import JuliaInfo, UnsupportedPythonError
82
 
83
+ julia_project, is_shared = _get_julia_project(julia_project)
84
+ os.environ["JULIA_PROJECT"] = "@" + julia_project if is_shared else julia_project
85
+
86
  try:
87
  info = JuliaInfo.load(julia="julia")
88
  except FileNotFoundError:
pysr/sr.py CHANGED
@@ -1430,7 +1430,7 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1430
  if multithreading:
1431
  os.environ["JULIA_NUM_THREADS"] = str(self.procs)
1432
 
1433
- Main = init_julia()
1434
 
1435
  if cluster_manager is not None:
1436
  Main.eval(f"import ClusterManagers: addprocs_{cluster_manager}")
 
1430
  if multithreading:
1431
  os.environ["JULIA_NUM_THREADS"] = str(self.procs)
1432
 
1433
+ Main = init_julia(self.julia_project)
1434
 
1435
  if cluster_manager is not None:
1436
  Main.eval(f"import ClusterManagers: addprocs_{cluster_manager}")