MilesCranmer commited on
Commit
4d915b2
·
1 Parent(s): 09b1cf7

Change lambda_format to same format as torch/jax

Browse files
Files changed (1) hide show
  1. pysr/sr.py +2 -1
pysr/sr.py CHANGED
@@ -795,7 +795,8 @@ def get_hof(equation_file=None, n_features=None, variable_names=None,
795
  if output_jax_format:
796
  func, params = sympy2jax(eqn, sympy_symbols)
797
  jax_format.append({'callable': func, 'parameters': params})
798
- lambda_format.append(lambdify(sympy_symbols, eqn))
 
799
  curMSE = output.loc[i, 'MSE']
800
  curComplexity = output.loc[i, 'Complexity']
801
 
 
795
  if output_jax_format:
796
  func, params = sympy2jax(eqn, sympy_symbols)
797
  jax_format.append({'callable': func, 'parameters': params})
798
+ tmp_lambda = lambdify(sympy_symbols, eqn)
799
+ lambda_format.append(lambda X: tmp_lambda(*X.T))
800
  curMSE = output.loc[i, 'MSE']
801
  curComplexity = output.loc[i, 'Complexity']
802