tttc3 commited on
Commit
ad1c492
1 Parent(s): 3182a3b

Addressed some DeepSource issues

Browse files
Files changed (1) hide show
  1. pysr/sr.py +44 -30
pysr/sr.py CHANGED
@@ -2,7 +2,6 @@ import os
2
  import sys
3
  import numpy as np
4
  import pandas as pd
5
- from sklearn.utils import check_array, check_consistent_length, check_random_state
6
  import sympy
7
  from sympy import sympify
8
  import re
@@ -13,6 +12,7 @@ from datetime import datetime
13
  import warnings
14
  from multiprocessing import cpu_count
15
  from sklearn.base import BaseEstimator, RegressorMixin, MultiOutputMixin
 
16
  from sklearn.utils.validation import (
17
  _check_feature_names_in,
18
  check_is_fitted,
@@ -76,7 +76,8 @@ sympy_mappings = {
76
 
77
  def pysr(X, y, weights=None, **kwargs): # pragma: no cover
78
  warnings.warn(
79
- "Calling `pysr` is deprecated. Please use `model = PySRRegressor(**params); model.fit(X, y)` going forward.",
 
80
  FutureWarning,
81
  )
82
  model = PySRRegressor(**kwargs)
@@ -95,7 +96,8 @@ def _process_constraints(binary_operators, unary_operators, constraints):
95
  if op in ["plus", "sub", "+", "-"]:
96
  if constraints[op][0] != constraints[op][1]:
97
  raise NotImplementedError(
98
- "You need equal constraints on both sides for - and +, due to simplification strategies."
 
99
  )
100
  elif op in ["mult", "*"]:
101
  # Make sure the complex expression is in the left side.
@@ -128,7 +130,8 @@ def _maybe_create_inline_operators(binary_operators, unary_operators):
128
  if not re.match(r"^[a-zA-Z0-9_]+$", function_name):
129
  raise ValueError(
130
  f"Invalid function name {function_name}. "
131
- "Only alphanumeric characters, numbers, and underscores are allowed."
 
132
  )
133
  op_list[i] = function_name
134
  return binary_operators, unary_operators
@@ -154,25 +157,32 @@ def _check_assertions(
154
 
155
  def best(*args, **kwargs): # pragma: no cover
156
  raise NotImplementedError(
157
- "`best` has been deprecated. Please use the `PySRRegressor` interface. After fitting, you can return `.sympy()` to get the sympy representation of the best equation."
 
 
158
  )
159
 
160
 
161
  def best_row(*args, **kwargs): # pragma: no cover
162
  raise NotImplementedError(
163
- "`best_row` has been deprecated. Please use the `PySRRegressor` interface. After fitting, you can run `print(model)` to view the best equation, or `model.get_best()` to return the best equation's row in `model.equations`."
 
 
164
  )
165
 
166
 
167
  def best_tex(*args, **kwargs): # pragma: no cover
168
  raise NotImplementedError(
169
- "`best_tex` has been deprecated. Please use the `PySRRegressor` interface. After fitting, you can return `.latex()` to get the sympy representation of the best equation."
 
 
170
  )
171
 
172
 
173
  def best_callable(*args, **kwargs): # pragma: no cover
174
  raise NotImplementedError(
175
- "`best_callable` has been deprecated. Please use the `PySRRegressor` interface. After fitting, you can use `.predict(X)` to use the best callable."
 
176
  )
177
 
178
 
@@ -775,7 +785,7 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
775
  setattr(self, updated_kwarg_name, v)
776
  warnings.warn(
777
  f"{k} has been renamed to {updated_kwarg_name} in PySRRegressor. "
778
- " Please use that instead.",
779
  FutureWarning,
780
  )
781
  # Handle kwargs that have been moved to the fit method
@@ -787,7 +797,7 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
787
  )
788
  else:
789
  raise TypeError(
790
- f"{k} is not a valid keyword argument for PySRRegressor"
791
  )
792
 
793
  def __repr__(self):
@@ -964,7 +974,6 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
964
  values. For example, default parameters are set here
965
  when a parameter is left set to `None`.
966
  """
967
-
968
  # Immutable parameter validation
969
  # Ensure instance parameters are allowable values:
970
  if self.tournament_selection_n > self.population_size:
@@ -974,27 +983,29 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
974
 
975
  if self.maxsize > 40:
976
  warnings.warn(
977
- "Note: Using a large maxsize for the equation search will be exponentially slower and use significant memory. You should consider turning `use_frequency` to False, and perhaps use `warmup_maxsize_by`."
 
 
978
  )
979
  elif self.maxsize < 7:
980
  raise ValueError("PySR requires a maxsize of at least 7")
981
 
982
- if self.deterministic:
983
- if not (
984
- self.multithreading in [False, None]
985
- and self.procs == 0
986
- and self.random_state != None
987
- ):
988
- raise ValueError(
989
- "To ensure deterministic searches, you must set `random_state` to a seed, "
990
- "`procs` to `0`, and `multithreading` to `False` or `None`."
991
- )
992
 
993
- if self.random_state != None and (not self.deterministic or self.procs != 0):
 
 
994
  warnings.warn(
995
  "Note: Setting `random_state` without also setting `deterministic` "
996
- "to True and `procs` to 0 "
997
- "will result in non-deterministic searches. "
998
  )
999
 
1000
  # NotImplementedError - Values that could be supported at a later time
@@ -1035,7 +1046,8 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1035
  parameter_value = 1
1036
  elif parameter == "progress" and not buffer_available:
1037
  warnings.warn(
1038
- "Note: it looks like you are running in Jupyter. The progress bar will be turned off."
 
1039
  )
1040
  parameter_value = False
1041
  packed_modified_params[parameter] = parameter_value
@@ -1087,7 +1099,6 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1087
  Validated list of variable names for each feature in `X`.
1088
 
1089
  """
1090
-
1091
  if isinstance(X, pd.DataFrame):
1092
  if variable_names:
1093
  variable_names = None
@@ -1803,7 +1814,8 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1803
  )
1804
  except FileNotFoundError:
1805
  raise RuntimeError(
1806
- "Couldn't find equation file! The equation search likely exited before a single iteration completed."
 
1807
  )
1808
 
1809
  # It is expected extra_jax/torch_mappings will be updated after fit.
@@ -1814,7 +1826,8 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1814
  for value in extra_jax_mappings.values():
1815
  if not isinstance(value, str):
1816
  raise ValueError(
1817
- "extra_jax_mappings must have keys that are strings! e.g., {sympy.sqrt: 'jnp.sqrt'}."
 
1818
  )
1819
  else:
1820
  extra_jax_mappings = {}
@@ -1822,7 +1835,8 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1822
  for value in extra_jax_mappings.values():
1823
  if not callable(value):
1824
  raise ValueError(
1825
- "extra_torch_mappings must be callable functions! e.g., {sympy.sqrt: torch.sqrt}."
 
1826
  )
1827
  else:
1828
  extra_torch_mappings = {}
 
2
  import sys
3
  import numpy as np
4
  import pandas as pd
 
5
  import sympy
6
  from sympy import sympify
7
  import re
 
12
  import warnings
13
  from multiprocessing import cpu_count
14
  from sklearn.base import BaseEstimator, RegressorMixin, MultiOutputMixin
15
+ from sklearn.utils import check_array, check_consistent_length, check_random_state
16
  from sklearn.utils.validation import (
17
  _check_feature_names_in,
18
  check_is_fitted,
 
76
 
77
  def pysr(X, y, weights=None, **kwargs): # pragma: no cover
78
  warnings.warn(
79
+ "Calling `pysr` is deprecated. "
80
+ "Please use `model = PySRRegressor(**params); model.fit(X, y)` going forward.",
81
  FutureWarning,
82
  )
83
  model = PySRRegressor(**kwargs)
 
96
  if op in ["plus", "sub", "+", "-"]:
97
  if constraints[op][0] != constraints[op][1]:
98
  raise NotImplementedError(
99
+ "You need equal constraints on both sides for - and +, "
100
+ "due to simplification strategies."
101
  )
102
  elif op in ["mult", "*"]:
103
  # Make sure the complex expression is in the left side.
 
130
  if not re.match(r"^[a-zA-Z0-9_]+$", function_name):
131
  raise ValueError(
132
  f"Invalid function name {function_name}. "
133
+ "Only alphanumeric characters, numbers, "
134
+ "and underscores are allowed."
135
  )
136
  op_list[i] = function_name
137
  return binary_operators, unary_operators
 
157
 
158
  def best(*args, **kwargs): # pragma: no cover
159
  raise NotImplementedError(
160
+ "`best` has been deprecated. Please use the `PySRRegressor` interface. "
161
+ "After fitting, you can return `.sympy()` to get the sympy representation "
162
+ "of the best equation."
163
  )
164
 
165
 
166
  def best_row(*args, **kwargs): # pragma: no cover
167
  raise NotImplementedError(
168
+ "`best_row` has been deprecated. Please use the `PySRRegressor` interface. "
169
+ "After fitting, you can run `print(model)` to view the best equation, or "
170
+ "`model.get_best()` to return the best equation's row in `model.equations`."
171
  )
172
 
173
 
174
  def best_tex(*args, **kwargs): # pragma: no cover
175
  raise NotImplementedError(
176
+ "`best_tex` has been deprecated. Please use the `PySRRegressor` interface. "
177
+ "After fitting, you can return `.latex()` to get the sympy representation "
178
+ "of the best equation."
179
  )
180
 
181
 
182
  def best_callable(*args, **kwargs): # pragma: no cover
183
  raise NotImplementedError(
184
+ "`best_callable` has been deprecated. Please use the `PySRRegressor` "
185
+ "interface. After fitting, you can use `.predict(X)` to use the best callable."
186
  )
187
 
188
 
 
785
  setattr(self, updated_kwarg_name, v)
786
  warnings.warn(
787
  f"{k} has been renamed to {updated_kwarg_name} in PySRRegressor. "
788
+ "Please use that instead.",
789
  FutureWarning,
790
  )
791
  # Handle kwargs that have been moved to the fit method
 
797
  )
798
  else:
799
  raise TypeError(
800
+ f"{k} is not a valid keyword argument for PySRRegressor."
801
  )
802
 
803
  def __repr__(self):
 
974
  values. For example, default parameters are set here
975
  when a parameter is left set to `None`.
976
  """
 
977
  # Immutable parameter validation
978
  # Ensure instance parameters are allowable values:
979
  if self.tournament_selection_n > self.population_size:
 
983
 
984
  if self.maxsize > 40:
985
  warnings.warn(
986
+ "Note: Using a large maxsize for the equation search will be "
987
+ "exponentially slower and use significant memory. You should consider "
988
+ "turning `use_frequency` to False, and perhaps use `warmup_maxsize_by`."
989
  )
990
  elif self.maxsize < 7:
991
  raise ValueError("PySR requires a maxsize of at least 7")
992
 
993
+ if self.deterministic and not (
994
+ self.multithreading in [False, None]
995
+ and self.procs == 0
996
+ and self.random_state is not None
997
+ ):
998
+ raise ValueError(
999
+ "To ensure deterministic searches, you must set `random_state` to a seed, "
1000
+ "`procs` to `0`, and `multithreading` to `False` or `None`."
1001
+ )
 
1002
 
1003
+ if self.random_state is not None and (
1004
+ not self.deterministic or self.procs != 0
1005
+ ):
1006
  warnings.warn(
1007
  "Note: Setting `random_state` without also setting `deterministic` "
1008
+ "to True and `procs` to 0 will result in non-deterministic searches. "
 
1009
  )
1010
 
1011
  # NotImplementedError - Values that could be supported at a later time
 
1046
  parameter_value = 1
1047
  elif parameter == "progress" and not buffer_available:
1048
  warnings.warn(
1049
+ "Note: it looks like you are running in Jupyter. "
1050
+ "The progress bar will be turned off."
1051
  )
1052
  parameter_value = False
1053
  packed_modified_params[parameter] = parameter_value
 
1099
  Validated list of variable names for each feature in `X`.
1100
 
1101
  """
 
1102
  if isinstance(X, pd.DataFrame):
1103
  if variable_names:
1104
  variable_names = None
 
1814
  )
1815
  except FileNotFoundError:
1816
  raise RuntimeError(
1817
+ "Couldn't find equation file! The equation search likely exited "
1818
+ "before a single iteration completed."
1819
  )
1820
 
1821
  # It is expected extra_jax/torch_mappings will be updated after fit.
 
1826
  for value in extra_jax_mappings.values():
1827
  if not isinstance(value, str):
1828
  raise ValueError(
1829
+ "extra_jax_mappings must have keys that are strings! "
1830
+ "e.g., {sympy.sqrt: 'jnp.sqrt'}."
1831
  )
1832
  else:
1833
  extra_jax_mappings = {}
 
1835
  for value in extra_jax_mappings.values():
1836
  if not callable(value):
1837
  raise ValueError(
1838
+ "extra_torch_mappings must be callable functions! "
1839
+ "e.g., {sympy.sqrt: torch.sqrt}."
1840
  )
1841
  else:
1842
  extra_torch_mappings = {}