MilesCranmer commited on
Commit
e84bed4
·
unverified ·
1 Parent(s): ddeae6c

test: fix mypy errors with sympy

Browse files
pysr/export_jax.py CHANGED
@@ -1,5 +1,5 @@
1
  import numpy as np # noqa: F401
2
- import sympy
3
 
4
  # Special since need to reduce arguments.
5
  MUL = 0
 
1
  import numpy as np # noqa: F401
2
+ import sympy # type: ignore
3
 
4
  # Special since need to reduce arguments.
5
  MUL = 0
pysr/export_latex.py CHANGED
@@ -3,8 +3,8 @@
3
  from typing import List, Optional, Tuple
4
 
5
  import pandas as pd
6
- import sympy
7
- from sympy.printing.latex import LatexPrinter
8
 
9
 
10
  class PreciseLatexPrinter(LatexPrinter):
 
3
  from typing import List, Optional, Tuple
4
 
5
  import pandas as pd
6
+ import sympy # type: ignore
7
+ from sympy.printing.latex import LatexPrinter # type: ignore
8
 
9
 
10
  class PreciseLatexPrinter(LatexPrinter):
pysr/export_numpy.py CHANGED
@@ -6,7 +6,7 @@ from typing import List, Union
6
  import numpy as np
7
  import pandas as pd
8
  from numpy.typing import NDArray
9
- from sympy import Expr, Symbol, lambdify
10
 
11
 
12
  def sympy2numpy(eqn, sympy_symbols, *, selection=None):
 
6
  import numpy as np
7
  import pandas as pd
8
  from numpy.typing import NDArray
9
+ from sympy import Expr, Symbol, lambdify # type: ignore
10
 
11
 
12
  def sympy2numpy(eqn, sympy_symbols, *, selection=None):
pysr/export_sympy.py CHANGED
@@ -2,7 +2,7 @@
2
 
3
  from typing import Callable, Dict, List, Optional
4
 
5
- import sympy
6
  from sympy import sympify
7
 
8
  from .utils import ArrayLike
 
2
 
3
  from typing import Callable, Dict, List, Optional
4
 
5
+ import sympy # type: ignore
6
  from sympy import sympify
7
 
8
  from .utils import ArrayLike
pysr/export_torch.py CHANGED
@@ -4,7 +4,7 @@ import collections as co
4
  import functools as ft
5
 
6
  import numpy as np # noqa: F401
7
- import sympy
8
 
9
 
10
  def _reduce(fn):
 
4
  import functools as ft
5
 
6
  import numpy as np # noqa: F401
7
+ import sympy # type: ignore
8
 
9
 
10
  def _reduce(fn):
pysr/test/test.py CHANGED
@@ -9,7 +9,7 @@ from pathlib import Path
9
 
10
  import numpy as np
11
  import pandas as pd
12
- import sympy
13
  from sklearn.utils.estimator_checks import check_estimator
14
 
15
  from pysr import PySRRegressor, install, jl
 
9
 
10
  import numpy as np
11
  import pandas as pd
12
+ import sympy # type: ignore
13
  from sklearn.utils.estimator_checks import check_estimator
14
 
15
  from pysr import PySRRegressor, install, jl
pysr/test/test_jax.py CHANGED
@@ -3,7 +3,7 @@ from functools import partial
3
 
4
  import numpy as np
5
  import pandas as pd
6
- import sympy
7
 
8
  import pysr
9
  from pysr import PySRRegressor, sympy2jax
@@ -102,7 +102,7 @@ class TestJAX(unittest.TestCase):
102
  )
103
 
104
  def test_issue_656(self):
105
- import sympy
106
 
107
  E_plus_x1 = sympy.exp(1) + sympy.symbols("x1")
108
  f, params = pysr.export_jax.sympy2jax(E_plus_x1, [sympy.symbols("x1")])
 
3
 
4
  import numpy as np
5
  import pandas as pd
6
+ import sympy # type: ignore
7
 
8
  import pysr
9
  from pysr import PySRRegressor, sympy2jax
 
102
  )
103
 
104
  def test_issue_656(self):
105
+ import sympy # type: ignore
106
 
107
  E_plus_x1 = sympy.exp(1) + sympy.symbols("x1")
108
  f, params = pysr.export_jax.sympy2jax(E_plus_x1, [sympy.symbols("x1")])
pysr/test/test_torch.py CHANGED
@@ -2,7 +2,7 @@ import unittest
2
 
3
  import numpy as np
4
  import pandas as pd
5
- import sympy
6
 
7
  import pysr
8
  from pysr import PySRRegressor, sympy2torch
 
2
 
3
  import numpy as np
4
  import pandas as pd
5
+ import sympy # type: ignore
6
 
7
  import pysr
8
  from pysr import PySRRegressor, sympy2torch