MilesCranmer committed on
Commit
71ca872
·
unverified ·
2 Parent(s): b41237e 9e6c4b4

Merge pull request #18 from johannbrehmer/platform-independence

Browse files
Files changed (1) hide show
  1. pysr/sr.py +34 -12
pysr/sr.py CHANGED
@@ -7,6 +7,10 @@ import pandas as pd
7
  import sympy
8
  from sympy import sympify, Symbol, lambdify
9
  import subprocess
 
 
 
 
10
 
11
  global_equation_file = 'hall_of_fame.csv'
12
  global_n_features = None
@@ -92,6 +96,8 @@ def pysr(X=None, y=None, weights=None,
92
  warmupMaxsize=0,
93
  constraints={},
94
  useFrequency=False,
 
 
95
  limitPowComplexity=False, #deprecated
96
  threads=None, #deprecated
97
  julia_optimization=3,
@@ -178,6 +184,8 @@ def pysr(X=None, y=None, weights=None,
178
  and use that instead of parsimony to explore equation space. Will
179
  naturally find equations of all complexities.
180
  :param julia_optimization: int, Optimization level (0, 1, 2, 3)
 
 
181
  :returns: pd.DataFrame, Results dataframe, giving complexity, MSE, and equations
182
  (as strings).
183
 
@@ -241,7 +249,15 @@ def pysr(X=None, y=None, weights=None,
241
  y = eval(eval_str)
242
  print("Running on", eval_str)
243
 
244
- pkg_directory = '/'.join(__file__.split('/')[:-2] + ['julia'])
 
 
 
 
 
 
 
 
245
 
246
  def_hyperparams = ""
247
 
@@ -273,7 +289,7 @@ def pysr(X=None, y=None, weights=None,
273
  elif op == 'mult':
274
  # Make sure the complex expression is in the left side.
275
  if constraints[op][0] == -1:
276
- continue
277
  elif constraints[op][1] == -1 or constraints[op][0] < constraints[op][1]:
278
  constraints[op][0], constraints[op][1] = constraints[op][1], constraints[op][0]
279
 
@@ -298,8 +314,7 @@ const bin_constraints = ["""
298
  first = False
299
  constraints_str += "]"
300
 
301
-
302
- def_hyperparams += f"""include("{pkg_directory}/operators.jl")
303
  {constraints_str}
304
  const binops = {'[' + ', '.join(binary_operators) + ']'}
305
  const unaops = {'[' + ', '.join(unary_operators) + ']'}
@@ -393,16 +408,16 @@ const weights = convert(Array{Float32, 1}, """f"{weight_str})"
393
  def_hyperparams += f"""
394
  const varMap = {'["' + '", "'.join(variable_names) + '"]'}"""
395
 
396
- with open(f'/tmp/.hyperparams_{rand_string}.jl', 'w') as f:
397
  print(def_hyperparams, file=f)
398
 
399
- with open(f'/tmp/.dataset_{rand_string}.jl', 'w') as f:
400
  print(def_datasets, file=f)
401
 
402
- with open(f'/tmp/.runfile_{rand_string}.jl', 'w') as f:
403
- print(f'@everywhere include("/tmp/.hyperparams_{rand_string}.jl")', file=f)
404
- print(f'@everywhere include("/tmp/.dataset_{rand_string}.jl")', file=f)
405
- print(f'@everywhere include("{pkg_directory}/sr.jl")', file=f)
406
  print(f'fullRun({niterations:d}, npop={npop:d}, ncyclesperiteration={ncyclesperiteration:d}, fractionReplaced={fractionReplaced:f}f0, verbosity=round(Int32, {verbosity:f}), topn={topn:d})', file=f)
407
  print(f'rmprocs(nprocs)', file=f)
408
 
@@ -410,7 +425,7 @@ const varMap = {'["' + '", "'.join(variable_names) + '"]'}"""
410
  command = [
411
  f'julia', f'-O{julia_optimization:d}',
412
  f'-p', f'{procs}',
413
- f'/tmp/.runfile_{rand_string}.jl',
414
  ]
415
  if timeout is not None:
416
  command = [f'timeout', f'{timeout}'] + command
@@ -439,6 +454,9 @@ const varMap = {'["' + '", "'.join(variable_names) + '"]'}"""
439
  print("Killing process... will return when done.")
440
  process.kill()
441
 
 
 
 
442
  return get_hof()
443
 
444
 
@@ -550,4 +568,8 @@ def best_callable(equations=None):
550
  if equations is None: equations = get_hof()
551
  return best_row(equations)['lambda_format']
552
 
553
-
 
 
 
 
 
7
  import sympy
8
  from sympy import sympify, Symbol, lambdify
9
  import subprocess
10
+ import tempfile
11
+ import shutil
12
+ from pathlib import Path
13
+
14
 
15
  global_equation_file = 'hall_of_fame.csv'
16
  global_n_features = None
 
96
  warmupMaxsize=0,
97
  constraints={},
98
  useFrequency=False,
99
+ tempdir=None,
100
+ delete_tempfiles=True,
101
  limitPowComplexity=False, #deprecated
102
  threads=None, #deprecated
103
  julia_optimization=3,
 
184
  and use that instead of parsimony to explore equation space. Will
185
  naturally find equations of all complexities.
186
  :param julia_optimization: int, Optimization level (0, 1, 2, 3)
187
+ :param tempdir: str or None, directory for the temporary files
188
+ :param delete_tempfiles: bool, whether to delete the temporary files after finishing
189
  :returns: pd.DataFrame, Results dataframe, giving complexity, MSE, and equations
190
  (as strings).
191
 
 
249
  y = eval(eval_str)
250
  print("Running on", eval_str)
251
 
252
+ # System-independent paths
253
+ pkg_directory = Path(__file__).parents[1] / 'julia'
254
+ pkg_filename = pkg_directory / "sr.jl"
255
+ operator_filename = pkg_directory / "operators.jl"
256
+
257
+ tmpdir = Path(tempfile.mkdtemp(dir=tempdir))
258
+ hyperparam_filename = tmpdir / f'.hyperparams_{rand_string}.jl'
259
+ dataset_filename = tmpdir / f'.dataset_{rand_string}.jl'
260
+ runfile_filename = tmpdir / f'.runfile_{rand_string}.jl'
261
 
262
  def_hyperparams = ""
263
 
 
289
  elif op == 'mult':
290
  # Make sure the complex expression is in the left side.
291
  if constraints[op][0] == -1:
292
+ continue
293
  elif constraints[op][1] == -1 or constraints[op][0] < constraints[op][1]:
294
  constraints[op][0], constraints[op][1] = constraints[op][1], constraints[op][0]
295
 
 
314
  first = False
315
  constraints_str += "]"
316
 
317
+ def_hyperparams += f"""include("{_escape_filename(operator_filename)}")
 
318
  {constraints_str}
319
  const binops = {'[' + ', '.join(binary_operators) + ']'}
320
  const unaops = {'[' + ', '.join(unary_operators) + ']'}
 
408
  def_hyperparams += f"""
409
  const varMap = {'["' + '", "'.join(variable_names) + '"]'}"""
410
 
411
+ with open(hyperparam_filename, 'w') as f:
412
  print(def_hyperparams, file=f)
413
 
414
+ with open(dataset_filename, 'w') as f:
415
  print(def_datasets, file=f)
416
 
417
+ with open(runfile_filename, 'w') as f:
418
+ print(f'@everywhere include("{_escape_filename(hyperparam_filename)}")', file=f)
419
+ print(f'@everywhere include("{_escape_filename(dataset_filename)}")', file=f)
420
+ print(f'@everywhere include("{_escape_filename(pkg_filename)}")', file=f)
421
  print(f'fullRun({niterations:d}, npop={npop:d}, ncyclesperiteration={ncyclesperiteration:d}, fractionReplaced={fractionReplaced:f}f0, verbosity=round(Int32, {verbosity:f}), topn={topn:d})', file=f)
422
  print(f'rmprocs(nprocs)', file=f)
423
 
 
425
  command = [
426
  f'julia', f'-O{julia_optimization:d}',
427
  f'-p', f'{procs}',
428
+ str(runfile_filename),
429
  ]
430
  if timeout is not None:
431
  command = [f'timeout', f'{timeout}'] + command
 
454
  print("Killing process... will return when done.")
455
  process.kill()
456
 
457
+ if delete_tempfiles:
458
+ shutil.rmtree(tmpdir)
459
+
460
  return get_hof()
461
 
462
 
 
568
  if equations is None: equations = get_hof()
569
  return best_row(equations)['lambda_format']
570
 
571
+ def _escape_filename(filename):
572
+ """Turns a file into a string representation with correctly escaped backslashes"""
573
+ repr = str(filename)
574
+ repr = repr.replace('\\', '\\\\')
575
+ return repr