MilesCranmer commited on
Commit
d0788ef
1 Parent(s): 44216ab

Fix syntax error in JAX converter

Browse files
Files changed (2) hide show
  1. pysr/export.py +1 -1
  2. pysr/sr.py +2 -2
pysr/export.py CHANGED
@@ -62,7 +62,7 @@ def sympy2jaxtext(expr, parameters, symbols_in):
62
  parameters.append(float(expr))
63
  return f"parameters[{len(parameters) - 1}]"
64
  elif issubclass(expr.func, sympy.Integer):
65
- return "{int(expr)}"
66
  elif issubclass(expr.func, sympy.Symbol):
67
  return f"X[:, {[i for i in range(len(symbols_in)) if symbols_in[i] == expr][0]}]"
68
  else:
 
62
  parameters.append(float(expr))
63
  return f"parameters[{len(parameters) - 1}]"
64
  elif issubclass(expr.func, sympy.Integer):
65
+ return f"{int(expr)}"
66
  elif issubclass(expr.func, sympy.Symbol):
67
  return f"X[:, {[i for i in range(len(symbols_in)) if symbols_in[i] == expr][0]}]"
68
  else:
pysr/sr.py CHANGED
@@ -686,7 +686,7 @@ def get_hof(equation_file=None, n_features=None, variable_names=None,
686
  sympy_format.append(eqn)
687
  if output_jax_format:
688
  func, params = sympy2jax(eqn, sympy_symbols)
689
- jax_format.append({'callable': func, 'parameters': parameters})
690
  lambda_format.append(lambdify(sympy_symbols, eqn))
691
  curMSE = output.loc[i, 'MSE']
692
  curComplexity = output.loc[i, 'Complexity']
@@ -705,7 +705,7 @@ def get_hof(equation_file=None, n_features=None, variable_names=None,
705
  output['lambda_format'] = lambda_format
706
  output_cols = ['Complexity', 'MSE', 'score', 'Equation', 'sympy_format', 'lambda_format']
707
  if output_jax_format:
708
- output_cols += 'jax_format'
709
  output['jax_format'] = jax_format
710
 
711
  return output[output_cols]
 
686
  sympy_format.append(eqn)
687
  if output_jax_format:
688
  func, params = sympy2jax(eqn, sympy_symbols)
689
+ jax_format.append({'callable': func, 'parameters': params})
690
  lambda_format.append(lambdify(sympy_symbols, eqn))
691
  curMSE = output.loc[i, 'MSE']
692
  curComplexity = output.loc[i, 'Complexity']
 
705
  output['lambda_format'] = lambda_format
706
  output_cols = ['Complexity', 'MSE', 'score', 'Equation', 'sympy_format', 'lambda_format']
707
  if output_jax_format:
708
+ output_cols += ['jax_format']
709
  output['jax_format'] = jax_format
710
 
711
  return output[output_cols]