MilesCranmer commited on
Commit
e68c63f
1 Parent(s): a06de5e

Add feature for operator-level size constraints

Browse files
Files changed (3) hide show
  1. docs/options.md +14 -8
  2. julia/sr.jl +110 -45
  3. pysr/sr.py +43 -4
docs/options.md CHANGED
@@ -14,7 +14,7 @@ may find useful include:
14
  - `maxsize`, `maxdepth`
15
  - `batching`, `batchSize`
16
  - `variable_names` (or pandas input)
17
- - Limiting pow complexity
18
  - LaTeX, SymPy, and callable equation output
19
 
20
  These are described below
@@ -129,13 +129,19 @@ alphabetical characters and `_` are used in these names.
129
 
130
  ## Limiting pow complexity
131
 
132
- One can limit the complexity of power laws
133
- with the `limitPowComplexity` argument.
134
- This will prevent the exponent part of the expression
135
- having complexity greater than one. This prevents uninterpretable
136
- expressions such as `x^(y+z^(5+y))`, which sometimes
137
- occur during training. But it still allows for, e.g., `5^y` or
138
- `(x+y)^5`: anything where the exponent is not overly complex.
 
 
 
 
 
 
139
 
140
  ## LaTeX, SymPy, callables
141
 
 
14
  - `maxsize`, `maxdepth`
15
  - `batching`, `batchSize`
16
  - `variable_names` (or pandas input)
17
+ - Constraining operator complexity
18
  - LaTeX, SymPy, and callable equation output
19
 
20
  These are described below
 
129
 
130
  ## Limiting pow complexity
131
 
132
+ One can limit the complexity of specific operators with the `constraints` parameter.
133
+ There is a "maxsize" parameter to PySR, but there is also an operator-level
134
+ "constraints" parameter. One supplies a dict, like so:
135
+
136
+ ```python
137
+ constraints={'pow': (-1, 1), 'mult': (3, 3), 'cos': 5}
138
+ ```
139
+
140
+ What this says is that: a power law x^y can have an expression of arbitrary (-1) complexity in the x, but only complexity 1 (e.g., a constant or variable) in the y. So (x0 + 3)^5.5 is allowed, but 5.5^(x0 + 3) is not.
141
+ I find this helps a lot for getting more interpretable equations.
142
+ The other terms say that each multiplication can only have sub-expressions
143
+ of up to complexity 3 (e.g., 5.0 + x2) in each side, and cosine can only operate on
144
+ expressions of complexity 5 (e.g., 5.0 + x2*exp(x3)).
145
 
146
  ## LaTeX, SymPy, callables
147
 
julia/sr.jl CHANGED
@@ -646,24 +646,46 @@ mutable struct PopMember
646
 
647
  end
648
 
649
- # Check if any power operator is to the power of a complex expression
650
- function deepPow(tree::Node)::Integer
651
  if tree.degree == 0
652
- return 0
653
  elseif tree.degree == 1
654
- return 0 + deepPow(tree.l)
655
  else
656
- if binops[tree.op] === pow
657
- complexity_in_power = countNodes(tree.r)
658
- is_deep_pow = (complexity_in_power > 1)
659
- if is_deep_pow
660
- return 1 + deepPow(tree.l)
661
- else
662
- return 0 + deepPow(tree.l)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
663
  end
664
- else
665
- return 0 + deepPow(tree.l) + deepPow(tree.r)
666
  end
 
 
 
667
  end
668
  end
669
 
@@ -671,61 +693,104 @@ end
671
  # exp(-delta/T) defines probability of accepting a change
672
  function iterate(member::PopMember, T::Float32, curmaxsize::Integer)::PopMember
673
  prev = member.tree
674
- tree = copyNode(prev)
675
  #TODO - reconsider this
676
  if batching
677
- beforeLoss = scoreFuncBatch(member.tree)
678
  else
679
  beforeLoss = member.score
680
  end
681
 
682
  mutationChoice = rand()
683
- weightAdjustmentMutateConstant = min(8, countConstants(tree))/8.0
684
- cur_weights = copy(mutationWeights) .* 1.0
685
  #More constants => more likely to do constant mutation
 
 
686
  cur_weights[1] *= weightAdjustmentMutateConstant
687
- n = countNodes(tree)
688
- depth = countDepth(tree)
689
 
690
  # If equation too big, don't add new operators
691
  if n >= curmaxsize || depth >= maxdepth
692
  cur_weights[3] = 0.0
693
  cur_weights[4] = 0.0
694
  end
695
-
696
  cur_weights /= sum(cur_weights)
697
  cweights = cumsum(cur_weights)
698
 
699
- if mutationChoice < cweights[1]
700
- tree = mutateConstant(tree, T)
701
- elseif mutationChoice < cweights[2]
702
- tree = mutateOperator(tree)
703
- elseif mutationChoice < cweights[3]
704
- if rand() < 0.5
705
- tree = appendRandomOp(tree)
706
- else
707
- tree = prependRandomOp(tree)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
708
  end
709
- elseif mutationChoice < cweights[4]
710
- tree = insertRandomOp(tree)
711
- elseif mutationChoice < cweights[5]
712
- tree = deleteRandomOp(tree)
713
- elseif mutationChoice < cweights[6]
714
- tree = simplifyTree(tree) # Sometimes we simplify tree
715
- tree = combineOperators(tree) # See if repeated constants at outer levels
716
- return PopMember(tree, beforeLoss)
717
- elseif mutationChoice < cweights[7]
718
- tree = genRandomTree(5) # Sometimes we generate a new tree completely tree
719
- else
720
- return PopMember(tree, beforeLoss)
721
- end
722
 
 
 
 
 
 
 
 
 
 
 
 
723
 
724
- # Check for illegal functions
725
- if limitPowComplexity && (deepPow(tree) > 0)
726
- return PopMember(copyNode(prev), beforeLoss)
727
  end
 
728
 
 
 
 
729
 
730
  if batching
731
  afterLoss = scoreFuncBatch(tree)
 
646
 
647
  end
648
 
649
+ # Check if any binary operator are overly complex
650
+ function flagBinOperatorComplexity(tree::Node, op::Int)::Bool
651
  if tree.degree == 0
652
+ return false
653
  elseif tree.degree == 1
654
+ return flagBinOperatorComplexity(tree.l, op)
655
  else
656
+ if tree.op == op
657
+ overly_complex = (
658
+ ((bin_constraints[op][1] > -1) &&
659
+ (countNodes(tree.l) > bin_constraints[op][1]))
660
+ ||
661
+ ((bin_constraints[op][2] > -1) &&
662
+ (countNodes(tree.r) > bin_constraints[op][2]))
663
+ )
664
+ if overly_complex
665
+ return true
666
+ end
667
+ end
668
+ return (flagBinOperatorComplexity(tree.l, op) || flagBinOperatorComplexity(tree.r, op))
669
+ end
670
+ end
671
+
672
+ # Check if any unary operators are overly complex
673
+ function flagUnaOperatorComplexity(tree::Node, op::Int)::Bool
674
+ if tree.degree == 0
675
+ return false
676
+ elseif tree.degree == 1
677
+ if tree.op == op
678
+ overly_complex = (
679
+ (una_constraints[op] > -1) &&
680
+ (countNodes(tree.l) > una_constraints[op])
681
+ )
682
+ if overly_complex
683
+ return true
684
  end
 
 
685
  end
686
+ return flagUnaOperatorComplexity(tree.l, op)
687
+ else
688
+ return (flagUnaOperatorComplexity(tree.l, op) || flagUnaOperatorComplexity(tree.r, op))
689
  end
690
  end
691
 
 
693
  # exp(-delta/T) defines probability of accepting a change
694
  function iterate(member::PopMember, T::Float32, curmaxsize::Integer)::PopMember
695
  prev = member.tree
696
+ tree = prev
697
  #TODO - reconsider this
698
  if batching
699
+ beforeLoss = scoreFuncBatch(prev)
700
  else
701
  beforeLoss = member.score
702
  end
703
 
704
  mutationChoice = rand()
 
 
705
  #More constants => more likely to do constant mutation
706
+ weightAdjustmentMutateConstant = min(8, countConstants(prev))/8.0
707
+ cur_weights = copy(mutationWeights) .* 1.0
708
  cur_weights[1] *= weightAdjustmentMutateConstant
709
+ n = countNodes(prev)
710
+ depth = countDepth(prev)
711
 
712
  # If equation too big, don't add new operators
713
  if n >= curmaxsize || depth >= maxdepth
714
  cur_weights[3] = 0.0
715
  cur_weights[4] = 0.0
716
  end
 
717
  cur_weights /= sum(cur_weights)
718
  cweights = cumsum(cur_weights)
719
 
720
+ successful_mutation = false
721
+ #TODO: Currently we dont take this \/ into account
722
+ is_success_always_possible = true
723
+ attempts = 0
724
+ max_attempts = 10
725
+
726
+ #############################################
727
+ # Mutations
728
+ #############################################
729
+ while (!successful_mutation) && attempts < max_attempts
730
+ tree = copyNode(prev)
731
+ successful_mutation = true
732
+ if mutationChoice < cweights[1]
733
+ tree = mutateConstant(tree, T)
734
+
735
+ is_success_always_possible = true
736
+ # Mutating a constant shouldn't invalidate an already-valid function
737
+
738
+ elseif mutationChoice < cweights[2]
739
+ tree = mutateOperator(tree)
740
+
741
+ is_success_always_possible = true
742
+ # Can always mutate to the same operator
743
+
744
+ elseif mutationChoice < cweights[3]
745
+ if rand() < 0.5
746
+ tree = appendRandomOp(tree)
747
+ else
748
+ tree = prependRandomOp(tree)
749
+ end
750
+ is_success_always_possible = false
751
+ # Can potentially have a situation without success
752
+ elseif mutationChoice < cweights[4]
753
+ tree = insertRandomOp(tree)
754
+ is_success_always_possible = false
755
+ elseif mutationChoice < cweights[5]
756
+ tree = deleteRandomOp(tree)
757
+ is_success_always_possible = true
758
+ elseif mutationChoice < cweights[6]
759
+ tree = simplifyTree(tree) # Sometimes we simplify tree
760
+ tree = combineOperators(tree) # See if repeated constants at outer levels
761
+ return PopMember(tree, beforeLoss)
762
+
763
+ is_success_always_possible = true
764
+ # Simplification shouldn't hurt complexity; unless some non-symmetric constraint
765
+ # to commutative operator...
766
+
767
+ elseif mutationChoice < cweights[7]
768
+ tree = genRandomTree(5) # Sometimes we generate a new tree completely tree
769
+
770
+ is_success_always_possible = true
771
+ else # no mutation applied
772
+ return PopMember(tree, beforeLoss)
773
  end
 
 
 
 
 
 
 
 
 
 
 
 
 
774
 
775
+ # Check for illegal equations
776
+ for i=1:nbin
777
+ if successful_mutation && flagBinOperatorComplexity(tree, i)
778
+ successful_mutation = false
779
+ end
780
+ end
781
+ for i=1:nuna
782
+ if successful_mutation && flagUnaOperatorComplexity(tree, i)
783
+ successful_mutation = false
784
+ end
785
+ end
786
 
787
+ attempts += 1
 
 
788
  end
789
+ #############################################
790
 
791
+ if !successful_mutation
792
+ return PopMember(copyNode(prev), beforeLoss)
793
+ end
794
 
795
  if batching
796
  afterLoss = scoreFuncBatch(tree)
pysr/sr.py CHANGED
@@ -89,7 +89,8 @@ def pysr(X=None, y=None, weights=None,
89
  batchSize=50,
90
  select_k_features=None,
91
  warmupMaxsize=0,
92
- limitPowComplexity=False,
 
93
  threads=None, #deprecated
94
  julia_optimization=3,
95
  ):
