MilesCranmer committed on
Commit
c96b30c
·
1 Parent(s): 8c55475

Clean up global variables into single dict

Browse files
Files changed (2) hide show
  1. pysr/sr.py +60 -49
  2. test/test.py +2 -1
pysr/sr.py CHANGED
@@ -14,12 +14,19 @@ from pathlib import Path
14
  from datetime import datetime
15
  import warnings
16
 
17
- global_equation_file = 'hall_of_fame.csv'
18
- global_n_features = None
19
- global_variable_names = []
20
- global_extra_sympy_mappings = {}
21
- global_multioutput = False
22
- global_nout = 1
 
 
 
 
 
 
 
23
 
24
  sympy_mappings = {
25
  'div': lambda x, y : x/y,
@@ -62,16 +69,20 @@ sympy_mappings = {
62
 
63
  class CallableEquation(object):
64
  """Simple wrapper for numpy lambda functions built with sympy"""
65
- def __init__(self, sympy_symbols, eqn):
66
  self._sympy = eqn
67
  self._sympy_symbols = sympy_symbols
 
68
  self._lambda = lambdify(sympy_symbols, eqn)
69
 
70
  def __repr__(self):
71
  return f"PySRFunction(X=>{self._sympy})"
72
 
73
  def __call__(self, X):
74
- return self._lambda(*X.T)
 
 
 
75
 
76
  def pysr(X, y, weights=None,
77
  binary_operators=None,
@@ -284,7 +295,7 @@ def pysr(X, y, weights=None,
284
  if maxsize > 40:
285
  warnings.warn("Note: Using a large maxsize for the equation search will be slow and use significant memory. You should consider turning `useFrequency` to False, and perhaps use `warmupMaxsizeBy`.")
286
 
287
- X, variable_names = _handle_feature_selection(
288
  X, select_k_features,
289
  use_custom_variable_names, variable_names, y
290
  )
@@ -343,6 +354,7 @@ def pysr(X, y, weights=None,
343
  julia_project=julia_project, loss=loss,
344
  output_jax_format=output_jax_format,
345
  output_torch_format=output_torch_format,
 
346
  multioutput=multioutput, nout=nout)
347
 
348
  kwargs = {**_set_paths(tempdir), **kwargs}
@@ -391,21 +403,13 @@ def pysr(X, y, weights=None,
391
  return equations
392
 
393
 
 
 
394
 
395
- def _set_globals(X, equation_file, extra_sympy_mappings, variable_names,
396
- multioutput, nout, **kwargs):
397
- global global_n_features
398
- global global_equation_file
399
- global global_variable_names
400
- global global_extra_sympy_mappings
401
- global global_multioutput
402
- global global_nout
403
- global_n_features = X.shape[1]
404
- global_equation_file = equation_file
405
- global_variable_names = variable_names
406
- global_extra_sympy_mappings = extra_sympy_mappings
407
- global_multioutput = multioutput
408
- global_nout = nout
409
 
410
 
411
  def _final_pysr_process(julia_optimization, runfile_filename, timeout, **kwargs):
@@ -668,7 +672,9 @@ def _handle_feature_selection(X, select_k_features, use_custom_variable_names, v
668
 
669
  if use_custom_variable_names:
670
  variable_names = [variable_names[selection[i]] for i in range(len(selection))]
671
- return X, variable_names
 
 
672
 
673
 
674
  def _set_paths(tempdir):
@@ -732,33 +738,38 @@ def run_feature_selection(X, y, select_k_features):
732
  return selector.get_support(indices=True)
733
 
734
  def get_hof(equation_file=None, n_features=None, variable_names=None,
735
- extra_sympy_mappings=None, output_jax_format=False,
736
- output_torch_format=False,
737
  extra_jax_mappings=None, extra_torch_mappings=None,
738
  multioutput=None, nout=None, **kwargs):
739
  """Get the equations from a hall of fame file. If no arguments
740
  entered, the ones used previously from a call to PySR will be used."""
741
 
742
- global global_n_features
743
- global global_equation_file
744
- global global_variable_names
745
- global global_extra_sympy_mappings
746
- global global_multioutput
747
- global global_nout
748
-
749
- if equation_file is None: equation_file = global_equation_file
750
- if n_features is None: n_features = global_n_features
751
- if variable_names is None: variable_names = global_variable_names
752
- if extra_sympy_mappings is None: extra_sympy_mappings = global_extra_sympy_mappings
753
- if multioutput is None: multioutput = global_multioutput
754
- if nout is None: nout = global_nout
755
-
756
- global_equation_file = equation_file
757
- global_n_features = n_features
758
- global_variable_names = variable_names
759
- global_extra_sympy_mappings = extra_sympy_mappings
760
- global_multioutput = multioutput
761
- global_nout = nout
 
 
 
 
 
762
 
763
  try:
764
  if multioutput:
@@ -797,18 +808,18 @@ def get_hof(equation_file=None, n_features=None, variable_names=None,
797
  sympy_format.append(eqn)
798
 
799
  # Numpy:
800
- lambda_format.append(CallableEquation(sympy_symbols, eqn))
801
 
802
  # JAX:
803
  if output_jax_format:
804
  from .export_jax import sympy2jax
805
- func, params = sympy2jax(eqn, sympy_symbols)
806
  jax_format.append({'callable': func, 'parameters': params})
807
 
808
  # Torch:
809
  if output_torch_format:
810
  from .export_torch import sympy2torch
811
- module = sympy2torch(eqn, sympy_symbols)
812
  torch_format.append(module)
813
 
814
  curMSE = output.loc[i, 'MSE']
 
14
  from datetime import datetime
15
  import warnings
16
 
17
+ global_state = dict(
18
+ equation_file='hall_of_fame.csv',
19
+ n_features=None,
20
+ variable_names=[],
21
+ extra_sympy_mappings={},
22
+ extra_torch_mappings={},
23
+ extra_jax_mappings={},
24
+ output_jax_format=False,
25
+ output_torch_format=False,
26
+ multioutput=False,
27
+ nout=1,
28
+ selection=None
29
+ )
30
 
31
  sympy_mappings = {
32
  'div': lambda x, y : x/y,
 
69
 
70
  class CallableEquation(object):
71
  """Simple wrapper for numpy lambda functions built with sympy"""
72
+ def __init__(self, sympy_symbols, eqn, selection=None):
73
  self._sympy = eqn
74
  self._sympy_symbols = sympy_symbols
75
+ self._selection = selection
76
  self._lambda = lambdify(sympy_symbols, eqn)
77
 
78
  def __repr__(self):
79
  return f"PySRFunction(X=>{self._sympy})"
80
 
81
  def __call__(self, X):
82
+ if self._selection is not None:
83
+ return self._lambda(*X[:, self._selection].T)
84
+ else:
85
+ return self._lambda(*X.T)
86
 
87
  def pysr(X, y, weights=None,
88
  binary_operators=None,
 
295
  if maxsize > 40:
296
  warnings.warn("Note: Using a large maxsize for the equation search will be slow and use significant memory. You should consider turning `useFrequency` to False, and perhaps use `warmupMaxsizeBy`.")
297
 
298
+ X, variable_names, selection = _handle_feature_selection(
299
  X, select_k_features,
300
  use_custom_variable_names, variable_names, y
301
  )
 
354
  julia_project=julia_project, loss=loss,
355
  output_jax_format=output_jax_format,
356
  output_torch_format=output_torch_format,
357
+ selection=selection,
358
  multioutput=multioutput, nout=nout)
359
 
360
  kwargs = {**_set_paths(tempdir), **kwargs}
 
403
  return equations
404
 
405
 
406
+ def _set_globals(X, **kwargs):
407
+ global global_state
408
 
409
+ global_state['n_features'] = X.shape[1]
410
+ for key, value in kwargs.items():
411
+ if key in global_state:
412
+ global_state[key] = value
 
 
 
 
 
 
 
 
 
 
413
 
414
 
415
  def _final_pysr_process(julia_optimization, runfile_filename, timeout, **kwargs):
 
672
 
673
  if use_custom_variable_names:
674
  variable_names = [variable_names[selection[i]] for i in range(len(selection))]
675
+ else:
676
+ selection = None
677
+ return X, variable_names, selection
678
 
679
 
680
  def _set_paths(tempdir):
 
738
  return selector.get_support(indices=True)
739
 
740
  def get_hof(equation_file=None, n_features=None, variable_names=None,
741
+ output_jax_format=None, output_torch_format=None,
742
+ selection=None, extra_sympy_mappings=None,
743
  extra_jax_mappings=None, extra_torch_mappings=None,
744
  multioutput=None, nout=None, **kwargs):
745
  """Get the equations from a hall of fame file. If no arguments
746
  entered, the ones used previously from a call to PySR will be used."""
747
 
748
+ global global_state
749
+
750
+ if equation_file is None: equation_file = global_state['equation_file']
751
+ if n_features is None: n_features = global_state['n_features']
752
+ if variable_names is None: variable_names = global_state['variable_names']
753
+ if extra_sympy_mappings is None: extra_sympy_mappings = global_state['extra_sympy_mappings']
754
+ if extra_jax_mappings is None: extra_jax_mappings = global_state['extra_jax_mappings']
755
+ if extra_torch_mappings is None: extra_torch_mappings = global_state['extra_torch_mappings']
756
+ if output_torch_format is None: output_torch_format = global_state['output_torch_format']
757
+ if output_jax_format is None: output_jax_format = global_state['output_jax_format']
758
+ if multioutput is None: multioutput = global_state['multioutput']
759
+ if nout is None: nout = global_state['nout']
760
+
761
+ global_state['selection'] = selection
762
+ global_state['equation_file'] = equation_file
763
+ global_state['n_features'] = n_features
764
+ global_state['variable_names'] = variable_names
765
+ global_state['extra_sympy_mappings'] = extra_sympy_mappings
766
+ global_state['extra_jax_mappings'] = extra_jax_mappings
767
+ global_state['extra_torch_mappings'] = extra_torch_mappings
768
+ global_state['output_torch_format'] = output_torch_format
769
+ global_state['output_jax_format'] = output_jax_format
770
+ global_state['multioutput'] = multioutput
771
+ global_state['nout'] = nout
772
+ global_state['selection'] = selection
773
 
774
  try:
775
  if multioutput:
 
808
  sympy_format.append(eqn)
809
 
810
  # Numpy:
811
+ lambda_format.append(CallableEquation(sympy_symbols, eqn, selection))
812
 
813
  # JAX:
814
  if output_jax_format:
815
  from .export_jax import sympy2jax
816
+ func, params = sympy2jax(eqn, sympy_symbols, selection)
817
  jax_format.append({'callable': func, 'parameters': params})
818
 
819
  # Torch:
820
  if output_torch_format:
821
  from .export_torch import sympy2torch
822
+ module = sympy2torch(eqn, sympy_symbols, selection)
823
  torch_format.append(module)
824
 
825
  curMSE = output.loc[i, 'MSE']
test/test.py CHANGED
@@ -112,11 +112,12 @@ class TestFeatureSelection(unittest.TestCase):
112
  X = np.random.randn(20000, 5)
113
  y = X[:, 2]**2 + X[:, 3]**2
114
  var_names = [f'x{i}' for i in range(5)]
115
- selected_X, selected_var_names = _handle_feature_selection(
116
  X, select_k_features=2,
117
  use_custom_variable_names=True,
118
  variable_names=[f'x{i}' for i in range(5)],
119
  y=y)
 
120
  self.assertEqual(set(selected_var_names), set('x2 x3'.split(' ')))
121
  np.testing.assert_array_equal(
122
  np.sort(selected_X, axis=1),
 
112
  X = np.random.randn(20000, 5)
113
  y = X[:, 2]**2 + X[:, 3]**2
114
  var_names = [f'x{i}' for i in range(5)]
115
+ selected_X, selected_var_names, selection = _handle_feature_selection(
116
  X, select_k_features=2,
117
  use_custom_variable_names=True,
118
  variable_names=[f'x{i}' for i in range(5)],
119
  y=y)
120
+ self.assertTrue((2 in selection) and (3 in selection))
121
  self.assertEqual(set(selected_var_names), set('x2 x3'.split(' ')))
122
  np.testing.assert_array_equal(
123
  np.sort(selected_X, axis=1),