File size: 2,055 Bytes
2c6b73d
 
211c6a7
 
 
 
14f69cf
 
211c6a7
 
 
 
 
 
5fcecfa
211c6a7
d781e18
211c6a7
c835184
 
14f69cf
211c6a7
 
 
 
 
a02cc8a
78a18f0
2c6b73d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import SpecialFunctions: gamma, lgamma, erf, erfc, beta


import Base.FastMath: sqrt_llvm_fast, neg_float_fast,
    add_float_fast, sub_float_fast, mul_float_fast, div_float_fast, rem_float_fast,
    eq_float_fast, ne_float_fast, lt_float_fast, le_float_fast,
    sign_fast, abs_fast, log_fast, log2_fast, log10_fast, sqrt_fast,
    pow_fast

# Implicitly defined:
#binary: mod
#unary: exp, abs, log1p, sin, cos, tan, sinh, cosh, tanh, asin, acos, atan, asinh, acosh, atanh, erf, erfc, gamma, relu, round, floor, ceil, sign.

# Use some fast operators from https://github.com/JuliaLang/julia/blob/81597635c4ad1e8c2e1c5753fda4ec0e7397543f/base/fastmath.jl
# Define allowed operators. Any julia operator can also be used.
"""
    plus(x::Float32, y::Float32)::Float32

Return `x + y`, computed with the fast-math addition intrinsic `add_float_fast`.
"""
function plus(x::Float32, y::Float32)::Float32
    # Do not change the name of this operator.
    return add_float_fast(x, y)
end

"""
    sub(x::Float32, y::Float32)::Float32

Return `x - y`, computed with the fast-math subtraction intrinsic `sub_float_fast`.
"""
function sub(x::Float32, y::Float32)::Float32
    # Do not change the name of this operator.
    return sub_float_fast(x, y)
end

"""
    mult(x::Float32, y::Float32)::Float32

Return `x * y`, computed with the fast-math multiplication intrinsic `mul_float_fast`.
"""
function mult(x::Float32, y::Float32)::Float32
    # Do not change the name of this operator.
    return mul_float_fast(x, y)
end

"""
    square(x::Float32)::Float32

Return `x^2` as a single fast-math multiplication `x * x`.
"""
function square(x::Float32)::Float32
    return mul_float_fast(x, x)
end

"""
    cube(x::Float32)::Float32

Return `x^3` as two fast-math multiplications: `(x * x) * x`.
"""
function cube(x::Float32)::Float32
    x_squared = mul_float_fast(x, x)
    return mul_float_fast(x_squared, x)
end
"""
    pow(x::Float32, y::Float32)::Float32

Sign-preserving power: return `sign(x) * |x|^y`.

The exponentiation is applied to `|x|`, so the base is never negative. A
non-integer exponent applied to a negative input therefore still yields a real
value, carrying the sign of `x`.

Every step uses the `Base.FastMath` intrinsics (`sign_fast`, `abs_fast`,
`pow_fast`, `mul_float_fast`). This matches the other operators in this file,
which are all built on the fast-math primitives.

NOTE(review): for `x == 0` and `y < 0`, `|x|^y` is `Inf`. The result then depends on
`sign_fast(0f0) * Inf` and may be `NaN`. Confirm this is the intended behaviour.
"""
function pow(x::Float32, y::Float32)::Float32
    magnitude = pow_fast(abs_fast(x), y)
    return mul_float_fast(sign_fast(x), magnitude)
end

"""
    div(x::Float32, y::Float32)::Float32

Return `x / y`, computed with the fast-math division intrinsic `div_float_fast`.

NOTE(review): this name coincides with `Base.div` (integer/truncating division).
If the defining module has not imported `Base.div`, this definition creates a
separate function rather than adding a method to Base's. Confirm that is intended.
"""
div(x::Float32, y::Float32)::Float32 = div_float_fast(x, y)
"""
    logm(x::Float32)::Float32

Protected natural logarithm: `log(|x| + 1f-8)`.

The absolute value plus the small `1f-8` offset keeps the argument strictly
positive, so every finite `x` (including zero) gives a finite result.
"""
function logm(x::Float32)::Float32
    guarded = abs_fast(x) + 1f-8
    return log_fast(guarded)
end

"""
    logm2(x::Float32)::Float32

Protected base-2 logarithm: `log2(|x| + 1f-8)`. This is the same guard as `logm`.
"""
function logm2(x::Float32)::Float32
    guarded = abs_fast(x) + 1f-8
    return log2_fast(guarded)
end

"""
    logm10(x::Float32)::Float32

Protected base-10 logarithm: `log10(|x| + 1f-8)`. This is the same guard as `logm`.
"""
function logm10(x::Float32)::Float32
    guarded = abs_fast(x) + 1f-8
    return log10_fast(guarded)
end

"""
    sqrtm(x::Float32)::Float32

Protected square root: `sqrt(|x|)`. Taking the absolute value first keeps the
argument non-negative.
"""
function sqrtm(x::Float32)::Float32
    return sqrt_fast(abs_fast(x))
end

"""
    neg(x::Float32)::Float32

Return `-x`, computed with the fast-math negation intrinsic `neg_float_fast`.
"""
function neg(x::Float32)::Float32
    return neg_float_fast(x)
end

"""
    greater(x::Float32, y::Float32)::Float32

Indicator of `x > y`. Return `1f0` when the strict comparison holds and `0f0`
otherwise. A `NaN` in either argument gives `0f0`, because every comparison with
`NaN` is false.
"""
greater(x::Float32, y::Float32)::Float32 = ifelse(x > y, 1f0, 0f0)

"""
    relu(x::Float32)::Float32

Rectified linear unit. Return `x` when `x` is strictly positive and `0f0`
otherwise. The `0f0` case includes `-0f0` and `NaN`, since neither compares
greater than zero.
"""
relu(x::Float32)::Float32 = x > 0f0 ? x : 0f0

"""
    logical_or(x::Float32, y::Float32)::Float32

Soft logical OR that treats a strictly positive input as "true". Return `1f0` if
at least one of `x`, `y` is greater than zero, else `0f0`.
"""
function logical_or(x::Float32, y::Float32)::Float32
    # De Morgan form: the result is false only when neither input is positive.
    neither_positive = !(x > 0f0) && !(y > 0f0)
    return neither_positive ? 0f0 : 1f0
end

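# (Plain multiplication can normally stand in for logical AND. The two are not equivalent, though: two negative inputs give a positive product, while logical_and returns 0.)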
"""
    logical_and(x::Float32, y::Float32)::Float32

Soft logical AND that treats a strictly positive input as "true". Return `1f0`
only when both `x` and `y` are greater than zero, else `0f0`.
"""
function logical_and(x::Float32, y::Float32)::Float32
    # Bail out as soon as either operand fails to be positive.
    x > 0f0 || return 0f0
    y > 0f0 || return 0f0
    return 1f0
end