MilesCranmer commited on
Commit
aea0528
·
1 Parent(s): 574628b

Add way of limiting the complexity of power operators

Browse files
Files changed (2) hide show
  1. julia/sr.jl +28 -0
  2. pysr/sr.py +5 -0
julia/sr.jl CHANGED
@@ -599,6 +599,27 @@ mutable struct PopMember
599
 
600
  end
601
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
602
  # Go through one simulated annealing mutation cycle
603
  # exp(-delta/T) defines probability of accepting a change
604
  function iterate(member::PopMember, T::Float32, curmaxsize::Integer)::PopMember
@@ -652,6 +673,13 @@ function iterate(member::PopMember, T::Float32, curmaxsize::Integer)::PopMember
652
  return PopMember(tree, beforeLoss)
653
  end
654
 
 
 
 
 
 
 
 
655
  if batching
656
  afterLoss = scoreFuncBatch(tree)
657
  else
 
599
 
600
  end
601
 
602
+ # Check if any power operator is to the power of a complex expression
603
+ function deepPow(tree::Node)::Integer
604
+ if tree.degree == 0
605
+ return 0
606
+ elseif tree.degree == 1
607
+ return 0 + deepPow(tree.l)
608
+ else
609
+ if binops[tree.op] == pow
610
+ complexity_in_power = countNodes(tree.r)
611
+ is_deep_pow = (complexity_in_power > 1)
612
+ if is_deep_pow
613
+ return 1 + deepPow(tree.l)
614
+ else
615
+ return 0 + deepPow(tree.l)
616
+ end
617
+ else
618
+ return 0 + deepPow(tree.l) + deepPow(tree.r)
619
+ end
620
+ end
621
+ end
622
+
623
  # Go through one simulated annealing mutation cycle
624
  # exp(-delta/T) defines probability of accepting a change
625
  function iterate(member::PopMember, T::Float32, curmaxsize::Integer)::PopMember
 
673
  return PopMember(tree, beforeLoss)
674
  end
675
 
676
+
677
+ # Check for illegal functions
678
+ if limitPowComplexity && (deepPow(tree) > 0)
679
+ return PopMember(copyNode(prev), beforeLoss)
680
+ end
681
+
682
+
683
  if batching
684
  afterLoss = scoreFuncBatch(tree)
685
  else
pysr/sr.py CHANGED
@@ -87,6 +87,7 @@ def pysr(X=None, y=None, weights=None,
87
  batchSize=50,
88
  select_k_features=None,
89
  warmupMaxsize=0,
 
90
  threads=None, #deprecated
91
  julia_optimization=3,
92
  ):
@@ -163,6 +164,9 @@ def pysr(X=None, y=None, weights=None,
163
  a small number up to the maxsize (if greater than 0).
164
  If greater than 0, says how many cycles before the maxsize
165
  is increased.
 
 
 
166
  :param julia_optimization: int, Optimization level (0, 1, 2, 3)
167
  :returns: pd.DataFrame, Results dataframe, giving complexity, MSE, and equations
168
  (as strings).
@@ -275,6 +279,7 @@ const mutationWeights = [
275
  {weightDoNothing:f}
276
  ]
277
  const warmupMaxsize = {warmupMaxsize:d}
 
278
  """
279
 
280
  op_runner = ""
 
87
  batchSize=50,
88
  select_k_features=None,
89
  warmupMaxsize=0,
90
+ limitPowComplexity=False,
91
  threads=None, #deprecated
92
  julia_optimization=3,
93
  ):
 
164
  a small number up to the maxsize (if greater than 0).
165
  If greater than 0, says how many cycles before the maxsize
166
  is increased.
167
+ :param limitPowComplexity: bool, whether to prevent pow from having
168
+ complex right arguments. I.e., 3.0^(x+y) becomes impossible,
169
+ but 3.0^x is possible.
170
  :param julia_optimization: int, Optimization level (0, 1, 2, 3)
171
  :returns: pd.DataFrame, Results dataframe, giving complexity, MSE, and equations
172
  (as strings).
 
279
  {weightDoNothing:f}
280
  ]
281
  const warmupMaxsize = {warmupMaxsize:d}
282
+ const limitPowComplexity = {"true" if limitPowComplexity else "false"}
283
  """
284
 
285
  op_runner = ""