MilesCranmer commited on
Commit
cf95c4d
1 Parent(s): 7b2d6ac

Treat dataset as global constant

Browse files
Files changed (1) hide show
  1. eureqa.jl +12 -17
eureqa.jl CHANGED
@@ -196,20 +196,18 @@ function mutateConstant(
196
  end
197
 
198
  # Evaluate an equation over an array of datapoints
199
- function evalTreeArray(
200
- tree::Node,
201
- x::Array{Float32, 2})::Array{Float32, 1}
202
- len = size(x)[1]
203
  if tree.degree == 0
204
  if tree.constant
205
  return ones(Float32, len) .* tree.val
206
  else
207
- return ones(Float32, len) .* x[:, tree.val]
208
  end
209
  elseif tree.degree == 1
210
- return tree.op.(evalTreeArray(tree.l, x))
211
  else
212
- return tree.op.(evalTreeArray(tree.l, x), evalTreeArray(tree.r, x))
213
  end
214
  end
215
 
@@ -225,12 +223,10 @@ end
225
 
226
  # Score an equation
227
  function scoreFunc(
228
- tree::Node,
229
- X::Array{Float32, 2},
230
- y::Array{Float32, 1};
231
  parsimony::Float32=0.1f0)::Float32
232
  try
233
- return MSE(evalTreeArray(tree, X), y) + countNodes(tree)*parsimony
234
  catch error
235
  if isa(error, DomainError)
236
  return 1f9
@@ -327,7 +323,6 @@ end
327
  # exp(-delta/T) defines probability of accepting a change
328
  function iterate(
329
  tree::Node, T::Float32,
330
- X::Array{Float32, 2}, y::Array{Float32, 1},
331
  alpha::Float32=1.0f0,
332
  mult::Float32=0.1f0;
333
  annealing::Bool=true
@@ -358,8 +353,8 @@ function iterate(
358
  end
359
 
360
  if annealing
361
- beforeLoss = scoreFunc(prev, X, y, parsimony=mult)
362
- afterLoss = scoreFunc(tree, X, y, parsimony=mult)
363
  delta = afterLoss - beforeLoss
364
  probChange = exp(-delta/(T*alpha))
365
 
@@ -387,7 +382,7 @@ mutable struct PopMember
387
  score::Float32
388
  birth::Int32
389
 
390
- PopMember(t::Node) = new(t, scoreFunc(t, X, y, parsimony=parsimony), round(Int32, 1e3*(time()-1.6e9)))
391
  PopMember(t::Node, score::Float32) = new(t, score, round(Int32, 1e3*(time()-1.6e9)))
392
 
393
  end
@@ -429,10 +424,10 @@ function iterateSample(
429
  annealing::Bool=true)::PopMember
430
  allstar = bestOfSample(pop)
431
  new = iterate(
432
- allstar.tree, T, X, y,
433
  alpha, parsimony, annealing=annealing)
434
  allstar.tree = new
435
- allstar.score = scoreFunc(new, X, y, parsimony=parsimony)
436
  allstar.birth = round(Int32, 1e3*(time()-1.6e9))
437
  return allstar
438
  end
 
196
  end
197
 
198
  # Evaluate an equation over an array of datapoints
199
+ function evalTreeArray(tree::Node)::Array{Float32, 1}
200
+ len = size(X)[1]
 
 
201
  if tree.degree == 0
202
  if tree.constant
203
  return ones(Float32, len) .* tree.val
204
  else
205
+ return ones(Float32, len) .* X[:, tree.val]
206
  end
207
  elseif tree.degree == 1
208
+ return tree.op.(evalTreeArray(tree.l))
209
  else
210
+ return tree.op.(evalTreeArray(tree.l), evalTreeArray(tree.r))
211
  end
212
  end
213
 
 
223
 
224
  # Score an equation
225
  function scoreFunc(
226
+ tree::Node;
 
 
227
  parsimony::Float32=0.1f0)::Float32
228
  try
229
+ return MSE(evalTreeArray(tree), y) + countNodes(tree)*parsimony
230
  catch error
231
  if isa(error, DomainError)
232
  return 1f9
 
323
  # exp(-delta/T) defines probability of accepting a change
324
  function iterate(
325
  tree::Node, T::Float32,
 
326
  alpha::Float32=1.0f0,
327
  mult::Float32=0.1f0;
328
  annealing::Bool=true
 
353
  end
354
 
355
  if annealing
356
+ beforeLoss = scoreFunc(prev, parsimony=mult)
357
+ afterLoss = scoreFunc(tree, parsimony=mult)
358
  delta = afterLoss - beforeLoss
359
  probChange = exp(-delta/(T*alpha))
360
 
 
382
  score::Float32
383
  birth::Int32
384
 
385
+ PopMember(t::Node) = new(t, scoreFunc(t, parsimony=parsimony), round(Int32, 1e3*(time()-1.6e9)))
386
  PopMember(t::Node, score::Float32) = new(t, score, round(Int32, 1e3*(time()-1.6e9)))
387
 
388
  end
 
424
  annealing::Bool=true)::PopMember
425
  allstar = bestOfSample(pop)
426
  new = iterate(
427
+ allstar.tree, T,
428
  alpha, parsimony, annealing=annealing)
429
  allstar.tree = new
430
+ allstar.score = scoreFunc(new, parsimony=parsimony)
431
  allstar.birth = round(Int32, 1e3*(time()-1.6e9))
432
  return allstar
433
  end