MilesCranmer commited on
Commit
e9b2ee8
1 Parent(s): cbb41ef

Only install packages as required

Browse files
Files changed (3) hide show
  1. pysr/julia_extensions.py +25 -0
  2. pysr/juliapkg.json +0 -8
  3. pysr/sr.py +7 -0
pysr/julia_extensions.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """This file installs and loads extensions for SymbolicRegression."""
2
+ from .julia_import import jl
3
+
4
+
5
+ def load_required_packages(*, turbo=False, enable_autodiff=False):
6
+ if turbo:
7
+ load_package("LoopVectorization")
8
+ if enable_autodiff:
9
+ load_package("Zygote")
10
+ if cluster_manager is not None:
11
+ load_package("ClusterManagers")
12
+
13
+
14
+ def load_package(package_name):
15
+ jl.seval(f"""
16
+ try
17
+ using {package_name}
18
+ catch e
19
+ isa(e, ArgumentError) || throw(e)
20
+ using Pkg: Pkg
21
+ Pkg.add("{package_name}")
22
+ using {package_name}
23
+ end
24
+ """)
25
+ return None
pysr/juliapkg.json CHANGED
@@ -5,17 +5,9 @@
5
  "uuid": "8254be44-1295-4e6a-a16d-46603ac705cb",
6
  "version": "=0.24.0"
7
  },
8
- "ClusterManagers": {
9
- "uuid": "34f1f09b-3a8b-5176-ab39-66d58a4d544e",
10
- "version": "0.4"
11
- },
12
  "Serialization": {
13
  "uuid": "9e88b42a-f829-5b0c-bbe9-9e923198166b",
14
  "version": "1"
15
- },
16
- "Zygote": {
17
- "uuid": "e88e6eb3-aa80-5325-afca-941959d7151f",
18
- "version": "0.6"
19
  }
20
  }
21
  }
 
5
  "uuid": "8254be44-1295-4e6a-a16d-46603ac705cb",
6
  "version": "=0.24.0"
7
  },
 
 
 
 
8
  "Serialization": {
9
  "uuid": "9e88b42a-f829-5b0c-bbe9-9e923198166b",
10
  "version": "1"
 
 
 
 
11
  }
12
  }
13
  }
pysr/sr.py CHANGED
@@ -32,6 +32,7 @@ from .export_numpy import sympy2numpy
32
  from .export_sympy import assert_valid_sympy_symbol, create_sympy_symbols, pysr2sympy
33
  from .export_torch import sympy2torch
34
  from .feature_selection import run_feature_selection
 
35
  from .julia_helpers import (
36
  PythonCall,
37
  _escape_filename,
@@ -1605,6 +1606,12 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1605
  else "nothing"
1606
  )
1607
 
 
 
 
 
 
 
1608
  mutation_weights = SymbolicRegression.MutationWeights(
1609
  mutate_constant=self.weight_mutate_constant,
1610
  mutate_operator=self.weight_mutate_operator,
 
32
  from .export_sympy import assert_valid_sympy_symbol, create_sympy_symbols, pysr2sympy
33
  from .export_torch import sympy2torch
34
  from .feature_selection import run_feature_selection
35
+ from .julia_extensions import load_required_packages
36
  from .julia_helpers import (
37
  PythonCall,
38
  _escape_filename,
 
1606
  else "nothing"
1607
  )
1608
 
1609
+ load_required_packages(
1610
+ turbo=turbo,
1611
+ enable_autodiff=enable_autodiff,
1612
+ cluster_manager=cluster_manager
1613
+ )
1614
+
1615
  mutation_weights = SymbolicRegression.MutationWeights(
1616
  mutate_constant=self.weight_mutate_constant,
1617
  mutate_operator=self.weight_mutate_operator,