MilesCranmer commited on
Commit
1d23dc2
·
1 Parent(s): 32f7c64

Move entire operator loop to code generation

Browse files
Files changed (2) hide show
  1. julia/sr.jl +2 -8
  2. pysr/sr.py +9 -9
julia/sr.jl CHANGED
@@ -285,9 +285,7 @@ function evalTreeArray(tree::Node, cX::Array{Float32, 2})::Union{Array{Float32,
285
  return nothing
286
  end
287
  op_idx = tree.op
288
- @inbounds @simd for i=1:clen
289
- cumulator[i] = UNAOP(op_idx, cumulator[i])
290
- end
291
  @inbounds for i=1:clen
292
  if isinf(cumulator[i]) || isnan(cumulator[i])
293
  return nothing
@@ -303,12 +301,8 @@ function evalTreeArray(tree::Node, cX::Array{Float32, 2})::Union{Array{Float32,
303
  if array2 === nothing
304
  return nothing
305
  end
306
-
307
  op_idx = tree.op
308
-
309
- @inbounds @simd for i=1:clen
310
- cumulator[i] = BINOP(op_idx, cumulator[i], array2[i])
311
- end
312
  @inbounds for i=1:clen
313
  if isinf(cumulator[i]) || isnan(cumulator[i])
314
  return nothing
 
285
  return nothing
286
  end
287
  op_idx = tree.op
288
+ UNAOP!(op_idx, cumulator)
 
 
289
  @inbounds for i=1:clen
290
  if isinf(cumulator[i]) || isnan(cumulator[i])
291
  return nothing
 
301
  if array2 === nothing
302
  return nothing
303
  end
 
304
  op_idx = tree.op
305
+ BINOP!(op_idx, cumulator, array2)
 
 
 
306
  @inbounds for i=1:clen
307
  if isinf(cumulator[i]) || isnan(cumulator[i])
308
  return nothing
pysr/sr.py CHANGED
@@ -286,27 +286,27 @@ const limitPowComplexity = {"true" if limitPowComplexity else "false"}
286
 
287
  op_runner = ""
288
  if len(binary_operators) > 0:
289
- op_runner += f"""
290
- @inline function BINOP(i::Int, x::Float32, y::Float32)::Float32
291
  if i == 1
292
- return @fastmath {binary_operators[0]}(x, y)"""
293
  for i in range(1, len(binary_operators)):
294
  op_runner += f"""
295
  elseif i == {i+1}
296
- return @fastmath {binary_operators[i]}(x, y)"""
297
  op_runner += """
298
  end
299
  end"""
300
 
301
  if len(unary_operators) > 0:
302
- op_runner += f"""
303
- @inline function UNAOP(i::Int, x::Float32)::Float32
304
  if i == 1
305
- return @fastmath {unary_operators[0]}(x)"""
306
  for i in range(1, len(unary_operators)):
307
- op_runner += f"""
308
  elseif i == {i+1}
309
- return @fastmath {unary_operators[i]}(x)"""
310
  op_runner += """
311
  end
312
  end"""
 
286
 
287
  op_runner = ""
288
  if len(binary_operators) > 0:
289
+ op_runner += """
290
+ function BINOP!(i::Int, x::Array{Float32, 1}, y::Array{Float32, 1})::Array{Float32, 1}
291
  if i == 1
292
+ x .= @fastmath """f"{binary_operators[0]}"".(x, y)"
293
  for i in range(1, len(binary_operators)):
294
  op_runner += f"""
295
  elseif i == {i+1}
296
+ x .= @fastmath {binary_operators[i]}.(x, y)"""
297
  op_runner += """
298
  end
299
  end"""
300
 
301
  if len(unary_operators) > 0:
302
+ op_runner += """
303
+ function UNAOP!(i::Int, x::Array{Float32, 1})::Array{Float32, 1}
304
  if i == 1
305
+ x .= @fastmath """f"{unary_operators[0]}.(x)"
306
  for i in range(1, len(unary_operators)):
307
+ op_runner += """
308
  elseif i == {i+1}
309
+ x .= @fastmath """f"{unary_operators[i]}.(x)"
310
  op_runner += """
311
  end
312
  end"""