tttc3 commited on
Commit
9490776
1 Parent(s): c7187a6

fixed issues from deepsource

Browse files
Files changed (2) hide show
  1. pysr/julia_helpers.py +5 -3
  2. pysr/sr.py +62 -62
pysr/julia_helpers.py CHANGED
@@ -6,9 +6,11 @@ from .version import __version__, __symbolic_regression_jl_version__
6
 
7
 
8
  def install(julia_project=None, quiet=False): # pragma: no cover
9
- """Install PyCall.jl and all required dependencies for SymbolicRegression.jl.
 
10
 
11
- Also updates the local Julia registry."""
 
12
  import julia
13
 
14
  julia.install(quiet=quiet)
@@ -40,7 +42,7 @@ def install(julia_project=None, quiet=False): # pragma: no cover
40
 
41
 
42
  def import_error_string(julia_project=None):
43
- s = f"""
44
  Required dependencies are not installed or built. Run the following code in the Python REPL:
45
 
46
  >>> import pysr
 
6
 
7
 
8
  def install(julia_project=None, quiet=False): # pragma: no cover
9
+ """
10
+ Install PyCall.jl and all required dependencies for SymbolicRegression.jl.
11
 
12
+ Also updates the local Julia registry.
13
+ """
14
  import julia
15
 
16
  julia.install(quiet=quiet)
 
42
 
43
 
44
  def import_error_string(julia_project=None):
45
+ s = """
46
  Required dependencies are not installed or built. Run the following code in the Python REPL:
47
 
48
  >>> import pysr
pysr/sr.py CHANGED
@@ -11,6 +11,7 @@ from pathlib import Path
11
  from datetime import datetime
12
  import warnings
13
  from multiprocessing import cpu_count
 
14
  from sklearn.base import BaseEstimator, RegressorMixin, MultiOutputMixin
15
  from sklearn.utils.validation import _check_feature_names_in, check_is_fitted
16
 
@@ -189,10 +190,9 @@ class CallableEquation:
189
  return self._lambda(
190
  **{k: X[k].values for k in self._variable_names}
191
  ) * np.ones(expected_shape)
192
- else:
193
- if self._selection is not None:
194
- X = X[:, self._selection]
195
- return self._lambda(*X.T) * np.ones(expected_shape)
196
 
197
 
198
  class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
@@ -349,13 +349,15 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
349
  Relative likelihood for mutation to leave the individual.
350
 
351
  weight_mutate_constant : float, default=0.048
352
- Relative likelihood for mutation to change the constant slightly in a random direction.
 
353
 
354
  weight_mutate_operator : float, default=0.47
355
  Relative likelihood for mutation to swap an operator.
356
 
357
  weight_randomize : float, default=0.00023
358
- Relative likelihood for mutation to completely delete and then randomly generate the equation
 
359
 
360
  weight_simplify : float, default=0.0020
361
  Relative likelihood for mutation to simplify constant parts by evaluation
@@ -666,13 +668,13 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
666
  ):
667
 
668
  # Hyperparameters
669
- ## Model search parameters
670
  self.model_selection = model_selection
671
  self.binary_operators = binary_operators
672
  self.unary_operators = unary_operators
673
  self.niterations = niterations
674
  self.populations = populations
675
- ## Model search Constraints
676
  self.population_size = population_size
677
  self.max_evals = max_evals
678
  self.maxsize = maxsize
@@ -681,7 +683,7 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
681
  self.timeout_in_seconds = timeout_in_seconds
682
  self.constraints = constraints
683
  self.nested_constraints = nested_constraints
684
- ## Loss parameters
685
  self.loss = loss
686
  self.complexity_of_operators = complexity_of_operators
687
  self.complexity_of_constants = complexity_of_constants
@@ -692,8 +694,8 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
692
  self.alpha = alpha
693
  self.annealing = annealing
694
  self.early_stop_condition = early_stop_condition
695
- ## Evolutionary search parameters
696
- ### Mutation parameters
697
  self.ncyclesperiteration = ncyclesperiteration
698
  self.fraction_replaced = fraction_replaced
