Spaces:
Sleeping
Sleeping
AutonLabTruth
commited on
Commit
•
0dfd8e3
1
Parent(s):
a9a1691
Refactored out paths and others
Browse files- pysr/sr.py +31 -22
pysr/sr.py
CHANGED
@@ -207,16 +207,7 @@ def pysr(X=None, y=None, weights=None,
|
|
207 |
if len(X.shape) == 1:
|
208 |
X = X[:, None]
|
209 |
|
210 |
-
|
211 |
-
assert len(unary_operators) + len(binary_operators) > 0
|
212 |
-
assert len(X.shape) == 2
|
213 |
-
assert len(y.shape) == 1
|
214 |
-
assert X.shape[0] == y.shape[0]
|
215 |
-
if weights is not None:
|
216 |
-
assert len(weights.shape) == 1
|
217 |
-
assert X.shape[0] == weights.shape[0]
|
218 |
-
if use_custom_variable_names:
|
219 |
-
assert len(variable_names) == X.shape[1]
|
220 |
|
221 |
if select_k_features is not None:
|
222 |
selection = run_feature_selection(X, y, select_k_features)
|
@@ -248,18 +239,8 @@ def pysr(X=None, y=None, weights=None,
|
|
248 |
y = eval(eval_str)
|
249 |
print("Running on", eval_str)
|
250 |
|
251 |
-
|
252 |
-
|
253 |
-
pkg_filename = pkg_directory / "sr.jl"
|
254 |
-
operator_filename = pkg_directory / "operators.jl"
|
255 |
-
|
256 |
-
tmpdir = Path(tempfile.mkdtemp(dir=tempdir))
|
257 |
-
hyperparam_filename = tmpdir / f'hyperparams.jl'
|
258 |
-
dataset_filename = tmpdir / f'dataset.jl'
|
259 |
-
runfile_filename = tmpdir / f'runfile.jl'
|
260 |
-
X_filename = tmpdir / "X.csv"
|
261 |
-
y_filename = tmpdir / "y.csv"
|
262 |
-
weights_filename = tmpdir / "weights.csv"
|
263 |
|
264 |
def_hyperparams = ""
|
265 |
|
@@ -463,6 +444,34 @@ const varMap = {'["' + '", "'.join(variable_names) + '"]'}"""
|
|
463 |
return get_hof()
|
464 |
|
465 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
466 |
def raise_depreciation_errors(limitPowComplexity, threads):
|
467 |
if threads is not None:
|
468 |
raise ValueError("The threads kwarg is deprecated. Use procs.")
|
|
|
207 |
if len(X.shape) == 1:
|
208 |
X = X[:, None]
|
209 |
|
210 |
+
check_assertions(X, binary_operators, unary_operators, use_custom_variable_names, variable_names, weights, y)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
211 |
|
212 |
if select_k_features is not None:
|
213 |
selection = run_feature_selection(X, y, select_k_features)
|
|
|
239 |
y = eval(eval_str)
|
240 |
print("Running on", eval_str)
|
241 |
|
242 |
+
X_filename, dataset_filename, hyperparam_filename, operator_filename, pkg_filename, runfile_filename, tmpdir, weights_filename, y_filename = set_paths(
|
243 |
+
tempdir)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
244 |
|
245 |
def_hyperparams = ""
|
246 |
|
|
|
444 |
return get_hof()
|
445 |
|
446 |
|
447 |
+
def set_paths(tempdir):
|
448 |
+
# System-independent paths
|
449 |
+
pkg_directory = Path(__file__).parents[1] / 'julia'
|
450 |
+
pkg_filename = pkg_directory / "sr.jl"
|
451 |
+
operator_filename = pkg_directory / "operators.jl"
|
452 |
+
tmpdir = Path(tempfile.mkdtemp(dir=tempdir))
|
453 |
+
hyperparam_filename = tmpdir / f'hyperparams.jl'
|
454 |
+
dataset_filename = tmpdir / f'dataset.jl'
|
455 |
+
runfile_filename = tmpdir / f'runfile.jl'
|
456 |
+
X_filename = tmpdir / "X.csv"
|
457 |
+
y_filename = tmpdir / "y.csv"
|
458 |
+
weights_filename = tmpdir / "weights.csv"
|
459 |
+
return X_filename, dataset_filename, hyperparam_filename, operator_filename, pkg_filename, runfile_filename, tmpdir, weights_filename, y_filename
|
460 |
+
|
461 |
+
|
462 |
+
def check_assertions(X, binary_operators, unary_operators, use_custom_variable_names, variable_names, weights, y):
|
463 |
+
# Check for potential errors before they happen
|
464 |
+
assert len(unary_operators) + len(binary_operators) > 0
|
465 |
+
assert len(X.shape) == 2
|
466 |
+
assert len(y.shape) == 1
|
467 |
+
assert X.shape[0] == y.shape[0]
|
468 |
+
if weights is not None:
|
469 |
+
assert len(weights.shape) == 1
|
470 |
+
assert X.shape[0] == weights.shape[0]
|
471 |
+
if use_custom_variable_names:
|
472 |
+
assert len(variable_names) == X.shape[1]
|
473 |
+
|
474 |
+
|
475 |
def raise_depreciation_errors(limitPowComplexity, threads):
|
476 |
if threads is not None:
|
477 |
raise ValueError("The threads kwarg is deprecated. Use procs.")
|