MilesCranmer commited on
Commit
1e13cd6
1 Parent(s): 3d94cb6

Fix bug with different precision in numpy data

Browse files
Files changed (1) hide show
  1. pysr/sr.py +3 -3
pysr/sr.py CHANGED
@@ -354,7 +354,7 @@ def _create_julia_files(dataset_filename, def_datasets, hyperparam_filename, de
354
  if len(variable_names) == 0:
355
  varMap = "[" + ",".join([f'"x{i}"' for i in range(X.shape[1])]) + "]"
356
  else:
357
- varMap = "[" + ",".join(variable_names) + "]"
358
 
359
  if weights is not None:
360
  print(f'EquationSearch(X, y, weights=weights, niterations={niterations:d}, varMap={varMap}, options=options, numprocs={procs})', file=f)
@@ -364,8 +364,8 @@ def _create_julia_files(dataset_filename, def_datasets, hyperparam_filename, de
364
 
365
  def _make_datasets_julia_str(X, X_filename, weights, weights_filename, y, y_filename, **kwargs):
366
  def_datasets = """using DelimitedFiles"""
367
- np.savetxt(X_filename, X, delimiter=',')
368
- np.savetxt(y_filename, y.reshape(-1, 1), delimiter=',')
369
  if weights is not None:
370
  np.savetxt(weights_filename, weights.reshape(-1, 1), delimiter=',')
371
  def_datasets += f"""
 
354
  if len(variable_names) == 0:
355
  varMap = "[" + ",".join([f'"x{i}"' for i in range(X.shape[1])]) + "]"
356
  else:
357
+ varMap = "[" + ",".join(['"' + vname + '"' for vname in variable_names]) + "]"
358
 
359
  if weights is not None:
360
  print(f'EquationSearch(X, y, weights=weights, niterations={niterations:d}, varMap={varMap}, options=options, numprocs={procs})', file=f)
 
364
 
365
  def _make_datasets_julia_str(X, X_filename, weights, weights_filename, y, y_filename, **kwargs):
366
  def_datasets = """using DelimitedFiles"""
367
+ np.savetxt(X_filename, X.astype(np.float32), delimiter=',')
368
+ np.savetxt(y_filename, y.reshape(-1, 1).astype(np.float32), delimiter=',')
369
  if weights is not None:
370
  np.savetxt(weights_filename, weights.reshape(-1, 1), delimiter=',')
371
  def_datasets += f"""