699
  self.fraction_replaced_hof = fraction_replaced_hof
@@ -707,18 +709,18 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
707
  self.weight_simplify = weight_simplify
708
  self.crossover_probability = crossover_probability
709
  self.skip_mutation_failures = skip_mutation_failures
710
- ### Migration parameters
711
  self.migration = migration
712
  self.hof_migration = hof_migration
713
  self.topn = topn
714
- ### Constants parameters
715
  self.should_optimize_constants = should_optimize_constants
716
  self.optimizer_algorithm = optimizer_algorithm
717
  self.optimizer_nrestarts = optimizer_nrestarts
718
  self.optimize_probability = optimize_probability
719
  self.optimizer_iterations = optimizer_iterations
720
  self.perturbation_factor = perturbation_factor
721
- ### Selection parameters
722
  self.tournament_selection_n = tournament_selection_n
723
  self.tournament_selection_p = tournament_selection_p
724
  # Solver parameters
@@ -730,11 +732,11 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
730
  self.fast_cycle = fast_cycle
731
  self.precision = precision
732
  # Additional runtime parameters
733
- ## Runtime user interface
734
  self.verbosity = verbosity
735
  self.update_verbosity = update_verbosity
736
  self.progress = progress
737
- ## Project management
738
  self.equation_file = equation_file
739
  self.temp_equation_file = temp_equation_file
740
  self.tempdir = tempdir
@@ -921,29 +923,27 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
921
  # Warn if instance parameters are not sensible values:
922
  if self.batch_size < 1:
923
  warnings.warn(
924
- f"Given :param`batch_size` must be greater than or equal to one."
925
- f":param`batch_size` has been increased to equal one."
926
  )
927
  self.batch_size = 1
928
 
929
  if n_samples > 10000 and not self.batching:
930
  warnings.warn(
931
- """
932
- Note: you are running with more than 10,000 datapoints.
933
- You should consider turning on batching (https://astroautomata.com/PySR/#/options?id=batching).
934
- You should also reconsider if you need that many datapoints.
935
- Unless you have a large amount of noise (in which case you
936
- should smooth your dataset first), generally < 10,000 datapoints
937
- is enough to find a functional form with symbolic regression.
938
- More datapoints will lower the search speed."
939
- """,
940
  )
941
 
942
  # Ensure instance parameters are allowable values:
943
  # ValueError - Incompatible values
944
- if not (self.tournament_selection_n < self.population_size):
945
  raise ValueError(
946
- f"tournament_selection_n parameter must be smaller than population_size"
947
  )
948
 
949
  if self.maxsize > 40:
@@ -951,7 +951,7 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
951
  "Note: Using a large maxsize for the equation search will be exponentially slower and use significant memory. You should consider turning `use_frequency` to False, and perhaps use `warmup_maxsize_by`."
952
  )
953
  elif self.maxsize < 7:
954
- raise ValueError(f"PySR requires a maxsize of at least 7")
955
 
956
  if self.extra_jax_mappings is not None:
957
  for value in self.extra_jax_mappings.values():
@@ -971,7 +971,7 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
971
  else:
972
  self.extra_torch_mappings = {}
973
 
974
- # NotImplementedError - Currently incompatible values that could be supported later
975
  if self.optimizer_algorithm not in self.VALID_OPTIMIZER_ALGORITHMS:
