MilesCranmer committed on
Commit
3dff82f
1 Parent(s): e7b4ea9

Make __init__ not modify parameters again

Browse files
Files changed (1) hide show
  1. pysr/sr.py +120 -126
pysr/sr.py CHANGED
@@ -759,8 +759,6 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
759
  f"{k} is not a valid keyword argument for PySRRegressor"
760
  )
761
 
762
- self._process_params()
763
-
764
  def __repr__(self):
765
  """
766
  Prints all current equations fitted by the model.
@@ -865,105 +863,6 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
865
  f"{self.model_selection} is not a valid model selection strategy."
866
  )
867
 
868
- def _process_params(self):
869
- """
870
- Perform validation on the parameters defined in init for the
871
- dataset specified in :term`fit`, and update them if necessary.
872
- For example, this will change :param`binary_operators`
873
- into `["+", "-", "*", "/"]` if `binary_operators` is `None`.
874
-
875
- Raises
876
- ------
877
- ValueError
878
- Raised when on of the following occurs: `tournament_selection_n`
879
- parameter is larger than `population_size`; `maxsize` is
880
- less than 7; invalid `extra_jax_mappings` or
881
- `extra_torch_mappings`; invalid optimizer algorithms.
882
-
883
- """
884
- # Handle None values for instance parameters:
885
- if self.binary_operators is None:
886
- self.binary_operators = "+ * - /".split(" ")
887
- if self.unary_operators is None:
888
- self.unary_operators = []
889
- if self.extra_sympy_mappings is None:
890
- self.extra_sympy_mappings = {}
891
- if self.constraints is None:
892
- self.constraints = {}
893
- if self.multithreading is None:
894
- # Default is multithreading=True, unless explicitly set,
895
- # or procs is set to 0 (serial mode).
896
- self.multithreading = self.procs != 0 and self.cluster_manager is None
897
- if self.update_verbosity is None:
898
- self.update_verbosity = self.verbosity
899
- if self.maxdepth is None:
900
- self.maxdepth = self.maxsize
901
-
902
- # Handle type conversion for instance parameters:
903
- if isinstance(self.binary_operators, str):
904
- self.binary_operators = [self.binary_operators]
905
- if isinstance(self.unary_operators, str):
906
- self.unary_operators = [self.unary_operators]
907
-
908
- # Warn if instance parameters are not sensible values:
909
- if self.batch_size < 1:
910
- warnings.warn(
911
- "Given :param`batch_size` must be greater than or equal to one. "
912
- ":param`batch_size` has been increased to equal one."
913
- )
914
- self.batch_size = 1
915
-
916
- # Ensure instance parameters are allowable values:
917
- # ValueError - Incompatible values
918
- if self.tournament_selection_n > self.population_size:
919
- raise ValueError(
920
- "tournament_selection_n parameter must be smaller than population_size."
921
- )
922
-
923
- if self.maxsize > 40:
924
- warnings.warn(
925
- "Note: Using a large maxsize for the equation search will be exponentially slower and use significant memory. You should consider turning `use_frequency` to False, and perhaps use `warmup_maxsize_by`."
926
- )
927
- elif self.maxsize < 7:
928
- raise ValueError("PySR requires a maxsize of at least 7")
929
-
930
- if self.extra_jax_mappings is not None:
931
- for value in self.extra_jax_mappings.values():
932
- if not isinstance(value, str):
933
- raise ValueError(
934
- "extra_jax_mappings must have keys that are strings! e.g., {sympy.sqrt: 'jnp.sqrt'}."
935
- )
936
- else:
937
- self.extra_jax_mappings = {}
938
-
939
- if self.extra_torch_mappings is not None:
940
- for value in self.extra_jax_mappings.values():
941
- if not callable(value):
942
- raise ValueError(
943
- "extra_torch_mappings must be callable functions! e.g., {sympy.sqrt: torch.sqrt}."
944
- )
945
- else:
946
- self.extra_torch_mappings = {}
947
-
948
- # NotImplementedError - Values that could be supported at a later time
949
- if self.optimizer_algorithm not in self.VALID_OPTIMIZER_ALGORITHMS:
950
- raise NotImplementedError(
951
- f"PySR currently only supports the following optimizer algorithms: {self.VALID_OPTIMIZER_ALGORITHMS}"
952
- )
953
-
954
- # Handle presentation of the progress bar:
955
- buffer_available = "buffer" in sys.stdout.__dir__()
956
- if self.progress is not None:
957
- if self.progress and not buffer_available:
958
- warnings.warn(
959
- "Note: it looks like you are running in Jupyter. The progress bar will be turned off."
960
- )
961
- self.progress = False
962
- else:
963
- self.progress = buffer_available
964
-
965
- return self
966
-
967
  def _setup_equation_file(self):
968
  """
969
  Sets the full pathname of the equation file, using :param`tempdir` and
@@ -1016,6 +915,39 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
1016
 
1017
  """
1018
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1019
  if isinstance(X, pd.DataFrame):
1020
  if variable_names:
1021
  variable_names = None
@@ -1165,23 +1097,82 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
1165
  global already_ran
1166
  global Main
1167
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1168
  # Start julia backend processes
1169
  if Main is None:
1170
- if self.multithreading:
1171
  os.environ["JULIA_NUM_THREADS"] = str(self.procs)
1172
 
1173
  Main = init_julia()
1174
 
1175
- if self.cluster_manager is not None:
1176
- Main.eval(f"import ClusterManagers: addprocs_{self.cluster_manager}")
1177
- cluster_manager = Main.eval(f"addprocs_{self.cluster_manager}")
1178
- else:
1179
- cluster_manager = None
1180
 
1181
  if not already_ran:
1182
  julia_project, is_shared = _get_julia_project(self.julia_project)
1183
  Main.eval("using Pkg")
1184
- io = "devnull" if self.update_verbosity == 0 else "stderr"
1185
  io_arg = f"io={io}" if is_julia_version_greater_eq(Main, "1.6") else ""
1186
 
1187
  Main.eval(
@@ -1211,39 +1202,35 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
1211
 
1212
  # TODO(mcranmer): These functions should be part of this class.
1213
  binary_operators, unary_operators = _maybe_create_inline_operators(
1214
- binary_operators=self.binary_operators, unary_operators=self.unary_operators
1215
  )
1216
  constraints = _process_constraints(
1217
  binary_operators=binary_operators,
1218
  unary_operators=unary_operators,
1219
- constraints=self.constraints,
1220
  )
1221
 
1222
  una_constraints = [constraints[op] for op in unary_operators]
1223
  bin_constraints = [constraints[op] for op in binary_operators]
1224
 
1225
  # Parse dict into Julia Dict for nested constraints::
1226
- if self.nested_constraints is not None:
1227
  nested_constraints_str = "Dict("
1228
- for outer_k, outer_v in self.nested_constraints.items():
1229
  nested_constraints_str += f"({outer_k}) => Dict("
1230
  for inner_k, inner_v in outer_v.items():
1231
  nested_constraints_str += f"({inner_k}) => {inner_v}, "
1232
  nested_constraints_str += "), "
1233
  nested_constraints_str += ")"
1234
  nested_constraints = Main.eval(nested_constraints_str)
1235
- else:
1236
- nested_constraints = None
1237
 
1238
  # Parse dict into Julia Dict for complexities:
1239
- if self.complexity_of_operators is not None:
1240
  complexity_of_operators_str = "Dict("
1241
- for k, v in self.complexity_of_operators.items():
1242
  complexity_of_operators_str += f"({k}) => {v}, "
1243
  complexity_of_operators_str += ")"
1244
  complexity_of_operators = Main.eval(complexity_of_operators_str)
1245
- else:
1246
- complexity_of_operators = None
1247
 
1248
  Main.custom_loss = Main.eval(self.loss)
1249
 
@@ -1274,14 +1261,14 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
1274
  hofFile=_escape_filename(self.equation_file_),
1275
  npopulations=int(self.populations),
1276
  batching=self.batching,
1277
- batchSize=int(min([self.batch_size, len(X)]) if self.batching else len(X)),
1278
  mutationWeights=mutationWeights,
1279
  probPickFirst=self.tournament_selection_p,
1280
  ns=self.tournament_selection_n,
1281
  # These have the same name:
1282
  parsimony=self.parsimony,
1283
  alpha=self.alpha,
1284
- maxdepth=self.maxdepth,
1285
  fast_cycle=self.fast_cycle,
1286
  migration=self.migration,
1287
  hofMigration=self.hof_migration,
@@ -1302,7 +1289,7 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
1302
  perturbationFactor=self.perturbation_factor,
1303
  annealing=self.annealing,
1304
  stateReturn=True, # Required for state saving.
1305
- progress=self.progress,
1306
  timeout_in_seconds=self.timeout_in_seconds,
1307
  crossoverProbability=self.crossover_probability,
1308
  skip_mutation_failures=self.skip_mutation_failures,
@@ -1313,6 +1300,7 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
1313
  # Convert data to desired precision
1314
  np_dtype = {16: np.float16, 32: np.float32, 64: np.float64}[self.precision]
1315
 
 
1316
  Main.X = np.array(X, dtype=np_dtype).T
1317
  if len(y.shape) == 1:
1318
  Main.y = np.array(y, dtype=np_dtype)
@@ -1326,7 +1314,7 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
1326
  else:
1327
  Main.weights = None
1328
 
1329
- cprocs = 0 if self.multithreading else self.procs
1330
 
1331
  # Call to Julia backend.
1332
  # See https://github.com/search?q=%22function+EquationSearch%22+repo%3AMilesCranmer%2FSymbolicRegression.jl+path%3A%2Fsrc%2F+filename%3ASymbolicRegression.jl+language%3AJulia&type=Code
@@ -1338,7 +1326,7 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
1338
  varMap=self.feature_names_in_.tolist(),
1339
  options=options,
1340
  numprocs=int(cprocs),
1341
- multithreading=bool(self.multithreading),
1342
  saved_state=self.raw_julia_state_,
1343
  addprocs_function=cluster_manager,
1344
  )
@@ -1714,7 +1702,7 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
1714
  if self.output_torch_format:
1715
  torch_format = []
1716
  local_sympy_mappings = {
1717
- **self.extra_sympy_mappings,
1718
  **sympy_mappings,
1719
  }
1720
 
@@ -1741,7 +1729,9 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
1741
  eqn,
1742
  sympy_symbols,
1743
  selection=self.selection_mask_,
1744
- extra_jax_mappings=self.extra_jax_mappings,
 
 
1745
  )
1746
  jax_format.append({"callable": func, "parameters": params})
1747
 
@@ -1753,7 +1743,11 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
1753
  eqn,
1754
  sympy_symbols,
1755
  selection=self.selection_mask_,
1756
- extra_torch_mappings=self.extra_torch_mappings,
 
 
 
 
1757
  )
1758
  torch_format.append(module)
1759
 
 
759
  f"{k} is not a valid keyword argument for PySRRegressor"
760
  )
761
 
 
 
762
  def __repr__(self):
763
  """
764
  Prints all current equations fitted by the model.
 
863
  f"{self.model_selection} is not a valid model selection strategy."
864
  )
865
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
866
  def _setup_equation_file(self):
867
  """
868
  Sets the full pathname of the equation file, using :param`tempdir` and
 
915
 
916
  """
917
 
918
+ # Ensure instance parameters are allowable values:
919
+ if self.tournament_selection_n > self.population_size:
920
+ raise ValueError(
921
+ "tournament_selection_n parameter must be smaller than population_size."
922
+ )
923
+
924
+ if self.maxsize > 40:
925
+ warnings.warn(
926
+ "Note: Using a large maxsize for the equation search will be exponentially slower and use significant memory. You should consider turning `use_frequency` to False, and perhaps use `warmup_maxsize_by`."
927
+ )
928
+ elif self.maxsize < 7:
929
+ raise ValueError("PySR requires a maxsize of at least 7")
930
+
931
+ if self.extra_jax_mappings is not None:
932
+ for value in self.extra_jax_mappings.values():
933
+ if not isinstance(value, str):
934
+ raise ValueError(
935
+ "extra_jax_mappings must have keys that are strings! e.g., {sympy.sqrt: 'jnp.sqrt'}."
936
+ )
937
+
938
+ if self.extra_torch_mappings is not None:
939
+ for value in self.extra_jax_mappings.values():
940
+ if not callable(value):
941
+ raise ValueError(
942
+ "extra_torch_mappings must be callable functions! e.g., {sympy.sqrt: torch.sqrt}."
943
+ )
944
+
945
+ # NotImplementedError - Values that could be supported at a later time
946
+ if self.optimizer_algorithm not in self.VALID_OPTIMIZER_ALGORITHMS:
947
+ raise NotImplementedError(
948
+ f"PySR currently only supports the following optimizer algorithms: {self.VALID_OPTIMIZER_ALGORITHMS}"
949
+ )
950
+
951
  if isinstance(X, pd.DataFrame):
952
  if variable_names:
953
  variable_names = None
 
1097
  global already_ran
1098
  global Main
1099
 
1100
+ # These are the parameters which may be modified from the ones
1101
+ # specified in init, so we define them here locally:
1102
+ binary_operators = self.binary_operators
1103
+ unary_operators = self.unary_operators
1104
+ constraints = self.constraints
1105
+ nested_constraints = self.nested_constraints
1106
+ complexity_of_operators = self.complexity_of_operators
1107
+ multithreading = self.multithreading
1108
+ update_verbosity = self.update_verbosity
1109
+ maxdepth = self.maxdepth
1110
+ batch_size = self.batch_size
1111
+ progress = self.progress
1112
+ cluster_manager = self.cluster_manager
1113
+
1114
+ # TODO: Clean this up into a readable format, such that
1115
+ # a function call automatically configures each default.
1116
+
1117
+ # Deal with default values, and type conversions:
1118
+ if binary_operators is None:
1119
+ binary_operators = "+ * - /".split(" ")
1120
+ elif isinstance(binary_operators, str):
1121
+ binary_operators = [binary_operators]
1122
+
1123
+ if unary_operators is None:
1124
+ unary_operators = []
1125
+ elif isinstance(unary_operators, str):
1126
+ unary_operators = [unary_operators]
1127
+
1128
+ if constraints is None:
1129
+ constraints = {}
1130
+
1131
+ if multithreading is None:
1132
+ # Default is multithreading=True, unless explicitly set,
1133
+ # or procs is set to 0 (serial mode).
1134
+ multithreading = self.procs != 0 and cluster_manager is None
1135
+
1136
+ if update_verbosity is None:
1137
+ update_verbosity = self.verbosity
1138
+
1139
+ if maxdepth is None:
1140
+ maxdepth = self.maxsize
1141
+
1142
+ # Warn if instance parameters are not sensible values:
1143
+ if batch_size < 1:
1144
+ warnings.warn(
1145
+ "Given :param`batch_size` must be greater than or equal to one. "
1146
+ ":param`batch_size` has been increased to equal one."
1147
+ )
1148
+ batch_size = 1
1149
+
1150
+ # Handle presentation of the progress bar:
1151
+ buffer_available = "buffer" in sys.stdout.__dir__()
1152
+ if progress is not None:
1153
+ if progress and not buffer_available:
1154
+ warnings.warn(
1155
+ "Note: it looks like you are running in Jupyter. The progress bar will be turned off."
1156
+ )
1157
+ progress = False
1158
+ else:
1159
+ progress = buffer_available
1160
+
1161
  # Start julia backend processes
1162
  if Main is None:
1163
+ if multithreading:
1164
  os.environ["JULIA_NUM_THREADS"] = str(self.procs)
1165
 
1166
  Main = init_julia()
1167
 
1168
+ if cluster_manager is not None:
1169
+ Main.eval(f"import ClusterManagers: addprocs_{cluster_manager}")
1170
+ cluster_manager = Main.eval(f"addprocs_{cluster_manager}")
 
 
1171
 
1172
  if not already_ran:
1173
  julia_project, is_shared = _get_julia_project(self.julia_project)
1174
  Main.eval("using Pkg")
1175
+ io = "devnull" if update_verbosity == 0 else "stderr"
1176
  io_arg = f"io={io}" if is_julia_version_greater_eq(Main, "1.6") else ""
1177
 
1178
  Main.eval(
 
1202
 
1203
  # TODO(mcranmer): These functions should be part of this class.
1204
  binary_operators, unary_operators = _maybe_create_inline_operators(
1205
+ binary_operators=binary_operators, unary_operators=unary_operators
1206
  )
1207
  constraints = _process_constraints(
1208
  binary_operators=binary_operators,
1209
  unary_operators=unary_operators,
1210
+ constraints=constraints,
1211
  )
1212
 
1213
  una_constraints = [constraints[op] for op in unary_operators]
1214
  bin_constraints = [constraints[op] for op in binary_operators]
1215
 
1216
  # Parse dict into Julia Dict for nested constraints::
1217
+ if nested_constraints is not None:
1218
  nested_constraints_str = "Dict("
1219
+ for outer_k, outer_v in nested_constraints.items():
1220
  nested_constraints_str += f"({outer_k}) => Dict("
1221
  for inner_k, inner_v in outer_v.items():
1222
  nested_constraints_str += f"({inner_k}) => {inner_v}, "
1223
  nested_constraints_str += "), "
1224
  nested_constraints_str += ")"
1225
  nested_constraints = Main.eval(nested_constraints_str)
 
 
1226
 
1227
  # Parse dict into Julia Dict for complexities:
1228
+ if complexity_of_operators is not None:
1229
  complexity_of_operators_str = "Dict("
1230
+ for k, v in complexity_of_operators.items():
1231
  complexity_of_operators_str += f"({k}) => {v}, "
1232
  complexity_of_operators_str += ")"
1233
  complexity_of_operators = Main.eval(complexity_of_operators_str)
 
 
1234
 
1235
  Main.custom_loss = Main.eval(self.loss)
1236
 
 
1261
  hofFile=_escape_filename(self.equation_file_),
1262
  npopulations=int(self.populations),
1263
  batching=self.batching,
1264
+ batchSize=int(min([batch_size, len(X)]) if self.batching else len(X)),
1265
  mutationWeights=mutationWeights,
1266
  probPickFirst=self.tournament_selection_p,
1267
  ns=self.tournament_selection_n,
1268
  # These have the same name:
1269
  parsimony=self.parsimony,
1270
  alpha=self.alpha,
1271
+ maxdepth=maxdepth,
1272
  fast_cycle=self.fast_cycle,
1273
  migration=self.migration,
1274
  hofMigration=self.hof_migration,
 
1289
  perturbationFactor=self.perturbation_factor,
1290
  annealing=self.annealing,
1291
  stateReturn=True, # Required for state saving.
1292
+ progress=progress,
1293
  timeout_in_seconds=self.timeout_in_seconds,
1294
  crossoverProbability=self.crossover_probability,
1295
  skip_mutation_failures=self.skip_mutation_failures,
 
1300
  # Convert data to desired precision
1301
  np_dtype = {16: np.float16, 32: np.float32, 64: np.float64}[self.precision]
1302
 
1303
+ # This converts the data into a Julia array:
1304
  Main.X = np.array(X, dtype=np_dtype).T
1305
  if len(y.shape) == 1:
1306
  Main.y = np.array(y, dtype=np_dtype)
 
1314
  else:
1315
  Main.weights = None
1316
 
1317
+ cprocs = 0 if multithreading else self.procs
1318
 
1319
  # Call to Julia backend.
1320
  # See https://github.com/search?q=%22function+EquationSearch%22+repo%3AMilesCranmer%2FSymbolicRegression.jl+path%3A%2Fsrc%2F+filename%3ASymbolicRegression.jl+language%3AJulia&type=Code
 
1326
  varMap=self.feature_names_in_.tolist(),
1327
  options=options,
1328
  numprocs=int(cprocs),
1329
+ multithreading=bool(multithreading),
1330
  saved_state=self.raw_julia_state_,
1331
  addprocs_function=cluster_manager,
1332
  )
 
1702
  if self.output_torch_format:
1703
  torch_format = []
1704
  local_sympy_mappings = {
1705
+ **(self.extra_sympy_mappings if self.extra_sympy_mappings else {}),
1706
  **sympy_mappings,
1707
  }
1708
 
 
1729
  eqn,
1730
  sympy_symbols,
1731
  selection=self.selection_mask_,
1732
+ extra_jax_mappings=(
1733
+ self.extra_jax_mappings if self.extra_jax_mappings else {}
1734
+ ),
1735
  )
1736
  jax_format.append({"callable": func, "parameters": params})
1737
 
 
1743
  eqn,
1744
  sympy_symbols,
1745
  selection=self.selection_mask_,
1746
+ extra_torch_mappings=(
1747
+ self.extra_torch_mappings
1748
+ if self.extra_torch_mappings
1749
+ else {}
1750
+ ),
1751
  )
1752
  torch_format.append(module)
1753