@@ -166,9 +167,11 @@ def pysr(X=None, y=None, weights=None,
166
  a small number up to the maxsize (if greater than 0).
167
  If greater than 0, says how many cycles before the maxsize
168
  is increased.
169
- :param limitPowComplexity: bool, whether to prevent pow from having
170
- complex right arguments. I.e., 3.0^(x+y) becomes impossible,
171
- but 3.0^x is possible.
 
 
172
  :param julia_optimization: int, Optimization level (0, 1, 2, 3)
173
  :returns: pd.DataFrame, Results dataframe, giving complexity, MSE, and equations
174
  (as strings).
@@ -176,6 +179,8 @@ def pysr(X=None, y=None, weights=None,
176
  """
177
  if threads is not None:
178
  raise ValueError("The threads kwarg is deprecated. Use procs.")
 
 
179
  if maxdepth is None:
180
  maxdepth = maxsize
181
 
@@ -207,6 +212,17 @@ def pysr(X=None, y=None, weights=None,
207
  if populations is None:
208
  populations = procs
209
 
 
 
 
 
 
 
 
 
 
 
 
210
  rand_string = f'{"".join([str(np.random.rand())[2] for i in range(20)])}'
211
 
212
  if isinstance(binary_operators, str): binary_operators = [binary_operators]
@@ -247,7 +263,30 @@ def pysr(X=None, y=None, weights=None,
247
  function_name = op[:first_non_char]
248
  op_list[i] = function_name
249
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
250
  def_hyperparams += f"""include("{pkg_directory}/operators.jl")
 
251
  const binops = {'[' + ', '.join(binary_operators) + ']'}
252
  const unaops = {'[' + ', '.join(unary_operators) + ']'}
253
  const ns=10;
 
89
  batchSize=50,
90
  select_k_features=None,
91
  warmupMaxsize=0,
92
+ constraints={},
93
+ limitPowComplexity=False, #deprecated
94
  threads=None, #deprecated
95
  julia_optimization=3,
96
  ):
 
167
  a small number up to the maxsize (if greater than 0).
168
  If greater than 0, says how many cycles before the maxsize
169
  is increased.
170
+ :param constraints: dict of int (unary) or 2-tuples (binary),
171
+ this enforces maxsize constraints on the individual
172
+ arguments of operators. E.g., `'pow': (-1, 1)`
173
+ says that power laws can have any complexity left argument, but only
174
+ 1 complexity exponent. Use this to force more interpretable solutions.
175
  :param julia_optimization: int, Optimization level (0, 1, 2, 3)
176
  :returns: pd.DataFrame, Results dataframe, giving complexity, MSE, and equations
177
  (as strings).
 
179
  """
180
  if threads is not None:
181
  raise ValueError("The threads kwarg is deprecated. Use procs.")
182
+ if limitPowComplexity:
183
+ raise ValueError("The limitPowComplexity kwarg is deprecated. Use constraints.")
184
  if maxdepth is None:
185
  maxdepth = maxsize
186
 
 
212
  if populations is None:
213
  populations = procs
214
 
215
+ #arbitrary complexity by default
216
+ for op in unary_operators:
217
+ if op not in constraints:
218
+ constraints[op] = -1
219
+ for op in binary_operators:
220
+ if op not in constraints:
221
+ constraints[op] = (-1, -1)
222
+ if op in ['mult', 'plus', 'sub']:
223
+ if constraints[op][0] != constraints[op][1]:
224
+ raise NotImplementedError("You need equal constraints on both sides for +, -, and *, due to simplification strategies.")
225
+
226
  rand_string = f'{"".join([str(np.random.rand())[2] for i in range(20)])}'
227
 
228
  if isinstance(binary_operators, str): binary_operators = [binary_operators]
 
263
  function_name = op[:first_non_char]
264
  op_list[i] = function_name
265
 
266
+ constraints_str = "const una_constraints = ["
267
+ first = True
268
+ for op in unary_operators:
269
+ val = constraints[op]
270
+ if not first:
271
+ constraints_str += ", "
272
+ constraints_str += f"{val:d}"
273
+ first = False
274
+
275
+ constraints_str += """]
276
+ const bin_constraints = ["""
277
+
278
+ first = True
279
+ for op in binary_operators:
280
+ tup = constraints[op]
281
+ if not first:
282
+ constraints_str += ", "
283
+ constraints_str += f"({tup[0]:d}, {tup[1]:d})"
284
+ first = False
285
+ constraints_str += "]"
286
+
287
+
288
  def_hyperparams += f"""include("{pkg_directory}/operators.jl")
289
+ {constraints_str}
290
  const binops = {'[' + ', '.join(binary_operators) + ']'}
291
  const unaops = {'[' + ', '.join(unary_operators) + ']'}
292
  const ns=10;