976
  raise NotImplementedError(
977
  f"PySR currently only supports the following optimizer algorithms: {self.VALID_OPTIMIZER_ALGORITHMS}"
@@ -1004,7 +1004,8 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
1004
  y : {ndarray | pandas.DataFrame} of shape (n_samples,) or (n_samples, n_targets)
1005
  Target values. Will be cast to X's dtype if necessary.
1006
 
1007
- Xresampled : {ndarray | pandas.DataFrame} of shape (n_resampled, n_features), default=None
 
1008
  Resampled training data used for denoising.
1009
 
1010
  variable_names : list[str] of length n_features
@@ -1022,7 +1023,6 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
1022
  Validated list of variable names for each feature in `X`.
1023
 
1024
  """
1025
-
1026
  if isinstance(X, pd.DataFrame):
1027
  variable_names = None
1028
  warnings.warn(
@@ -1037,14 +1037,13 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
1037
  "Spaces have been replaced with underscores. \n"
1038
  "Please rename the columns to valid names."
1039
  )
1040
- elif variable_names:
1041
- if [" " in name for name in variable_names].any():
1042
- variable_names = [name.replace(" ", "_") for name in variable_names]
1043
- warnings.warn(
1044
- "Spaces in `variable_names` are not supported. "
1045
- "Spaces have been replaced with underscores. \n"
1046
- "Please use valid names instead."
1047
- )
1048
  # Only numpy values are needed from Xresampled, column metadata is
1049
  # provided by X
1050
  if isinstance(Xresampled, pd.DataFrame):
@@ -1080,7 +1079,8 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
1080
  y : {ndarray | pandas.DataFrame} of shape (n_samples,) or (n_samples, n_targets)
1081
  Target values. Will be cast to X's dtype if necessary.
1082
 
1083
- Xresampled : {ndarray | pandas.DataFrame} of shape (n_resampled, n_features), default=None
 
1084
  Resampled training data used for denoising.
1085
 
1086
  variable_names : list[str] of length n_features
@@ -1118,17 +1118,17 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
1118
  variable_names = [variable_names[i] for i in self.selection_mask_]
1119
 
1120
  # Re-perform data validation and feature name updating
1121
- X, y_transformed = self._validate_data(
1122
  X=X, y=y, reset=True, multi_output=True
1123
  )
1124
  # Update feature names with selected variable names
1125
  self.feature_names_in_ = _check_feature_names_in(self, variable_names)
1126
- print(f"Using features {[name for name in self.feature_names_in_]}")
1127
 
1128
  # Denoising transformation
1129
  if self.denoise:
1130
  if self.nout_ > 1:
1131
- y_transformed = np.stack(
1132
  [
1133
  _denoise(X, y[:, i], Xresampled=Xresampled)[1]
1134
  for i in range(self.nout_)
@@ -1168,8 +1168,8 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
1168
  ImportError
1169
  Raised when the julia backend fails to import a package.
1170
  """
1171
-
1172
- # Need to be global as we don't want to recreate/reinstate julia for every new instance of PySRRegressor
1173
  global already_ran
1174
  global Main
1175
 
@@ -1379,7 +1379,8 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
1379
  y : {ndarray | pandas.DataFrame} of shape (n_samples,) or (n_samples, n_targets)
1380
  Target values. Will be cast to X's dtype if necessary.
1381
 
1382
- Xresampled : {ndarray | pandas.DataFrame} of shape (n_resampled, n_features), default=None
 
1383
  Resampled training data used for denoising.
1384
 
1385
  weights : {ndarray | pandas.DataFrame} of the same shape as y, default=None
@@ -1420,7 +1421,8 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
1420
  X, y, Xresampled, variable_names
1421
  )
1422
 
1423
- # Warn about large feature counts (still warn if feature count is large after running feature selection)
 
1424
  if self.n_features_in_ >= 10:
1425
  warnings.warn(
1426
  "Note: you are running with 10 features or more. "
@@ -1512,8 +1514,8 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
1512
  def predict(self, X, index=None):
1513
  """Predict y from input X using the equation chosen by `model_selection`.
1514
 
1515
- You may see what equation is used by printing this object. X should have the same
1516
- columns as the training data.
1517
 
1518
  Parameters
1519
  ----------
@@ -1550,10 +1552,10 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
1550
  SymPy representation of the best equation.
1551
  """
1552
  self.refresh()
1553
- best = self.get_best(index=index)
1554
  if self.nout_ > 1:
1555
- return [eq["sympy_format"] for eq in best]
1556
- return best["sympy_format"]
1557
 
1558
  def latex(self, index=None):
1559
  """Return latex representation of the equation(s) chosen by `model_selection`.
@@ -1596,13 +1598,12 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
1596
  Dictionary of callable jax function in "callable" key,
1597
  and jax array of parameters as "parameters" key.
1598
  """
1599
-
1600
  self.set_params(output_jax_format=True)
1601
  self.refresh()
1602
- best = self.get_best(index=index)
1603
  if self.nout_ > 1:
1604
- return [eq["jax_format"] for eq in best]
1605
- return best["jax_format"]
1606
 
1607
  def pytorch(self, index=None):
1608
  """Return pytorch representation of the equation(s) chosen by `model_selection`.
@@ -1626,10 +1627,10 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
1626
  """
1627
  self.set_params(output_torch_format=True)
1628
  self.refresh()
1629
- best = self.get_best(index=index)
1630
  if self.nout_ > 1:
1631
- return [eq["torch_format"] for eq in best]
1632
- return best["torch_format"]
1633
 
1634
  def get_hof(self):
1635
  """Get the equations from a hall of fame file. If no arguments
@@ -1796,7 +1797,6 @@ def run_feature_selection(X, y, select_k_features):
1796
  """Use a gradient boosting tree regressor as a proxy for finding
1797
  the k most important features in X, returning indices for those
1798
  features as output."""
1799
-
1800
  from sklearn.ensemble import RandomForestRegressor
1801
  from sklearn.feature_selection import SelectFromModel
1802
 
 
11
  from datetime import datetime
12
  import warnings
13
  from multiprocessing import cpu_count
14
+ from sklearn.linear_model import LinearRegression
15
  from sklearn.base import BaseEstimator, RegressorMixin, MultiOutputMixin
16
  from sklearn.utils.validation import _check_feature_names_in, check_is_fitted
17
 
 
190
  return self._lambda(
191
  **{k: X[k].values for k in self._variable_names}
192
  ) * np.ones(expected_shape)
193
+ if self._selection is not None:
194
+ X = X[:, self._selection]
195
+ return self._lambda(*X.T) * np.ones(expected_shape)
 
196
 
197
 
198
  class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
 
349
  Relative likelihood for mutation to leave the individual.
350
 
351
  weight_mutate_constant : float, default=0.048
352
+ Relative likelihood for mutation to change the constant slightly
353
+ in a random direction.
354
 
355
  weight_mutate_operator : float, default=0.47
356
  Relative likelihood for mutation to swap an operator.
357
 
358
  weight_randomize : float, default=0.00023
359
+ Relative likelihood for mutation to completely delete and then
360
+ randomly generate the equation
361
 
362
  weight_simplify : float, default=0.0020
363
  Relative likelihood for mutation to simplify constant parts by evaluation
 
668
  ):
669
 
670
  # Hyperparameters
671
+ # - Model search parameters
672
  self.model_selection = model_selection
673
  self.binary_operators = binary_operators
674
  self.unary_operators = unary_operators
675
  self.niterations = niterations
676
  self.populations = populations
677
+ # - Model search Constraints
678
  self.population_size = population_size
679
  self.max_evals = max_evals
680
  self.maxsize = maxsize
 
683
  self.timeout_in_seconds = timeout_in_seconds
684
  self.constraints = constraints
685
  self.nested_constraints = nested_constraints
686
+ # - Loss parameters
687
  self.loss = loss
688
  self.complexity_of_operators = complexity_of_operators
689
  self.complexity_of_constants = complexity_of_constants
 
694
  self.alpha = alpha
695
  self.annealing = annealing
696
  self.early_stop_condition = early_stop_condition
697
+ # - Evolutionary search parameters
698
+ # -- Mutation parameters
699
  self.ncyclesperiteration = ncyclesperiteration
700
  self.fraction_replaced = fraction_replaced
701
  self.fraction_replaced_hof = fraction_replaced_hof
 
709
  self.weight_simplify = weight_simplify
710
  self.crossover_probability = crossover_probability
711
  self.skip_mutation_failures = skip_mutation_failures
712
+ # -- Migration parameters
713
  self.migration = migration
714
  self.hof_migration = hof_migration
715
  self.topn = topn
716
+ # -- Constants parameters
717
  self.should_optimize_constants = should_optimize_constants
718
  self.optimizer_algorithm = optimizer_algorithm
719
  self.optimizer_nrestarts = optimizer_nrestarts
720
  self.optimize_probability = optimize_probability
721
  self.optimizer_iterations = optimizer_iterations
722
  self.perturbation_factor = perturbation_factor
723
+ # -- Selection parameters
724
  self.tournament_selection_n = tournament_selection_n
725
  self.tournament_selection_p = tournament_selection_p
726
  # Solver parameters
 
732
  self.fast_cycle = fast_cycle
733
  self.precision = precision
734
  # Additional runtime parameters
735
+ # - Runtime user interface
736
  self.verbosity = verbosity
737
  self.update_verbosity = update_verbosity
738
  self.progress = progress
739
+ # - Project management
740
  self.equation_file = equation_file
741
  self.temp_equation_file = temp_equation_file
742
  self.tempdir = tempdir
 
923
  # Warn if instance parameters are not sensible values:
924
  if self.batch_size < 1:
925
  warnings.warn(
926
+ "Given :param`batch_size` must be greater than or equal to one. "
927
+ ":param`batch_size` has been increased to equal one."
928
  )
929
  self.batch_size = 1
930
 
931
  if n_samples > 10000 and not self.batching:
932
  warnings.warn(
933
+ "Note: you are running with more than 10,000 datapoints. "
934
+ "You should consider turning on batching (https://astroautomata.com/PySR/#/options?id=batching). "
935
+ "You should also reconsider if you need that many datapoints. "
936
+ "Unless you have a large amount of noise (in which case you "
937
+ "should smooth your dataset first), generally < 10,000 datapoints "
938
+ "is enough to find a functional form with symbolic regression. "
939
+ "More datapoints will lower the search speed."
 
 
940
  )
941
 
942
  # Ensure instance parameters are allowable values:
943
  # ValueError - Incompatible values
944
+ if self.tournament_selection_n > self.population_size:
945
  raise ValueError(
946
+ "tournament_selection_n parameter must be smaller than population_size."
947
  )
948
 
949
  if self.maxsize > 40:
 
951
  "Note: Using a large maxsize for the equation search will be exponentially slower and use significant memory. You should consider turning `use_frequency` to False, and perhaps use `warmup_maxsize_by`."
952
  )
953
  elif self.maxsize < 7:
954
+ raise ValueError("PySR requires a maxsize of at least 7")
955
 
956
  if self.extra_jax_mappings is not None:
957
  for value in self.extra_jax_mappings.values():
 
971
  else:
972
  self.extra_torch_mappings = {}
973
 
974
+ # NotImplementedError - Values that could be supported at a later time
975
  if self.optimizer_algorithm not in self.VALID_OPTIMIZER_ALGORITHMS:
976
  raise NotImplementedError(
977
  f"PySR currently only supports the following optimizer algorithms: {self.VALID_OPTIMIZER_ALGORITHMS}"
 
1004
  y : {ndarray | pandas.DataFrame} of shape (n_samples,) or (n_samples, n_targets)
1005
  Target values. Will be cast to X's dtype if necessary.
1006
 
1007
+ Xresampled : {ndarray | pandas.DataFrame} of shape
1008
+ (n_resampled, n_features), default=None
1009
  Resampled training data used for denoising.
1010
 
1011
  variable_names : list[str] of length n_features
 
1023
  Validated list of variable names for each feature in `X`.
1024
 
1025
  """
 
1026
  if isinstance(X, pd.DataFrame):
1027
  variable_names = None
1028
  warnings.warn(
 
1037
  "Spaces have been replaced with underscores. \n"
1038
  "Please rename the columns to valid names."
1039
  )
1040
+ elif variable_names and [" " in name for name in variable_names].any():
1041
+ variable_names = [name.replace(" ", "_") for name in variable_names]
1042
+ warnings.warn(
1043
+ "Spaces in `variable_names` are not supported. "
1044
+ "Spaces have been replaced with underscores. \n"
1045
+ "Please use valid names instead."
1046
+ )
 
1047
  # Only numpy values are needed from Xresampled, column metadata is
1048
  # provided by X
1049
  if isinstance(Xresampled, pd.DataFrame):
 
1079
  y : {ndarray | pandas.DataFrame} of shape (n_samples,) or (n_samples, n_targets)
1080
  Target values. Will be cast to X's dtype if necessary.
1081
 
1082
+ Xresampled : {ndarray | pandas.DataFrame} of shape
1083
+ (n_resampled, n_features), default=None
1084
  Resampled training data used for denoising.
1085
 
1086
  variable_names : list[str] of length n_features
 
1118
  variable_names = [variable_names[i] for i in self.selection_mask_]
1119
 
1120
  # Re-perform data validation and feature name updating
1121
+ X, y = self._validate_data(
1122
  X=X, y=y, reset=True, multi_output=True
1123
  )
1124
  # Update feature names with selected variable names
1125
  self.feature_names_in_ = _check_feature_names_in(self, variable_names)
1126
+ print(f"Using features {self.feature_names_in_}")
1127
 
1128
  # Denoising transformation
1129
  if self.denoise:
1130
  if self.nout_ > 1:
1131
+ y = np.stack(
1132
  [
1133
  _denoise(X, y[:, i], Xresampled=Xresampled)[1]
1134
  for i in range(self.nout_)
 
1168
  ImportError
1169
  Raised when the julia backend fails to import a package.
1170
  """
1171
+ # Need to be global as we don't want to recreate/reinstate julia for
1172
+ # every new instance of PySRRegressor
1173
  global already_ran
1174
  global Main
1175
 
 
1379
  y : {ndarray | pandas.DataFrame} of shape (n_samples,) or (n_samples, n_targets)
1380
  Target values. Will be cast to X's dtype if necessary.
1381
 
1382
+ Xresampled : {ndarray | pandas.DataFrame} of shape
1383
+ (n_resampled, n_features), default=None
1384
  Resampled training data used for denoising.
1385
 
1386
  weights : {ndarray | pandas.DataFrame} of the same shape as y, default=None
 
1421
  X, y, Xresampled, variable_names
1422
  )
1423
 
1424
+ # Warn about large feature counts (still warn if feature count is large
1425
+ # after running feature selection)
1426
  if self.n_features_in_ >= 10:
1427
  warnings.warn(
1428
  "Note: you are running with 10 features or more. "
 
1514
  def predict(self, X, index=None):
1515
  """Predict y from input X using the equation chosen by `model_selection`.
1516
 
1517
+ You may see what equation is used by printing this object. X should
1518
+ have the same columns as the training data.
1519
 
1520
  Parameters
1521
  ----------
 
1552
  SymPy representation of the best equation.
1553
  """
1554
  self.refresh()
1555
+ best_equation = self.get_best(index=index)
1556
  if self.nout_ > 1:
1557
+ return [eq["sympy_format"] for eq in best_equation]
1558
+ return best_equation["sympy_format"]
1559
 
1560
  def latex(self, index=None):
1561
  """Return latex representation of the equation(s) chosen by `model_selection`.
 
1598
  Dictionary of callable jax function in "callable" key,
1599
  and jax array of parameters as "parameters" key.
1600
  """
 
1601
  self.set_params(output_jax_format=True)
1602
  self.refresh()
1603
+ best_equation = self.get_best(index=index)
1604
  if self.nout_ > 1:
1605
+ return [eq["jax_format"] for eq in best_equation]
1606
+ return best_equation["jax_format"]
1607
 
1608
  def pytorch(self, index=None):
1609
  """Return pytorch representation of the equation(s) chosen by `model_selection`.
 
1627
  """
1628
  self.set_params(output_torch_format=True)
1629
  self.refresh()
1630
+ best_equation = self.get_best(index=index)
1631
  if self.nout_ > 1:
1632
+ return [eq["torch_format"] for eq in best_equation]
1633
+ return best_equation["torch_format"]
1634
 
1635
  def get_hof(self):
1636
  """Get the equations from a hall of fame file. If no arguments
 
1797
  """Use a gradient boosting tree regressor as a proxy for finding
1798
  the k most important features in X, returning indices for those
1799
  features as output."""
 
1800
  from sklearn.ensemble import RandomForestRegressor
1801
  from sklearn.feature_selection import SelectFromModel
1802