MilesCranmer commited on
Commit
17c9b1a
1 Parent(s): beaf20b

Fix sympy2jax for rational numbers

Browse files
Files changed (1) hide show
  1. pysr/export_jax.py +4 -2
pysr/export_jax.py CHANGED
@@ -58,9 +58,11 @@ def sympy2jaxtext(expr, parameters, symbols_in, extra_jax_mappings=None):
58
  if issubclass(expr.func, sympy.Float):
59
  parameters.append(float(expr))
60
  return f"parameters[{len(parameters) - 1}]"
61
- if issubclass(expr.func, sympy.Integer):
 
 
62
  return f"{int(expr)}"
63
- if issubclass(expr.func, sympy.Symbol):
64
  return (
65
  f"X[:, {[i for i in range(len(symbols_in)) if symbols_in[i] == expr][0]}]"
66
  )
 
58
  if issubclass(expr.func, sympy.Float):
59
  parameters.append(float(expr))
60
  return f"parameters[{len(parameters) - 1}]"
61
+ elif issubclass(expr.func, sympy.Rational):
62
+ return f"{float(expr)}"
63
+ elif issubclass(expr.func, sympy.Integer):
64
  return f"{int(expr)}"
65
+ elif issubclass(expr.func, sympy.Symbol):
66
  return (
67
  f"X[:, {[i for i in range(len(symbols_in)) if symbols_in[i] == expr][0]}]"
68
  )