MilesCranmer commited on
Commit
c0ffbd2
1 Parent(s): 32690ba

Pass through `enable_autodiff` parameter

Browse files
Files changed (2) hide show
  1. docs/param_groupings.yml +1 -0
  2. pysr/sr.py +8 -0
docs/param_groupings.yml CHANGED
@@ -71,6 +71,7 @@
71
  - precision
72
  - fast_cycle
73
  - turbo
 
74
  - random_state
75
  - deterministic
76
  - warm_start
 
71
  - precision
72
  - fast_cycle
73
  - turbo
74
+ - enable_autodiff
75
  - random_state
76
  - deterministic
77
  - warm_start
pysr/sr.py CHANGED
@@ -525,6 +525,11 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
525
  If you pass complex data, the corresponding complex precision
526
  will be used (i.e., `64` for complex128, `32` for complex64).
527
  Default is `32`.
 
 
 
 
 
528
  random_state : int, Numpy RandomState instance or None
529
  Pass an int for reproducible results across multiple function calls.
530
  See :term:`Glossary <random_state>`.
@@ -747,6 +752,7 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
747
  fast_cycle=False,
748
  turbo=False,
749
  precision=32,
 
750
  random_state=None,
751
  deterministic=False,
752
  warm_start=False,
@@ -839,6 +845,7 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
839
  self.fast_cycle = fast_cycle
840
  self.turbo = turbo
841
  self.precision = precision
 
842
  self.random_state = random_state
843
  self.deterministic = deterministic
844
  self.warm_start = warm_start
@@ -1623,6 +1630,7 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1623
  maxdepth=maxdepth,
1624
  fast_cycle=self.fast_cycle,
1625
  turbo=self.turbo,
 
1626
  migration=self.migration,
1627
  hof_migration=self.hof_migration,
1628
  fraction_replaced_hof=self.fraction_replaced_hof,
 
525
  If you pass complex data, the corresponding complex precision
526
  will be used (i.e., `64` for complex128, `32` for complex64).
527
  Default is `32`.
528
+ enable_autodiff : bool
529
+ Whether to create derivative versions of operators for automatic
530
+ differentiation. This is only necessary if you wish to compute
531
+ the gradients of an expression within a custom loss function.
532
+ Default is `False`.
533
  random_state : int, Numpy RandomState instance or None
534
  Pass an int for reproducible results across multiple function calls.
535
  See :term:`Glossary <random_state>`.
 
752
  fast_cycle=False,
753
  turbo=False,
754
  precision=32,
755
+ enable_autodiff=False,
756
  random_state=None,
757
  deterministic=False,
758
  warm_start=False,
 
845
  self.fast_cycle = fast_cycle
846
  self.turbo = turbo
847
  self.precision = precision
848
+ self.enable_autodiff = enable_autodiff
849
  self.random_state = random_state
850
  self.deterministic = deterministic
851
  self.warm_start = warm_start
 
1630
  maxdepth=maxdepth,
1631
  fast_cycle=self.fast_cycle,
1632
  turbo=self.turbo,
1633
+ enable_autodiff=self.enable_autodiff,
1634
  migration=self.migration,
1635
  hof_migration=self.hof_migration,
1636
  fraction_replaced_hof=self.fraction_replaced_hof,