tttc3 commited on
Commit
6881818
1 Parent(s): 3e8d44d

Updated parameter validation

Browse files
Files changed (1) hide show
  1. pysr/sr.py +112 -97
pysr/sr.py CHANGED
@@ -529,6 +529,12 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
529
  List of indices for input features that are selected when
530
  :param`select_k_features` is set.
531
 
 
 
 
 
 
 
532
  raw_julia_state_ : tuple[list[PyCall.jlwrap], PyCall.jlwrap]
533
  The state for the julia SymbolicRegression.jl backend post fitting.
534
 
@@ -928,6 +934,71 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
928
  else:
929
  self.equation_file_ = self.equation_file
930
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
931
  def _validate_fit_params(self, X, y, Xresampled, variable_names):
932
  """
933
  Validates the parameters passed to the :term`fit` method.
@@ -965,39 +1036,6 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
965
 
966
  """
967
 
968
- # Ensure instance parameters are allowable values:
969
- if self.tournament_selection_n > self.population_size:
970
- raise ValueError(
971
- "tournament_selection_n parameter must be smaller than population_size."
972
- )
973
-
974
- if self.maxsize > 40:
975
- warnings.warn(
976
- "Note: Using a large maxsize for the equation search will be exponentially slower and use significant memory. You should consider turning `use_frequency` to False, and perhaps use `warmup_maxsize_by`."
977
- )
978
- elif self.maxsize < 7:
979
- raise ValueError("PySR requires a maxsize of at least 7")
980
-
981
- if self.extra_jax_mappings is not None:
982
- for value in self.extra_jax_mappings.values():
983
- if not isinstance(value, str):
984
- raise ValueError(
985
- "extra_jax_mappings must have keys that are strings! e.g., {sympy.sqrt: 'jnp.sqrt'}."
986
- )
987
-
988
- if self.extra_torch_mappings is not None:
989
- for value in self.extra_jax_mappings.values():
990
- if not callable(value):
991
- raise ValueError(
992
- "extra_torch_mappings must be callable functions! e.g., {sympy.sqrt: torch.sqrt}."
993
- )
994
-
995
- # NotImplementedError - Values that could be supported at a later time
996
- if self.optimizer_algorithm not in VALID_OPTIMIZER_ALGORITHMS:
997
- raise NotImplementedError(
998
- f"PySR currently only supports the following optimizer algorithms: {VALID_OPTIMIZER_ALGORITHMS}"
999
- )
1000
-
1001
  if isinstance(X, pd.DataFrame):
1002
  if variable_names:
1003
  variable_names = None
@@ -1020,13 +1058,10 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1020
  "Spaces have been replaced with underscores. \n"
1021
  "Please use valid names instead."
1022
  )
1023
- # Only numpy values are needed from Xresampled, column metadata is
1024
- # provided by X
1025
- if isinstance(Xresampled, pd.DataFrame):
1026
- Xresampled = Xresampled.values
1027
 
1028
  # Data validation and feature name fetching via sklearn
1029
  # This method sets the n_features_in_ attribute
 
1030
  X, y = self._validate_data(X=X, y=y, reset=True, multi_output=True)
1031
  self.feature_names_in_ = _check_feature_names_in(self, variable_names)
1032
  variable_names = self.feature_names_in_
@@ -1126,7 +1161,7 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1126
 
1127
  return X, y, variable_names
1128
 
