import taichi as ti
import taichi.math as tm

from functools import reduce


@ti.kernel
def sepconv_out(tenIn: ti.types.ndarray(), tenVer: ti.types.ndarray(), tenHor: ti.types.ndarray(), tenOut: ti.types.ndarray()):
    # The output spatial size comes from the filter tensors, not from tenIn:
    # tenIn is larger by (kernel size - 1) pixels in each spatial dimension.
    N, C = tenIn.shape[0], tenIn.shape[1]
    H, W = tenVer.shape[2], tenHor.shape[3]

    # Taichi parallelises this outermost loop, so each output element derives
    # its own flat index instead of sharing a serial counter.
    for i, ch, y, x in ti.ndrange(N, C, H, W):
        fltOut, fltKahanc, fltKahany, fltKahant = 0.0, 0.0, 0.0, 0.0

        # Separable convolution accumulated with Kahan (compensated) summation.
        for intFy, intFx in ti.ndrange(tenVer.shape[1], tenHor.shape[1]):
            fltKahany = tenIn[i, ch, y + intFy, x + intFx] * tenVer[i, intFy, y, x] * tenHor[i, intFx, y, x]
            fltKahany = fltKahany - fltKahanc
            fltKahant = fltOut + fltKahany
            fltKahanc = (fltKahant - fltOut) - fltKahany
            fltOut = fltKahant

        intIndex = ((i * C + ch) * H + y) * W + x
        tenOut[intIndex] = fltOut
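

# Illustrative reference (an addition, not part of the original worker code):
# a plain PyTorch version of the same separable convolution, handy for
# checking sepconv_out on small inputs. The name sepconv_out_reference and
# its signature are assumptions for this sketch, not project API.
def sepconv_out_reference(tenIn, tenVer, tenHor):
    intKernel = tenVer.shape[1]
    N, C = tenIn.shape[0], tenIn.shape[1]
    H, W = tenVer.shape[2], tenHor.shape[3]
    tenOut = tenIn.new_zeros([N, C, H, W])
    for intFy in range(intKernel):
        for intFx in range(intKernel):
            # Shifted input window weighted by the per-pixel vertical and
            # horizontal 1D kernels, broadcast over the channel dimension.
            tenOut += (
                tenIn[:, :, intFy:intFy + H, intFx:intFx + W]
                * tenVer[:, None, intFy, :, :]
                * tenHor[:, None, intFx, :, :]
            )
    return tenOut

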
def worker_interface(op_name, tensors):
    if op_name == "sepconv_out":
        tenIn, tenVer, tenHor = tensors

        # tenVer and tenHor share the same spatial size, which is also the
        # spatial size of the output.
        real_tenOut_shape = [
            tenIn.shape[0],
            tenIn.shape[1],
            tenVer.shape[2],
            tenHor.shape[3],
        ]

        # The kernel writes linear indices into a flat buffer, which is
        # reshaped to the real output shape afterwards.
        tenOut = tenIn.new_zeros([
            int(reduce(lambda a, b: a * b, real_tenOut_shape))
        ])

        sepconv_out(tenIn, tenVer, tenHor, tenOut)

        tenOut = tenOut.view(*real_tenOut_shape)

        return (tenOut, )

    raise NotImplementedError(op_name)


__all__ = ["worker_interface"]
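

# Minimal smoke test (an illustrative sketch, not part of the original worker
# protocol): it assumes PyTorch is installed, that Taichi has not been
# initialised elsewhere, and that the CPU backend suffices. Shapes are small
# arbitrary values; the input exceeds the output by (kernel size - 1) pixels
# per spatial dimension, matching the convention used above. The comparison
# uses the hypothetical sepconv_out_reference helper sketched earlier.
if __name__ == "__main__":
    import torch

    ti.init(arch=ti.cpu)  # assumption: CPU backend is enough for a quick check

    intKernel, N, C, H, W = 5, 1, 3, 8, 8
    tenIn = torch.rand(N, C, H + intKernel - 1, W + intKernel - 1)
    tenVer = torch.rand(N, intKernel, H, W)
    tenHor = torch.rand(N, intKernel, H, W)

    (tenOut,) = worker_interface("sepconv_out", (tenIn, tenVer, tenHor))
    print(tenOut.shape)  # expected: torch.Size([1, 3, 8, 8])
    print(torch.allclose(tenOut, sepconv_out_reference(tenIn, tenVer, tenHor), atol=1e-5))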