MilesCranmer commited on
Commit
76cb421
1 Parent(s): 095a45e

Include additional kwargs if passed

Browse files
Files changed (1) hide show
  1. pysr/sr.py +4 -0
pysr/sr.py CHANGED
@@ -149,6 +149,7 @@ def pysr(
149
  Xresampled=None,
150
  precision=32,
151
  multithreading=None,
 
152
  ):
153
  """Run symbolic regression to fit f(X[i, :]) ~ y[i] for all i.
154
  Note: most default parameters have been tuned over several example
@@ -265,6 +266,8 @@ def pysr(
265
  :type precision: int
266
  :param multithreading: Use multithreading instead of distributed backend. Default is yes. Using procs=0 will turn off both.
267
  :type multithreading: bool
 
 
268
  :returns: Results dataframe, giving complexity, MSE, and equations (as strings), as well as functional forms. If list, each element corresponds to a dataframe of equations for each output.
269
  :type: pd.DataFrame/list
270
  """
@@ -507,6 +510,7 @@ Tried to activate project {julia_project} but failed."""
507
  verbosity=int(verbosity),
508
  progress=progress,
509
  terminal_width=int(term_width),
 
510
  )
511
 
512
  np_dtype = {16: np.float16, 32: np.float32, 64: np.float64}[precision]
 
149
  Xresampled=None,
150
  precision=32,
151
  multithreading=None,
152
+ **kwargs,
153
  ):
154
  """Run symbolic regression to fit f(X[i, :]) ~ y[i] for all i.
155
  Note: most default parameters have been tuned over several example
 
266
  :type precision: int
267
  :param multithreading: Use multithreading instead of distributed backend. Default is yes. Using procs=0 will turn off both.
268
  :type multithreading: bool
269
+ :param **kwargs: Other options passed to SymbolicRegression.Options, for example, if you modify SymbolicRegression.jl to include additional arguments.
270
+ :type **kwargs: dict
271
  :returns: Results dataframe, giving complexity, MSE, and equations (as strings), as well as functional forms. If list, each element corresponds to a dataframe of equations for each output.
272
  :type: pd.DataFrame/list
273
  """
 
510
  verbosity=int(verbosity),
511
  progress=progress,
512
  terminal_width=int(term_width),
513
+ **kwargs,
514
  )
515
 
516
  np_dtype = {16: np.float16, 32: np.float32, 64: np.float64}[precision]