jaxmetaverse's picture
Upload folder using huggingface_hub
82ea528 verified
import taichi as ti
import taichi.math as tm
from functools import reduce
# Separable local convolution (Taichi kernel).
# For every output element (n, c, y, x) it computes
#   sum over fy < Kv, fx < Kh of
#       tenIn[n, c, y + fy, x + fx] * tenVer[n, fy, y, x] * tenHor[n, fx, y, x]
# using Kahan (compensated) summation to limit float rounding error.
#
# Arguments are indexed as 4-D arrays: tenIn (N, C, H_in, W_in),
# tenVer (N, Kv, Hv, Wv), tenHor (N, Kh, Hh, Wh).
# The output grid is H_out = min(Hv, Hh), W_out = min(Wv, Wh), so both filter
# grids are indexed in range. tenOut is a flat buffer written in row-major
# (N, C, H_out, W_out) order.
#
# The kernel itself performs no bounds checks. The caller must guarantee that:
#   - H_in >= H_out + Kv - 1 and W_in >= W_out + Kh - 1 (presumably tenIn is
#     padded by the caller);
#   - tenOut holds at least N * C * H_out * W_out elements.
@ti.kernel
def sepconv_out(tenIn: ti.types.ndarray(), tenVer: ti.types.ndarray(), tenHor: ti.types.ndarray(), tenOut: ti.types.ndarray()):
    N = tenIn.shape[0]
    C = tenIn.shape[1]
    # Iterate over the OUTPUT grid, not tenIn's (larger, padded) grid.
    # Otherwise y + fy / x + fx and the filter lookups run out of bounds.
    H = ti.min(tenVer.shape[2], tenHor.shape[2])
    W = ti.min(tenVer.shape[3], tenHor.shape[3])
    for i, ch, y, x in ti.ndrange(N, C, H, W):
        fltOut, fltKahanc, fltKahany, fltKahant = 0.0, 0.0, 0.0, 0.0
        for intFy, intFx in ti.ndrange(tenVer.shape[1], tenHor.shape[1]):
            fltKahany = tenIn[i, ch, y + intFy, x + intFx] * tenVer[i, intFy, y, x] * tenHor[i, intFx, y, x]
            # Kahan step: subtract the running compensation, accumulate,
            # then recover the low-order bits lost in the addition.
            fltKahany = fltKahany - fltKahanc
            fltKahant = fltOut + fltKahany
            fltKahanc = (fltKahant - fltOut) - fltKahany
            fltOut = fltKahant
        # The outermost loop runs in parallel. A shared running counter would
        # race, so derive the row-major flat position from the loop indices.
        intIndex = ((i * C + ch) * H + y) * W + x
        tenOut[intIndex] = fltOut
def worker_interface(op_name, tensors):
    """Run the Taichi implementation of a named op and return its outputs.

    Args:
        op_name: Name of the op. Only ``"sepconv_out"`` is handled here.
        tensors: For ``"sepconv_out"``, a ``(tenIn, tenVer, tenHor)`` triple of
            4-D tensors. tenIn is (N, C, H_in, W_in), tenVer is (N, Kv, Hv, Wv)
            and tenHor is (N, Kh, Hh, Wh).

    Returns:
        A 1-tuple holding the output tensor of shape
        ``(N, C, min(Hv, Hh), min(Wv, Wh))``, with the dtype and device of tenIn.

    Raises:
        ValueError: If tenIn is too small for the output size and filter lengths.
        NotImplementedError: For any other ``op_name``.
    """
    if op_name == "sepconv_out":
        tenIn, tenVer, tenHor = tensors
        # The kernel indexes both filter grids at every output pixel, so the
        # output can only cover their overlap.
        real_tenOut_shape = [
            tenIn.shape[0],
            tenIn.shape[1],
            min(tenVer.shape[2], tenHor.shape[2]),
            min(tenVer.shape[3], tenHor.shape[3]),
        ]
        # The kernel reads tenIn[y + fy, x + fx] without bounds checks. Reject
        # inputs that would make it read outside tenIn.
        if (
            tenIn.shape[2] < real_tenOut_shape[2] + tenVer.shape[1] - 1
            or tenIn.shape[3] < real_tenOut_shape[3] + tenHor.shape[1] - 1
        ):
            raise ValueError(
                "sepconv_out: input of spatial size "
                f"{tuple(tenIn.shape[2:])} is too small for output size "
                f"{tuple(real_tenOut_shape[2:])} with filter lengths "
                f"({tenVer.shape[1]}, {tenHor.shape[1]})"
            )
        # Taichi's torch interop rejects non-contiguous tensors.
        # .contiguous() is a no-op when the tensor already is contiguous.
        tenIn = tenIn.contiguous()
        tenVer = tenVer.contiguous()
        tenHor = tenHor.contiguous()
        tenOut = tenIn.new_zeros([
            int(reduce(lambda a, b: a * b, real_tenOut_shape, 1))
        ])
        sepconv_out(tenIn, tenVer, tenHor, tenOut)
        tenOut = tenOut.view(*real_tenOut_shape)
        return (tenOut, )
    raise NotImplementedError(op_name)
# Public API of this module: only the op dispatcher is exported.
__all__ = [
    "worker_interface",
]