Spaces:
Running
Running
MilesCranmer
commited on
Commit
•
5841096
1
Parent(s):
42bb187
Install SR directly from GitHub repo
Browse files- pysr/sr.py +29 -38
pysr/sr.py
CHANGED
@@ -26,7 +26,7 @@ def install(julia_project=None, quiet=False): # pragma: no cover
|
|
26 |
|
27 |
julia.install(quiet=quiet)
|
28 |
|
29 |
-
julia_project = _get_julia_project(julia_project)
|
30 |
|
31 |
Main = init_julia()
|
32 |
Main.eval("using Pkg")
|
@@ -37,15 +37,10 @@ def install(julia_project=None, quiet=False): # pragma: no cover
|
|
37 |
# Can't pass IO to Julia call as it evaluates to PyObject, so just directly
|
38 |
# use Main.eval:
|
39 |
Main.eval(f'Pkg.activate("{_escape_filename(julia_project)}", {io_arg})')
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
"Could not update Julia project. "
|
45 |
-
"It is possible that your Julia registry is out-of-date. "
|
46 |
-
"To switch to an always-updated registry, "
|
47 |
-
"see the solution in https://github.com/MilesCranmer/PySR/issues/27."
|
48 |
-
) from e
|
49 |
Main.eval(f"Pkg.instantiate({io_arg})")
|
50 |
Main.eval(f"Pkg.precompile({io_arg})")
|
51 |
if not quiet:
|
@@ -78,8 +73,8 @@ sympy_mappings = {
|
|
78 |
"div": lambda x, y: x / y,
|
79 |
"mult": lambda x, y: x * y,
|
80 |
"sqrt_abs": lambda x: sympy.sqrt(abs(x)),
|
81 |
-
"square": lambda x: x**2,
|
82 |
-
"cube": lambda x: x**3,
|
83 |
"plus": lambda x, y: x + y,
|
84 |
"sub": lambda x, y: x - y,
|
85 |
"neg": lambda x: -x,
|
@@ -282,14 +277,16 @@ class CallableEquation:
|
|
282 |
|
283 |
def _get_julia_project(julia_project):
|
284 |
if julia_project is None:
|
|
|
285 |
# Create temp directory:
|
286 |
tmp_dir = tempfile.mkdtemp()
|
287 |
tmp_dir = Path(tmp_dir)
|
288 |
-
#
|
289 |
-
|
290 |
-
return tmp_dir
|
291 |
else:
|
292 |
-
|
|
|
|
|
293 |
|
294 |
|
295 |
def silence_julia_warning():
|
@@ -337,27 +334,13 @@ To silence this warning, you can run pysr.silence_julia_warning() after importin
|
|
337 |
return Main
|
338 |
|
339 |
|
340 |
-
def
|
341 |
-
|
342 |
-
|
343 |
-
|
344 |
-
|
345 |
-
|
346 |
-
""
|
347 |
-
|
348 |
-
project_toml = """
|
349 |
-
[deps]
|
350 |
-
SymbolicRegression = "8254be44-1295-4e6a-a16d-46603ac705cb"
|
351 |
-
|
352 |
-
[compat]
|
353 |
-
SymbolicRegression = "0.7.7, 0.7.8"
|
354 |
-
julia = "1.5"
|
355 |
-
"""
|
356 |
-
|
357 |
-
project_toml_path = tmp_dir / "Project.toml"
|
358 |
-
project_toml_path.write_text(project_toml)
|
359 |
-
|
360 |
-
|
361 |
class PySRRegressor(BaseEstimator, RegressorMixin):
|
362 |
def __init__(
|
363 |
self,
|
@@ -1025,7 +1008,7 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
|
|
1025 |
else:
|
1026 |
X, y = _denoise(X, y, Xresampled=Xresampled)
|
1027 |
|
1028 |
-
self.julia_project = _get_julia_project(self.julia_project)
|
1029 |
|
1030 |
tmpdir = Path(tempfile.mkdtemp(dir=self.params["tempdir"]))
|
1031 |
|
@@ -1058,11 +1041,19 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
|
|
1058 |
io = "devnull" if self.params["update_verbosity"] == 0 else "stderr"
|
1059 |
io_arg = f"io={io}" if is_julia_version_greater_eq(Main, "1.6") else ""
|
1060 |
|
|
|
|
|
|
|
|
|
1061 |
Main.eval(
|
1062 |
f'Pkg.activate("{_escape_filename(self.julia_project)}", {io_arg})'
|
1063 |
)
|
1064 |
from julia.api import JuliaError
|
1065 |
|
|
|
|
|
|
|
|
|
1066 |
try:
|
1067 |
if update:
|
1068 |
Main.eval(f"Pkg.resolve({io_arg})")
|
|
|
26 |
|
27 |
julia.install(quiet=quiet)
|
28 |
|
29 |
+
julia_project, is_fresh_env = _get_julia_project(julia_project)
|
30 |
|
31 |
Main = init_julia()
|
32 |
Main.eval("using Pkg")
|
|
|
37 |
# Can't pass IO to Julia call as it evaluates to PyObject, so just directly
|
38 |
# use Main.eval:
|
39 |
Main.eval(f'Pkg.activate("{_escape_filename(julia_project)}", {io_arg})')
|
40 |
+
if is_fresh_env:
|
41 |
+
# Install SymbolicRegression.jl:
|
42 |
+
_add_sr_to_julia_project(Main, io_arg)
|
43 |
+
|
|
|
|
|
|
|
|
|
|
|
44 |
Main.eval(f"Pkg.instantiate({io_arg})")
|
45 |
Main.eval(f"Pkg.precompile({io_arg})")
|
46 |
if not quiet:
|
|
|
73 |
"div": lambda x, y: x / y,
|
74 |
"mult": lambda x, y: x * y,
|
75 |
"sqrt_abs": lambda x: sympy.sqrt(abs(x)),
|
76 |
+
"square": lambda x: x ** 2,
|
77 |
+
"cube": lambda x: x ** 3,
|
78 |
"plus": lambda x, y: x + y,
|
79 |
"sub": lambda x, y: x - y,
|
80 |
"neg": lambda x: -x,
|
|
|
277 |
|
278 |
def _get_julia_project(julia_project):
|
279 |
if julia_project is None:
|
280 |
+
is_fresh_env = True
|
281 |
# Create temp directory:
|
282 |
tmp_dir = tempfile.mkdtemp()
|
283 |
tmp_dir = Path(tmp_dir)
|
284 |
+
# Will create Project.toml in temp dir:
|
285 |
+
julia_project = tmp_dir
|
|
|
286 |
else:
|
287 |
+
is_fresh_env = False
|
288 |
+
julia_project = Path(julia_project)
|
289 |
+
return julia_project, is_fresh_env
|
290 |
|
291 |
|
292 |
def silence_julia_warning():
|
|
|
334 |
return Main
|
335 |
|
336 |
|
337 |
+
def _add_sr_to_julia_project(Main, io_arg):
|
338 |
+
Main.spec = Main.PackageSpec(
|
339 |
+
name="SymbolicRegression",
|
340 |
+
url="https://github.com/MilesCranmer/SymbolicRegression.jl",
|
341 |
+
rev="v0.7.8",
|
342 |
+
)
|
343 |
+
Main.eval(f"Pkg.add(spec, {io_arg})")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
344 |
class PySRRegressor(BaseEstimator, RegressorMixin):
|
345 |
def __init__(
|
346 |
self,
|
|
|
1008 |
else:
|
1009 |
X, y = _denoise(X, y, Xresampled=Xresampled)
|
1010 |
|
1011 |
+
self.julia_project, is_fresh_env = _get_julia_project(self.julia_project)
|
1012 |
|
1013 |
tmpdir = Path(tempfile.mkdtemp(dir=self.params["tempdir"]))
|
1014 |
|
|
|
1041 |
io = "devnull" if self.params["update_verbosity"] == 0 else "stderr"
|
1042 |
io_arg = f"io={io}" if is_julia_version_greater_eq(Main, "1.6") else ""
|
1043 |
|
1044 |
+
# [deps]
|
1045 |
+
# SymbolicRegression = "8254be44-1295-4e6a-a16d-46603ac705cb"
|
1046 |
+
# "0.7.8"
|
1047 |
+
|
1048 |
Main.eval(
|
1049 |
f'Pkg.activate("{_escape_filename(self.julia_project)}", {io_arg})'
|
1050 |
)
|
1051 |
from julia.api import JuliaError
|
1052 |
|
1053 |
+
if is_fresh_env:
|
1054 |
+
# Install SymbolicRegression.jl:
|
1055 |
+
_add_sr_to_julia_project(Main, io_arg)
|
1056 |
+
|
1057 |
try:
|
1058 |
if update:
|
1059 |
Main.eval(f"Pkg.resolve({io_arg})")
|