# jaxmetaverse's picture
# Upload folder using huggingface_hub
# 82ea528 verified
import torch
from .utils import cuda_launch, cuda_kernel, cuda_int32
import cupy
import collections
# CUDA source template for the kernel that fills ``tenFlowgrad``, the gradient
# of the splatted output with respect to ``tenFlow``.
#
# Each loop iteration handles one element (n, c, y, x) of ``tenFlowgrad``;
# c == 0 is the flow component added to x and c == 1 the one added to y.
# The source pixel (y, x) is displaced to (x + flow_x, y + flow_y) and the four
# surrounding integer positions (northwest/northeast/southwest/southeast) are
# considered. For every channel of ``tenOutgrad`` the kernel accumulates
#     outgrad(neighbour) * in(y, x) * d(bilinear weight) / d(flow_c)
# over the neighbours that lie inside the output bounds, then stores the sum.
#
# Notes:
# - ``{{type}}``, ``SIZE_k(...)`` and ``VALUE_4(...)`` are placeholders, not
#   valid CUDA on their own; presumably ``cuda_kernel`` expands them from the
#   tensors it receives -- confirm in ``.utils``.
# - ``tenIngrad`` is part of the parameter list but is never accessed here.
# - A non-finite target coordinate makes the thread ``return`` before writing,
#   so that element keeps whatever value the caller initialised it with.
softsplat_flowgrad = """
extern "C" __global__ void __launch_bounds__(512) softsplat_flowgrad(
const int n,
const {{type}}* __restrict__ tenIn,
const {{type}}* __restrict__ tenFlow,
const {{type}}* __restrict__ tenOutgrad,
{{type}}* __restrict__ tenIngrad,
{{type}}* __restrict__ tenFlowgrad
) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
const int intN = ( intIndex / SIZE_3(tenFlowgrad) / SIZE_2(tenFlowgrad) / SIZE_1(tenFlowgrad) ) % SIZE_0(tenFlowgrad);
const int intC = ( intIndex / SIZE_3(tenFlowgrad) / SIZE_2(tenFlowgrad) ) % SIZE_1(tenFlowgrad);
const int intY = ( intIndex / SIZE_3(tenFlowgrad) ) % SIZE_2(tenFlowgrad);
const int intX = ( intIndex ) % SIZE_3(tenFlowgrad);
assert(SIZE_1(tenFlow) == 2);
{{type}} fltFlowgrad = 0.0f;
{{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX);
{{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX);
if (isfinite(fltX) == false) { return; }
if (isfinite(fltY) == false) { return; }
int intNorthwestX = (int) (floor(fltX));
int intNorthwestY = (int) (floor(fltY));
int intNortheastX = intNorthwestX + 1;
int intNortheastY = intNorthwestY;
int intSouthwestX = intNorthwestX;
int intSouthwestY = intNorthwestY + 1;
int intSoutheastX = intNorthwestX + 1;
int intSoutheastY = intNorthwestY + 1;
{{type}} fltNorthwest = 0.0f;
{{type}} fltNortheast = 0.0f;
{{type}} fltSouthwest = 0.0f;
{{type}} fltSoutheast = 0.0f;
if (intC == 0) {
fltNorthwest = (({{type}}) (-1.0f)) * (({{type}}) (intSoutheastY) - fltY);
fltNortheast = (({{type}}) (+1.0f)) * (({{type}}) (intSouthwestY) - fltY);
fltSouthwest = (({{type}}) (-1.0f)) * (fltY - ({{type}}) (intNortheastY));
fltSoutheast = (({{type}}) (+1.0f)) * (fltY - ({{type}}) (intNorthwestY));
} else if (intC == 1) {
fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (-1.0f));
fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (-1.0f));
fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (({{type}}) (+1.0f));
fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (({{type}}) (+1.0f));
}
for (int intChannel = 0; intChannel < SIZE_1(tenOutgrad); intChannel += 1) {
{{type}} fltIn = VALUE_4(tenIn, intN, intChannel, intY, intX);
if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOutgrad)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOutgrad))) {
fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intNorthwestY, intNorthwestX) * fltIn * fltNorthwest;
}
if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOutgrad)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOutgrad))) {
fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intNortheastY, intNortheastX) * fltIn * fltNortheast;
}
if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOutgrad)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOutgrad))) {
fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intSouthwestY, intSouthwestX) * fltIn * fltSouthwest;
}
if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOutgrad)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOutgrad))) {
fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intSoutheastY, intSoutheastX) * fltIn * fltSoutheast;
}
}
tenFlowgrad[intIndex] = fltFlowgrad;
} }
"""
# CUDA source template for the kernel that fills ``tenIngrad``, the gradient of
# the splatted output with respect to the splatted input.
#
# Each loop iteration handles one element (n, c, y, x) of ``tenIngrad``. The
# pixel is displaced by the flow to (x + flow_x, y + flow_y); the gradient is
# the sum of ``tenOutgrad`` (same n and c) at the four surrounding integer
# positions, each multiplied by its bilinear weight, skipping positions that
# fall outside the output bounds.
#
# Notes:
# - ``{{type}}``, ``SIZE_k(...)`` and ``VALUE_4(...)`` are placeholders, not
#   valid CUDA on their own; presumably ``cuda_kernel`` expands them from the
#   tensors it receives -- confirm in ``.utils``.
# - ``tenIn`` and ``tenFlowgrad`` are part of the parameter list but are never
#   accessed here.
# - A non-finite target coordinate makes the thread ``return`` before writing,
#   so that element keeps whatever value the caller initialised it with.
softsplat_ingrad = """
extern "C" __global__ void __launch_bounds__(512) softsplat_ingrad(
const int n,
const {{type}}* __restrict__ tenIn,
const {{type}}* __restrict__ tenFlow,
const {{type}}* __restrict__ tenOutgrad,
{{type}}* __restrict__ tenIngrad,
{{type}}* __restrict__ tenFlowgrad
) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
const int intN = ( intIndex / SIZE_3(tenIngrad) / SIZE_2(tenIngrad) / SIZE_1(tenIngrad) ) % SIZE_0(tenIngrad);
const int intC = ( intIndex / SIZE_3(tenIngrad) / SIZE_2(tenIngrad) ) % SIZE_1(tenIngrad);
const int intY = ( intIndex / SIZE_3(tenIngrad) ) % SIZE_2(tenIngrad);
const int intX = ( intIndex ) % SIZE_3(tenIngrad);
assert(SIZE_1(tenFlow) == 2);
{{type}} fltIngrad = 0.0f;
{{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX);
{{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX);
if (isfinite(fltX) == false) { return; }
if (isfinite(fltY) == false) { return; }
int intNorthwestX = (int) (floor(fltX));
int intNorthwestY = (int) (floor(fltY));
int intNortheastX = intNorthwestX + 1;
int intNortheastY = intNorthwestY;
int intSouthwestX = intNorthwestX;
int intSouthwestY = intNorthwestY + 1;
int intSoutheastX = intNorthwestX + 1;
int intSoutheastY = intNorthwestY + 1;
{{type}} fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (intSoutheastY) - fltY);
{{type}} fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (intSouthwestY) - fltY);
{{type}} fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (fltY - ({{type}}) (intNortheastY));
{{type}} fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (fltY - ({{type}}) (intNorthwestY));
if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOutgrad)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOutgrad))) {
fltIngrad += VALUE_4(tenOutgrad, intN, intC, intNorthwestY, intNorthwestX) * fltNorthwest;
}
if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOutgrad)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOutgrad))) {
fltIngrad += VALUE_4(tenOutgrad, intN, intC, intNortheastY, intNortheastX) * fltNortheast;
}
if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOutgrad)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOutgrad))) {
fltIngrad += VALUE_4(tenOutgrad, intN, intC, intSouthwestY, intSouthwestX) * fltSouthwest;
}
if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOutgrad)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOutgrad))) {
fltIngrad += VALUE_4(tenOutgrad, intN, intC, intSoutheastY, intSoutheastX) * fltSoutheast;
}
tenIngrad[intIndex] = fltIngrad;
} }
"""
# CUDA source template for the forward splatting kernel.
#
# Each loop iteration handles one source element (n, c, y, x), indexed through
# the sizes of ``tenOut``. The pixel is displaced by the flow to
# (x + flow_x, y + flow_y) and its value is distributed onto the four
# surrounding integer positions of ``tenOut`` with bilinear weights; positions
# outside the output bounds are skipped. The writes use ``atomicAdd`` because
# the target location depends on the flow, so different source pixels can
# accumulate into the same output element. The kernel only ever adds to
# ``tenOut``, so the caller has to provide it zero-initialised.
#
# Notes:
# - ``{{type}}``, ``SIZE_k(...)``, ``VALUE_4(...)`` and ``OFFSET_4(...)`` are
#   placeholders, not valid CUDA on their own; presumably ``cuda_kernel``
#   expands them from the tensors it receives -- confirm in ``.utils``.
# - A non-finite target coordinate makes the thread ``return`` without
#   contributing anything.
softsplat_out = """
extern "C" __global__ void __launch_bounds__(512) softsplat_out(
const int n,
const {{type}}* __restrict__ tenIn,
const {{type}}* __restrict__ tenFlow,
{{type}}* __restrict__ tenOut
) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
const int intN = ( intIndex / SIZE_3(tenOut) / SIZE_2(tenOut) / SIZE_1(tenOut) ) % SIZE_0(tenOut);
const int intC = ( intIndex / SIZE_3(tenOut) / SIZE_2(tenOut) ) % SIZE_1(tenOut);
const int intY = ( intIndex / SIZE_3(tenOut) ) % SIZE_2(tenOut);
const int intX = ( intIndex ) % SIZE_3(tenOut);
assert(SIZE_1(tenFlow) == 2);
{{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX);
{{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX);
if (isfinite(fltX) == false) { return; }
if (isfinite(fltY) == false) { return; }
{{type}} fltIn = VALUE_4(tenIn, intN, intC, intY, intX);
int intNorthwestX = (int) (floor(fltX));
int intNorthwestY = (int) (floor(fltY));
int intNortheastX = intNorthwestX + 1;
int intNortheastY = intNorthwestY;
int intSouthwestX = intNorthwestX;
int intSouthwestY = intNorthwestY + 1;
int intSoutheastX = intNorthwestX + 1;
int intSoutheastY = intNorthwestY + 1;
{{type}} fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (intSoutheastY) - fltY);
{{type}} fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (intSouthwestY) - fltY);
{{type}} fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (fltY - ({{type}}) (intNortheastY));
{{type}} fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (fltY - ({{type}}) (intNorthwestY));
if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOut)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOut))) {
atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intNorthwestY, intNorthwestX)], fltIn * fltNorthwest);
}
if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOut)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOut))) {
atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intNortheastY, intNortheastX)], fltIn * fltNortheast);
}
if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOut)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOut))) {
atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intSouthwestY, intSouthwestX)], fltIn * fltSouthwest);
}
if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOut)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOut))) {
atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intSoutheastY, intSoutheastX)], fltIn * fltSoutheast);
}
} }
"""
# end
# Threads per block for every softsplat kernel; the kernel sources are declared
# with ``__launch_bounds__(512)``, so this must not exceed 512.
_SOFTSPLAT_THREADS = 512


