Spaces:
Running
Running
Commit
·
4d915b2
1
Parent(s):
09b1cf7
Change lambda_format to same format as torch/jax
Browse files- pysr/sr.py +2 -1
pysr/sr.py
CHANGED
@@ -795,7 +795,8 @@ def get_hof(equation_file=None, n_features=None, variable_names=None,
|
|
795 |
if output_jax_format:
|
796 |
func, params = sympy2jax(eqn, sympy_symbols)
|
797 |
jax_format.append({'callable': func, 'parameters': params})
|
798 |
-
|
|
|
799 |
curMSE = output.loc[i, 'MSE']
|
800 |
curComplexity = output.loc[i, 'Complexity']
|
801 |
|
|
|
795 |
if output_jax_format:
|
796 |
func, params = sympy2jax(eqn, sympy_symbols)
|
797 |
jax_format.append({'callable': func, 'parameters': params})
|
798 |
+
tmp_lambda = lambdify(sympy_symbols, eqn)
|
799 |
+
lambda_format.append(lambda X: tmp_lambda(*X.T))
|
800 |
curMSE = output.loc[i, 'MSE']
|
801 |
curComplexity = output.loc[i, 'Complexity']
|
802 |
|