MilesCranmer commited on
Commit
583beaf
1 Parent(s): 7113eed

refactor: typings for sympy export

Browse files
Files changed (1) hide show
  1. pysr/export_sympy.py +5 -3
pysr/export_sympy.py CHANGED
@@ -5,6 +5,8 @@ from typing import Callable, Dict, List, Optional
5
  import sympy
6
  from sympy import sympify
7
 
 
 
8
  sympy_mappings = {
9
  "div": lambda x, y: x / y,
10
  "mult": lambda x, y: x * y,
@@ -60,13 +62,13 @@ sympy_mappings = {
60
 
61
 
62
  def create_sympy_symbols_map(
63
- feature_names_in: List[str],
64
  ) -> Dict[str, sympy.Symbol]:
65
  return {variable: sympy.Symbol(variable) for variable in feature_names_in}
66
 
67
 
68
  def create_sympy_symbols(
69
- feature_names_in: List[str],
70
  ) -> List[sympy.Symbol]:
71
  return [sympy.Symbol(variable) for variable in feature_names_in]
72
 
@@ -74,7 +76,7 @@ def create_sympy_symbols(
74
  def pysr2sympy(
75
  equation: str,
76
  *,
77
- feature_names_in: Optional[List[str]] = None,
78
  extra_sympy_mappings: Optional[Dict[str, Callable]] = None,
79
  ):
80
  if feature_names_in is None:
 
5
  import sympy
6
  from sympy import sympify
7
 
8
+ from .utils import ArrayLike
9
+
10
  sympy_mappings = {
11
  "div": lambda x, y: x / y,
12
  "mult": lambda x, y: x * y,
 
62
 
63
 
64
  def create_sympy_symbols_map(
65
+ feature_names_in: ArrayLike[str],
66
  ) -> Dict[str, sympy.Symbol]:
67
  return {variable: sympy.Symbol(variable) for variable in feature_names_in}
68
 
69
 
70
  def create_sympy_symbols(
71
+ feature_names_in: ArrayLike[str],
72
  ) -> List[sympy.Symbol]:
73
  return [sympy.Symbol(variable) for variable in feature_names_in]
74
 
 
76
  def pysr2sympy(
77
  equation: str,
78
  *,
79
+ feature_names_in: Optional[ArrayLike[str]] = None,
80
  extra_sympy_mappings: Optional[Dict[str, Callable]] = None,
81
  ):
82
  if feature_names_in is None: