MilesCranmer commited on
Commit
82b18ca
1 Parent(s): 44b5271

Include unit-test for multi-output equations

Browse files
Files changed (1) hide show
  1. test/test.py +53 -7
test/test.py CHANGED
@@ -296,13 +296,24 @@ def manually_create_model(equations, feature_names=None):
296
  )
297
 
298
  # Set up internal parameters as if it had been fitted:
299
- model.equation_file_ = "equation_file.csv"
300
- model.nout_ = 1
301
- model.selection_mask_ = None
302
- model.feature_names_in_ = np.array(feature_names, dtype=object)
303
- equations["complexity loss equation".split(" ")].to_csv(
304
- "equation_file.csv.bkup", sep="|"
305
- )
 
 
 
 
 
 
 
 
 
 
 
306
 
307
  model.refresh()
308
 
@@ -581,6 +592,41 @@ class TestLaTeXTable(unittest.TestCase):
581
  true_latex_table_str = self.create_true_latex(middle_part)
582
  self.assertEqual(latex_table_str, true_latex_table_str)
583
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
584
  def test_latex_float_precision(self):
585
  """Test that we can print latex expressions with custom precision"""
586
  expr = sympy.Float(4583.4485748, dps=50)
 
296
  )
297
 
298
  # Set up internal parameters as if it had been fitted:
299
+ if isinstance(equations, list):
300
+ # Multi-output.
301
+ model.equation_file_ = "equation_file.csv"
302
+ model.nout_ = len(equations)
303
+ model.selection_mask_ = None
304
+ model.feature_names_in_ = np.array(feature_names, dtype=object)
305
+ for i in range(model.nout_):
306
+ equations[i]["complexity loss equation".split(" ")].to_csv(
307
+ f"equation_file.csv.out{i+1}.bkup", sep="|"
308
+ )
309
+ else:
310
+ model.equation_file_ = "equation_file.csv"
311
+ model.nout_ = 1
312
+ model.selection_mask_ = None
313
+ model.feature_names_in_ = np.array(feature_names, dtype=object)
314
+ equations["complexity loss equation".split(" ")].to_csv(
315
+ "equation_file.csv.bkup", sep="|"
316
+ )
317
 
318
  model.refresh()
319
 
 
592
  true_latex_table_str = self.create_true_latex(middle_part)
593
  self.assertEqual(latex_table_str, true_latex_table_str)
594
 
595
+ def test_multi_output(self):
596
+ equations1 = pd.DataFrame(
597
+ dict(
598
+ equation=["x0", "cos(x0)", "x0 + x1 - cos(x1 * x0)"],
599
+ loss=[1.052, 0.02315, 1.12347e-15],
600
+ complexity=[1, 2, 8],
601
+ )
602
+ )
603
+ equations2 = pd.DataFrame(
604
+ dict(
605
+ equation=["x1", "cos(x1)", "x0 * x0 * x1"],
606
+ loss=[1.32, 0.052, 2e-15],
607
+ complexity=[1, 2, 5],
608
+ )
609
+ )
610
+ equations = [equations1, equations2]
611
+ model = manually_create_model(equations)
612
+ middle_part_1 = r"""
613
+ $x_{0}$ & $1$ & $1.05$ & $0.0$ \\
614
+ $\cos{\left(x_{0} \right)}$ & $2$ & $0.0232$ & $3.82$ \\
615
+ $x_{0} + x_{1} - \cos{\left(x_{0} x_{1} \right)}$ & $8$ & $1.12 \cdot 10^{-15}$ & $5.11$ \\
616
+ """
617
+ middle_part_2 = r"""
618
+ $x_{1}$ & $1$ & $1.32$ & $0.0$ \\
619
+ $\cos{\left(x_{1} \right)}$ & $2$ & $0.0520$ & $3.23$ \\
620
+ $x_{0}^{2} x_{1}$ & $5$ & $2.00 \cdot 10^{-15}$ & $10.3$ \\
621
+ """
622
+ true_latex_table_str = "\n\n".join(
623
+ self.create_true_latex(part, include_score=True)
624
+ for part in [middle_part_1, middle_part_2]
625
+ )
626
+ latex_table_str = model.latex_table()
627
+
628
+ self.assertEqual(latex_table_str, true_latex_table_str)
629
+
630
  def test_latex_float_precision(self):
631
  """Test that we can print latex expressions with custom precision"""
632
  expr = sympy.Float(4583.4485748, dps=50)