Spaces:
Running
Running
import SpecialFunctions: gamma, lgamma, erf, erfc, beta | |
import Base.FastMath: sqrt_llvm_fast, neg_float_fast, | |
add_float_fast, sub_float_fast, mul_float_fast, div_float_fast, rem_float_fast, | |
eq_float_fast, ne_float_fast, lt_float_fast, le_float_fast, | |
sign_fast, abs_fast, log_fast, log2_fast, log10_fast, sqrt_fast, | |
pow_fast | |
# Implicitly defined: | |
#binary: mod | |
#unary: exp, abs, log1p, sin, cos, tan, sinh, cosh, tanh, asin, acos, atan, asinh, acosh, atanh, erf, erfc, gamma, relu, round, floor, ceil, round, sign. | |
# Use some fast operators from https://github.com/JuliaLang/julia/blob/81597635c4ad1e8c2e1c5753fda4ec0e7397543f/base/fastmath.jl | |
# Define allowed operators. Any julia operator can also be used. | |
plus(x::Float32, y::Float32)::Float32 = add_float_fast(x, y) #Do not change the name of this operator. | |
sub(x::Float32, y::Float32)::Float32 = sub_float_fast(x, y) #Do not change the name of this operator. | |
mult(x::Float32, y::Float32)::Float32 = mul_float_fast(x, y) #Do not change the name of this operator. | |
square(x::Float32)::Float32 = mul_float_fast(x, x) | |
cube(x::Float32)::Float32 = mul_float_fast(mul_float_fast(x, x), x) | |
pow(x::Float32, y::Float32)::Float32 = sign_fast(x)*pow_fast(abs(x), y) | |
div(x::Float32, y::Float32)::Float32 = div_float_fast(x, y) | |
logm(x::Float32)::Float32 = log_fast(abs_fast(x) + 1f-8) | |
logm2(x::Float32)::Float32 = log2_fast(abs_fast(x) + 1f-8) | |
logm10(x::Float32)::Float32 = log10_fast(abs_fast(x) + 1f-8) | |
sqrtm(x::Float32)::Float32 = sqrt_fast(abs_fast(x)) | |
neg(x::Float32)::Float32 = sub_float_fast(x) | |
function greater(x::Float32, y::Float32)::Float32 | |
if x > y | |
return 1f0 | |
end | |
return 0f0 | |
end | |
function relu(x::Float32)::Float32 | |
if x > 0f0 | |
return x | |
end | |
return 0f0 | |
end | |
function logical_or(x::Float32, y::Float32)::Float32 | |
if x > 0f0 || y > 0f0 | |
return 1f0 | |
end | |
return 0f0 | |
end | |
# (Just use multiplication normally) | |
function logical_and(x::Float32, y::Float32)::Float32 | |
if x > 0f0 && y > 0f0 | |
return 1f0 | |
end | |
return 0f0 | |
end | |