MilesCranmer commited on
Commit
7f6d86d
1 Parent(s): 5130409

Only evaluate individuals on small batches during evolution

Browse files
Files changed (3) hide show
  1. README.md +1 -0
  2. julia/sr.jl +104 -16
  3. pysr/sr.py +1 -0
README.md CHANGED
@@ -309,6 +309,7 @@ pd.DataFrame, Results dataframe, giving complexity, MSE, and equations
309
 
310
  ## Feature ideas
311
 
 
312
  - [ ] Sympy printing
313
  - [ ] Better cleanup of zombie processes after <ctl-c>
314
  - [ ] Hierarchical model, so can re-use functional forms. Output of one equation goes into second equation?
 
309
 
310
  ## Feature ideas
311
 
312
+ - [ ] Cross-validation
313
  - [ ] Sympy printing
314
  - [ ] Better cleanup of zombie processes after <ctl-c>
315
  - [ ] Hierarchical model, so can re-use functional forms. Output of one equation goes into second equation?
julia/sr.jl CHANGED
@@ -1,6 +1,6 @@
1
  import Optim
2
  import Printf: @printf
3
- import Random: shuffle!
4
 
5
  const maxdegree = 2
6
  const actualMaxsize = maxsize + maxdegree
@@ -9,11 +9,13 @@ const actualMaxsize = maxsize + maxdegree
9
  # Sum of square error between two arrays
10
  function SSE(x::Array{Float32}, y::Array{Float32})::Float32
11
  diff = (x - y)
12
- if weighted
13
- return sum(diff .* diff .* weights)
14
- else
15
- return sum(diff .* diff)
16
- end
 
 
17
  end
18
 
19
  # Mean of square error between two arrays
@@ -21,15 +23,21 @@ function MSE(x::Array{Float32}, y::Array{Float32})::Float32
21
  return SSE(x, y)/size(x)[1]
22
  end
23
 
 
 
 
 
 
24
  const len = size(X)[1]
25
 
26
  if weighted
27
- const avgy = sum(y .* weights)/len/sum(weights)
 
28
  else
29
  const avgy = sum(y)/len
 
30
  end
31
 
32
- const baselineSSE = SSE(y, convert(Array{Float32, 1}, ones(len) .* avgy))
33
 
34
  id = (x,) -> x
35
  const nuna = size(unaops)[1]
@@ -278,10 +286,70 @@ function evalTreeArray(tree::Node)::Array{Float32, 1}
278
  end
279
  end
280
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
281
  # Score an equation
282
  function scoreFunc(tree::Node)::Float32
283
  try
284
- return SSE(evalTreeArray(tree), y)/baselineSSE + countNodes(tree)*parsimony
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
285
  catch error
286
  if isa(error, DomainError) || isa(error, LoadError) || isa(error, TaskFailedException)
287
  return 1f9
@@ -542,7 +610,9 @@ end
542
  function iterate(member::PopMember, T::Float32)::PopMember
543
  prev = member.tree
544
  tree = copyNode(prev)
545
- beforeLoss = member.score
 
 
546
 
547
  mutationChoice = rand()
548
  weightAdjustmentMutateConstant = min(8, countConstants(tree))/8.0
@@ -573,7 +643,7 @@ function iterate(member::PopMember, T::Float32)::PopMember
573
  return PopMember(tree, beforeLoss)
574
  end
575
 
576
- afterLoss = scoreFunc(tree)
577
 
578
  if annealing
579
  delta = afterLoss - beforeLoss
