AutonLabTruth commited on
Commit
181a454
·
1 Parent(s): 762987c

Refactored till handle_constraints

Browse files
Files changed (1) hide show
  1. pysr/sr.py +73 -55
pysr/sr.py CHANGED
@@ -207,75 +207,29 @@ def pysr(X=None, y=None, weights=None,
207
 
208
  check_assertions(X, binary_operators, unary_operators, use_custom_variable_names, variable_names, weights, y)
209
 
 
 
210
  if maxdepth is None:
211
  maxdepth = maxsize
212
  if equation_file is None:
213
  date_time = datetime.now().strftime("%Y-%m-%d_%H%M%S.%f")[:-3]
214
  equation_file = 'hall_of_fame_' + date_time + '.csv'
215
-
216
- if select_k_features is not None:
217
- selection = run_feature_selection(X, y, select_k_features)
218
- print(f"Using features {selection}")
219
- X = X[:, selection]
220
-
221
- if use_custom_variable_names:
222
- variable_names = [variable_names[selection[i]] for i in range(len(selection))]
223
-
224
  if populations is None:
225
  populations = procs
226
-
227
- if isinstance(binary_operators, str): binary_operators = [binary_operators]
228
- if isinstance(unary_operators, str): unary_operators = [unary_operators]
229
-
230
  if X is None:
231
- if test == 'simple1':
232
- eval_str = "np.sign(X[:, 2])*np.abs(X[:, 2])**2.5 + 5*np.cos(X[:, 3]) - 5"
233
- elif test == 'simple2':
234
- eval_str = "np.sign(X[:, 2])*np.abs(X[:, 2])**3.5 + 1/(np.abs(X[:, 0])+1)"
235
- elif test == 'simple3':
236
- eval_str = "np.exp(X[:, 0]/2) + 12.0 + np.log(np.abs(X[:, 0])*10 + 1)"
237
- elif test == 'simple4':
238
- eval_str = "1.0 + 3*X[:, 0]**2 - 0.5*X[:, 0]**3 + 0.1*X[:, 0]**4"
239
- elif test == 'simple5':
240
- eval_str = "(np.exp(X[:, 3]) + 3)/(np.abs(X[:, 1]) + np.cos(X[:, 0]) + 1.1)"
241
-
242
- X = np.random.randn(100, 5)*3
243
- y = eval(eval_str)
244
- print("Running on", eval_str)
245
 
246
  def_hyperparams = ""
247
 
248
  # Add pre-defined functions to Julia
249
- for op_list in [binary_operators, unary_operators]:
250
- for i in range(len(op_list)):
251
- op = op_list[i]
252
- is_user_defined_operator = '(' in op
253
-
254
- if is_user_defined_operator:
255
- def_hyperparams += op + "\n"
256
- # Cut off from the first non-alphanumeric char:
257
- first_non_char = [
258
- j for j in range(len(op))
259
- if not (op[j].isalpha() or op[j].isdigit())][0]
260
- function_name = op[:first_non_char]
261
- op_list[i] = function_name
262
 
263
  #arbitrary complexity by default
264
- for op in unary_operators:
265
- if op not in constraints:
266
- constraints[op] = -1
267
- for op in binary_operators:
268
- if op not in constraints:
269
- constraints[op] = (-1, -1)
270
- if op in ['plus', 'sub']:
271
- if constraints[op][0] != constraints[op][1]:
272
- raise NotImplementedError("You need equal constraints on both sides for - and *, due to simplification strategies.")
273
- elif op == 'mult':
274
- # Make sure the complex expression is in the left side.
275
- if constraints[op][0] == -1:
276
- continue
277
- elif constraints[op][1] == -1 or constraints[op][0] < constraints[op][1]:
278
- constraints[op][0], constraints[op][1] = constraints[op][1], constraints[op][0]
279
 
280
  constraints_str = "const una_constraints = ["
281
  first = True
@@ -445,6 +399,70 @@ const varMap = {'["' + '", "'.join(variable_names) + '"]'}"""
445
  return get_hof()
446
 
447
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
448
  def set_paths(tempdir):
449
  # System-independent paths
450
  pkg_directory = Path(__file__).parents[1] / 'julia'
 
207
 
208
  check_assertions(X, binary_operators, unary_operators, use_custom_variable_names, variable_names, weights, y)
209
 
210
+ X, variable_names = handle_feature_selection(X, select_k_features, use_custom_variable_names, variable_names, y)
211
+
212
  if maxdepth is None:
213
  maxdepth = maxsize
214
  if equation_file is None:
215
  date_time = datetime.now().strftime("%Y-%m-%d_%H%M%S.%f")[:-3]
216
  equation_file = 'hall_of_fame_' + date_time + '.csv'
 
 
 
 
 
 
 
 
 
217
  if populations is None:
218
  populations = procs
219
+ if isinstance(binary_operators, str):
220
+ binary_operators = [binary_operators]
221
+ if isinstance(unary_operators, str):
222
+ unary_operators = [unary_operators]
223
  if X is None:
224
+ X, y = using_test_input(X, test, y)
 
 
 
 
 
 
 
 
 
 
 
 
 
225
 
226
  def_hyperparams = ""
227
 
228
  # Add pre-defined functions to Julia
229
+ def_hyperparams = predefined_function_addition(binary_operators, def_hyperparams, unary_operators)
 
 
 
 
 
 
 
 
 
 
 
 
230
 
231
  #arbitrary complexity by default
232
+ handle_constraints(binary_operators, constraints, unary_operators)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
233
 
234
  constraints_str = "const una_constraints = ["
235
  first = True
 
399
  return get_hof()
400
 
401
 
402
+ def handle_constraints(binary_operators, constraints, unary_operators):
403
+ for op in unary_operators:
404
+ if op not in constraints:
405
+ constraints[op] = -1
406
+ for op in binary_operators:
407
+ if op not in constraints:
408
+ constraints[op] = (-1, -1)
409
+ if op in ['plus', 'sub']:
410
+ if constraints[op][0] != constraints[op][1]:
411
+ raise NotImplementedError(
412
+ "You need equal constraints on both sides for - and *, due to simplification strategies.")
413
+ elif op == 'mult':
414
+ # Make sure the complex expression is in the left side.
415
+ if constraints[op][0] == -1:
416
+ continue
417
+ elif constraints[op][1] == -1 or constraints[op][0] < constraints[op][1]:
418
+ constraints[op][0], constraints[op][1] = constraints[op][1], constraints[op][0]
419
+
420
+
421
+ def predefined_function_addition(binary_operators, def_hyperparams, unary_operators):
422
+ for op_list in [binary_operators, unary_operators]:
423
+ for i in range(len(op_list)):
424
+ op = op_list[i]
425
+ is_user_defined_operator = '(' in op
426
+
427
+ if is_user_defined_operator:
428
+ def_hyperparams += op + "\n"
429
+ # Cut off from the first non-alphanumeric char:
430
+ first_non_char = [
431
+ j for j in range(len(op))
432
+ if not (op[j].isalpha() or op[j].isdigit())][0]
433
+ function_name = op[:first_non_char]
434
+ op_list[i] = function_name
435
+ return def_hyperparams
436
+
437
+
438
+ def using_test_input(X, test, y):
439
+ if test == 'simple1':
440
+ eval_str = "np.sign(X[:, 2])*np.abs(X[:, 2])**2.5 + 5*np.cos(X[:, 3]) - 5"
441
+ elif test == 'simple2':
442
+ eval_str = "np.sign(X[:, 2])*np.abs(X[:, 2])**3.5 + 1/(np.abs(X[:, 0])+1)"
443
+ elif test == 'simple3':
444
+ eval_str = "np.exp(X[:, 0]/2) + 12.0 + np.log(np.abs(X[:, 0])*10 + 1)"
445
+ elif test == 'simple4':
446
+ eval_str = "1.0 + 3*X[:, 0]**2 - 0.5*X[:, 0]**3 + 0.1*X[:, 0]**4"
447
+ elif test == 'simple5':
448
+ eval_str = "(np.exp(X[:, 3]) + 3)/(np.abs(X[:, 1]) + np.cos(X[:, 0]) + 1.1)"
449
+ X = np.random.randn(100, 5) * 3
450
+ y = eval(eval_str)
451
+ print("Running on", eval_str)
452
+ return X, y
453
+
454
+
455
+ def handle_feature_selection(X, select_k_features, use_custom_variable_names, variable_names, y):
456
+ if select_k_features is not None:
457
+ selection = run_feature_selection(X, y, select_k_features)
458
+ print(f"Using features {selection}")
459
+ X = X[:, selection]
460
+
461
+ if use_custom_variable_names:
462
+ variable_names = [variable_names[selection[i]] for i in range(len(selection))]
463
+ return X, variable_names
464
+
465
+
466
  def set_paths(tempdir):
467
  # System-independent paths
468
  pkg_directory = Path(__file__).parents[1] / 'julia'