MilesCranmer commited on
Commit
c253783
1 Parent(s): 7f5b38a

Fix sympy output variable names

Browse files
Files changed (1) hide show
  1. pysr/sr.py +4 -1
pysr/sr.py CHANGED
@@ -293,7 +293,10 @@ const varMap = {'["' + '", "'.join(variable_names) + '"]'}"""
293
  lastComplexity = 0
294
  sympy_format = []
295
  lambda_format = []
296
- sympy_symbols = [sympy.Symbol('x%d'%i) for i in range(X.shape[1])]
 
 
 
297
  for i in range(len(output)):
298
  eqn = sympify(output.loc[i, 'Equation'], locals=local_sympy_mappings)
299
  sympy_format.append(eqn)
 
293
  lastComplexity = 0
294
  sympy_format = []
295
  lambda_format = []
296
+ if len(variable_names) != 0:
297
+ sympy_symbols = [sympy.Symbol(variable_names[i]) for i in range(X.shape[1])]
298
+ else:
299
+ sympy_symbols = [sympy.Symbol('x%d'%i) for i in range(X.shape[1])]
300
  for i in range(len(output)):
301
  eqn = sympify(output.loc[i, 'Equation'], locals=local_sympy_mappings)
302
  sympy_format.append(eqn)