MilesCranmer commited on
Commit
bfe511a
·
unverified ·
2 Parent(s): c623864 5bb2875

Merge pull request #52 from MilesCranmer/deepsource-fix-147c2d73

Browse files

Refactor unnecessary `else` / `elif` when `if` block has a `return` statement

Files changed (2) hide show
  1. pysr/export_jax.py +9 -11
  2. pysr/sr.py +8 -15
pysr/export_jax.py CHANGED
@@ -55,21 +55,19 @@ def sympy2jaxtext(expr, parameters, symbols_in):
55
  if issubclass(expr.func, sympy.Float):
56
  parameters.append(float(expr))
57
  return f"parameters[{len(parameters) - 1}]"
58
- elif issubclass(expr.func, sympy.Integer):
59
  return f"{int(expr)}"
60
- elif issubclass(expr.func, sympy.Symbol):
61
  return (
62
  f"X[:, {[i for i in range(len(symbols_in)) if symbols_in[i] == expr][0]}]"
63
  )
64
- else:
65
- _func = _jnp_func_lookup[expr.func]
66
- args = [sympy2jaxtext(arg, parameters, symbols_in) for arg in expr.args]
67
- if _func == MUL:
68
- return " * ".join(["(" + arg + ")" for arg in args])
69
- elif _func == ADD:
70
- return " + ".join(["(" + arg + ")" for arg in args])
71
- else:
72
- return f'{_func}({", ".join(args)})'
73
 
74
 
75
  jax_initialized = False
 
55
  if issubclass(expr.func, sympy.Float):
56
  parameters.append(float(expr))
57
  return f"parameters[{len(parameters) - 1}]"
58
+ if issubclass(expr.func, sympy.Integer):
59
  return f"{int(expr)}"
60
+ if issubclass(expr.func, sympy.Symbol):
61
  return (
62
  f"X[:, {[i for i in range(len(symbols_in)) if symbols_in[i] == expr][0]}]"
63
  )
64
+ _func = _jnp_func_lookup[expr.func]
65
+ args = [sympy2jaxtext(arg, parameters, symbols_in) for arg in expr.args]
66
+ if _func == MUL:
67
+ return " * ".join(["(" + arg + ")" for arg in args])
68
+ if _func == ADD:
69
+ return " + ".join(["(" + arg + ")" for arg in args])
70
+ return f'{_func}({", ".join(args)})'
 
 
71
 
72
 
73
  jax_initialized = False
pysr/sr.py CHANGED
@@ -643,10 +643,9 @@ def _make_hyperparams_julia_str(
643
  def tuple_fix(ops):
644
  if len(ops) > 1:
645
  return ", ".join(ops)
646
- elif len(ops) == 0:
647
  return ""
648
- else:
649
- return ops[0] + ","
650
 
651
  def_hyperparams += f"""\n
652
  plus=(+)
@@ -1024,8 +1023,7 @@ def get_hof(
1024
 
1025
  if multioutput:
1026
  return ret_outputs
1027
- else:
1028
- return ret_outputs[0]
1029
 
1030
 
1031
  def best_row(equations=None):
@@ -1036,8 +1034,7 @@ def best_row(equations=None):
1036
  equations = get_hof()
1037
  if isinstance(equations, list):
1038
  return [eq.iloc[np.argmax(eq["score"])] for eq in equations]
1039
- else:
1040
- return equations.iloc[np.argmax(equations["score"])]
1041
 
1042
 
1043
  def best_tex(equations=None):
@@ -1050,8 +1047,7 @@ def best_tex(equations=None):
1050
  return [
1051
  sympy.latex(best_row(eq)["sympy_format"].simplify()) for eq in equations
1052
  ]
1053
- else:
1054
- return sympy.latex(best_row(equations)["sympy_format"].simplify())
1055
 
1056
 
1057
  def best(equations=None):
@@ -1062,8 +1058,7 @@ def best(equations=None):
1062
  equations = get_hof()
1063
  if isinstance(equations, list):
1064
  return [best_row(eq)["sympy_format"].simplify() for eq in equations]
1065
- else:
1066
- return best_row(equations)["sympy_format"].simplify()
1067
 
1068
 
1069
  def best_callable(equations=None):
@@ -1074,8 +1069,7 @@ def best_callable(equations=None):
1074
  equations = get_hof()
1075
  if isinstance(equations, list):
1076
  return [best_row(eq)["lambda_format"] for eq in equations]
1077
- else:
1078
- return best_row(equations)["lambda_format"]
1079
 
1080
 
1081
  def _escape_filename(filename):
@@ -1113,5 +1107,4 @@ class CallableEquation(object):
1113
  def __call__(self, X):
1114
  if self._selection is not None:
1115
  return self._lambda(*X[:, self._selection].T)
1116
- else:
1117
- return self._lambda(*X.T)
 
643
  def tuple_fix(ops):
644
  if len(ops) > 1:
645
  return ", ".join(ops)
646
+ if len(ops) == 0:
647
  return ""
648
+ return ops[0] + ","
 
649
 
650
  def_hyperparams += f"""\n
651
  plus=(+)
 
1023
 
1024
  if multioutput:
1025
  return ret_outputs
1026
+ return ret_outputs[0]
 
1027
 
1028
 
1029
  def best_row(equations=None):
 
1034
  equations = get_hof()
1035
  if isinstance(equations, list):
1036
  return [eq.iloc[np.argmax(eq["score"])] for eq in equations]
1037
+ return equations.iloc[np.argmax(equations["score"])]
 
1038
 
1039
 
1040
  def best_tex(equations=None):
 
1047
  return [
1048
  sympy.latex(best_row(eq)["sympy_format"].simplify()) for eq in equations
1049
  ]
1050
+ return sympy.latex(best_row(equations)["sympy_format"].simplify())
 
1051
 
1052
 
1053
  def best(equations=None):
 
1058
  equations = get_hof()
1059
  if isinstance(equations, list):
1060
  return [best_row(eq)["sympy_format"].simplify() for eq in equations]
1061
+ return best_row(equations)["sympy_format"].simplify()
 
1062
 
1063
 
1064
  def best_callable(equations=None):
 
1069
  equations = get_hof()
1070
  if isinstance(equations, list):
1071
  return [best_row(eq)["lambda_format"] for eq in equations]
1072
+ return best_row(equations)["lambda_format"]
 
1073
 
1074
 
1075
  def _escape_filename(filename):
 
1107
  def __call__(self, X):
1108
  if self._selection is not None:
1109
  return self._lambda(*X[:, self._selection].T)
1110
+ return self._lambda(*X.T)