MilesCranmer commited on
Commit
f77e43f
1 Parent(s): 1159740

Fix selection not being propagated

Browse files
Files changed (2) hide show
  1. pysr/sr.py +2 -1
  2. setup.py +1 -1
pysr/sr.py CHANGED
@@ -750,6 +750,7 @@ def get_hof(equation_file=None, n_features=None, variable_names=None,
750
  if output_jax_format is None: output_jax_format = global_state['output_jax_format']
751
  if multioutput is None: multioutput = global_state['multioutput']
752
  if nout is None: nout = global_state['nout']
 
753
 
754
  global_state['selection'] = selection
755
  global_state['equation_file'] = equation_file
@@ -812,7 +813,7 @@ def get_hof(equation_file=None, n_features=None, variable_names=None,
812
  # Torch:
813
  if output_torch_format:
814
  from .export_torch import sympy2torch
815
- module = sympy2torch(eqn, sympy_symbols, selection)
816
  torch_format.append(module)
817
 
818
  curMSE = output.loc[i, 'MSE']
 
750
  if output_jax_format is None: output_jax_format = global_state['output_jax_format']
751
  if multioutput is None: multioutput = global_state['multioutput']
752
  if nout is None: nout = global_state['nout']
753
+ if selection is None: selection = global_state['selection']
754
 
755
  global_state['selection'] = selection
756
  global_state['equation_file'] = equation_file
 
813
  # Torch:
814
  if output_torch_format:
815
  from .export_torch import sympy2torch
816
+ module = sympy2torch(eqn, sympy_symbols, selection=selection)
817
  torch_format.append(module)
818
 
819
  curMSE = output.loc[i, 'MSE']
setup.py CHANGED
@@ -5,7 +5,7 @@ with open("README.md", "r") as fh:
5
 
6
  setuptools.setup(
7
  name="pysr",
8
- version="0.6.0",
9
  author="Miles Cranmer",
10
  author_email="[email protected]",
11
  description="Simple and efficient symbolic regression",
 
5
 
6
  setuptools.setup(
7
  name="pysr",
8
+ version="0.6.0-1",
9
  author="Miles Cranmer",
10
  author_email="[email protected]",
11
  description="Simple and efficient symbolic regression",