MilesCranmer commited on
Commit
e2ae8ef
·
1 Parent(s): d781e18

Meta-program branching table rather than array of operators

Browse files
Files changed (2) hide show
  1. julia/sr.jl +5 -4
  2. pysr/sr.py +38 -3
julia/sr.jl CHANGED
@@ -281,9 +281,9 @@ function evalTreeArray(tree::Node, cX::Array{Float32, 2})::Union{Array{Float32,
281
  if cumulator == nothing
282
  return nothing
283
  end
284
- op = unaops[tree.op]
285
  @inbounds @simd for i=1:clen
286
- cumulator[i] = op(cumulator[i])
287
  end
288
  @inbounds for i=1:clen
289
  if isinf(cumulator[i]) || isnan(cumulator[i])
@@ -292,7 +292,6 @@ function evalTreeArray(tree::Node, cX::Array{Float32, 2})::Union{Array{Float32,
292
  end
293
  return cumulator
294
  else
295
- op = binops[tree.op]
296
  cumulator = evalTreeArray(tree.l, cX)
297
  if cumulator == nothing
298
  return nothing
@@ -302,8 +301,10 @@ function evalTreeArray(tree::Node, cX::Array{Float32, 2})::Union{Array{Float32,
302
  return nothing
303
  end
304
 
 
 
305
  @inbounds @simd for i=1:clen
306
- cumulator[i] = op(cumulator[i], array2[i])
307
  end
308
  @inbounds for i=1:clen
309
  if isinf(cumulator[i]) || isnan(cumulator[i])
 
281
  if cumulator == nothing
282
  return nothing
283
  end
284
+ op_idx = tree.op
285
  @inbounds @simd for i=1:clen
286
+ cumulator[i] = UNAOP(op_idx, cumulator[i])
287
  end
288
  @inbounds for i=1:clen
289
  if isinf(cumulator[i]) || isnan(cumulator[i])
 
292
  end
293
  return cumulator
294
  else
 
295
  cumulator = evalTreeArray(tree.l, cX)
296
  if cumulator == nothing
297
  return nothing
 
301
  return nothing
302
  end
303
 
304
+ op_idx = tree.op
305
+
306
  @inbounds @simd for i=1:clen
307
+ cumulator[i] = BINOP(op_idx, cumulator[i], array2[i])
308
  end
309
  @inbounds for i=1:clen
310
  if isinf(cumulator[i]) || isnan(cumulator[i])
pysr/sr.py CHANGED
@@ -242,8 +242,8 @@ def pysr(X=None, y=None, weights=None,
242
  op_list[i] = function_name
243
 
244
  def_hyperparams += f"""include("{pkg_directory}/operators.jl")
245
- const binops = @fastmath {'[' + ', '.join(binary_operators) + ']'}
246
- const unaops = @fastmath {'[' + ', '.join(unary_operators) + ']'}
247
  const ns=10;
248
  const parsimony = {parsimony:f}f0
249
  const alpha = {alpha:f}f0
@@ -275,7 +275,42 @@ const mutationWeights = [
275
  {weightDoNothing:f}
276
  ]
277
  const warmupMaxsize = {warmupMaxsize:d}
278
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
279
 
280
  if X.shape[1] == 1:
281
  X_str = 'transpose([' + str(X.tolist()).replace(']', '').replace(',', '').replace('[', '') + '])'
 
242
  op_list[i] = function_name
243
 
244
  def_hyperparams += f"""include("{pkg_directory}/operators.jl")
245
+ const binops = {'[' + ', '.join(binary_operators) + ']'}
246
+ const unaops = {'[' + ', '.join(unary_operators) + ']'}
247
  const ns=10;
248
  const parsimony = {parsimony:f}f0
249
  const alpha = {alpha:f}f0
 
275
  {weightDoNothing:f}
276
  ]
277
  const warmupMaxsize = {warmupMaxsize:d}
278
+ """
279
+
280
+ op_runner = ""
281
+ if len(binary_operators) > 0:
282
+ op_runner += f"""
283
+ @inline function BINOP(i::Int, x::Float32, y::Float32)::Float32
284
+ if i == 1
285
+ return @fastmath {binary_operators[0]}(x, y)
286
+ """
287
+ for i in range(1, len(binary_operators)):
288
+ op_runner += f"""
289
+ elseif i == {i+1}
290
+ return @fastmath {binary_operators[i]}(x, y)
291
+ """
292
+ op_runner += """
293
+ end
294
+ end
295
+ """
296
+
297
+ if len(unary_operators) > 0:
298
+ op_runner += f"""
299
+ @inline function UNAOP(i::Int, x::Float32)::Float32
300
+ if i == 1
301
+ return @fastmath {unary_operators[0]}(x)
302
+ """
303
+ for i in range(1, len(unary_operators)):
304
+ op_runner += f"""
305
+ elseif i == {i+1}
306
+ return @fastmath {unary_operators[i]}(x)
307
+ """
308
+ op_runner += """
309
+ end
310
+ end
311
+ """
312
+
313
+ def_hyperparams += op_runner
314
 
315
  if X.shape[1] == 1:
316
  X_str = 'transpose([' + str(X.tolist()).replace(']', '').replace(',', '').replace('[', '') + '])'