Spaces:
Running
Running
MilesCranmer
commited on
Commit
•
e9b2ee8
1
Parent(s):
cbb41ef
Only install packages as required
Browse files- pysr/julia_extensions.py +25 -0
- pysr/juliapkg.json +0 -8
- pysr/sr.py +7 -0
pysr/julia_extensions.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""This file installs and loads extensions for SymbolicRegression."""
|
2 |
+
from .julia_import import jl
|
3 |
+
|
4 |
+
|
5 |
+
def load_required_packages(*, turbo=False, enable_autodiff=False):
|
6 |
+
if turbo:
|
7 |
+
load_package("LoopVectorization")
|
8 |
+
if enable_autodiff:
|
9 |
+
load_package("Zygote")
|
10 |
+
if cluster_manager is not None:
|
11 |
+
load_package("ClusterManagers")
|
12 |
+
|
13 |
+
|
14 |
+
def load_package(package_name):
|
15 |
+
jl.seval(f"""
|
16 |
+
try
|
17 |
+
using {package_name}
|
18 |
+
catch e
|
19 |
+
isa(e, ArgumentError) || throw(e)
|
20 |
+
using Pkg: Pkg
|
21 |
+
Pkg.add("{package_name}")
|
22 |
+
using {package_name}
|
23 |
+
end
|
24 |
+
""")
|
25 |
+
return None
|
pysr/juliapkg.json
CHANGED
@@ -5,17 +5,9 @@
|
|
5 |
"uuid": "8254be44-1295-4e6a-a16d-46603ac705cb",
|
6 |
"version": "=0.24.0"
|
7 |
},
|
8 |
-
"ClusterManagers": {
|
9 |
-
"uuid": "34f1f09b-3a8b-5176-ab39-66d58a4d544e",
|
10 |
-
"version": "0.4"
|
11 |
-
},
|
12 |
"Serialization": {
|
13 |
"uuid": "9e88b42a-f829-5b0c-bbe9-9e923198166b",
|
14 |
"version": "1"
|
15 |
-
},
|
16 |
-
"Zygote": {
|
17 |
-
"uuid": "e88e6eb3-aa80-5325-afca-941959d7151f",
|
18 |
-
"version": "0.6"
|
19 |
}
|
20 |
}
|
21 |
}
|
|
|
5 |
"uuid": "8254be44-1295-4e6a-a16d-46603ac705cb",
|
6 |
"version": "=0.24.0"
|
7 |
},
|
|
|
|
|
|
|
|
|
8 |
"Serialization": {
|
9 |
"uuid": "9e88b42a-f829-5b0c-bbe9-9e923198166b",
|
10 |
"version": "1"
|
|
|
|
|
|
|
|
|
11 |
}
|
12 |
}
|
13 |
}
|
pysr/sr.py
CHANGED
@@ -32,6 +32,7 @@ from .export_numpy import sympy2numpy
|
|
32 |
from .export_sympy import assert_valid_sympy_symbol, create_sympy_symbols, pysr2sympy
|
33 |
from .export_torch import sympy2torch
|
34 |
from .feature_selection import run_feature_selection
|
|
|
35 |
from .julia_helpers import (
|
36 |
PythonCall,
|
37 |
_escape_filename,
|
@@ -1605,6 +1606,12 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
1605 |
else "nothing"
|
1606 |
)
|
1607 |
|
|
|
|
|
|
|
|
|
|
|
|
|
1608 |
mutation_weights = SymbolicRegression.MutationWeights(
|
1609 |
mutate_constant=self.weight_mutate_constant,
|
1610 |
mutate_operator=self.weight_mutate_operator,
|
|
|
32 |
from .export_sympy import assert_valid_sympy_symbol, create_sympy_symbols, pysr2sympy
|
33 |
from .export_torch import sympy2torch
|
34 |
from .feature_selection import run_feature_selection
|
35 |
+
from .julia_extensions import load_required_packages
|
36 |
from .julia_helpers import (
|
37 |
PythonCall,
|
38 |
_escape_filename,
|
|
|
1606 |
else "nothing"
|
1607 |
)
|
1608 |
|
1609 |
+
load_required_packages(
|
1610 |
+
turbo=turbo,
|
1611 |
+
enable_autodiff=enable_autodiff,
|
1612 |
+
cluster_manager=cluster_manager
|
1613 |
+
)
|
1614 |
+
|
1615 |
mutation_weights = SymbolicRegression.MutationWeights(
|
1616 |
mutate_constant=self.weight_mutate_constant,
|
1617 |
mutate_operator=self.weight_mutate_operator,
|