MilesCranmer commited on
Commit
69fc6d0
1 Parent(s): 1b17efe

Make function names work when custom function has underscore

Browse files
Files changed (2) hide show
  1. pysr/sr.py +10 -6
  2. test/test.py +1 -1
pysr/sr.py CHANGED
@@ -4,7 +4,7 @@ import numpy as np
4
  import pandas as pd
5
  import sympy
6
  from sympy import sympify, lambdify
7
- import subprocess
8
  import tempfile
9
  import shutil
10
  from pathlib import Path
@@ -155,12 +155,16 @@ def _create_inline_operators(binary_operators, unary_operators):
155
  if is_user_defined_operator:
156
  Main.eval(op)
157
  # Cut off from the first non-alphanumeric char:
158
- first_non_char = [
159
- j
160
- for j, char in enumerate(op)
161
- if not (char.isalpha() or char.isdigit())
162
- ][0]
163
  function_name = op[:first_non_char]
 
 
 
 
 
 
 
 
164
  op_list[i] = function_name
165
 
166
 
 
4
  import pandas as pd
5
  import sympy
6
  from sympy import sympify, lambdify
7
+ import re
8
  import tempfile
9
  import shutil
10
  from pathlib import Path
 
155
  if is_user_defined_operator:
156
  Main.eval(op)
157
  # Cut off from the first non-alphanumeric char:
158
+ first_non_char = [j for j, char in enumerate(op) if char == "("][0]
 
 
 
 
159
  function_name = op[:first_non_char]
160
+ # Assert that function_name only contains
161
+ # alphabetical characters, numbers,
162
+ # and underscores:
163
+ if not re.match(r"^[a-zA-Z0-9_]+$", function_name):
164
+ raise ValueError(
165
+ f"Invalid function name {function_name}. "
166
+ "Only alphanumeric characters, numbers, and underscores are allowed."
167
+ )
168
  op_list[i] = function_name
169
 
170
 
test/test.py CHANGED
@@ -49,7 +49,7 @@ class TestPipeline(unittest.TestCase):
49
  model.fit(self.X, y)
50
  equations = model.equations
51
  print(equations)
52
- self.assertIn("square_op", equations.sympy())
53
  self.assertLessEqual(equations[0].iloc[-1]["loss"], 1e-4)
54
  self.assertLessEqual(equations[1].iloc[-1]["loss"], 1e-4)
55
 
 
49
  model.fit(self.X, y)
50
  equations = model.equations
51
  print(equations)
52
+ self.assertIn("square_op", model.equations[0].iloc[-1]["equation"])
53
  self.assertLessEqual(equations[0].iloc[-1]["loss"], 1e-4)
54
  self.assertLessEqual(equations[1].iloc[-1]["loss"], 1e-4)
55