@@ -877,10 +947,19 @@ function fullRun(niterations::Integer;
877
  for size=1:actualMaxsize
878
  if hallOfFame.exists[size]
879
  member = hallOfFame.members[size]
880
- curMSE = MSE(evalTreeArray(member.tree), y)
 
 
 
 
881
  numberSmallerAndBetter = 0
882
  for i=1:(size-1)
883
- if (hallOfFame.exists[size] && curMSE > MSE(evalTreeArray(hallOfFame.members[i].tree), y))
 
 
 
 
 
884
  numberSmallerAndBetter += 1
885
  end
886
  end
@@ -951,7 +1030,7 @@ function fullRun(niterations::Integer;
951
  @printf("Hall of Fame:\n")
952
  @printf("-----------------------------------------\n")
953
  @printf("%-10s %-8s %-8s %-8s\n", "Complexity", "MSE", "Score", "Equation")
954
- curMSE = baselineSSE / len
955
  @printf("%-10d %-8.3e %-8.3e %-.f\n", 0, curMSE, 0f0, avgy)
956
  lastMSE = curMSE
957
  lastComplexity = 0
@@ -959,10 +1038,19 @@ function fullRun(niterations::Integer;
959
  for size=1:actualMaxsize
960
  if hallOfFame.exists[size]
961
  member = hallOfFame.members[size]
962
- curMSE = MSE(evalTreeArray(member.tree), y)
 
 
 
 
963
  numberSmallerAndBetter = 0
964
  for i=1:(size-1)
965
- if (hallOfFame.exists[size] && curMSE > MSE(evalTreeArray(hallOfFame.members[i].tree), y))
 
 
 
 
 
966
  numberSmallerAndBetter += 1
967
  end
968
  end
 
1
  import Optim
2
  import Printf: @printf
3
+ import Random: shuffle!, randperm
4
 
5
  const maxdegree = 2
6
  const actualMaxsize = maxsize + maxdegree
 
9
  # Sum of square error between two arrays
10
  function SSE(x::Array{Float32}, y::Array{Float32})::Float32
11
  diff = (x - y)
12
+ return sum(diff .* diff)
13
+ end
14
+
15
+ # Sum of square error between two arrays, with weights
16
+ function SSE(x::Array{Float32}, y::Array{Float32}, w::Array{Float32})::Float32
17
+ diff = (x - y)
18
+ return sum(diff .* diff .* w)
19
  end
20
 
21
  # Mean of square error between two arrays
 
23
  return SSE(x, y)/size(x)[1]
24
  end
25
 
26
+ # Mean of square error between two arrays
27
+ function MSE(x::Array{Float32}, y::Array{Float32}, w::Array{Float32})::Float32
28
+ return SSE(x, y, w)/sum(w)
29
+ end
30
+
31
  const len = size(X)[1]
32
 
33
  if weighted
34
+ const avgy = sum(y .* weights)/sum(weights)
35
+ const baselineMSE = MSE(y, convert(Array{Float32, 1}, ones(len) .* avgy), weights)
36
  else
37
  const avgy = sum(y)/len
38
+ const baselineMSE = MSE(y, convert(Array{Float32, 1}, ones(len) .* avgy))
39
  end
40
 
 
41
 
42
  id = (x,) -> x
43
  const nuna = size(unaops)[1]
 
286
  end
287
  end
288
 
289
+
290
+ # Evaluate an equation over an array of datapoints
291
+ function evalTreeArray(tree::Node, cX::Array{Float32, 2})::Array{Float32, 1}
292
+ clen = size(cX)[1]
293
+ if tree.degree == 0
294
+ if tree.constant
295
+ return fill(tree.val, clen)
296
+ else
297
+ return copy(cX[:, tree.val])
298
+ end
299
+ elseif tree.degree == 1
300
+ cumulator = evalTreeArray(tree.l, cX)
301
+ op = unaops[tree.op]
302
+ @inbounds for i=1:clen
303
+ cumulator[i] = op(cumulator[i])
304
+ end
305
+ return cumulator
306
+ else
307
+ op = binops[tree.op]
308
+ cumulator = evalTreeArray(tree.l, cX)
309
+ array2 = evalTreeArray(tree.r, cX)
310
+ @inbounds for i=1:clen
311
+ cumulator[i] = op(cumulator[i], array2[i])
312
+ end
313
+ return cumulator
314
+ end
315
+ end
316
+
317
  # Score an equation
318
  function scoreFunc(tree::Node)::Float32
319
  try
320
+ prediction = evalTreeArray(tree)
321
+ if weighted
322
+ mse = MSE(prediction, y, weights)
323
+ else
324
+ mse = MSE(prediction, y)
325
+ end
326
+ return mse / baselineMSE + countNodes(tree)*parsimony
327
+ catch error
328
+ if isa(error, DomainError) || isa(error, LoadError) || isa(error, TaskFailedException)
329
+ return 1f9
330
+ else
331
+ throw(error)
332
+ end
333
+ end
334
+ end
335
+
336
+ # Score an equation with a small batch
337
+ function scoreFuncBatch(tree::Node)::Float32
338
+ try
339
+ # batchSize
340
+ batch_idx = randperm(len)[1:batchSize]
341
+ batch_X = X[batch_idx, :]
342
+ batch_y = y[batch_idx]
343
+ prediction = evalTreeArray(tree, batch_X)
344
+ size_adjustment = 1
345
+ if weighted
346
+ batch_w = weights[batch_idx]
347
+ mse = MSE(prediction, batch_y, batch_w)
348
+ size_adjustment = 1f0 * len / batchSize
349
+ else
350
+ mse = MSE(prediction, batch_y)
351
+ end
352
+ return size_adjustment * mse / baselineMSE + countNodes(tree)*parsimony
353
  catch error
354
  if isa(error, DomainError) || isa(error, LoadError) || isa(error, TaskFailedException)
355
  return 1f9
 
610
  function iterate(member::PopMember, T::Float32)::PopMember
611
  prev = member.tree
612
  tree = copyNode(prev)
613
+ #TODO - reconsider this
614
+ # beforeLoss = member.score
615
+ beforeLoss = scoreFuncBatch(member.tree)
616
 
617
  mutationChoice = rand()
618
  weightAdjustmentMutateConstant = min(8, countConstants(tree))/8.0
 
643
  return PopMember(tree, beforeLoss)
644
  end
645
 
646
+ afterLoss = scoreFuncBatch(tree)
647
 
648
  if annealing
649
  delta = afterLoss - beforeLoss
 
947
  for size=1:actualMaxsize
948
  if hallOfFame.exists[size]
949
  member = hallOfFame.members[size]
950
+ if weighted
951
+ curMSE = MSE(evalTreeArray(member.tree), y, weights)
952
+ else
953
+ curMSE = MSE(evalTreeArray(member.tree), y)
954
+ end
955
  numberSmallerAndBetter = 0
956
  for i=1:(size-1)
957
+ if weighted
958
+ hofMSE = MSE(evalTreeArray(hallOfFame.members[i].tree), y, weights)
959
+ else
960
+ hofMSE = MSE(evalTreeArray(hallOfFame.members[i].tree), y)
961
+ end
962
+ if (hallOfFame.exists[size] && curMSE > hofMSE)
963
  numberSmallerAndBetter += 1
964
  end
965
  end
 
1030
  @printf("Hall of Fame:\n")
1031
  @printf("-----------------------------------------\n")
1032
  @printf("%-10s %-8s %-8s %-8s\n", "Complexity", "MSE", "Score", "Equation")
1033
+ curMSE = baselineMSE
1034
  @printf("%-10d %-8.3e %-8.3e %-.f\n", 0, curMSE, 0f0, avgy)
1035
  lastMSE = curMSE
1036
  lastComplexity = 0
 
1038
  for size=1:actualMaxsize
1039
  if hallOfFame.exists[size]
1040
  member = hallOfFame.members[size]
1041
+ if weighted
1042
+ curMSE = MSE(evalTreeArray(member.tree), y, weights)
1043
+ else
1044
+ curMSE = MSE(evalTreeArray(member.tree), y)
1045
+ end
1046
  numberSmallerAndBetter = 0
1047
  for i=1:(size-1)
1048
+ if weighted
1049
+ hofMSE = MSE(evalTreeArray(hallOfFame.members[i].tree), y, weights)
1050
+ else
1051
+ hofMSE = MSE(evalTreeArray(hallOfFame.members[i].tree), y)
1052
+ end
1053
+ if (hallOfFame.exists[size] && curMSE > hofMSE)
1054
  numberSmallerAndBetter += 1
1055
  end
1056
  end
pysr/sr.py CHANGED
@@ -227,6 +227,7 @@ const nrestarts = {nrestarts:d}
227
  const perturbationFactor = {perturbationFactor:f}f0
228
  const annealing = {"true" if annealing else "false"}
229
  const weighted = {"true" if weights is not None else "false"}
 
230
  const useVarMap = {"false" if len(variable_names) == 0 else "true"}
231
  const mutationWeights = [
232
  {weightMutateConstant:f},
 
227
  const perturbationFactor = {perturbationFactor:f}f0
228
  const annealing = {"true" if annealing else "false"}
229
  const weighted = {"true" if weights is not None else "false"}
230
+ const batchSize = {min([50, len(X)]):d}
231
  const useVarMap = {"false" if len(variable_names) == 0 else "true"}
232
  const mutationWeights = [
233
  {weightMutateConstant:f},