MilesCranmer commited on
Commit
09a7186
β€’
1 Parent(s): 5620b3a

Refactor utility functions

Browse files
Files changed (2) hide show
  1. pysr/sr.py +8 -52
  2. pysr/utils.py +55 -0
pysr/sr.py CHANGED
@@ -33,6 +33,12 @@ from .julia_helpers import (
33
  init_julia,
34
  is_julia_version_greater_eq,
35
  )
 
 
 
 
 
 
36
 
37
  Main = None # TODO: Rename to more descriptive name like "julia_runtime"
38
 
@@ -945,10 +951,8 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
945
  model : PySRRegressor
946
  The model with fitted equations.
947
  """
948
- if os.path.splitext(equation_file)[1] != ".pkl":
949
- pkl_filename = _csv_filename_to_pkl_filename(equation_file)
950
- else:
951
- pkl_filename = equation_file
952
 
953
  # Try to load model from <equation_file>.pkl
954
  print(f"Checking if {pkl_filename} exists...")
@@ -2437,51 +2441,3 @@ def run_feature_selection(X, y, select_k_features, random_state=None):
2437
  clf, threshold=-np.inf, max_features=select_k_features, prefit=True
2438
  )
2439
  return selector.get_support(indices=True)
2440
-
2441
-
2442
- def _csv_filename_to_pkl_filename(csv_filename) -> str:
2443
- # Assume that the csv filename is of the form "foo.csv"
2444
- assert str(csv_filename).endswith(".csv")
2445
-
2446
- dirname = str(os.path.dirname(csv_filename))
2447
- basename = str(os.path.basename(csv_filename))
2448
- base = str(os.path.splitext(basename)[0])
2449
-
2450
- pkl_basename = base + ".pkl"
2451
-
2452
- return os.path.join(dirname, pkl_basename)
2453
-
2454
-
2455
- _regexp_im = re.compile(r"\b(\d+\.\d+)im\b")
2456
- _regexp_im_sci = re.compile(r"\b(\d+\.\d+)[eEfF]([+-]?\d+)im\b")
2457
- _regexp_sci = re.compile(r"\b(\d+\.\d+)[eEfF]([+-]?\d+)\b")
2458
-
2459
- _apply_regexp_im = lambda x: _regexp_im.sub(r"\1j", x)
2460
- _apply_regexp_im_sci = lambda x: _regexp_im_sci.sub(r"\1e\2j", x)
2461
- _apply_regexp_sci = lambda x: _regexp_sci.sub(r"\1e\2", x)
2462
-
2463
-
2464
- def _preprocess_julia_floats(s: str) -> str:
2465
- if isinstance(s, str):
2466
- s = _apply_regexp_im(s)
2467
- s = _apply_regexp_im_sci(s)
2468
- s = _apply_regexp_sci(s)
2469
- return s
2470
-
2471
-
2472
- def _subscriptify(i: int) -> str:
2473
- """Converts integer to subscript text form.
2474
-
2475
- For example, 123 -> "₁₂₃".
2476
- """
2477
- return "".join([chr(0x2080 + int(c)) for c in str(i)])
2478
-
2479
-
2480
- def _safe_check_feature_names_in(self, variable_names, generate_names=True):
2481
- """_check_feature_names_in with compat for old versions."""
2482
- try:
2483
- return _check_feature_names_in(
2484
- self, variable_names, generate_names=generate_names
2485
- )
2486
- except TypeError:
2487
- return _check_feature_names_in(self, variable_names)
 
33
  init_julia,
34
  is_julia_version_greater_eq,
35
  )
36
+ from .utils import (
37
+ _csv_filename_to_pkl_filename,
38
+ _preprocess_julia_floats,
39
+ _safe_check_feature_names_in,
40
+ _subscriptify,
41
+ )
42
 
43
  Main = None # TODO: Rename to more descriptive name like "julia_runtime"
44
 
 
951
  model : PySRRegressor
952
  The model with fitted equations.
953
  """
954
+
955
+ pkl_filename = _csv_filename_to_pkl_filename(equation_file)
 
 
956
 
957
  # Try to load model from <equation_file>.pkl
958
  print(f"Checking if {pkl_filename} exists...")
 
2441
  clf, threshold=-np.inf, max_features=select_k_features, prefit=True
2442
  )
2443
  return selector.get_support(indices=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pysr/utils.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+
4
+ from sklearn.utils.validation import _check_feature_names_in
5
+
6
+
7
+ def _csv_filename_to_pkl_filename(csv_filename: str) -> str:
8
+ if os.path.splitext(csv_filename)[1] == ".pkl":
9
+ return csv_filename
10
+
11
+ # Assume that the csv filename is of the form "foo.csv"
12
+ assert str(csv_filename).endswith(".csv")
13
+
14
+ dirname = str(os.path.dirname(csv_filename))
15
+ basename = str(os.path.basename(csv_filename))
16
+ base = str(os.path.splitext(basename)[0])
17
+
18
+ pkl_basename = base + ".pkl"
19
+
20
+ return os.path.join(dirname, pkl_basename)
21
+
22
+
23
+ _regexp_im = re.compile(r"\b(\d+\.\d+)im\b")
24
+ _regexp_im_sci = re.compile(r"\b(\d+\.\d+)[eEfF]([+-]?\d+)im\b")
25
+ _regexp_sci = re.compile(r"\b(\d+\.\d+)[eEfF]([+-]?\d+)\b")
26
+
27
+ _apply_regexp_im = lambda x: _regexp_im.sub(r"\1j", x)
28
+ _apply_regexp_im_sci = lambda x: _regexp_im_sci.sub(r"\1e\2j", x)
29
+ _apply_regexp_sci = lambda x: _regexp_sci.sub(r"\1e\2", x)
30
+
31
+
32
+ def _preprocess_julia_floats(s: str) -> str:
33
+ if isinstance(s, str):
34
+ s = _apply_regexp_im(s)
35
+ s = _apply_regexp_im_sci(s)
36
+ s = _apply_regexp_sci(s)
37
+ return s
38
+
39
+
40
+ def _safe_check_feature_names_in(self, variable_names, generate_names=True):
41
+ """_check_feature_names_in with compat for old versions."""
42
+ try:
43
+ return _check_feature_names_in(
44
+ self, variable_names, generate_names=generate_names
45
+ )
46
+ except TypeError:
47
+ return _check_feature_names_in(self, variable_names)
48
+
49
+
50
+ def _subscriptify(i: int) -> str:
51
+ """Converts integer to subscript text form.
52
+
53
+ For example, 123 -> "₁₂₃".
54
+ """
55
+ return "".join([chr(0x2080 + int(c)) for c in str(i)])