1129
- def _run(self, X, y, weights, seed):
1130
  """
1131
  Run the symbolic regression fitting process on the julia backend.
1132
 
@@ -1138,10 +1173,16 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1138
  y : {ndarray | pandas.DataFrame} of shape (n_samples,) or (n_samples, n_targets)
1139
  Target values. Will be cast to X's dtype if necessary.
1140
 
1141
- weights : {ndarray | pandas.DataFrame} of the same shape as y, default=None
 
 
 
1142
  Each element is how to weight the mean-square-error loss
1143
  for that particular element of y.
1144
 
 
 
 
1145
  Returns
1146
  -------
1147
  self : object
@@ -1159,66 +1200,17 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1159
 
1160
  # These are the parameters which may be modified from the ones
1161
  # specified in init, so we define them here locally:
1162
- binary_operators = self.binary_operators
1163
- unary_operators = self.unary_operators
1164
- constraints = self.constraints
 
1165
  nested_constraints = self.nested_constraints
1166
  complexity_of_operators = self.complexity_of_operators
1167
- multithreading = self.multithreading
1168
- update_verbosity = self.update_verbosity
1169
- maxdepth = self.maxdepth
1170
- batch_size = self.batch_size
1171
- progress = self.progress
1172
  cluster_manager = self.cluster_manager
1173
-
1174
- # TODO: Clean this up into a readable format, such that
1175
- # a function call automatically configures each default.
1176
-
1177
- # Deal with default values, and type conversions:
1178
- if binary_operators is None:
1179
- binary_operators = "+ * - /".split(" ")
1180
- elif isinstance(binary_operators, str):
1181
- binary_operators = [binary_operators]
1182
-
1183
- if unary_operators is None:
1184
- unary_operators = []
1185
- elif isinstance(unary_operators, str):
1186
- unary_operators = [unary_operators]
1187
-
1188
- assert len(unary_operators) + len(binary_operators) > 0
1189
-
1190
- if constraints is None:
1191
- constraints = {}
1192
-
1193
- if multithreading is None:
1194
- # Default is multithreading=True, unless explicitly set,
1195
- # or procs is set to 0 (serial mode).
1196
- multithreading = self.procs != 0 and cluster_manager is None
1197
-
1198
- if update_verbosity is None:
1199
- update_verbosity = self.verbosity
1200
-
1201
- if maxdepth is None:
1202
- maxdepth = self.maxsize
1203
-
1204
- # Warn if instance parameters are not sensible values:
1205
- if batch_size < 1:
1206
- warnings.warn(
1207
- "Given :param`batch_size` must be greater than or equal to one. "
1208
- ":param`batch_size` has been increased to equal one."
1209
- )
1210
- batch_size = 1
1211
-
1212
- # Handle presentation of the progress bar:
1213
- buffer_available = "buffer" in sys.stdout.__dir__()
1214
- if progress is not None:
1215
- if progress and not buffer_available:
1216
- warnings.warn(
1217
- "Note: it looks like you are running in Jupyter. The progress bar will be turned off."
1218
- )
1219
- progress = False
1220
- else:
1221
- progress = buffer_available
1222
 
1223
  # Start julia backend processes
1224
  if Main is None:
@@ -1455,6 +1447,8 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1455
 
1456
  self._setup_equation_file()
1457
 
 
 
1458
  # Parameter input validation (for parameters defined in __init__)
1459
  X, y, Xresampled, variable_names = self._validate_fit_params(
1460
  X, y, Xresampled, variable_names
@@ -1505,7 +1499,7 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1505
  )
1506
 
1507
  # Fitting procedure
1508
- return self._run(X=X, y=y, weights=weights, seed=seed)
1509
 
1510
  def refresh(self, checkpoint_file=None):
1511
  """
@@ -1736,6 +1730,27 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1736
  "Couldn't find equation file! The equation search likely exited before a single iteration completed."
1737
  )
1738
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1739
  ret_outputs = []
1740
 
1741
  for output in all_outputs:
 
529
  List of indices for input features that are selected when
530
  :param`select_k_features` is set.
531
 
532
+ tempdir_ : Path
533
+ Path to the temporary equations directory.
534
+
535
+ equation_file_ : str
536
+ Output equation file name produced by the julia backend.
537
+
538
  raw_julia_state_ : tuple[list[PyCall.jlwrap], PyCall.jlwrap]
539
  The state for the julia SymbolicRegression.jl backend post fitting.
540
 
 
934
  else:
935
  self.equation_file_ = self.equation_file
936
 
937
+ def _validate_init_params(self):
938
+
939
+ # Immutable parameter validation
940
+ # Ensure instance parameters are allowable values:
941
+ if self.tournament_selection_n > self.population_size:
942
+ raise ValueError(
943
+ "tournament_selection_n parameter must be smaller than population_size."
944
+ )
945
+
946
+ if self.maxsize > 40:
947
+ warnings.warn(
948
+ "Note: Using a large maxsize for the equation search will be exponentially slower and use significant memory. You should consider turning `use_frequency` to False, and perhaps use `warmup_maxsize_by`."
949
+ )
950
+ elif self.maxsize < 7:
951
+ raise ValueError("PySR requires a maxsize of at least 7")
952
+
953
+ # NotImplementedError - Values that could be supported at a later time
954
+ if self.optimizer_algorithm not in VALID_OPTIMIZER_ALGORITHMS:
955
+ raise NotImplementedError(
956
+ f"PySR currently only supports the following optimizer algorithms: {VALID_OPTIMIZER_ALGORITHMS}"
957
+ )
958
+
959
+ # 'Mutable' parameter validation
960
+ buffer_available = "buffer" in sys.stdout.__dir__()
961
+ modifiable_params = {
962
+ "binary_operators": "+ * - /".split(" "),
963
+ "unary_operators": [],
964
+ "maxdepth": self.maxsize,
965
+ "constraints": {},
966
+ "multithreading": self.procs != 0 and self.cluster_manager is None,
967
+ "batch_size": 1,
968
+ "update_verbosity": self.verbosity,
969
+ "progress": buffer_available,
970
+ }
971
+ packed_modified_params = {}
972
+ for parameter, default_value in modifiable_params.items():
973
+ parameter_value = getattr(self, parameter)
974
+ if parameter_value is None:
975
+ parameter_value = default_value
976
+ else:
977
+ # Special cases such as when binary_operators is a string
978
+ if parameter in ["binary_operators", "unary_operators"] and isinstance(
979
+ parameter_value, str
980
+ ):
981
+ parameter_value = [parameter_value]
982
+ elif parameter is "batch_size" and parameter_value < 1:
983
+ warnings.warn(
984
+ "Given :param`batch_size` must be greater than or equal to one. "
985
+ ":param`batch_size` has been increased to equal one."
986
+ )
987
+ parameter_value = 1
988
+ elif parameter is "progress" and not buffer_available:
989
+ warnings.warn(
990
+ "Note: it looks like you are running in Jupyter. The progress bar will be turned off."
991
+ )
992
+ parameter_value = False
993
+ packed_modified_params[parameter] = parameter_value
994
+
995
+ assert (
996
+ len(packed_modified_params["binary_operators"])
997
+ + len(packed_modified_params["unary_operators"])
998
+ > 0
999
+ )
1000
+ return packed_modified_params
1001
+
1002
  def _validate_fit_params(self, X, y, Xresampled, variable_names):
