|
import torch |
|
from .utils import cuda_launch, cuda_kernel, cuda_int32 |
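
# CUDA kernels and autograd wrapper for adaptive separable convolution: every
# output pixel is filtered by its own pair of 1-D kernels (one vertical, one
# horizontal), i.e.
#
#     out[n, c, y, x] = sum_fy sum_fx ver[n, fy, y, x] * hor[n, fx, y, x] * in[n, c, y + fy, x + fx]
#
# so tenIn must be spatially larger than tenVer / tenHor by the filter size
# minus one in each dimension, and the output takes the spatial size of the
# kernel tensors.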
|
|
|
sepconv_vergrad = """ |
|
extern "C" __global__ void __launch_bounds__(512) sepconv_vergrad( |
|
const int n, |
|
const {{type}}* __restrict__ tenIn, |
|
const {{type}}* __restrict__ tenVer, |
|
const {{type}}* __restrict__ tenHor, |
|
const {{type}}* __restrict__ tenOutgrad, |
|
{{type}}* __restrict__ tenIngrad, |
|
{{type}}* __restrict__ tenVergrad, |
|
{{type}}* __restrict__ tenHorgrad |
|
) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { |
|
const int intN = ( intIndex / SIZE_3(tenVergrad) / SIZE_2(tenVergrad) / SIZE_1(tenVergrad) ) % SIZE_0(tenVergrad); |
|
const int intC = ( intIndex / SIZE_3(tenVergrad) / SIZE_2(tenVergrad) ) % SIZE_1(tenVergrad); |
|
const int intY = ( intIndex / SIZE_3(tenVergrad) ) % SIZE_2(tenVergrad); |
|
const int intX = ( intIndex ) % SIZE_3(tenVergrad); |
|
|
|
{{type}} fltVergrad = 0.0; |
|
|
|
{{type}} fltKahanc = 0.0; |
|
{{type}} fltKahany = 0.0; |
|
{{type}} fltKahant = 0.0; |
|
|
|
for (int intI = 0; intI < SIZE_1(tenIn); intI += 1) { |
|
for (int intFx = 0; intFx < SIZE_1(tenHor); intFx += 1) { |
|
fltKahany = VALUE_4(tenHor, intN, intFx, intY, intX) * VALUE_4(tenIn, intN, intI, intY + intC, intX + intFx) * VALUE_4(tenOutgrad, intN, intI, intY, intX); |
|
fltKahany = fltKahany - fltKahanc; |
|
fltKahant = fltVergrad + fltKahany; |
|
fltKahanc = (fltKahant - fltVergrad) - fltKahany; |
|
fltVergrad = fltKahant; |
|
} |
|
} |
|
|
|
tenVergrad[intIndex] = fltVergrad; |
|
} } |
|
""" |
|
|
|
sepconv_ingrad = """ |
|
extern "C" __global__ void __launch_bounds__(512) sepconv_ingrad( |
|
const int n, |
|
const {{type}}* __restrict__ tenIn, |
|
const {{type}}* __restrict__ tenVer, |
|
const {{type}}* __restrict__ tenHor, |
|
const {{type}}* __restrict__ tenOutgrad, |
|
{{type}}* __restrict__ tenIngrad, |
|
{{type}}* __restrict__ tenVergrad, |
|
{{type}}* __restrict__ tenHorgrad |
|
) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { |
|
const int intN = ( intIndex / SIZE_3(tenIngrad) / SIZE_2(tenIngrad) / SIZE_1(tenIngrad) ) % SIZE_0(tenIngrad); |
|
const int intC = ( intIndex / SIZE_3(tenIngrad) / SIZE_2(tenIngrad) ) % SIZE_1(tenIngrad); |
|
const int intY = ( intIndex / SIZE_3(tenIngrad) ) % SIZE_2(tenIngrad); |
|
const int intX = ( intIndex ) % SIZE_3(tenIngrad); |
|
|
|
{{type}} fltIngrad = 0.0; |
|
|
|
{{type}} fltKahanc = 0.0; |
|
{{type}} fltKahany = 0.0; |
|
{{type}} fltKahant = 0.0; |
|
|
|
for (int intFy = 0; intFy < SIZE_1(tenVer); intFy += 1) { |
|
int intKy = intY + intFy - (SIZE_1(tenVer) - 1); |
|
|
|
if (intKy < 0) { continue; } |
|
if (intKy >= SIZE_2(tenVer)) { continue; } |
|
|
|
for (int intFx = 0; intFx < SIZE_1(tenHor); intFx += 1) { |
|
int intKx = intX + intFx - (SIZE_1(tenHor) - 1); |
|
|
|
if (intKx < 0) { continue; } |
|
if (intKx >= SIZE_3(tenHor)) { continue; } |
|
|
|
fltKahany = VALUE_4(tenVer, intN, (SIZE_1(tenVer) - 1) - intFy, intKy, intKx) * VALUE_4(tenHor, intN, (SIZE_1(tenHor) - 1) - intFx, intKy, intKx) * VALUE_4(tenOutgrad, intN, intC, intKy, intKx); |
|
fltKahany = fltKahany - fltKahanc; |
|
fltKahant = fltIngrad + fltKahany; |
|
fltKahanc = (fltKahant - fltIngrad) - fltKahany; |
|
fltIngrad = fltKahant; |
|
} |
|
} |
|
|
|
tenIngrad[intIndex] = fltIngrad; |
|
} } |
|
""" |
|
|
|
sepconv_out = """ |
|
extern "C" __global__ void __launch_bounds__(512) sepconv_out( |
|
const int n, |
|
const {{type}}* __restrict__ tenIn, |
|
const {{type}}* __restrict__ tenVer, |
|
const {{type}}* __restrict__ tenHor, |
|
{{type}}* __restrict__ tenOut |
|
) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { |
|
const int intN = ( intIndex / SIZE_3(tenOut) / SIZE_2(tenOut) / SIZE_1(tenOut) ) % SIZE_0(tenOut); |
|
const int intC = ( intIndex / SIZE_3(tenOut) / SIZE_2(tenOut) ) % SIZE_1(tenOut); |
|
const int intY = ( intIndex / SIZE_3(tenOut) ) % SIZE_2(tenOut); |
|
const int intX = ( intIndex ) % SIZE_3(tenOut); |
|
|
|
{{type}} fltOut = 0.0; |
|
|
|
{{type}} fltKahanc = 0.0; |
|
{{type}} fltKahany = 0.0; |
|
{{type}} fltKahant = 0.0; |
|
|
|
for (int intFy = 0; intFy < SIZE_1(tenVer); intFy += 1) { |
|
for (int intFx = 0; intFx < SIZE_1(tenHor); intFx += 1) { |
|
fltKahany = VALUE_4(tenIn, intN, intC, intY + intFy, intX + intFx) * VALUE_4(tenVer, intN, intFy, intY, intX) * VALUE_4(tenHor, intN, intFx, intY, intX); |
|
fltKahany = fltKahany - fltKahanc; |
|
fltKahant = fltOut + fltKahany; |
|
fltKahanc = (fltKahant - fltOut) - fltKahany; |
|
fltOut = fltKahant; |
|
} |
|
} |
|
|
|
tenOut[intIndex] = fltOut; |
|
} } |
|
""" |
|
|
|
sepconv_horgrad = """ |
|
extern "C" __global__ void __launch_bounds__(512) sepconv_horgrad( |
|
const int n, |
|
const {{type}}* __restrict__ tenIn, |
|
const {{type}}* __restrict__ tenVer, |
|
const {{type}}* __restrict__ tenHor, |
|
const {{type}}* __restrict__ tenOutgrad, |
|
{{type}}* __restrict__ tenIngrad, |
|
{{type}}* __restrict__ tenVergrad, |
|
{{type}}* __restrict__ tenHorgrad |
|
) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { |
|
const int intN = ( intIndex / SIZE_3(tenHorgrad) / SIZE_2(tenHorgrad) / SIZE_1(tenHorgrad) ) % SIZE_0(tenHorgrad); |
|
const int intC = ( intIndex / SIZE_3(tenHorgrad) / SIZE_2(tenHorgrad) ) % SIZE_1(tenHorgrad); |
|
const int intY = ( intIndex / SIZE_3(tenHorgrad) ) % SIZE_2(tenHorgrad); |
|
const int intX = ( intIndex ) % SIZE_3(tenHorgrad); |
|
|
|
{{type}} fltHorgrad = 0.0; |
|
|
|
{{type}} fltKahanc = 0.0; |
|
{{type}} fltKahany = 0.0; |
|
{{type}} fltKahant = 0.0; |
|
|
|
for (int intI = 0; intI < SIZE_1(tenIn); intI += 1) { |
|
for (int intFy = 0; intFy < SIZE_1(tenVer); intFy += 1) { |
|
fltKahany = VALUE_4(tenVer, intN, intFy, intY, intX) * VALUE_4(tenIn, intN, intI, intY + intFy, intX + intC) * VALUE_4(tenOutgrad, intN, intI, intY, intX); |
|
fltKahany = fltKahany - fltKahanc; |
|
fltKahant = fltHorgrad + fltKahany; |
|
fltKahanc = (fltKahant - fltHorgrad) - fltKahany; |
|
fltHorgrad = fltKahant; |
|
} |
|
} |
|
|
|
tenHorgrad[intIndex] = fltHorgrad; |
|
} } |
|
""" |
|
|
|
class sepconv_func(torch.autograd.Function): |
|
@staticmethod |
|
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) |
|
def forward(self, tenIn, tenVer, tenHor): |
|
tenOut = tenIn.new_empty( |
|
[ |
|
tenIn.shape[0], |
|
tenIn.shape[1], |
|
                tenVer.shape[2] and tenHor.shape[2],  # `a and b` evaluates to b; ver and hor must share their spatial size
|
tenVer.shape[3] and tenHor.shape[3], |
|
] |
|
) |
|
|
|
        if tenIn.is_cuda:
|
cuda_launch( |
|
cuda_kernel( |
|
"sepconv_out", |
|
sepconv_out, |
|
{ |
|
"tenIn": tenIn, |
|
"tenVer": tenVer, |
|
"tenHor": tenHor, |
|
"tenOut": tenOut, |
|
}, |
|
) |
|
)( |
|
                grid=tuple([(tenOut.nelement() + 512 - 1) // 512, 1, 1]),
|
block=tuple([512, 1, 1]), |
|
args=[ |
|
cuda_int32(tenOut.nelement()), |
|
tenIn.data_ptr(), |
|
tenVer.data_ptr(), |
|
tenHor.data_ptr(), |
|
tenOut.data_ptr(), |
|
], |
|
) |
|
|
|
        else:
            raise NotImplementedError("sepconv_func only supports CUDA tensors")
|
|
|
|
|
|
|
self.save_for_backward(tenIn, tenVer, tenHor) |
|
|
|
return tenOut |
|
|
|
|
|
|
|
@staticmethod |
|
@torch.cuda.amp.custom_bwd |
|
def backward(self, tenOutgrad): |
|
tenIn, tenVer, tenHor = self.saved_tensors |
|
|
|
tenOutgrad = tenOutgrad.contiguous() |
|
        assert tenOutgrad.is_cuda
|
|
|
tenIngrad = ( |
|
tenIn.new_empty( |
|
[tenIn.shape[0], tenIn.shape[1], tenIn.shape[2], tenIn.shape[3]] |
|
) |
|
            if self.needs_input_grad[0]
|
else None |
|
) |
|
tenVergrad = ( |
|
tenVer.new_empty( |
|
[tenVer.shape[0], tenVer.shape[1], tenVer.shape[2], tenVer.shape[3]] |
|
) |
|
            if self.needs_input_grad[1]
|
else None |
|
) |
|
tenHorgrad = ( |
|
tenHor.new_empty( |
|
[tenHor.shape[0], tenHor.shape[1], tenHor.shape[2], tenHor.shape[3]] |
|
) |
|
            if self.needs_input_grad[2]
|
else None |
|
) |
|
|
|
if tenIngrad is not None: |
|
cuda_launch( |
|
cuda_kernel( |
|
"sepconv_ingrad", |
|
sepconv_ingrad, |
|
{ |
|
"tenIn": tenIn, |
|
"tenVer": tenVer, |
|
"tenHor": tenHor, |
|
"tenOutgrad": tenOutgrad, |
|
"tenIngrad": tenIngrad, |
|
"tenVergrad": tenVergrad, |
|
"tenHorgrad": tenHorgrad, |
|
}, |
|
) |
|
)( |
|
                grid=tuple([(tenIngrad.nelement() + 512 - 1) // 512, 1, 1]),
|
block=tuple([512, 1, 1]), |
|
args=[ |
|
cuda_int32(tenIngrad.nelement()), |
|
tenIn.data_ptr(), |
|
tenVer.data_ptr(), |
|
tenHor.data_ptr(), |
|
tenOutgrad.data_ptr(), |
|
tenIngrad.data_ptr(), |
|
None, |
|
None, |
|
], |
|
) |
|
|
|
|
|
if tenVergrad is not None: |
|
cuda_launch( |
|
cuda_kernel( |
|
"sepconv_vergrad", |
|
sepconv_vergrad, |
|
{ |
|
"tenIn": tenIn, |
|
"tenVer": tenVer, |
|
"tenHor": tenHor, |
|
"tenOutgrad": tenOutgrad, |
|
"tenIngrad": tenIngrad, |
|
"tenVergrad": tenVergrad, |
|
"tenHorgrad": tenHorgrad, |
|
}, |
|
) |
|
)( |
|
                grid=tuple([(tenVergrad.nelement() + 512 - 1) // 512, 1, 1]),
|
block=tuple([512, 1, 1]), |
|
args=[ |
|
cuda_int32(tenVergrad.nelement()), |
|
tenIn.data_ptr(), |
|
tenVer.data_ptr(), |
|
tenHor.data_ptr(), |
|
tenOutgrad.data_ptr(), |
|
None, |
|
tenVergrad.data_ptr(), |
|
None, |
|
], |
|
) |
|
|
|
|
|
if tenHorgrad is not None: |
|
cuda_launch( |
|
cuda_kernel( |
|
"sepconv_horgrad", |
|
sepconv_horgrad, |
|
{ |
|
"tenIn": tenIn, |
|
"tenVer": tenVer, |
|
"tenHor": tenHor, |
|
"tenOutgrad": tenOutgrad, |
|
"tenIngrad": tenIngrad, |
|
"tenVergrad": tenVergrad, |
|
"tenHorgrad": tenHorgrad, |
|
}, |
|
) |
|
)( |
|
                grid=tuple([(tenHorgrad.nelement() + 512 - 1) // 512, 1, 1]),
|
block=tuple([512, 1, 1]), |
|
args=[ |
|
cuda_int32(tenHorgrad.nelement()), |
|
tenIn.data_ptr(), |
|
tenVer.data_ptr(), |
|
tenHor.data_ptr(), |
|
tenOutgrad.data_ptr(), |
|
None, |
|
None, |
|
tenHorgrad.data_ptr(), |
|
], |
|
) |
|
|
|
|
|
return tenIngrad, tenVergrad, tenHorgrad |
|
|
|
|
|
|
|
|
|
|
|
__all__ = ["sepconv_func"] |