def _softsplat_launch(strFunction, strKernel, objVariables, intElements, listPointers):
    """Launch one softsplat CUDA kernel on the current PyTorch CUDA stream.

    Args:
        strFunction: Name of the ``extern "C"`` entry point inside ``strKernel``.
        strKernel: Kernel source template handed to ``cuda_kernel``.
        objVariables: Mapping from the tensor names used in the kernel source to
            the corresponding tensors (entries may be ``None``); handed to
            ``cuda_kernel`` unchanged.
        intElements: Number of elements to process; passed as the kernel's ``n``
            and used to size the grid.
        listPointers: Device pointers (or ``None`` for unused parameters) that
            follow ``n`` in the kernel's parameter list, in order.
    """
    if intElements == 0:
        # Nothing to compute, and a grid of zero blocks is not a valid launch
        # configuration, so skip the launch for empty tensors.
        return
    # Integer ceil-division: one thread per element, rounded up to whole blocks.
    intBlocks = (intElements + _SOFTSPLAT_THREADS - 1) // _SOFTSPLAT_THREADS
    # The launcher is given an object exposing the raw stream handle as ``.ptr``.
    objStream = collections.namedtuple("Stream", "ptr")(
        torch.cuda.current_stream().cuda_stream
    )
    cuda_launch(cuda_kernel(strFunction, strKernel, objVariables))(
        grid=(intBlocks, 1, 1),
        block=(_SOFTSPLAT_THREADS, 1, 1),
        args=[cuda_int32(intElements)] + list(listPointers),
        stream=objStream,
    )


class softsplat_func(torch.autograd.Function):
    """Differentiable forward warping (summation splatting) implemented in CUDA.

    ``forward`` scatters every input pixel to the position given by the flow,
    ``backward`` produces the gradients for both the input and the flow. Only
    CUDA tensors are supported.

    NOTE(review): the kernels index ``tenFlow`` with the batch / row / column
    indices of ``tenIn``, so ``tenFlow`` is assumed to be an (N, 2, H, W) CUDA
    tensor matching ``tenIn`` -- this is not validated on the host side.
    NOTE(review): ``torch.cuda.amp.custom_fwd`` / ``custom_bwd`` are deprecated
    in recent PyTorch releases in favour of ``torch.amp.custom_fwd`` /
    ``custom_bwd`` with ``device_type="cuda"``; migrate once the minimum
    supported PyTorch version allows it.
    """

    @staticmethod
    @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, tenIn, tenFlow):
        """Splat ``tenIn`` along ``tenFlow`` and return the accumulated result.

        Args:
            ctx: Autograd context used to stash tensors for ``backward``.
            tenIn: 4-D CUDA tensor to splat.
            tenFlow: Flow tensor; channel 0 is added to x, channel 1 to y.

        Returns:
            A tensor shaped like ``tenIn`` holding the summed contributions.

        Raises:
            AssertionError: If ``tenIn`` is not a CUDA tensor.
        """
        # The kernel only accumulates, so the output has to start at zero.
        tenOut = tenIn.new_zeros(
            [tenIn.shape[0], tenIn.shape[1], tenIn.shape[2], tenIn.shape[3]]
        )
        if not tenIn.is_cuda:
            # Raised explicitly rather than via ``assert False`` so the check is
            # not stripped under ``python -O`` (which would silently return zeros).
            raise AssertionError("softsplat_func only supports CUDA tensors")
        _softsplat_launch(
            "softsplat_out",
            softsplat_out,
            {"tenIn": tenIn, "tenFlow": tenFlow, "tenOut": tenOut},
            tenOut.nelement(),
            [tenIn.data_ptr(), tenFlow.data_ptr(), tenOut.data_ptr()],
        )
        ctx.save_for_backward(tenIn, tenFlow)
        return tenOut

    @staticmethod
    @torch.cuda.amp.custom_bwd
    def backward(ctx, tenOutgrad):
        """Compute the gradients with respect to ``tenIn`` and ``tenFlow``.

        Args:
            ctx: Autograd context populated by ``forward``.
            tenOutgrad: Gradient of the loss with respect to the forward output.

        Returns:
            ``(tenIngrad, tenFlowgrad)``; an entry is ``None`` when the
            corresponding input does not require a gradient.

        Raises:
            AssertionError: If ``tenOutgrad`` is not a CUDA tensor.
        """
        tenIn, tenFlow = ctx.saved_tensors
        tenOutgrad = tenOutgrad.contiguous()
        if not tenOutgrad.is_cuda:
            # Explicit raise so the check survives ``python -O``.
            raise AssertionError("softsplat_func only supports CUDA tensors")
        # Gradient buffers are only allocated for inputs that need them; they
        # start at zero because the kernels skip non-finite target positions.
        tenIngrad = None
        if ctx.needs_input_grad[0]:
            tenIngrad = tenIn.new_zeros(
                [tenIn.shape[0], tenIn.shape[1], tenIn.shape[2], tenIn.shape[3]]
            )
        tenFlowgrad = None
        if ctx.needs_input_grad[1]:
            tenFlowgrad = tenFlow.new_zeros(
                [tenFlow.shape[0], tenFlow.shape[1], tenFlow.shape[2], tenFlow.shape[3]]
            )
        # Both backward kernels are described by the same set of named tensors.
        objTensors = {
            "tenIn": tenIn,
            "tenFlow": tenFlow,
            "tenOutgrad": tenOutgrad,
            "tenIngrad": tenIngrad,
            "tenFlowgrad": tenFlowgrad,
        }
        if tenIngrad is not None:
            _softsplat_launch(
                "softsplat_ingrad",
                softsplat_ingrad,
                dict(objTensors),
                tenIngrad.nelement(),
                [
                    tenIn.data_ptr(),
                    tenFlow.data_ptr(),
                    tenOutgrad.data_ptr(),
                    tenIngrad.data_ptr(),
                    None,  # tenFlowgrad is not accessed by this kernel
                ],
            )
        if tenFlowgrad is not None:
            _softsplat_launch(
                "softsplat_flowgrad",
                softsplat_flowgrad,
                dict(objTensors),
                tenFlowgrad.nelement(),
                [
                    tenIn.data_ptr(),
                    tenFlow.data_ptr(),
                    tenOutgrad.data_ptr(),
                    None,  # tenIngrad is not accessed by this kernel
                    tenFlowgrad.data_ptr(),
                ],
            )
        return tenIngrad, tenFlowgrad
def FunctionSoftsplat(tenInput, tenFlow, tenMetric, strType):
    """Forward-warp ``tenInput`` along ``tenFlow`` and normalise by ``strType``.

    Args:
        tenInput: 4-D tensor to splat (channels on dimension 1).
        tenFlow: Flow tensor handed to ``softsplat_func`` unchanged.
        tenMetric: Single-channel weighting tensor. Required for "linear" and
            "softmax"; may be ``None`` for "summation" and "average", where it
            is not used.
        strType: One of "summation", "average", "linear" or "softmax".

    Returns:
        The splatted tensor. For every type except "summation" the result is
        divided by the splatted weights, with zero weights replaced by one.

    Raises:
        AssertionError: If ``tenMetric`` has more than one channel, if
            ``strType`` is unknown, or if a metric-based type is requested
            without a metric.
    """
    assert tenMetric is None or tenMetric.shape[1] == 1
    assert strType in ["summation", "average", "linear", "softmax"]
    if strType in ["linear", "softmax"]:
        # Fail up front instead of with a TypeError / AttributeError raised by
        # the tensor arithmetic below.
        assert tenMetric is not None, "tenMetric is required for strType " + strType

    # Append the per-pixel weight as an extra last channel so that it is
    # splatted together with the (weighted) input.
    if strType == "average":
        tenWeight = tenInput.new_ones(
            tenInput.shape[0], 1, tenInput.shape[2], tenInput.shape[3]
        )
        tenInput = torch.cat([tenInput, tenWeight], 1)
    elif strType == "linear":
        tenInput = torch.cat([tenInput * tenMetric, tenMetric], 1)
    elif strType == "softmax":
        tenInput = torch.cat([tenInput * tenMetric.exp(), tenMetric.exp()], 1)

    tenOutput = softsplat_func.apply(tenInput, tenFlow)

    if strType != "summation":
        # The last channel holds the accumulated weights. Locations that
        # received nothing have weight zero; set those to one (in place, on a
        # view of ``tenOutput``) to avoid dividing by zero.
        tenNormalize = tenOutput[:, -1:, :, :]
        tenNormalize[tenNormalize == 0.0] = 1.0
        tenOutput = tenOutput[:, :-1, :, :] / tenNormalize

    return tenOutput
class ModuleSoftsplat(torch.nn.Module):
    """``torch.nn.Module`` wrapper around ``FunctionSoftsplat`` with a fixed type.

    Args:
        strType: Splatting type forwarded to ``FunctionSoftsplat`` on every
            call; it is stored as given and only validated there.
    """

    def __init__(self, strType):
        super().__init__()
        self.strType = strType

    def forward(self, tenInput, tenFlow, tenMetric=None):
        """Splat ``tenInput`` along ``tenFlow`` using the configured type.

        Args:
            tenInput: Tensor to splat.
            tenFlow: Flow tensor.
            tenMetric: Weighting tensor. Defaults to ``None``, which
                ``FunctionSoftsplat`` accepts for the types that do not read
                the metric.

        Returns:
            Whatever ``FunctionSoftsplat`` returns for these arguments.
        """
        return FunctionSoftsplat(tenInput, tenFlow, tenMetric, self.strType)
def softsplat(
    tenIn: torch.Tensor, tenFlow: torch.Tensor, tenMetric: torch.Tensor, strMode: str
):
    """Forward-warp ``tenIn`` along ``tenFlow`` using the given splatting mode.

    ``strMode`` has the form ``"<base>"`` or ``"<base>-<eps>"``:

    * base ``"sum"``: plain summation, no normalisation;
    * base ``"avg"``: every pixel has weight one;
    * base ``"linear"``: pixels are weighted by ``tenMetric``;
    * base ``"soft"``: pixels are weighted by ``exp(tenMetric)``.

    For all bases except ``"sum"`` the result is divided by the splatted
    weights. ``<eps>`` selects how zero weights are handled: ``"addeps"``
    (also used when no suffix is given) adds 1e-7, ``"zeroeps"`` replaces
    exact zeros by one, ``"clipeps"`` clamps the weights to at least 1e-7.
    Any other suffix leaves the weights untouched.

    Args:
        tenIn: 4-D tensor to splat (channels on dimension 1).
        tenFlow: Flow tensor handed to ``softsplat_func`` unchanged.
        tenMetric: Weighting tensor; must be ``None`` for the ``"sum"`` and
            ``"avg"`` bases and must be given for ``"linear"`` and ``"soft"``.
        strMode: Splatting mode as described above.

    Returns:
        The splatted (and, except for ``"sum"``, normalised) tensor.

    Raises:
        AssertionError: If the base mode is unknown or ``tenMetric`` does not
            match the requirements of the base mode.
    """
    # Parse the mode once and use the base consistently. Previously the ones
    # channel for averaging was only appended when strMode was exactly "avg",
    # so "avg-<eps>" normalised by (and dropped) the last real input channel.
    listMode = strMode.split("-")
    strBase = listMode[0]

    assert strBase in ["sum", "avg", "linear", "soft"]
    if strBase in ["sum", "avg"]:
        assert tenMetric is None
    else:
        assert tenMetric is not None

    # Append the per-pixel weight as an extra last channel so that it is
    # splatted together with the (weighted) input.
    if strBase == "avg":
        tenOnes = tenIn.new_ones([tenIn.shape[0], 1, tenIn.shape[2], tenIn.shape[3]])
        tenIn = torch.cat([tenIn, tenOnes], 1)
    elif strBase == "linear":
        tenIn = torch.cat([tenIn * tenMetric, tenMetric], 1)
    elif strBase == "soft":
        tenIn = torch.cat([tenIn * tenMetric.exp(), tenMetric.exp()], 1)

    tenOut = softsplat_func.apply(tenIn, tenFlow)

    if strBase != "sum":
        # The last channel holds the accumulated weights.
        tenNormalize = tenOut[:, -1:, :, :]
        strEps = listMode[1] if len(listMode) > 1 else "addeps"
        if strEps == "addeps":
            tenNormalize = tenNormalize + 0.0000001
        elif strEps == "zeroeps":
            # in place, on a view of tenOut
            tenNormalize[tenNormalize == 0.0] = 1.0
        elif strEps == "clipeps":
            tenNormalize = tenNormalize.clip(0.0000001, None)
        tenOut = tenOut[:, :-1, :, :] / tenNormalize

    return tenOut
# Names exported by ``from <module> import *``; the kernel source strings are
# deliberately not listed.
__all__ = ["FunctionSoftsplat", "ModuleSoftsplat", "softsplat", "softsplat_func"]