deepsource-autofix[bot] commited on
Commit
5bb2875
·
unverified ·
1 Parent(s): b5d0afb

Refactor unnecessary `else` / `elif` when `if` block has a `return` statement

Browse files
Files changed (2) hide show
  1. pysr/export_jax.py +9 -11
  2. pysr/sr.py +8 -15
pysr/export_jax.py CHANGED
@@ -55,21 +55,19 @@ def sympy2jaxtext(expr, parameters, symbols_in):
55
  if issubclass(expr.func, sympy.Float):
56
  parameters.append(float(expr))
57
  return f"parameters[{len(parameters) - 1}]"
58
- elif issubclass(expr.func, sympy.Integer):
59
  return f"{int(expr)}"
60
- elif issubclass(expr.func, sympy.Symbol):
61
  return (
62
  f"X[:, {[i for i in range(len(symbols_in)) if symbols_in[i] == expr][0]}]"
63
  )
64
- else:
65
- _func = _jnp_func_lookup[expr.func]
66
- args = [sympy2jaxtext(arg, parameters, symbols_in) for arg in expr.args]
67
- if _func == MUL:
68
- return " * ".join(["(" + arg + ")" for arg in args])
69
- elif _func == ADD:
70
- return " + ".join(["(" + arg + ")" for arg in args])
71
- else:
72
- return f'{_func}({", ".join(args)})'
73
 
74
 
75
  jax_initialized = False
 
55
  if issubclass(expr.func, sympy.Float):
56
  parameters.append(float(expr))
57
  return f"parameters[{len(parameters) - 1}]"
58
+ if issubclass(expr.func, sympy.Integer):
59
  return f"{int(expr)}"
60
+ if issubclass(expr.func, sympy.Symbol):
61
  return (
62
  f"X[:, {[i for i in range(len(symbols_in)) if symbols_in[i] == expr][0]}]"
63
  )
64
+ _func = _jnp_func_lookup[expr.func]
65
+ args = [sympy2jaxtext(arg, parameters, symbols_in) for arg in expr.args]
66
+ if _func == MUL:
67
+ return " * ".join(["(" + arg + ")" for arg in args])
68
+ if _func == ADD:
69
+ return " + ".join(["(" + arg + ")" for arg in args])
70
+ return f'{_func}({", ".join(args)})'
 
 
71
 
72
 
73
  jax_initialized = False
pysr/sr.py CHANGED
@@ -643,10 +643,9 @@ def _make_hyperparams_julia_str(
643
  def tuple_fix(ops):
644
  if len(ops) > 1:
645
  return ", ".join(ops)
646
- elif len(ops) == 0:
647
  return ""
648
- else:
649
- return ops[0] + ","
650
 
651
  def_hyperparams += f"""\n
652
  plus=(+)
@@ -1025,8 +1024,7 @@ def get_hof(
1025
 
1026
  if multioutput:
1027
  return ret_outputs
1028
- else:
1029
- return ret_outputs[0]
1030
 
1031
 
1032
  def best_row(equations=None):
@@ -1037,8 +1035,7 @@ def best_row(equations=None):
1037
  equations = get_hof()
1038
  if isinstance(equations, list):
1039
  return [eq.iloc[np.argmax(eq["score"])] for eq in equations]
1040
- else:
1041
- return equations.iloc[np.argmax(equations["score"])]
1042
 
1043
 
1044
  def best_tex(equations=None):
@@ -1051,8 +1048,7 @@ def best_tex(equations=None):
1051
  return [
1052
  sympy.latex(best_row(eq)["sympy_format"].simplify()) for eq in equations
1053
  ]
1054
- else:
1055
- return sympy.latex(best_row(equations)["sympy_format"].simplify())
1056
 
1057
 
1058
  def best(equations=None):
@@ -1063,8 +1059,7 @@ def best(equations=None):
1063
  equations = get_hof()
1064
  if isinstance(equations, list):
1065
  return [best_row(eq)["sympy_format"].simplify() for eq in equations]
1066
- else:
1067
- return best_row(equations)["sympy_format"].simplify()
1068
 
1069
 
1070
  def best_callable(equations=None):
@@ -1075,8 +1070,7 @@ def best_callable(equations=None):
1075
  equations = get_hof()
1076
  if isinstance(equations, list):
1077
  return [best_row(eq)["lambda_format"] for eq in equations]
1078
- else:
1079
- return best_row(equations)["lambda_format"]
1080
 
1081
 
1082
  def _escape_filename(filename):
@@ -1114,5 +1108,4 @@ class CallableEquation(object):
1114
  def __call__(self, X):
1115
  if self._selection is not None:
1116
  return self._lambda(*X[:, self._selection].T)
1117
- else:
1118
- return self._lambda(*X.T)
 
643
  def tuple_fix(ops):
644
  if len(ops) > 1:
645
  return ", ".join(ops)
646
+ if len(ops) == 0:
647
  return ""
648
+ return ops[0] + ","
 
649
 
650
  def_hyperparams += f"""\n
651
  plus=(+)
 
1024
 
1025
  if multioutput:
1026
  return ret_outputs
1027
+ return ret_outputs[0]
 
1028
 
1029
 
1030
  def best_row(equations=None):
 
1035
  equations = get_hof()
1036
  if isinstance(equations, list):
1037
  return [eq.iloc[np.argmax(eq["score"])] for eq in equations]
1038
+ return equations.iloc[np.argmax(equations["score"])]
 
1039
 
1040
 
1041
  def best_tex(equations=None):
 
1048
  return [
1049
  sympy.latex(best_row(eq)["sympy_format"].simplify()) for eq in equations
1050
  ]
1051
+ return sympy.latex(best_row(equations)["sympy_format"].simplify())
 
1052
 
1053
 
1054
  def best(equations=None):
 
1059
  equations = get_hof()
1060
  if isinstance(equations, list):
1061
  return [best_row(eq)["sympy_format"].simplify() for eq in equations]
1062
+ return best_row(equations)["sympy_format"].simplify()
 
1063
 
1064
 
1065
  def best_callable(equations=None):
 
1070
  equations = get_hof()
1071
  if isinstance(equations, list):
1072
  return [best_row(eq)["lambda_format"] for eq in equations]
1073
+ return best_row(equations)["lambda_format"]
 
1074
 
1075
 
1076
  def _escape_filename(filename):
 
1108
  def __call__(self, X):
1109
  if self._selection is not None:
1110
  return self._lambda(*X[:, self._selection].T)
1111
+ return self._lambda(*X.T)