MilesCranmer commited on
Commit
3a578d2
1 Parent(s): cb0e2e9

Parallelize all recursive functions

Browse files
Files changed (3) hide show
  1. README.md +3 -0
  2. julia/sr.jl +19 -10
  3. pysr/sr.py +1 -0
README.md CHANGED
@@ -307,3 +307,6 @@ pd.DataFrame, Results dataframe, giving complexity, MSE, and equations
307
  - [ ] Create flexible way of providing "simplification recipes." I.e., plus(plus(T, C), C) => plus(T, +(C, C)). The user could pass these.
308
  - [ ] Can we cache calculations, or does the compiler do that? E.g., I should only have to run exp(x0) once; after that it should be read from memory.
309
  - Maybe I could store the result of calculations in a tree (or an index to a massive array that does this). And only when something in the subtree updates, does the rest of the tree update!
 
 
 
 
307
  - [ ] Create flexible way of providing "simplification recipes." I.e., plus(plus(T, C), C) => plus(T, +(C, C)). The user could pass these.
308
  - [ ] Can we cache calculations, or does the compiler do that? E.g., I should only have to run exp(x0) once; after that it should be read from memory.
309
  - Maybe I could store the result of calculations in a tree (or an index to a massive array that does this). And only when something in the subtree updates, does the rest of the tree update!
310
+ - [ ] Try Memoize.jl instead of manually caching.
311
+ - [ ] Try threading over population. Do random sort, compute mutation for each, then replace 10% oldest.
312
+
julia/sr.jl CHANGED
@@ -72,8 +72,9 @@ function copyNode(tree::Node)::Node
72
  return Node(tree.val)
73
  elseif tree.degree == 1
74
  return Node(tree.op, copyNode(tree.l))
75
- else
76
- return Node(tree.op, copyNode(tree.l), copyNode(tree.r))
 
77
  end
78
  end
79
 
@@ -88,7 +89,8 @@ function evalTree(tree::Node, x::Array{Float32, 1}=Float32[])::Float32
88
  elseif tree.degree == 1
89
  return tree.op(evalTree(tree.l, x))
90
  else
91
- return tree.op(evalTree(tree.l, x), evalTree(tree.r, x))
 
92
  end
93
  end
94
 
@@ -99,7 +101,8 @@ function countNodes(tree::Node)::Integer
99
  elseif tree.degree == 1
100
  return 1 + countNodes(tree.l)
101
  else
102
- return 1 + countNodes(tree.l) + countNodes(tree.r)
 
103
  end
104
  end
105
 
@@ -114,7 +117,8 @@ function stringTree(tree::Node)::String
114
  elseif tree.degree == 1
115
  return "$(tree.op)($(stringTree(tree.l)))"
116
  else
117
- return "$(tree.op)($(stringTree(tree.l)), $(stringTree(tree.r)))"
 
118
  end
119
  end
120
 
@@ -155,7 +159,8 @@ function countUnaryOperators(tree::Node)::Integer
155
  elseif tree.degree == 1
156
  return 1 + countUnaryOperators(tree.l)
157
  else
158
- return 0 + countUnaryOperators(tree.l) + countUnaryOperators(tree.r)
 
159
  end
160
  end
161
 
@@ -166,7 +171,8 @@ function countBinaryOperators(tree::Node)::Integer
166
  elseif tree.degree == 1
167
  return 0 + countBinaryOperators(tree.l)
168
  else
169
- return 1 + countBinaryOperators(tree.l) + countBinaryOperators(tree.r)
 
170
  end
171
  end
172
 
@@ -200,7 +206,8 @@ function countConstants(tree::Node)::Integer
200
  elseif tree.degree == 1
201
  return 0 + countConstants(tree.l)
202
  else
203
- return 0 + countConstants(tree.l) + countConstants(tree.r)
 
204
  end
205
  end
206
 
@@ -247,7 +254,8 @@ function evalTreeArray(tree::Node)::Array{Float32, 1}
247
  elseif tree.degree == 1
248
  return tree.op.(evalTreeArray(tree.l))
249
  else
250
- return tree.op.(evalTreeArray(tree.l), evalTreeArray(tree.r))
 
251
  end
252
  end
253
 
@@ -486,8 +494,9 @@ function simplifyTree(tree::Node)::Node
486
  return Node(tree.op(tree.l.val))
487
  end
488
  elseif tree.degree == 2
489
- tree.r = simplifyTree(tree.r)
490
  tree.l = simplifyTree(tree.l)
 
491
  constantsBelow = (
492
  tree.l.degree == 0 && tree.l.constant &&
493
  tree.r.degree == 0 && tree.r.constant
 
72
  return Node(tree.val)
73
  elseif tree.degree == 1
74
  return Node(tree.op, copyNode(tree.l))
75
+ else
76
+ right = Threads.@spawn copyNode(tree.r)
77
+ return Node(tree.op, copyNode(tree.l), fetch(right))
78
  end
79
  end
80
 
 
89
  elseif tree.degree == 1
90
  return tree.op(evalTree(tree.l, x))
91
  else
92
+ right = Threads.@spawn evalTree(tree.r, x)
93
+ return tree.op(evalTree(tree.l, x), fetch(right))
94
  end
95
  end
96
 
 
101
  elseif tree.degree == 1
102
  return 1 + countNodes(tree.l)
103
  else
104
+ right = Threads.@spawn countNodes(tree.r)
105
+ return 1 + countNodes(tree.l) + fetch(right)
106
  end
107
  end
108
 
 
117
  elseif tree.degree == 1
118
  return "$(tree.op)($(stringTree(tree.l)))"
119
  else
120
+ right = Threads.@spawn stringTree(tree.r)
121
+ return "$(tree.op)($(stringTree(tree.l)), $(fetch(right)))"
122
  end
123
  end
124
 
 
159
  elseif tree.degree == 1
160
  return 1 + countUnaryOperators(tree.l)
161
  else
162
+ right = Threads.@spawn countUnaryOperators(tree.r)
163
+ return 0 + countUnaryOperators(tree.l) + fetch(right)
164
  end
165
  end
166
 
 
171
  elseif tree.degree == 1
172
  return 0 + countBinaryOperators(tree.l)
173
  else
174
+ right = Threads.@spawn countBinaryOperators(tree.r)
175
+ return 1 + countBinaryOperators(tree.l) + fetch(right)
176
  end
177
  end
178
 
 
206
  elseif tree.degree == 1
207
  return 0 + countConstants(tree.l)
208
  else
209
+ right = Threads.@spawn countConstants(tree.r)
210
+ return 0 + countConstants(tree.l) + fetch(right)
211
  end
212
  end
213
 
 
254
  elseif tree.degree == 1
255
  return tree.op.(evalTreeArray(tree.l))
256
  else
257
+ right = Threads.@spawn evalTreeArray(tree.r)
258
+ return tree.op.(evalTreeArray(tree.l), fetch(right))
259
  end
260
  end
261
 
 
494
  return Node(tree.op(tree.l.val))
495
  end
496
  elseif tree.degree == 2
497
+ right = Threads.@spawn simplifyTree(tree.r)
498
  tree.l = simplifyTree(tree.l)
499
+ tree.r = fetch(right)
500
  constantsBelow = (
501
  tree.l.degree == 0 && tree.l.constant &&
502
  tree.r.degree == 0 && tree.r.constant
pysr/sr.py CHANGED
@@ -213,6 +213,7 @@ const weights = convert(Array{Float32, 1}, """f"{weight_str})"
213
  command = [
214
  f'julia -O{julia_optimization:d}',
215
  f'-p {procs}',
 
216
  f'/tmp/.runfile_{rand_string}.jl',
217
  ]
218
  if timeout is not None:
 
213
  command = [
214
  f'julia -O{julia_optimization:d}',
215
  f'-p {procs}',
216
+ f'--threads auto',
217
  f'/tmp/.runfile_{rand_string}.jl',
218
  ]
219
  if timeout is not None: