MilesCranmer commited on
Commit
5e0dd71
1 Parent(s): 0a3d3e9

Fix up other arguments in test

Browse files
Files changed (2) hide show
  1. test/test_jax.py +1 -3
  2. test/test_torch.py +4 -8
test/test_jax.py CHANGED
@@ -38,9 +38,7 @@ class TestJAX(unittest.TestCase):
38
  model = PySRRegressor(
39
  equation_file="equation_file.csv",
40
  output_jax_format=True,
41
- variables_names="x1 x2 x3".split(" "),
42
- multioutput=False,
43
- nout=1,
44
  selection=[1, 2, 3],
45
  )
46
 
 
38
  model = PySRRegressor(
39
  equation_file="equation_file.csv",
40
  output_jax_format=True,
41
+ variable_names="x1 x2 x3".split(" "),
 
 
42
  selection=[1, 2, 3],
43
  )
44
 
test/test_torch.py CHANGED
@@ -37,13 +37,11 @@ class TestTorch(unittest.TestCase):
37
  model = PySRRegressor(
38
  model_selection="accuracy",
39
  equation_file="equation_file.csv",
40
- variables_names="x1 x2 x3".split(" "),
41
  extra_sympy_mappings={},
42
  output_torch_format=True,
43
- multioutput=False,
44
- nout=1,
45
- selection=[1, 2, 3],
46
  )
 
47
  model.n_features = 2 # TODO: Why is this 2 and not 3?
48
  model.using_pandas = False
49
  model.refresh()
@@ -91,14 +89,12 @@ class TestTorch(unittest.TestCase):
91
  model = PySRRegressor(
92
  model_selection="accuracy",
93
  equation_file="equation_file_custom_operator.csv",
94
- variables_names="x1 x2 x3".split(" "),
95
  extra_sympy_mappings={"mycustomoperator": sympy.sin},
96
  extra_torch_mappings={"mycustomoperator": torch.sin},
97
  output_torch_format=True,
98
- multioutput=False,
99
- nout=1,
100
- selection=[0, 1, 2],
101
  )
 
102
  model.n_features = 3
103
  model.using_pandas = False
104
  model.refresh()
 
37
  model = PySRRegressor(
38
  model_selection="accuracy",
39
  equation_file="equation_file.csv",
40
+ variable_names="x1 x2 x3".split(" "),
41
  extra_sympy_mappings={},
42
  output_torch_format=True,
 
 
 
43
  )
44
+ model.selection = [1, 2, 3]
45
  model.n_features = 2 # TODO: Why is this 2 and not 3?
46
  model.using_pandas = False
47
  model.refresh()
 
89
  model = PySRRegressor(
90
  model_selection="accuracy",
91
  equation_file="equation_file_custom_operator.csv",
92
+ variable_names="x1 x2 x3".split(" "),
93
  extra_sympy_mappings={"mycustomoperator": sympy.sin},
94
  extra_torch_mappings={"mycustomoperator": torch.sin},
95
  output_torch_format=True,
 
 
 
96
  )
97
+ model.selection = [0, 1, 2]
98
  model.n_features = 3
99
  model.using_pandas = False
100
  model.refresh()