# Source listing metadata: file size 1,411 bytes, revision 82ea528.
import taichi as ti
import taichi.math as tm
from functools import reduce
@ti.kernel
def sepconv_out(tenIn: ti.types.ndarray(), tenVer: ti.types.ndarray(), tenHor: ti.types.ndarray(), tenOut: ti.types.ndarray()):
    """Apply per-pixel separable filters to tenIn and store the flat result.

    For every output position (i, ch, y, x) the kernel accumulates
    tenIn[i, ch, y + fy, x + fx] * tenVer[i, fy, y, x] * tenHor[i, fx, y, x]
    over fy in range(tenVer.shape[1]) and fx in range(tenHor.shape[1]).

    Args:
        tenIn: 4-D array indexed as [i, ch, y + fy, x + fx].
            NOTE(review): presumably already padded so that these reads stay
            in bounds for every output position -- confirm against the caller.
        tenVer: 4-D array indexed as [i, fy, y, x]; its dims 2 and 3 define
            the output height and width used here.
        tenHor: 4-D array indexed as [i, fx, y, x]; assumed to share its
            dims 2 and 3 with tenVer -- TODO confirm.
        tenOut: 1-D array receiving the result in row-major
            (i, ch, y, x) order; it needs at least
            tenIn.shape[0] * tenIn.shape[1] * tenVer.shape[2] * tenVer.shape[3]
            elements.
    """
    intN, intC = tenIn.shape[0], tenIn.shape[1]
    # Iterate over the OUTPUT extent, not tenIn's spatial size: tenIn is read
    # at (y + intFy, x + intFx), so walking its full height/width would read
    # and write out of bounds.
    intH, intW = tenVer.shape[2], tenVer.shape[3]
    for i, ch, y, x in ti.ndrange(intN, intC, intH, intW):
        fltOut, fltKahanc, fltKahany, fltKahant = 0.0, 0.0, 0.0, 0.0
        for intFy, intFx in ti.ndrange(tenVer.shape[1], tenHor.shape[1]):
            # Compensated (Kahan) summation: fltKahanc carries the rounding
            # error of the previous addition into the next term.
            fltKahany = tenIn[i, ch, y + intFy, x + intFx] * tenVer[i, intFy, y, x] * tenHor[i, intFx, y, x]
            fltKahany = fltKahany - fltKahanc
            fltKahant = fltOut + fltKahany
            fltKahanc = (fltKahant - fltOut) - fltKahany
            fltOut = fltKahant
        # The outermost loop of a kernel is parallelised, so the write
        # position must come from the loop indices; a shared running counter
        # would race and scramble the element order.
        tenOut[((i * intC + ch) * intH + y) * intW + x] = fltOut
def worker_interface(op_name, tensors):
    """Run the operation called ``op_name`` on ``tensors``.

    Args:
        op_name: Name of the requested operation. Only "sepconv_out" is
            handled.
        tensors: Iterable that unpacks into exactly three objects
            (input, vertical filter, horizontal filter).
            NOTE(review): they look like torch tensors, given the
            ``new_zeros`` / ``view`` calls -- confirm against the caller.

    Returns:
        A one-element tuple holding the output, viewed with shape
        [input.shape[0], input.shape[1], height, width], where height and
        width come from dims 2 and 3 of the two filter tensors.

    Raises:
        NotImplementedError: If ``op_name`` is not "sepconv_out".
    """
    if op_name != "sepconv_out":
        raise NotImplementedError(op_name)

    source, vertical, horizontal = tensors

    # `a and b` evaluates to `b` unless `a` is falsy (0), in which case it
    # evaluates to `a`; so each spatial size is the horizontal filter's
    # unless the vertical filter's is zero.
    out_height = vertical.shape[2] and horizontal.shape[2]
    out_width = vertical.shape[3] and horizontal.shape[3]
    out_shape = [source.shape[0], source.shape[1], out_height, out_width]

    # The kernel writes into a flat buffer; it is reshaped afterwards.
    flat_len = int(reduce(lambda acc, dim: acc * dim, out_shape))
    flat_out = source.new_zeros([flat_len])
    sepconv_out(source, vertical, horizontal, flat_out)

    return (flat_out.view(*out_shape),)
# Public API of this module: only the dispatcher is exported.
__all__ = ["worker_interface"]