MilesCranmer commited on
Commit
6210be0
1 Parent(s): f257e58

Allow specifying precision in LaTeX output

Browse files
Files changed (1) hide show
  1. pysr/sr.py +14 -3
pysr/sr.py CHANGED
@@ -1721,7 +1721,7 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1721
  return [eq["sympy_format"] for eq in best_equation]
1722
  return best_equation["sympy_format"]
1723
 
1724
- def latex(self, index=None):
1725
  """
1726
  Return latex representation of the equation(s) chosen by `model_selection`.
1727
 
@@ -1733,6 +1733,9 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1733
  the `model_selection` parameter. If there are multiple output
1734
  features, then pass a list of indices with the order the same
1735
  as the output feature.
 
 
 
1736
 
1737
  Returns
1738
  -------
@@ -1742,8 +1745,16 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1742
  self.refresh()
1743
  sympy_representation = self.sympy(index=index)
1744
  if self.nout_ > 1:
1745
- return [sympy.latex(s) for s in sympy_representation]
1746
- return sympy.latex(sympy_representation)
 
 
 
 
 
 
 
 
1747
 
1748
  def jax(self, index=None):
1749
  """
 
1721
  return [eq["sympy_format"] for eq in best_equation]
1722
  return best_equation["sympy_format"]
1723
 
1724
+ def latex(self, index=None, precision=3):
1725
  """
1726
  Return latex representation of the equation(s) chosen by `model_selection`.
1727
 
 
1733
  the `model_selection` parameter. If there are multiple output
1734
  features, then pass a list of indices with the order the same
1735
  as the output feature.
1736
+ precision : int, default=3
1737
+ The number of significant figures shown in the LaTeX
1738
+ representation.
1739
 
1740
  Returns
1741
  -------
 
1745
  self.refresh()
1746
  sympy_representation = self.sympy(index=index)
1747
  if self.nout_ > 1:
1748
+ output = []
1749
+ for s in sympy_representation:
1750
+ raw_latex = sympy.latex(s)
1751
+ reduced_latex = set_precision_of_constants_in_string(
1752
+ raw_latex, precision
1753
+ )
1754
+ output.append(reduced_latex)
1755
+ return output
1756
+ raw_latex = sympy.latex(sympy_representation)
1757
+ return set_precision_of_constants_in_string(raw_latex, precision)
1758
 
1759
  def jax(self, index=None):
1760
  """