1003
  """
1004
  Validates the parameters passed to the :term`fit` method.
 
1036
 
1037
  """
1038
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1039
  if isinstance(X, pd.DataFrame):
1040
  if variable_names:
1041
  variable_names = None
 
1058
  "Spaces have been replaced with underscores. \n"
1059
  "Please use valid names instead."
1060
  )
 
 
 
 
1061
 
1062
  # Data validation and feature name fetching via sklearn
1063
  # This method sets the n_features_in_ attribute
1064
+ Xresampled = check_array(Xresampled)
1065
  X, y = self._validate_data(X=X, y=y, reset=True, multi_output=True)
1066
  self.feature_names_in_ = _check_feature_names_in(self, variable_names)
1067
  variable_names = self.feature_names_in_
 
1161
 
1162
  return X, y, variable_names
1163
 
1164
+ def _run(self, X, y, mutated_params, weights, seed):
1165
  """
1166
  Run the symbolic regression fitting process on the julia backend.
1167
 
 
1173
  y : {ndarray | pandas.DataFrame} of shape (n_samples,) or (n_samples, n_targets)
1174
  Target values. Will be cast to X's dtype if necessary.
1175
 
1176
+ mutated_params : dict[str, Any]
1177
+ Dictionary of mutated versions of some parameters passed in __init__.
1178
+
1179
+ weights : {ndarray | pandas.DataFrame} of the same shape as y
1180
  Each element is how to weight the mean-square-error loss
1181
  for that particular element of y.
1182
 
1183
+ seed : int
1184
+ Random seed for julia backend process.
1185
+
1186
  Returns
1187
  -------
1188
  self : object
 
1200
 
1201
  # These are the parameters which may be modified from the ones
1202
  # specified in init, so we define them here locally:
1203
+ binary_operators = mutated_params["binary_operators"]
1204
+ unary_operators = mutated_params["unary_operators"]
1205
+ maxdepth = mutated_params["maxdepth"]
1206
+ constraints = mutated_params["constraints"]
1207
  nested_constraints = self.nested_constraints
1208
  complexity_of_operators = self.complexity_of_operators
1209
+ multithreading = mutated_params["multithreading"]
 
 
 
 
1210
  cluster_manager = self.cluster_manager
1211
+ batch_size = mutated_params["batch_size"]
1212
+ update_verbosity = mutated_params["update_verbosity"]
1213
+ progress = mutated_params["progress"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1214
 
1215
  # Start julia backend processes
1216
  if Main is None:
 
1447
 
1448
  self._setup_equation_file()
1449
 
1450
+ mutated_params = self._validate_init_params()
1451
+
1452
  # Parameter input validation (for parameters defined in __init__)
1453
  X, y, Xresampled, variable_names = self._validate_fit_params(
1454
  X, y, Xresampled, variable_names
 
1499
  )
1500
 
1501
  # Fitting procedure
1502
+ return self._run(X, y, mutated_params, weights=weights, seed=seed)
1503
 
1504
  def refresh(self, checkpoint_file=None):
1505
  """
 
1730
  "Couldn't find equation file! The equation search likely exited before a single iteration completed."
1731
  )
1732
 
1733
+ # It is expected extra_jax/torch_mappings will be updated after fit.
1734
+ # Thus, validation is performed here instead of in _validate_init_params
1735
+ extra_jax_mappings = self.extra_jax_mappings
1736
+ extra_torch_mappings = self.extra_torch_mappings
1737
+ if extra_jax_mappings is not None:
1738
+ for value in self.extra_jax_mappings.values():
1739
+ if not isinstance(value, str):
1740
+ raise ValueError(
1741
+ "extra_jax_mappings must have keys that are strings! e.g., {sympy.sqrt: 'jnp.sqrt'}."
1742
+ )
1743
+ else:
1744
+ extra_jax_mappings = {}
1745
+ if extra_torch_mappings is not None:
1746
+ for value in self.extra_jax_mappings.values():
1747
+ if not callable(value):
1748
+ raise ValueError(
1749
+ "extra_torch_mappings must be callable functions! e.g., {sympy.sqrt: torch.sqrt}."
1750
+ )
1751
+ else:
1752
+ extra_torch_mappings = {}
1753
+
1754
  ret_outputs = []
1755
 
1756
  for output in all_outputs: