MilesCranmer commited on
Commit
d11773b
1 Parent(s): 787cf18

Fix varMap for no feature selection

Browse files
Files changed (1) hide show
  1. pysr/sr.py +9 -6
pysr/sr.py CHANGED
@@ -343,6 +343,7 @@ def pysr(
343
  )
344
 
345
  use_custom_variable_names = len(variable_names) != 0
 
346
 
347
  _check_assertions(
348
  X,
@@ -366,7 +367,9 @@ def pysr(
366
  if maxsize < 7:
367
  raise NotImplementedError("PySR requires a maxsize of at least 7")
368
 
369
- X, selection = _handle_feature_selection(X, select_k_features, y, variable_names)
 
 
370
 
371
  if maxdepth is None:
372
  maxdepth = maxsize
@@ -540,7 +543,7 @@ Tried to activate project {julia_project} but failed."""
540
  Main.y,
541
  weights=Main.weights,
542
  niterations=int(niterations),
543
- varMap=[variable_names[i] for i in selection],
544
  options=options,
545
  numprocs=int(cprocs),
546
  multithreading=bool(multithreading),
@@ -654,7 +657,9 @@ def _create_inline_operators(binary_operators, unary_operators):
654
  op_list[i] = function_name
655
 
656
 
657
- def _handle_feature_selection(X, select_k_features, y, variable_names):
 
 
658
  if select_k_features is not None:
659
  selection = run_feature_selection(X, y, select_k_features)
660
  print(f"Using features {[variable_names[i] for i in selection]}")
@@ -796,9 +801,7 @@ def get_hof(
796
  sympy_format.append(eqn)
797
 
798
  # Numpy:
799
- lambda_format.append(
800
- CallableEquation(sympy_symbols, eqn, selection, variable_names)
801
- )
802
 
803
  # JAX:
804
  if output_jax_format:
 
343
  )
344
 
345
  use_custom_variable_names = len(variable_names) != 0
346
+ # TODO: this is always true.
347
 
348
  _check_assertions(
349
  X,
 
367
  if maxsize < 7:
368
  raise NotImplementedError("PySR requires a maxsize of at least 7")
369
 
370
+ X, selection = _handle_feature_selection(
371
+ X, select_k_features, y, variable_names
372
+ )
373
 
374
  if maxdepth is None:
375
  maxdepth = maxsize
 
543
  Main.y,
544
  weights=Main.weights,
545
  niterations=int(niterations),
546
+ varMap=([variable_names[i] for i in selection] if selection else None),
547
  options=options,
548
  numprocs=int(cprocs),
549
  multithreading=bool(multithreading),
 
657
  op_list[i] = function_name
658
 
659
 
660
+ def _handle_feature_selection(
661
+ X, select_k_features, y, variable_names
662
+ ):
663
  if select_k_features is not None:
664
  selection = run_feature_selection(X, y, select_k_features)
665
  print(f"Using features {[variable_names[i] for i in selection]}")
 
801
  sympy_format.append(eqn)
802
 
803
  # Numpy:
804
+ lambda_format.append(CallableEquation(sympy_symbols, eqn, selection, variable_names))
 
 
805
 
806
  # JAX:
807
  if output_jax_format: