from .utils import cuda_kernel, cuda_launch, cuda_int32

import torch, collections

costvol_out = """
    extern "C" __global__ void __launch_bounds__(512) costvol_out(
        const int n,
        const {{type}}* __restrict__ tenOne,
        const {{type}}* __restrict__ tenTwo,
        {{type}}* __restrict__ tenOut
    ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
        const int intN = ( intIndex / SIZE_3(tenOut) / SIZE_2(tenOut) ) % SIZE_0(tenOut);
        const int intC = -1;
        const int intY = ( intIndex / SIZE_3(tenOut) ) % SIZE_2(tenOut);
        const int intX = ( intIndex ) % SIZE_3(tenOut);

        // cache the feature vector of tenOne at (intY, intX) in a per-thread array
        {{type}} fltOne[{{intChans}}];

        for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) {
            fltOne[intValue] = VALUE_4(tenOne, intN, intValue, intY, intX);
        }

        int intOffset = OFFSET_4(tenOut, intN, 0, intY, intX);

        // 9x9 search window -> 81 output channels, each the mean absolute feature difference
        for (int intOy = intY - 4; intOy <= intY + 4; intOy += 1) {
            for (int intOx = intX - 4; intOx <= intX + 4; intOx += 1) {
                {{type}} fltValue = 0.0f;

                if ((intOy >= 0) && (intOy < SIZE_2(tenOut)) && (intOx >= 0) && (intOx < SIZE_3(tenOut))) {
                    for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) {
                        fltValue += abs(fltOne[intValue] - VALUE_4(tenTwo, intN, intValue, intOy, intOx));
                    }
                } else {
                    // out-of-range neighbours are treated as zero features
                    for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) {
                        fltValue += abs(fltOne[intValue]);
                    }
                }

                tenOut[intOffset] = fltValue / SIZE_1(tenOne);
                intOffset += SIZE_2(tenOut) * SIZE_3(tenOut);
            }
        }
    } }
"""

costvol_onegrad = """
    extern "C" __global__ void __launch_bounds__(512) costvol_onegrad(
        const int n,
        const {{type}}* __restrict__ tenOne,
        const {{type}}* __restrict__ tenTwo,
        const {{type}}* __restrict__ tenOutgrad,
        {{type}}* __restrict__ tenOnegrad,
        {{type}}* __restrict__ tenTwograd
    ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
        const int intN = ( intIndex / SIZE_3(tenOnegrad) / SIZE_2(tenOnegrad) ) % SIZE_0(tenOnegrad);
        const int intC = -1;
        const int intY = ( intIndex / SIZE_3(tenOnegrad) ) % SIZE_2(tenOnegrad);
        const int intX = ( intIndex ) % SIZE_3(tenOnegrad);

        {{type}} fltOne[{{intChans}}];

        for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) {
            fltOne[intValue] = VALUE_4(tenOne, intN, intValue, intY, intX);
        }

        int intOffset = OFFSET_4(tenOutgrad, intN, 0, intY, intX);

        // each thread accumulates into its own (intY, intX) of tenOnegrad, so plain writes suffice
        for (int intOy = intY - 4; intOy <= intY + 4; intOy += 1) {
            for (int intOx = intX - 4; intOx <= intX + 4; intOx += 1) {
                if ((intOy >= 0) && (intOy < SIZE_2(tenOutgrad)) && (intOx >= 0) && (intOx < SIZE_3(tenOutgrad))) {
                    for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) {
                        if (fltOne[intValue] - VALUE_4(tenTwo, intN, intValue, intOy, intOx) >= 0.0f) {
                            tenOnegrad[OFFSET_4(tenOnegrad, intN, intValue, intY, intX)] += +tenOutgrad[intOffset] / SIZE_1(tenOne);
                        } else {
                            tenOnegrad[OFFSET_4(tenOnegrad, intN, intValue, intY, intX)] += -tenOutgrad[intOffset] / SIZE_1(tenOne);
                        }
                    }
                } else {
                    for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) {
                        if (fltOne[intValue] >= 0.0f) {
                            tenOnegrad[OFFSET_4(tenOnegrad, intN, intValue, intY, intX)] += +tenOutgrad[intOffset] / SIZE_1(tenOne);
                        } else {
                            tenOnegrad[OFFSET_4(tenOnegrad, intN, intValue, intY, intX)] += -tenOutgrad[intOffset] / SIZE_1(tenOne);
                        }
                    }
                }

                intOffset += SIZE_2(tenOutgrad) * SIZE_3(tenOutgrad);
            }
        }
    } }
"""

costvol_twograd = """
    extern "C" __global__ void __launch_bounds__(512) costvol_twograd(
        const int n,
        const {{type}}* __restrict__ tenOne,
        const {{type}}* __restrict__ tenTwo,
        const {{type}}* __restrict__ tenOutgrad,
        {{type}}* __restrict__ tenOnegrad,
        {{type}}* __restrict__ tenTwograd
    ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
        const int intN = ( intIndex / SIZE_3(tenTwograd) / SIZE_2(tenTwograd) ) % SIZE_0(tenTwograd);
        const int intC = -1;
        const int intY = ( intIndex / SIZE_3(tenTwograd) ) % SIZE_2(tenTwograd);
        const int intX = ( intIndex ) % SIZE_3(tenTwograd);

        {{type}} fltOne[{{intChans}}];

        for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) {
            fltOne[intValue] = VALUE_4(tenOne, intN, intValue, intY, intX);
        }

        int intOffset = OFFSET_4(tenOutgrad, intN, 0, intY, intX);

        for (int intOy = intY - 4; intOy <= intY + 4; intOy += 1) {
            for (int intOx = intX - 4; intOx <= intX + 4; intOx += 1) {
                if ((intOy >= 0) && (intOy < SIZE_2(tenOutgrad)) && (intOx >= 0) && (intOx < SIZE_3(tenOutgrad))) {
                    // several threads can scatter into the same (intOy, intOx) of tenTwograd, hence atomicAdd
                    for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) {
                        if (fltOne[intValue] - VALUE_4(tenTwo, intN, intValue, intOy, intOx) >= 0.0f) {
                            atomicAdd(&tenTwograd[OFFSET_4(tenTwograd, intN, intValue, intOy, intOx)], -tenOutgrad[intOffset] / SIZE_1(tenOne));
                        } else {
                            atomicAdd(&tenTwograd[OFFSET_4(tenTwograd, intN, intValue, intOy, intOx)], +tenOutgrad[intOffset] / SIZE_1(tenOne));
                        }
                    }
                } else {
                    // ...
                    // (nothing to do: out-of-range positions never read tenTwo, so no gradient is scattered)
                }

                intOffset += SIZE_2(tenOutgrad) * SIZE_3(tenOutgrad);
            }
        }
    } }
"""


class costvol_func(torch.autograd.Function):
    @staticmethod
    @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
    def forward(self, tenOne, tenTwo):
        tenOut = tenOne.new_empty(
            [tenOne.shape[0], 81, tenOne.shape[2], tenOne.shape[3]]
        )

        # one CUDA thread per output pixel (batch * height * width), 512 threads per block
        cuda_launch(
            cuda_kernel(
                "costvol_out",
                costvol_out,
                {
                    "intChans": tenOne.shape[1],
                    "tenOne": tenOne,
                    "tenTwo": tenTwo,
                    "tenOut": tenOut,
                },
            )
        )(
            grid=tuple(
                [
                    int(
                        (
                            (tenOut.shape[0] * tenOut.shape[2] * tenOut.shape[3])
                            + 512
                            - 1
                        )
                        / 512
                    ),
                    1,
                    1,
                ]
            ),
            block=tuple([512, 1, 1]),
            args=[
                cuda_int32(tenOut.shape[0] * tenOut.shape[2] * tenOut.shape[3]),
                tenOne.data_ptr(),
                tenTwo.data_ptr(),
                tenOut.data_ptr(),
            ],
            stream=collections.namedtuple("Stream", "ptr")(
                torch.cuda.current_stream().cuda_stream
            ),
        )

        self.save_for_backward(tenOne, tenTwo)

        return tenOut

    @staticmethod
    @torch.cuda.amp.custom_bwd
    def backward(self, tenOutgrad):
        tenOne, tenTwo = self.saved_tensors

        tenOutgrad = tenOutgrad.contiguous()
        assert tenOutgrad.is_cuda == True

        tenOnegrad = (
            tenOne.new_zeros(
                [tenOne.shape[0], tenOne.shape[1], tenOne.shape[2], tenOne.shape[3]]
            )
            if self.needs_input_grad[0] == True
            else None
        )
        tenTwograd = (
            tenTwo.new_zeros(
                [tenTwo.shape[0], tenTwo.shape[1], tenTwo.shape[2], tenTwo.shape[3]]
            )
            if self.needs_input_grad[1] == True
            else None
        )

        if tenOnegrad is not None:
            cuda_launch(
                cuda_kernel(
                    "costvol_onegrad",
                    costvol_onegrad,
                    {
                        "intChans": tenOne.shape[1],
                        "tenOne": tenOne,
                        "tenTwo": tenTwo,
                        "tenOutgrad": tenOutgrad,
                        "tenOnegrad": tenOnegrad,
                        "tenTwograd": tenTwograd,
                    },
                )
            )(
                grid=tuple(
                    [
                        int(
                            (
                                (
                                    tenOnegrad.shape[0]
                                    * tenOnegrad.shape[2]
                                    * tenOnegrad.shape[3]
                                )
                                + 512
                                - 1
                            )
                            / 512
                        ),
                        1,
                        1,
                    ]
                ),
                block=tuple([512, 1, 1]),
                args=[
                    cuda_int32(
                        tenOnegrad.shape[0] * tenOnegrad.shape[2] * tenOnegrad.shape[3]
                    ),
                    tenOne.data_ptr(),
                    tenTwo.data_ptr(),
                    tenOutgrad.data_ptr(),
                    tenOnegrad.data_ptr(),
                    tenTwograd.data_ptr(),
                ],
                stream=collections.namedtuple("Stream", "ptr")(
                    torch.cuda.current_stream().cuda_stream
                ),
            )

        if tenTwograd is not None:
            cuda_launch(
                cuda_kernel(
                    "costvol_twograd",
                    costvol_twograd,
                    {
                        "intChans": tenOne.shape[1],
                        "tenOne": tenOne,
                        "tenTwo": tenTwo,
                        "tenOutgrad": tenOutgrad,
                        "tenOnegrad": tenOnegrad,
                        "tenTwograd": tenTwograd,
                    },
                )
            )(
                grid=tuple(
                    [
                        int(
                            (
                                (
                                    tenTwograd.shape[0]
                                    * tenTwograd.shape[2]
                                    * tenTwograd.shape[3]
                                )
                                + 512
                                - 1
                            )
                            / 512
                        ),
                        1,
                        1,
                    ]
                ),
                block=tuple([512, 1, 1]),
                args=[
                    cuda_int32(
                        tenTwograd.shape[0] * tenTwograd.shape[2] * tenTwograd.shape[3]
                    ),
                    tenOne.data_ptr(),
                    tenTwo.data_ptr(),
                    tenOutgrad.data_ptr(),
                    tenOnegrad.data_ptr(),
                    tenTwograd.data_ptr(),
                ],
                stream=collections.namedtuple("Stream", "ptr")(
                    torch.cuda.current_stream().cuda_stream
                ),
            )

        return tenOnegrad, tenTwograd


__all__ = ["costvol_func"]
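

# A minimal usage sketch (added for illustration; the shapes and values below are
# assumptions, not part of the original module): given two CUDA feature maps of shape
# [N, C, H, W], costvol_func.apply returns an [N, 81, H, W] cost volume whose channels
# enumerate the 9x9 search window, and backpropagation exercises the gradient kernels.
if __name__ == "__main__":
    if torch.cuda.is_available() == True:
        tenOne = torch.rand(1, 16, 64, 64, device="cuda", requires_grad=True)
        tenTwo = torch.rand(1, 16, 64, 64, device="cuda", requires_grad=True)

        tenCost = costvol_func.apply(tenOne, tenTwo)  # [1, 81, 64, 64]
        tenCost.mean().backward()  # runs costvol_onegrad and costvol_twograd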