MilesCranmer commited on
Commit
97f43e5
·
1 Parent(s): 42acd41

Remove non-PyJulia parts of codebase

Browse files
Files changed (1) hide show
  1. pysr/sr.py +177 -529
pysr/sr.py CHANGED
@@ -27,7 +27,7 @@ global_state = dict(
27
  selection=None,
28
  )
29
 
30
- already_ran_with_pyjulia = False
31
 
32
  sympy_mappings = {
33
  "div": lambda x, y: x / y,
@@ -99,7 +99,6 @@ def pysr(
99
  weightRandomize=1,
100
  weightSimplify=0.01,
101
  perturbationFactor=1.0,
102
- timeout=None,
103
  extra_sympy_mappings=None,
104
  extra_torch_mappings=None,
105
  extra_jax_mappings=None,
@@ -118,7 +117,6 @@ def pysr(
118
  useFrequency=True,
119
  tempdir=None,
120
  delete_tempfiles=True,
121
- julia_optimization=3,
122
  julia_project=None,
123
  user_input=True,
124
  update=True,
@@ -135,7 +133,6 @@ def pysr(
135
  Xresampled=None,
136
  precision=32,
137
  multithreading=None,
138
- pyjulia=False,
139
  ):
140
  """Run symbolic regression to fit f(X[i, :]) ~ y[i] for all i.
141
  Note: most default parameters have been tuned over several example
@@ -202,8 +199,6 @@ def pysr(
202
  :type weightRandomize: float
203
  :param weightSimplify: Relative likelihood for mutation to simplify constant parts by evaluation
204
  :type weightSimplify: float
205
- :param timeout: Time in seconds to timeout search
206
- :type timeout: float
207
  :param equation_file: Where to save the files (.csv separated by |)
208
  :type equation_file: str
209
  :param verbosity: What verbosity level to use. 0 means minimal print statements.
@@ -230,8 +225,6 @@ def pysr(
230
  :type constraints: dict
231
  :param useFrequency: whether to measure the frequency of complexities, and use that instead of parsimony to explore equation space. Will naturally find equations of all complexities.
232
  :type useFrequency: bool
233
- :param julia_optimization: Optimization level (0, 1, 2, 3)
234
- :type julia_optimization: int
235
  :param tempdir: directory for the temporary files
236
  :type tempdir: str/None
237
  :param delete_tempfiles: whether to delete the temporary files after finishing
@@ -258,11 +251,11 @@ def pysr(
258
  :type precision: int
259
  :param multithreading: Use multithreading instead of distributed backend. Default is yes. Using procs=0 will turn off both.
260
  :type multithreading: bool
261
- :param pyjulia: Whether to use PyJulia instead of julia binary. PyJulia should reduce startup time for repeat calls.
262
- :type pyjulia: bool
263
  :returns: Results dataframe, giving complexity, MSE, and equations (as strings), as well as functional forms. If list, each element corresponds to a dataframe of equations for each output.
264
  :type: pd.DataFrame/list
265
  """
 
 
266
  if binary_operators is None:
267
  binary_operators = "+ * - /".split(" ")
268
  if unary_operators is None:
@@ -278,19 +271,14 @@ def pysr(
278
  # or procs is set to 0 (serial mode).
279
  multithreading = procs != 0
280
 
281
- # Start up Julia:
282
  global Main
283
- if pyjulia and Main is None:
284
- # if not multithreading:
285
- # raise AssertionError(
286
- # "PyJulia does not support multiprocessing. Turn multithreading=True."
287
- # )
288
-
289
  if multithreading:
290
  os.environ["JULIA_NUM_THREADS"] = str(procs)
 
291
  from julia import Main
292
 
293
- buffer_available = "buffer" in sys.stdout.__dir__() and not pyjulia
294
 
295
  if progress is not None:
296
  if progress and not buffer_available:
@@ -298,11 +286,6 @@ def pysr(
298
  "Note: it looks like you are running in Jupyter. The progress bar will be turned off."
299
  )
300
  progress = False
301
- if progress and pyjulia:
302
- warnings.warn(
303
- "Note: it looks like you are using PyJulia. The progress bar will be turned off."
304
- )
305
- progress = False
306
  else:
307
  progress = buffer_available
308
 
@@ -344,8 +327,6 @@ def pysr(
344
  weights,
345
  y,
346
  )
347
- if not pyjulia:
348
- _check_for_julia_installation()
349
 
350
  if len(X) > 10000 and not batching:
351
  warnings.warn(
@@ -398,503 +379,212 @@ def pysr(
398
  else:
399
  X, y = _denoise(X, y, Xresampled=Xresampled)
400
 
401
- kwargs = dict(
402
- X=X,
403
- y=y,
404
- weights=weights,
405
- alpha=alpha,
406
- annealing=annealing,
407
- batchSize=batchSize,
408
- batching=batching,
409
- binary_operators=binary_operators,
410
- fast_cycle=fast_cycle,
411
- fractionReplaced=fractionReplaced,
412
- ncyclesperiteration=ncyclesperiteration,
413
- niterations=niterations,
414
- npop=npop,
415
- topn=topn,
416
- verbosity=verbosity,
417
- progress=progress,
418
- update=update,
419
- julia_optimization=julia_optimization,
420
- timeout=timeout,
421
- fractionReplacedHof=fractionReplacedHof,
422
- hofMigration=hofMigration,
423
- maxdepth=maxdepth,
424
- maxsize=maxsize,
425
- migration=migration,
426
- optimizer_algorithm=optimizer_algorithm,
427
- optimizer_nrestarts=optimizer_nrestarts,
428
- optimize_probability=optimize_probability,
429
- optimizer_iterations=optimizer_iterations,
430
- parsimony=parsimony,
431
- perturbationFactor=perturbationFactor,
432
- populations=populations,
433
- procs=procs,
434
- shouldOptimizeConstants=shouldOptimizeConstants,
435
- unary_operators=unary_operators,
436
- useFrequency=useFrequency,
437
- use_custom_variable_names=use_custom_variable_names,
438
- variable_names=variable_names,
439
- warmupMaxsizeBy=warmupMaxsizeBy,
440
- weightAddNode=weightAddNode,
441
- weightDeleteNode=weightDeleteNode,
442
- weightDoNothing=weightDoNothing,
443
- weightInsertNode=weightInsertNode,
444
- weightMutateConstant=weightMutateConstant,
445
- weightMutateOperator=weightMutateOperator,
446
- weightRandomize=weightRandomize,
447
- weightSimplify=weightSimplify,
448
- constraints=constraints,
449
- extra_sympy_mappings=extra_sympy_mappings,
450
- extra_jax_mappings=extra_jax_mappings,
451
- extra_torch_mappings=extra_torch_mappings,
452
- julia_project=julia_project,
453
- loss=loss,
454
- output_jax_format=output_jax_format,
455
- output_torch_format=output_torch_format,
456
- selection=selection,
457
- multioutput=multioutput,
458
- nout=nout,
459
- tournament_selection_n=tournament_selection_n,
460
- tournament_selection_p=tournament_selection_p,
461
- denoise=denoise,
462
- precision=precision,
463
- multithreading=multithreading,
464
- pyjulia=pyjulia,
465
- )
466
-
467
- kwargs = {**_set_paths(tempdir), **kwargs}
468
 
469
  if temp_equation_file:
470
- equation_file = kwargs["tmpdir"] / "hall_of_fame.csv"
471
  elif equation_file is None:
472
  date_time = datetime.now().strftime("%Y-%m-%d_%H%M%S.%f")[:-3]
473
  equation_file = "hall_of_fame_" + date_time + ".csv"
474
 
475
- kwargs = {**dict(equation_file=equation_file), **kwargs}
476
-
477
- pkg_directory = kwargs["pkg_directory"]
478
- if kwargs["julia_project"] is not None:
479
- manifest_filepath = Path(kwargs["julia_project"]) / "Manifest.toml"
480
- else:
481
  manifest_filepath = pkg_directory / "Manifest.toml"
482
-
483
- # Set julia project to correct directory:
484
- if kwargs["julia_project"] is None:
485
- kwargs["julia_project"] = pkg_directory
486
  else:
487
- kwargs["julia_project"] = Path(kwargs["julia_project"])
 
488
 
489
- kwargs["need_install"] = False
490
 
491
  if not (manifest_filepath).is_file() and not pyjulia:
492
- kwargs["need_install"] = (not user_input) or _yesno(
493
  "I will install Julia packages using PySR's Project.toml file. OK?"
494
  )
495
- if kwargs["need_install"]:
496
  print("OK. I will install at launch.")
497
  assert update
498
 
499
- global already_ran_with_pyjulia
500
-
501
- kwargs["def_hyperparams"] = _create_inline_operators(**kwargs)
502
- _handle_constraints(**kwargs)
503
- kwargs["constraints_str"] = _make_constraints_str(**kwargs)
504
- kwargs["def_hyperparams"] = _make_hyperparams_julia_str(**kwargs)
505
-
506
- if pyjulia:
507
-
508
- from julia import Pkg
509
-
510
- Pkg.activate(f"{_escape_filename(kwargs['julia_project'])}")
511
- if kwargs["need_install"]:
512
- Pkg.instantiate()
513
- Pkg.update()
514
- Pkg.precompile()
515
- elif update:
516
- Pkg.update()
517
- from julia import SymbolicRegression
518
-
519
- already_ran_with_pyjulia = True
520
-
521
- X = kwargs["X"]
522
- y = kwargs["y"]
523
- weights = kwargs["weights"]
524
- def_hyperparams = kwargs["def_hyperparams"]
525
- variable_names = kwargs["variable_names"]
526
- multithreading = kwargs["multithreading"]
527
- procs = kwargs["procs"]
528
- niterations = kwargs["niterations"]
529
- precision = kwargs["precision"]
530
- np_dtype = {16: np.float16, 32: np.float32, 64: np.float64}[precision]
531
-
532
- Main.X = np.array(X, dtype=np_dtype).T
533
- if len(y.shape) == 1:
534
- Main.y = np.array(y, dtype=np_dtype)
535
- else:
536
- Main.y = np.array(y, dtype=np_dtype).T
537
- if weights is not None:
538
- if len(weights.shape) == 1:
539
- Main.weights = np.array(weights, dtype=np_dtype)
540
- else:
541
- Main.weights = np.array(weights, dtype=np_dtype).T
542
- else:
543
- Main.weights = None
544
-
545
- Main.eval(def_hyperparams)
546
-
547
- varMap = Main.eval(_make_varmap(X, variable_names))
548
- cprocs = 0 if multithreading else procs
549
-
550
- SymbolicRegression.EquationSearch(
551
- Main.X,
552
- Main.y,
553
- weights=Main.weights,
554
- niterations=niterations,
555
- varMap=varMap,
556
- options=Main.options,
557
- numprocs=cprocs,
558
- multithreading=multithreading,
559
- )
560
-
561
- else:
562
- kwargs["def_datasets"] = _make_datasets_julia_str(**kwargs)
563
-
564
- _create_julia_files(**kwargs)
565
- _final_pysr_process(**kwargs)
566
-
567
- _set_globals(**kwargs)
568
- equations = get_hof(**kwargs)
569
-
570
- if delete_tempfiles:
571
- shutil.rmtree(kwargs["tmpdir"])
572
-
573
- return equations
574
-
575
-
576
- def _set_globals(X, **kwargs):
577
- global global_state
578
-
579
- global_state["n_features"] = X.shape[1]
580
- for key, value in kwargs.items():
581
- if key in global_state:
582
- global_state[key] = value
583
-
584
-
585
- def _final_pysr_process(
586
- julia_optimization, runfile_filename, timeout, multithreading, procs, **kwargs
587
- ):
588
- command = [
589
- "julia",
590
- f"-O{julia_optimization:d}",
591
- ]
592
-
593
- if multithreading:
594
- command.append("--threads")
595
- command.append(f"{procs}")
596
-
597
- command.append(str(runfile_filename))
598
- if timeout is not None:
599
- command = ["timeout", f"{timeout}"] + command
600
- _cmd_runner(command, **kwargs)
601
 
 
 
602
 
603
- def _cmd_runner(command, progress, **kwargs):
604
- if kwargs["verbosity"] > 0:
605
- print("Running on", " ".join(command))
606
- process = subprocess.Popen(command, stdout=subprocess.PIPE, bufsize=-1)
607
  try:
608
- while True:
609
- line = process.stdout.readline()
610
- if not line:
611
- break
612
- decoded_line = line.decode("utf-8")
613
- if progress:
614
- decoded_line = (
615
- decoded_line.replace("\\033[K", "\033[K")
616
- .replace("\\033[1A", "\033[1A")
617
- .replace("\\033[1B", "\033[1B")
618
- .replace("\\r", "\r")
619
- .encode(sys.stdout.encoding, errors="replace")
620
- )
621
- sys.stdout.buffer.write(decoded_line)
622
- sys.stdout.flush()
623
- else:
624
- print(decoded_line, end="")
625
-
626
- process.stdout.close()
627
- process.wait()
628
- except KeyboardInterrupt:
629
- print("Killing process... will return when done.")
630
- process.kill()
631
-
632
-
633
- def _create_julia_files(
634
- dataset_filename,
635
- def_datasets,
636
- hyperparam_filename,
637
- def_hyperparams,
638
- niterations,
639
- runfile_filename,
640
- julia_project,
641
- procs,
642
- weights,
643
- X,
644
- variable_names,
645
- need_install,
646
- update,
647
- multithreading,
648
- **kwargs,
649
- ):
650
- with open(hyperparam_filename, "w") as f:
651
- print(def_hyperparams, file=f)
652
-
653
- with open(dataset_filename, "w") as f:
654
- print(def_datasets, file=f)
655
-
656
- with open(runfile_filename, "w") as f:
657
-
658
- print(f"import Pkg", file=f)
659
- print(f'Pkg.activate("{_escape_filename(julia_project)}")', file=f)
660
- if need_install:
661
- print(f"Pkg.instantiate()", file=f)
662
- print("Pkg.update()", file=f)
663
- print("Pkg.precompile()", file=f)
664
- elif update:
665
- print(f"Pkg.update()", file=f)
666
- print(f"using SymbolicRegression", file=f)
667
-
668
- print(f'include("{_escape_filename(hyperparam_filename)}")', file=f)
669
-
670
- print(f'include("{_escape_filename(dataset_filename)}")', file=f)
671
-
672
- varMap = _make_varmap(X, variable_names)
673
-
674
- cprocs = 0 if multithreading else procs
675
- if weights is not None:
676
- print(
677
- f"EquationSearch(X, y, weights=weights, niterations={niterations:d}, varMap={varMap}, options=options, numprocs={cprocs}, multithreading={'true' if multithreading else 'false'})",
678
- file=f,
679
- )
680
- else:
681
- print(
682
- f"EquationSearch(X, y, niterations={niterations:d}, varMap={varMap}, options=options, numprocs={cprocs}, multithreading={'true' if multithreading else 'false'})",
683
- file=f,
684
- )
685
-
686
 
687
- def _make_varmap(X, variable_names):
688
- if len(variable_names) == 0:
689
- return "[" + ",".join([f'"x{i}"' for i in range(X.shape[1])]) + "]"
690
- else:
691
- return "[" + ",".join(['"' + vname + '"' for vname in variable_names]) + "]"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
692
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
693
 
694
- def _make_datasets_julia_str(
695
- X,
696
- X_filename,
697
- weights,
698
- weights_filename,
699
- y,
700
- y_filename,
701
- multioutput,
702
- precision,
703
- **kwargs,
704
- ):
705
- def_datasets = """using DelimitedFiles"""
706
- julia_dtype = {16: "Float16", 32: "Float32", 64: "Float64"}[precision]
707
  np_dtype = {16: np.float16, 32: np.float32, 64: np.float64}[precision]
708
 
709
- np.savetxt(X_filename, X.astype(np_dtype), delimiter=",")
710
- if multioutput:
711
- np.savetxt(y_filename, y.astype(np_dtype), delimiter=",")
712
  else:
713
- np.savetxt(y_filename, y.reshape(-1, 1).astype(np_dtype), delimiter=",")
714
-
715
  if weights is not None:
716
- if multioutput:
717
- np.savetxt(weights_filename, weights.astype(np_dtype), delimiter=",")
718
  else:
719
- np.savetxt(
720
- weights_filename,
721
- weights.reshape(-1, 1).astype(np_dtype),
722
- delimiter=",",
723
- )
 
 
 
 
 
 
 
 
 
 
 
724
 
725
- def_datasets += f"""
726
- X = copy(transpose(readdlm("{_escape_filename(X_filename)}", ',', {julia_dtype}, '\\n')))"""
 
 
 
 
 
 
 
 
 
 
 
727
 
728
- if multioutput:
729
- def_datasets += f"""
730
- y = copy(transpose(readdlm("{_escape_filename(y_filename)}", ',', {julia_dtype}, '\\n')))"""
731
- else:
732
- def_datasets += f"""
733
- y = readdlm("{_escape_filename(y_filename)}", ',', {julia_dtype}, '\\n')[:, 1]"""
 
 
 
 
 
 
 
734
 
735
- if weights is not None:
736
- if multioutput:
737
- def_datasets += f"""
738
- weights = copy(transpose(readdlm("{_escape_filename(weights_filename)}", ',', {julia_dtype}, '\\n')))"""
739
- else:
740
- def_datasets += f"""
741
- weights = readdlm("{_escape_filename(weights_filename)}", ',', {julia_dtype}, '\\n')[:, 1]"""
742
- return def_datasets
743
 
 
744
 
745
- def _make_hyperparams_julia_str(
 
 
746
  X,
747
- alpha,
748
- annealing,
749
- batchSize,
750
- batching,
751
- binary_operators,
752
- constraints_str,
753
- def_hyperparams,
754
  equation_file,
755
- fast_cycle,
756
- fractionReplacedHof,
757
- hofMigration,
758
- maxdepth,
759
- maxsize,
760
- migration,
761
- optimizer_algorithm,
762
- optimizer_nrestarts,
763
- optimize_probability,
764
- optimizer_iterations,
765
- npop,
766
- parsimony,
767
- perturbationFactor,
768
- populations,
769
- shouldOptimizeConstants,
770
- unary_operators,
771
- useFrequency,
772
- warmupMaxsizeBy,
773
- weightAddNode,
774
- ncyclesperiteration,
775
- fractionReplaced,
776
- topn,
777
- verbosity,
778
- progress,
779
- loss,
780
- weightDeleteNode,
781
- weightDoNothing,
782
- weightInsertNode,
783
- weightMutateConstant,
784
- weightMutateOperator,
785
- weightRandomize,
786
- weightSimplify,
787
- tournament_selection_n,
788
- tournament_selection_p,
789
- **kwargs,
790
  ):
791
- try:
792
- term_width = shutil.get_terminal_size().columns
793
- except:
794
- _, term_width = subprocess.check_output(["stty", "size"]).split()
795
 
796
- def tuple_fix(ops):
797
- if len(ops) > 1:
798
- return ", ".join(ops)
799
- if len(ops) == 0:
800
- return ""
801
- return ops[0] + ","
802
-
803
- def_hyperparams += f"""\n
804
- plus=(+)
805
- sub=(-)
806
- mult=(*)
807
- square=SymbolicRegression.square
808
- cube=SymbolicRegression.cube
809
- pow=(^)
810
- div=(/)
811
- log_abs=SymbolicRegression.log_abs
812
- log2_abs=SymbolicRegression.log2_abs
813
- log10_abs=SymbolicRegression.log10_abs
814
- log1p_abs=SymbolicRegression.log1p_abs
815
- acosh_abs=SymbolicRegression.acosh_abs
816
- atanh_clip=SymbolicRegression.atanh_clip
817
- sqrt_abs=SymbolicRegression.sqrt_abs
818
- neg=SymbolicRegression.neg
819
- greater=SymbolicRegression.greater
820
- relu=SymbolicRegression.relu
821
- logical_or=SymbolicRegression.logical_or
822
- logical_and=SymbolicRegression.logical_and
823
- _custom_loss = {loss}
824
-
825
- options = SymbolicRegression.Options(binary_operators={'(' + tuple_fix(binary_operators) + ')'},
826
- unary_operators={'(' + tuple_fix(unary_operators) + ')'},
827
- {constraints_str}
828
- parsimony={parsimony:f}f0,
829
- loss=_custom_loss,
830
- alpha={alpha:f}f0,
831
- maxsize={maxsize:d},
832
- maxdepth={maxdepth:d},
833
- fast_cycle={'true' if fast_cycle else 'false'},
834
- migration={'true' if migration else 'false'},
835
- hofMigration={'true' if hofMigration else 'false'},
836
- fractionReplacedHof={fractionReplacedHof}f0,
837
- shouldOptimizeConstants={'true' if shouldOptimizeConstants else 'false'},
838
- hofFile="{_escape_filename(equation_file)}",
839
- npopulations={populations:d},
840
- optimizer_algorithm="{optimizer_algorithm}",
841
- optimizer_nrestarts={optimizer_nrestarts:d},
842
- optimize_probability={optimize_probability:f}f0,
843
- optimizer_iterations={optimizer_iterations:d},
844
- perturbationFactor={perturbationFactor:f}f0,
845
- annealing={"true" if annealing else "false"},
846
- batching={"true" if batching else "false"},
847
- batchSize={min([batchSize, len(X)]) if batching else len(X):d},
848
- mutationWeights=[
849
- {weightMutateConstant:f},
850
- {weightMutateOperator:f},
851
- {weightAddNode:f},
852
- {weightInsertNode:f},
853
- {weightDeleteNode:f},
854
- {weightSimplify:f},
855
- {weightRandomize:f},
856
- {weightDoNothing:f}
857
- ],
858
- warmupMaxsizeBy={warmupMaxsizeBy:f}f0,
859
- useFrequency={"true" if useFrequency else "false"},
860
- npop={npop:d},
861
- ns={tournament_selection_n:d},
862
- probPickFirst={tournament_selection_p:f}f0,
863
- ncyclesperiteration={ncyclesperiteration:d},
864
- fractionReplaced={fractionReplaced:f}f0,
865
- topn={topn:d},
866
- verbosity=round(Int32, {verbosity:f}),
867
- progress={'true' if progress else 'false'},
868
- terminal_width={term_width:d}
869
- """
870
-
871
- def_hyperparams += "\n)"
872
- return def_hyperparams
873
-
874
-
875
- def _make_constraints_str(binary_operators, constraints, unary_operators, **kwargs):
876
- constraints_str = "una_constraints = ["
877
- first = True
878
- for op in unary_operators:
879
- val = constraints[op]
880
- if not first:
881
- constraints_str += ", "
882
- constraints_str += f"{val:d}"
883
- first = False
884
- constraints_str += """],
885
- bin_constraints = ["""
886
- first = True
887
- for op in binary_operators:
888
- tup = constraints[op]
889
- if not first:
890
- constraints_str += ", "
891
- constraints_str += f"({tup[0]:d}, {tup[1]:d})"
892
- first = False
893
- constraints_str += "],"
894
- return constraints_str
895
 
896
 
897
- def _handle_constraints(binary_operators, constraints, unary_operators, **kwargs):
898
  for op in unary_operators:
899
  if op not in constraints:
900
  constraints[op] = -1
@@ -917,14 +607,14 @@ def _handle_constraints(binary_operators, constraints, unary_operators, **kwargs
917
  )
918
 
919
 
920
- def _create_inline_operators(binary_operators, unary_operators, **kwargs):
921
- def_hyperparams = ""
922
  for op_list in [binary_operators, unary_operators]:
923
  for i, op in enumerate(op_list):
924
  is_user_defined_operator = "(" in op
925
 
926
  if is_user_defined_operator:
927
- def_hyperparams += op + "\n"
928
  # Cut off from the first non-alphanumeric char:
929
  first_non_char = [
930
  j
@@ -933,7 +623,6 @@ def _create_inline_operators(binary_operators, unary_operators, **kwargs):
933
  ][0]
934
  function_name = op[:first_non_char]
935
  op_list[i] = function_name
936
- return def_hyperparams
937
 
938
 
939
  def _handle_feature_selection(
@@ -951,30 +640,6 @@ def _handle_feature_selection(
951
  return X, variable_names, selection
952
 
953
 
954
- def _set_paths(tempdir):
955
- # System-independent paths
956
- pkg_directory = Path(__file__).parents[1]
957
- default_project_file = pkg_directory / "Project.toml"
958
- tmpdir = Path(tempfile.mkdtemp(dir=tempdir))
959
- hyperparam_filename = tmpdir / f"hyperparams.jl"
960
- dataset_filename = tmpdir / f"dataset.jl"
961
- runfile_filename = tmpdir / "runfile.jl"
962
- X_filename = tmpdir / "X.csv"
963
- y_filename = tmpdir / "y.csv"
964
- weights_filename = tmpdir / "weights.csv"
965
- return dict(
966
- pkg_directory=pkg_directory,
967
- default_project_file=default_project_file,
968
- X_filename=X_filename,
969
- dataset_filename=dataset_filename,
970
- hyperparam_filename=hyperparam_filename,
971
- runfile_filename=runfile_filename,
972
- tmpdir=tmpdir,
973
- weights_filename=weights_filename,
974
- y_filename=y_filename,
975
- )
976
-
977
-
978
  def _check_assertions(
979
  X,
980
  binary_operators,
@@ -996,23 +661,6 @@ def _check_assertions(
996
  assert len(variable_names) == X.shape[1]
997
 
998
 
999
- def _check_for_julia_installation():
1000
- try:
1001
- process = subprocess.Popen(["julia", "-v"], stdout=subprocess.PIPE, bufsize=-1)
1002
- while True:
1003
- line = process.stdout.readline()
1004
- if not line:
1005
- break
1006
- process.stdout.close()
1007
- process.wait()
1008
- except FileNotFoundError:
1009
-
1010
- raise RuntimeError(
1011
- f"Your current $PATH is: {os.environ['PATH']}\nPySR could not start julia. Make sure julia is installed and on your $PATH."
1012
- )
1013
- process.kill()
1014
-
1015
-
1016
  def run_feature_selection(X, y, select_k_features):
1017
  """Use a gradient boosting tree regressor as a proxy for finding
1018
  the k most important features in X, returning indices for those
 
27
  selection=None,
28
  )
29
 
30
+ already_ran = False
31
 
32
  sympy_mappings = {
33
  "div": lambda x, y: x / y,
 
99
  weightRandomize=1,
100
  weightSimplify=0.01,
101
  perturbationFactor=1.0,
 
102
  extra_sympy_mappings=None,
103
  extra_torch_mappings=None,
104
  extra_jax_mappings=None,
 
117
  useFrequency=True,
118
  tempdir=None,
119
  delete_tempfiles=True,
 
120
  julia_project=None,
121
  user_input=True,
122
  update=True,
 
133
  Xresampled=None,
134
  precision=32,
135
  multithreading=None,
 
136
  ):
137
  """Run symbolic regression to fit f(X[i, :]) ~ y[i] for all i.
138
  Note: most default parameters have been tuned over several example
 
199
  :type weightRandomize: float
200
  :param weightSimplify: Relative likelihood for mutation to simplify constant parts by evaluation
201
  :type weightSimplify: float
 
 
202
  :param equation_file: Where to save the files (.csv separated by |)
203
  :type equation_file: str
204
  :param verbosity: What verbosity level to use. 0 means minimal print statements.
 
225
  :type constraints: dict
226
  :param useFrequency: whether to measure the frequency of complexities, and use that instead of parsimony to explore equation space. Will naturally find equations of all complexities.
227
  :type useFrequency: bool
 
 
228
  :param tempdir: directory for the temporary files
229
  :type tempdir: str/None
230
  :param delete_tempfiles: whether to delete the temporary files after finishing
 
251
  :type precision: int
252
  :param multithreading: Use multithreading instead of distributed backend. Default is yes. Using procs=0 will turn off both.
253
  :type multithreading: bool
 
 
254
  :returns: Results dataframe, giving complexity, MSE, and equations (as strings), as well as functional forms. If list, each element corresponds to a dataframe of equations for each output.
255
  :type: pd.DataFrame/list
256
  """
257
+ global already_ran
258
+
259
  if binary_operators is None:
260
  binary_operators = "+ * - /".split(" ")
261
  if unary_operators is None:
 
271
  # or procs is set to 0 (serial mode).
272
  multithreading = procs != 0
273
 
 
274
  global Main
275
+ if Main is None:
 
 
 
 
 
276
  if multithreading:
277
  os.environ["JULIA_NUM_THREADS"] = str(procs)
278
+
279
  from julia import Main
280
 
281
+ buffer_available = "buffer" in sys.stdout.__dir__()
282
 
283
  if progress is not None:
284
  if progress and not buffer_available:
 
286
  "Note: it looks like you are running in Jupyter. The progress bar will be turned off."
287
  )
288
  progress = False
 
 
 
 
 
289
  else:
290
  progress = buffer_available
291
 
 
327
  weights,
328
  y,
329
  )
 
 
330
 
331
  if len(X) > 10000 and not batching:
332
  warnings.warn(
 
379
  else:
380
  X, y = _denoise(X, y, Xresampled=Xresampled)
381
 
382
+ pkg_directory = Path(__file__).parents[1]
383
+ default_project_file = pkg_directory / "Project.toml"
384
+ tmpdir = Path(tempfile.mkdtemp(dir=tempdir))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
385
 
386
  if temp_equation_file:
387
+ equation_file = tmpdir / "hall_of_fame.csv"
388
  elif equation_file is None:
389
  date_time = datetime.now().strftime("%Y-%m-%d_%H%M%S.%f")[:-3]
390
  equation_file = "hall_of_fame_" + date_time + ".csv"
391
 
392
+ if julia_project is None:
 
 
 
 
 
393
  manifest_filepath = pkg_directory / "Manifest.toml"
394
+ julia_project = pkg_directory
 
 
 
395
  else:
396
+ manifest_filepath = Path(julia_project) / "Manifest.toml"
397
+ julia_project = Path(julia_project)
398
 
399
+ need_install = False
400
 
401
  if not (manifest_filepath).is_file() and not pyjulia:
402
+ need_install = (not user_input) or _yesno(
403
  "I will install Julia packages using PySR's Project.toml file. OK?"
404
  )
405
+ if need_install:
406
  print("OK. I will install at launch.")
407
  assert update
408
 
409
+ _create_inline_operators(
410
+ binary_operators=binary_operators, unary_operators=unary_operators
411
+ )
412
+ _handle_constraints(
413
+ binary_operators=binary_operators,
414
+ unary_operators=unary_operators,
415
+ constraints=constraints,
416
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
417
 
418
+ una_constraints = [constraints[op] for op in unary_operators]
419
+ bin_constraints = [constraints[op] for op in binary_operators]
420
 
 
 
 
 
421
  try:
422
+ term_width = shutil.get_terminal_size().columns
423
+ except:
424
+ _, term_width = subprocess.check_output(["stty", "size"]).split()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
425
 
426
+ from julia import Pkg
427
+
428
+ Pkg.activate(f"{_escape_filename(julia_project)}")
429
+ if need_install:
430
+ Pkg.instantiate()
431
+ Pkg.update()
432
+ Pkg.precompile()
433
+ elif update:
434
+ Pkg.update()
435
+
436
+ Main.eval("using SymbolicRegression")
437
+
438
+ Main.plus = Main.eval("(+)")
439
+ Main.sub = Main.eval("(-)")
440
+ Main.mult = Main.eval("(*)")
441
+ Main.pow = Main.eval("(^)")
442
+ Main.div = Main.eval("(/)")
443
+
444
+ Main.custom_loss = Main.eval(loss)
445
+
446
+ mutationWeights = [
447
+ float(weightMutateConstant),
448
+ float(weightMutateOperator),
449
+ float(weightAddNode),
450
+ float(weightInsertNode),
451
+ float(weightDeleteNode),
452
+ float(weightSimplify),
453
+ float(weightRandomize),
454
+ float(weightDoNothing),
455
+ ]
456
 
457
+ options = Main.Options(
458
+ binary_operators=Main.eval(str(tuple(binary_operators)).replace("'", "")),
459
+ unary_operators=Main.eval(str(tuple(unary_operators)).replace("'", "")),
460
+ bin_constraints=bin_constraints,
461
+ una_constraints=una_constraints,
462
+ parsimony=float(parsimony),
463
+ loss=Main.custom_loss,
464
+ alpha=float(alpha),
465
+ maxsize=int(maxsize),
466
+ maxdepth=int(maxdepth),
467
+ fast_cycle=fast_cycle,
468
+ migration=migration,
469
+ hofMigration=hofMigration,
470
+ fractionReplacedHof=float(fractionReplacedHof),
471
+ shouldOptimizeConstants=shouldOptimizeConstants,
472
+ hofFile=_escape_filename(equation_file),
473
+ npopulations=int(populations),
474
+ optimizer_algorithm=optimizer_algorithm,
475
+ optimizer_nrestarts=int(optimizer_nrestarts),
476
+ optimize_probability=float(optimize_probability),
477
+ optimizer_iterations=int(optimizer_iterations),
478
+ perturbationFactor=float(perturbationFactor),
479
+ annealing=annealing,
480
+ batching=batching,
481
+ batchSize=int(min([batchSize, len(X)]) if batching else len(X)),
482
+ mutationWeights=mutationWeights,
483
+ warmupMaxsizeBy=float(warmupMaxsizeBy),
484
+ useFrequency=useFrequency,
485
+ npop=int(npop),
486
+ ns=int(tournament_selection_n),
487
+ probPickFirst=float(tournament_selection_p),
488
+ ncyclesperiteration=int(ncyclesperiteration),
489
+ fractionReplaced=float(fractionReplaced),
490
+ topn=int(topn),
491
+ verbosity=int(verbosity),
492
+ progress=progress,
493
+ terminal_width=int(term_width),
494
+ )
495
 
 
 
 
 
 
 
 
 
 
 
 
 
 
496
  np_dtype = {16: np.float16, 32: np.float32, 64: np.float64}[precision]
497
 
498
+ Main.X = np.array(X, dtype=np_dtype).T
499
+ if len(y.shape) == 1:
500
+ Main.y = np.array(y, dtype=np_dtype)
501
  else:
502
+ Main.y = np.array(y, dtype=np_dtype).T
 
503
  if weights is not None:
504
+ if len(weights.shape) == 1:
505
+ Main.weights = np.array(weights, dtype=np_dtype)
506
  else:
507
+ Main.weights = np.array(weights, dtype=np_dtype).T
508
+ else:
509
+ Main.weights = None
510
+
511
+ cprocs = 0 if multithreading else procs
512
+
513
+ output_equations = Main.EquationSearch(
514
+ Main.X,
515
+ Main.y,
516
+ weights=Main.weights,
517
+ niterations=int(niterations),
518
+ varMap=variable_names,
519
+ options=options,
520
+ numprocs=int(cprocs),
521
+ multithreading=bool(multithreading),
522
+ )
523
 
524
+ _set_globals(
525
+ X=X,
526
+ equation_file=equation_file,
527
+ variable_names=variable_names,
528
+ extra_sympy_mappings=extra_sympy_mappings,
529
+ extra_torch_mappings=extra_torch_mappings,
530
+ extra_jax_mappings=extra_jax_mappings,
531
+ output_jax_format=output_jax_format,
532
+ output_torch_format=output_torch_format,
533
+ multioutput=multioutput,
534
+ nout=nout,
535
+ selection=selection,
536
+ )
537
 
538
+ equations = get_hof(
539
+ equation_file=equation_file,
540
+ n_features=X.shape[1],
541
+ variable_names=variable_names,
542
+ output_jax_format=output_jax_format,
543
+ output_torch_format=output_torch_format,
544
+ selection=selection,
545
+ extra_sympy_mappings=extra_sympy_mappings,
546
+ extra_jax_mappings=extra_jax_mappings,
547
+ extra_torch_mappings=extra_torch_mappings,
548
+ multioutput=multioutput,
549
+ nout=nout,
550
+ )
551
 
552
+ if delete_tempfiles:
553
+ shutil.rmtree(tmpdir)
 
 
 
 
 
 
554
 
555
+ return equations, output_equations
556
 
557
+
558
+ def _set_globals(
559
+ *,
560
  X,
 
 
 
 
 
 
 
561
  equation_file,
562
+ variable_names,
563
+ extra_sympy_mappings,
564
+ extra_torch_mappings,
565
+ extra_jax_mappings,
566
+ output_jax_format,
567
+ output_torch_format,
568
+ multioutput,
569
+ nout,
570
+ selection,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
571
  ):
572
+ global global_state
 
 
 
573
 
574
+ global_state["n_features"] = X.shape[1]
575
+ global_state["equation_file"] = equation_file
576
+ global_state["variable_names"] = variable_names
577
+ global_state["extra_sympy_mappings"] = extra_sympy_mappings
578
+ global_state["extra_torch_mappings"] = extra_torch_mappings
579
+ global_state["extra_jax_mappings"] = extra_jax_mappings
580
+ global_state["output_jax_format"] = output_jax_format
581
+ global_state["output_torch_format"] = output_torch_format
582
+ global_state["multioutput"] = multioutput
583
+ global_state["nout"] = nout
584
+ global_state["selection"] = selection
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
585
 
586
 
587
+ def _handle_constraints(binary_operators, unary_operators, constraints):
588
  for op in unary_operators:
589
  if op not in constraints:
590
  constraints[op] = -1
 
607
  )
608
 
609
 
610
+ def _create_inline_operators(binary_operators, unary_operators):
611
+ global Main
612
  for op_list in [binary_operators, unary_operators]:
613
  for i, op in enumerate(op_list):
614
  is_user_defined_operator = "(" in op
615
 
616
  if is_user_defined_operator:
617
+ Main.eval(op)
618
  # Cut off from the first non-alphanumeric char:
619
  first_non_char = [
620
  j
 
623
  ][0]
624
  function_name = op[:first_non_char]
625
  op_list[i] = function_name
 
626
 
627
 
628
  def _handle_feature_selection(
 
640
  return X, variable_names, selection
641
 
642
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
643
  def _check_assertions(
644
  X,
645
  binary_operators,
 
661
  assert len(variable_names) == X.shape[1]
662
 
663
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
664
  def run_feature_selection(X, y, select_k_features):
665
  """Use a gradient boosting tree regressor as a proxy for finding
666
  the k most important features in X, returning indices for those