MilesCranmer commited on
Commit
59d9435
1 Parent(s): 6443e24

Fix extensions not being added to package env

Browse files
Files changed (2) hide show
  1. pysr/julia_extensions.py +19 -15
  2. pysr/julia_import.py +3 -0
pysr/julia_extensions.py CHANGED
@@ -1,10 +1,16 @@
1
  """This file installs and loads extensions for SymbolicRegression."""
2
 
3
- from .julia_import import jl
 
 
4
 
5
 
6
  def load_required_packages(
7
- *, turbo=False, bumper=False, enable_autodiff=False, cluster_manager=None
 
 
 
 
8
  ):
9
  if turbo:
10
  load_package("LoopVectorization", "bdcacae8-1622-11e9-2a5c-532679323890")
@@ -16,17 +22,15 @@ def load_required_packages(
16
  load_package("ClusterManagers", "34f1f09b-3a8b-5176-ab39-66d58a4d544e")
17
 
18
 
19
- def load_package(package_name, uuid):
20
- jl.seval(
21
- f"""
22
- try
23
- using {package_name}
24
- catch e
25
- isa(e, ArgumentError) || throw(e)
26
- using Pkg: Pkg
27
- Pkg.add(name="{package_name}", uuid="{uuid}")
28
- using {package_name}
29
- end
30
- """
31
- )
32
  return None
 
1
  """This file installs and loads extensions for SymbolicRegression."""
2
 
3
+ from typing import Optional
4
+
5
+ from .julia_import import Pkg, jl
6
 
7
 
8
  def load_required_packages(
9
+ *,
10
+ turbo: bool = False,
11
+ bumper: bool = False,
12
+ enable_autodiff: bool = False,
13
+ cluster_manager: Optional[str] = None,
14
  ):
15
  if turbo:
16
  load_package("LoopVectorization", "bdcacae8-1622-11e9-2a5c-532679323890")
 
22
  load_package("ClusterManagers", "34f1f09b-3a8b-5176-ab39-66d58a4d544e")
23
 
24
 
25
+ def isinstalled(uuid_s: str) -> bool:
26
+ return jl.haskey(Pkg.dependencies(), jl.Base.UUID(uuid_s))
27
+
28
+
29
+ def load_package(package_name: str, uuid_s: str) -> None:
30
+ if not isinstalled(uuid_s):
31
+ Pkg.add(name=package_name, uuid=uuid_s)
32
+
33
+ # TODO: Protect against loading the same symbol from two packages,
34
+ # maybe with a @gensym here.
35
+ jl.seval(f"using {package_name}")
 
 
36
  return None
pysr/julia_import.py CHANGED
@@ -63,3 +63,6 @@ elif autoload_extensions not in {"no", "yes", ""}:
63
 
64
  jl.seval("using SymbolicRegression")
65
  SymbolicRegression = jl.SymbolicRegression
 
 
 
 
63
 
64
  jl.seval("using SymbolicRegression")
65
  SymbolicRegression = jl.SymbolicRegression
66
+
67
+ jl.seval("using Pkg: Pkg")
68
+ Pkg = jl.Pkg