MilesCranmer commited on
Commit
c9cb1b6
1 Parent(s): 02e2655

Remove nested threading for tree traversal

Browse files
Files changed (2) hide show
  1. julia/sr.jl +9 -18
  2. pysr/sr.py +0 -1
julia/sr.jl CHANGED
@@ -73,8 +73,7 @@ function copyNode(tree::Node)::Node
73
  elseif tree.degree == 1
74
  return Node(tree.op, copyNode(tree.l))
75
  else
76
- right = Threads.@spawn copyNode(tree.r)
77
- return Node(tree.op, copyNode(tree.l), fetch(right))
78
  end
79
  end
80
 
@@ -89,8 +88,7 @@ function evalTree(tree::Node, x::Array{Float32, 1}=Float32[])::Float32
89
  elseif tree.degree == 1
90
  return unaops[tree.op](evalTree(tree.l, x))
91
  else
92
- right = Threads.@spawn evalTree(tree.r, x)
93
- return binops[tree.op](evalTree(tree.l, x), fetch(right))
94
  end
95
  end
96
 
@@ -101,8 +99,7 @@ function countNodes(tree::Node)::Integer
101
  elseif tree.degree == 1
102
  return 1 + countNodes(tree.l)
103
  else
104
- right = Threads.@spawn countNodes(tree.r)
105
- return 1 + countNodes(tree.l) + fetch(right)
106
  end
107
  end
108
 
@@ -117,8 +114,7 @@ function stringTree(tree::Node)::String
117
  elseif tree.degree == 1
118
  return "$(unaops[tree.op])($(stringTree(tree.l)))"
119
  else
120
- right = Threads.@spawn stringTree(tree.r)
121
- return "$(binops[tree.op])($(stringTree(tree.l)), $(fetch(right)))"
122
  end
123
  end
124
 
@@ -159,8 +155,7 @@ function countUnaryOperators(tree::Node)::Integer
159
  elseif tree.degree == 1
160
  return 1 + countUnaryOperators(tree.l)
161
  else
162
- right = Threads.@spawn countUnaryOperators(tree.r)
163
- return 0 + countUnaryOperators(tree.l) + fetch(right)
164
  end
165
  end
166
 
@@ -171,8 +166,7 @@ function countBinaryOperators(tree::Node)::Integer
171
  elseif tree.degree == 1
172
  return 0 + countBinaryOperators(tree.l)
173
  else
174
- right = Threads.@spawn countBinaryOperators(tree.r)
175
- return 1 + countBinaryOperators(tree.l) + fetch(right)
176
  end
177
  end
178
 
@@ -206,8 +200,7 @@ function countConstants(tree::Node)::Integer
206
  elseif tree.degree == 1
207
  return 0 + countConstants(tree.l)
208
  else
209
- right = Threads.@spawn countConstants(tree.r)
210
- return 0 + countConstants(tree.l) + fetch(right)
211
  end
212
  end
213
 
@@ -254,8 +247,7 @@ function evalTreeArray(tree::Node)::Array{Float32, 1}
254
  elseif tree.degree == 1
255
  return unaops[tree.op].(evalTreeArray(tree.l))
256
  else
257
- right = Threads.@spawn evalTreeArray(tree.r)
258
- return binops[tree.op].(evalTreeArray(tree.l), fetch(right))
259
  end
260
  end
261
 
@@ -494,9 +486,8 @@ function simplifyTree(tree::Node)::Node
494
  return Node(unaops[tree.op](tree.l.val))
495
  end
496
  elseif tree.degree == 2
497
- right = Threads.@spawn simplifyTree(tree.r)
498
  tree.l = simplifyTree(tree.l)
499
- tree.r = fetch(right)
500
  constantsBelow = (
501
  tree.l.degree == 0 && tree.l.constant &&
502
  tree.r.degree == 0 && tree.r.constant
 
73
  elseif tree.degree == 1
74
  return Node(tree.op, copyNode(tree.l))
75
  else
76
+ return Node(tree.op, copyNode(tree.l), copyNode(tree.r))
 
77
  end
78
  end
79
 
 
88
  elseif tree.degree == 1
89
  return unaops[tree.op](evalTree(tree.l, x))
90
  else
91
+ return binops[tree.op](evalTree(tree.l, x), evalTree(tree.r, x))
 
92
  end
93
  end
94
 
 
99
  elseif tree.degree == 1
100
  return 1 + countNodes(tree.l)
101
  else
102
+ return 1 + countNodes(tree.l) + countNodes(tree.r)
 
103
  end
104
  end
105
 
 
114
  elseif tree.degree == 1
115
  return "$(unaops[tree.op])($(stringTree(tree.l)))"
116
  else
117
+ return "$(binops[tree.op])($(stringTree(tree.l)), $(stringTree(tree.r)))"
 
118
  end
119
  end
120
 
 
155
  elseif tree.degree == 1
156
  return 1 + countUnaryOperators(tree.l)
157
  else
158
+ return 0 + countUnaryOperators(tree.l) + countUnaryOperators(tree.r)
 
159
  end
160
  end
161
 
 
166
  elseif tree.degree == 1
167
  return 0 + countBinaryOperators(tree.l)
168
  else
169
+ return 1 + countBinaryOperators(tree.l) + countBinaryOperators(tree.r)
 
170
  end
171
  end
172
 
 
200
  elseif tree.degree == 1
201
  return 0 + countConstants(tree.l)
202
  else
203
+ return 0 + countConstants(tree.l) + countConstants(tree.r)
 
204
  end
205
  end
206
 
 
247
  elseif tree.degree == 1
248
  return unaops[tree.op].(evalTreeArray(tree.l))
249
  else
250
+ return binops[tree.op].(evalTreeArray(tree.l), evalTreeArray(tree.r))
 
251
  end
252
  end
253
 
 
486
  return Node(unaops[tree.op](tree.l.val))
487
  end
488
  elseif tree.degree == 2
 
489
  tree.l = simplifyTree(tree.l)
490
+ tree.r = simplifyTree(tree.r)
491
  constantsBelow = (
492
  tree.l.degree == 0 && tree.l.constant &&
493
  tree.r.degree == 0 && tree.r.constant
pysr/sr.py CHANGED
@@ -254,7 +254,6 @@ const weights = convert(Array{Float32, 1}, """f"{weight_str})"
254
  command = [
255
  f'julia -O{julia_optimization:d}',
256
  f'-p {procs}',
257
- f'--threads 2', #Threading within each worker
258
  f'/tmp/.runfile_{rand_string}.jl',
259
  ]
260
  if timeout is not None:
 
254
  command = [
255
  f'julia -O{julia_optimization:d}',
256
  f'-p {procs}',
 
257
  f'/tmp/.runfile_{rand_string}.jl',
258
  ]
259
  if timeout is not None: