MilesCranmer commited on
Commit
1e552a8
1 Parent(s): 0e52e29

Use more pure-PyJulia calls rather than evals

Browse files
Files changed (1) hide show
  1. pysr/sr.py +57 -27
pysr/sr.py CHANGED
@@ -479,6 +479,12 @@ def pysr(
479
  else:
480
  manifest_filepath = pkg_directory / "Manifest.toml"
481
 
 
 
 
 
 
 
482
  kwargs["need_install"] = False
483
 
484
  if not (manifest_filepath).is_file() and not pyjulia:
@@ -489,14 +495,37 @@ def pysr(
489
  print("OK. I will install at launch.")
490
  assert update
491
 
492
- kwargs["def_hyperparams"] = _create_inline_operators(**kwargs)
493
 
 
494
  _handle_constraints(**kwargs)
495
-
496
  kwargs["constraints_str"] = _make_constraints_str(**kwargs)
497
  kwargs["def_hyperparams"] = _make_hyperparams_julia_str(**kwargs)
498
 
499
  if pyjulia:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
500
  np_dtype = {16: np.float16, 32: np.float32, 64: np.float64}[precision]
501
 
502
  Main.X = np.array(X, dtype=np_dtype).T
@@ -509,23 +538,32 @@ def pysr(
509
  Main.weights = np.array(weights, dtype=np_dtype)
510
  else:
511
  Main.weights = np.array(weights, dtype=np_dtype).T
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
512
 
513
- kwargs["def_datasets"] = ""
514
  else:
515
  kwargs["def_datasets"] = _make_datasets_julia_str(**kwargs)
516
 
517
  _create_julia_files(**kwargs)
518
- global already_ran_with_pyjulia
519
- if pyjulia:
520
- # Read entire file as a single string:
521
- print("Running main runfile in PyJulia!")
522
- Main.eval(f"include('{_escape_filename(kwargs['runfile_filename'])}')")
523
- already_ran_with_pyjulia = True
524
- else:
525
  _final_pysr_process(**kwargs)
526
 
527
  _set_globals(**kwargs)
528
-
529
  equations = get_hof(**kwargs)
530
 
531
  if delete_tempfiles:
@@ -603,28 +641,19 @@ def _create_julia_files(
603
  weights,
604
  X,
605
  variable_names,
606
- pkg_directory,
607
  need_install,
608
  update,
609
  multithreading,
610
- pyjulia,
611
  **kwargs,
612
  ):
613
- global already_ran_with_pyjulia
614
  with open(hyperparam_filename, "w") as f:
615
  print(def_hyperparams, file=f)
616
 
617
- if not pyjulia:
618
  with open(dataset_filename, "w") as f:
619
  print(def_datasets, file=f)
620
 
621
  with open(runfile_filename, "w") as f:
622
- if julia_project is None:
623
- julia_project = pkg_directory
624
- else:
625
- julia_project = Path(julia_project)
626
 
627
- if (pyjulia and not already_ran_with_pyjulia) or (not pyjulia):
628
  print(f"import Pkg", file=f)
629
  print(f'Pkg.activate("{_escape_filename(julia_project)}")', file=f)
630
  if need_install:
@@ -637,15 +666,9 @@ def _create_julia_files(
637
 
638
  print(f'include("{_escape_filename(hyperparam_filename)}")', file=f)
639
 
640
- if not pyjulia:
641
  print(f'include("{_escape_filename(dataset_filename)}")', file=f)
642
 
643
- if len(variable_names) == 0:
644
- varMap = "[" + ",".join([f'"x{i}"' for i in range(X.shape[1])]) + "]"
645
- else:
646
- varMap = (
647
- "[" + ",".join(['"' + vname + '"' for vname in variable_names]) + "]"
648
- )
649
 
650
  cprocs = 0 if multithreading else procs
651
  if weights is not None:
@@ -660,6 +683,13 @@ def _create_julia_files(
660
  )
661
 
662
 
 
 
 
 
 
 
 
663
  def _make_datasets_julia_str(
664
  X,
665
  X_filename,
 
479
  else:
480
  manifest_filepath = pkg_directory / "Manifest.toml"
481
 
482
+ # Set julia project to correct directory:
483
+ if kwargs["julia_project"] is None:
484
+ kwargs["julia_project"] = pkg_directory
485
+ else:
486
+ kwargs["julia_project"] = Path(kwargs["julia_project"])
487
+
488
  kwargs["need_install"] = False
489
 
490
  if not (manifest_filepath).is_file() and not pyjulia:
 
495
  print("OK. I will install at launch.")
496
  assert update
497
 
498
+ global already_ran_with_pyjulia
499
 
500
+ kwargs["def_hyperparams"] = _create_inline_operators(**kwargs)
501
  _handle_constraints(**kwargs)
 
502
  kwargs["constraints_str"] = _make_constraints_str(**kwargs)
503
  kwargs["def_hyperparams"] = _make_hyperparams_julia_str(**kwargs)
504
 
505
  if pyjulia:
506
+
507
+ from julia import Pkg
508
+
509
+ Pkg.activate(f"{_escape_filename(kwargs['julia_project'])}")
510
+ if kwargs["need_install"]:
511
+ Pkg.instantiate()
512
+ Pkg.update()
513
+ Pkg.precompile()
514
+ elif update:
515
+ Pkg.update()
516
+ from julia import SymbolicRegression
517
+
518
+ already_ran_with_pyjulia = True
519
+
520
+ X = kwargs["X"]
521
+ y = kwargs["y"]
522
+ weights = kwargs["weights"]
523
+ def_hyperparams = kwargs["def_hyperparams"]
524
+ variable_names = kwargs["variable_names"]
525
+ multithreading = kwargs["multithreading"]
526
+ procs = kwargs["procs"]
527
+ niterations = kwargs["niterations"]
528
+ precision = kwargs["precision"]
529
  np_dtype = {16: np.float16, 32: np.float32, 64: np.float64}[precision]
530
 
531
  Main.X = np.array(X, dtype=np_dtype).T
 
538
  Main.weights = np.array(weights, dtype=np_dtype)
539
  else:
540
  Main.weights = np.array(weights, dtype=np_dtype).T
541
+ else:
542
+ Main.weights = None
543
+
544
+ Main.eval(def_hyperparams)
545
+
546
+ varMap = Main.eval(_make_varmap(X, variable_names))
547
+ cprocs = 0 if multithreading else procs
548
+
549
+ SymbolicRegression.EquationSearch(
550
+ Main.X,
551
+ Main.y,
552
+ weights=Main.weights,
553
+ niterations=niterations,
554
+ varMap=varMap,
555
+ options=Main.options,
556
+ numprocs=cprocs,
557
+ multithreading=multithreading,
558
+ )
559
 
 
560
  else:
561
  kwargs["def_datasets"] = _make_datasets_julia_str(**kwargs)
562
 
563
  _create_julia_files(**kwargs)
 
 
 
 
 
 
 
564
  _final_pysr_process(**kwargs)
565
 
566
  _set_globals(**kwargs)
 
567
  equations = get_hof(**kwargs)
568
 
569
  if delete_tempfiles:
 
641
  weights,
642
  X,
643
  variable_names,
 
644
  need_install,
645
  update,
646
  multithreading,
 
647
  **kwargs,
648
  ):
 
649
  with open(hyperparam_filename, "w") as f:
650
  print(def_hyperparams, file=f)
651
 
 
652
  with open(dataset_filename, "w") as f:
653
  print(def_datasets, file=f)
654
 
655
  with open(runfile_filename, "w") as f:
 
 
 
 
656
 
 
657
  print(f"import Pkg", file=f)
658
  print(f'Pkg.activate("{_escape_filename(julia_project)}")', file=f)
659
  if need_install:
 
666
 
667
  print(f'include("{_escape_filename(hyperparam_filename)}")', file=f)
668
 
 
669
  print(f'include("{_escape_filename(dataset_filename)}")', file=f)
670
 
671
+ varMap = _make_varmap(X, variable_names)
 
 
 
 
 
672
 
673
  cprocs = 0 if multithreading else procs
674
  if weights is not None:
 
683
  )
684
 
685
 
686
+ def _make_varmap(X, variable_names):
687
+ if len(variable_names) == 0:
688
+ return "[" + ",".join([f'"x{i}"' for i in range(X.shape[1])]) + "]"
689
+ else:
690
+ return "[" + ",".join(['"' + vname + '"' for vname in variable_names]) + "]"
691
+
692
+
693
  def _make_datasets_julia_str(
694
  X,
695
  X_filename,