#include <torch/extension.h>
#include <THC/THCAtomics.cuh>  // atomicAdd overloads for at::Half
#include <vector>
#include <iostream>

using namespace torch::indexing;

#define THREADS 256
#define BLOCKS(n) (((n) + THREADS - 1) / THREADS)

__forceinline__ __device__
bool within_bounds(int h, int w, int H, int W) {
  return h >= 0 && h < H && w >= 0 && w < W;
}
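
// patchify_forward_kernel: one thread per (batch, point, patch row, patch col).
// Each thread copies all channels of a single feature-map pixel from `net`
// into the D x D patch of `patches` centered at the (floored) coordinate of
// point m. Out-of-bounds pixels are skipped, so they keep the zeros the host
// wrapper allocates.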
template <typename scalar_t>
__global__ void patchify_forward_kernel(int R,
    const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> net,
    const torch::PackedTensorAccessor32<float,3,torch::RestrictPtrTraits> coords,
    torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> patches)
{
  // patch diameter (one extra row/column beyond 2*R + 1)
  const int D = 2*R + 2;

  const int B = coords.size(0);
  const int M = coords.size(1);
  const int C = net.size(1);
  const int H = net.size(2);
  const int W = net.size(3);

  int n = blockIdx.x * blockDim.x + threadIdx.x;
  if (n < B * M * D * D) {
    // decode the flat index; after the divisions, n is the batch index
    const int ii = n % D; n /= D;
    const int jj = n % D; n /= D;
    const int m  = n % M; n /= M;

    const float x = coords[n][m][0];
    const float y = coords[n][m][1];

    // source pixel in the feature map
    const int i = static_cast<int>(floor(y)) + (ii - R);
    const int j = static_cast<int>(floor(x)) + (jj - R);

    if (within_bounds(i, j, H, W)) {
      for (int k=0; k<C; k++)
        patches[n][m][k][ii][jj] = net[n][k][i][j];
    }
  }
}
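
// patchify_backward_kernel: the adjoint of patchify_forward_kernel. Each
// thread scatter-adds one patch-gradient pixel back into the image-shaped
// gradient tensor; patches may overlap, hence the atomicAdd.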
template <typename scalar_t>
__global__ void patchify_backward_kernel(int R,
    const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> patch_gradient,
    const torch::PackedTensorAccessor32<float,3,torch::RestrictPtrTraits> coords,
    torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> gradient)
{
  // patch diameter (one extra row/column beyond 2*R + 1)
  const int D = 2*R + 2;

  const int B = coords.size(0);
  const int M = coords.size(1);
  const int C = gradient.size(1);
  const int H = gradient.size(2);
  const int W = gradient.size(3);

  int n = blockIdx.x * blockDim.x + threadIdx.x;
  if (n < B * M * D * D) {
    // decode the flat index; after the divisions, n is the batch index
    const int ii = n % D; n /= D;
    const int jj = n % D; n /= D;
    const int m  = n % M; n /= M;

    const float x = coords[n][m][0];
    const float y = coords[n][m][1];

    const int i = static_cast<int>(floor(y)) + (ii - R);
    const int j = static_cast<int>(floor(x)) + (jj - R);

    if (within_bounds(i, j, H, W)) {
      // patches may overlap, so accumulation must be atomic
      for (int k=0; k<C; k++)
        atomicAdd(&gradient[n][k][i][j], patch_gradient[n][m][k][ii][jj]);
    }
  }
}
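
// corr_forward_kernel: one thread per (batch, edge, pixel, offset). Each
// thread computes the dot product between the feature vector of fmap1 at
// pixel (i0, j0) of frame us[m] and the feature vector of fmap2 of frame
// vs[m], sampled at offset (ii - R, jj - R) from the floored coordinate.
// Out-of-bounds samples produce a correlation of zero.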
template <typename scalar_t>
__global__ void corr_forward_kernel(int R,
    const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> fmap1,
    const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> fmap2,
    const torch::PackedTensorAccessor32<float,5,torch::RestrictPtrTraits> coords,
    const torch::PackedTensorAccessor32<long,1,torch::RestrictPtrTraits> us,
    const torch::PackedTensorAccessor32<long,1,torch::RestrictPtrTraits> vs,
    torch::PackedTensorAccessor32<scalar_t,6,torch::RestrictPtrTraits> corr)
{
  // sampling window size: 2*R + 1 offsets plus one extra sample,
  // consumed by the bilinear interpolation in the host wrapper
  const int D = 2*R + 2;

  const int B = coords.size(0);
  const int M = coords.size(1);
  const int H = coords.size(3);
  const int W = coords.size(4);

  const int C  = fmap1.size(2);
  const int H2 = fmap2.size(3);
  const int W2 = fmap2.size(4);

  int n = blockIdx.x * blockDim.x + threadIdx.x;
  if (n < B * M * H * W * D * D) {
    // decode the flat index; after the divisions, n is the batch index
    const int jj = n % D; n /= D;
    const int ii = n % D; n /= D;
    const int j0 = n % W; n /= W;
    const int i0 = n % H; n /= H;
    const int m  = n % M; n /= M;

    // frame indices of the two feature maps being correlated
    const int ix = us[m];
    const int jx = vs[m];

    const float x = coords[n][m][0][i0][j0];
    const float y = coords[n][m][1][i0][j0];

    const int i1 = static_cast<int>(floor(y)) + (ii - R);
    const int j1 = static_cast<int>(floor(x)) + (jj - R);

    scalar_t s = 0;
    if (within_bounds(i1, j1, H2, W2)) {
      // dot product over the channel dimension, processed 8 channels at a
      // time (assumes C is a multiple of 8)
      #pragma unroll 8
      for (int i=0; i<C; i+=8) {
        scalar_t f1[8]; for (int j=0; j<8; j++) f1[j] = fmap1[n][ix][i+j][i0][j0];
        scalar_t f2[8]; for (int j=0; j<8; j++) f2[j] = fmap2[n][jx][i+j][i1][j1];

        #pragma unroll
        for (int j=0; j<8; j++) s += f1[j] * f2[j];
      }
    }

    corr[n][m][ii][jj][i0][j0] = s;
  }
}
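
// corr_backward_kernel: the adjoint of corr_forward_kernel. Since the
// correlation is a dot product, the gradient w.r.t. fmap1 is the incoming
// gradient times fmap2 and vice versa. atomicAdd is required because many
// threads can write to the same feature-map location.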
template <typename scalar_t>
__global__ void corr_backward_kernel(int R,
    const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> fmap1,
    const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> fmap2,
    const torch::PackedTensorAccessor32<float,5,torch::RestrictPtrTraits> coords,
    const torch::PackedTensorAccessor32<long,1,torch::RestrictPtrTraits> us,
    const torch::PackedTensorAccessor32<long,1,torch::RestrictPtrTraits> vs,
    const torch::PackedTensorAccessor32<float,6,torch::RestrictPtrTraits> corr_grad,
    torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> fmap1_grad,
    torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> fmap2_grad)
{
  // sampling window size, matching corr_forward_kernel
  const int D = 2*R + 2;

  const int B = coords.size(0);
  const int M = coords.size(1);
  const int H = coords.size(3);
  const int W = coords.size(4);

  const int C  = fmap1.size(2);
  const int H2 = fmap2.size(3);
  const int W2 = fmap2.size(4);

  int n = blockIdx.x * blockDim.x + threadIdx.x;
  if (n < B * M * H * W * D * D) {
    const int jj = n % D; n /= D;
    const int ii = n % D; n /= D;
    const int j0 = n % W; n /= W;
    const int i0 = n % H; n /= H;
    const int m  = n % M; n /= M;

    const int ix = us[m];
    const int jx = vs[m];

    const float x = coords[n][m][0][i0][j0];
    const float y = coords[n][m][1][i0][j0];

    const int i1 = static_cast<int>(floor(y)) + (ii - R);
    const int j1 = static_cast<int>(floor(x)) + (jj - R);

    const scalar_t g = (scalar_t) corr_grad[n][m][ii][jj][i0][j0];

    if (within_bounds(i1, j1, H2, W2)) {
      #pragma unroll 32
      for (int i=0; i<C; i++) {
        atomicAdd(&fmap1_grad[n][ix][i][i0][j0], g * fmap2[n][jx][i][i1][j1]);
        atomicAdd(&fmap2_grad[n][jx][i][i1][j1], g * fmap1[n][ix][i][i0][j0]);
      }
    }
  }
}
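
// Host wrapper for the correlation forward pass: launches corr_forward_kernel
// on the D x D grid of integer offsets, then bilinearly interpolates between
// the four neighbouring integer offsets using the fractional parts of the
// sampling coordinates, yielding a (D-1) x (D-1) correlation volume per pixel.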
std::vector<torch::Tensor> corr_cuda_forward(
    torch::Tensor fmap1,
    torch::Tensor fmap2,
    torch::Tensor coords,
    torch::Tensor ii,
    torch::Tensor jj,
    int radius)
{
  const int B = coords.size(0);
  const int M = coords.size(1);
  const int H = coords.size(3);
  const int W = coords.size(4);
  const int D = 2 * radius + 2;

  auto opts = fmap1.options();
  auto corr = torch::empty({B, M, D, D, H, W}, opts);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(fmap1.scalar_type(), "corr_forward_kernel", ([&] {
    corr_forward_kernel<scalar_t><<<BLOCKS(B * M * H * W * D * D), THREADS>>>(radius,
        fmap1.packed_accessor32<scalar_t,5,torch::RestrictPtrTraits>(),
        fmap2.packed_accessor32<scalar_t,5,torch::RestrictPtrTraits>(),
        coords.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
        ii.packed_accessor32<long,1,torch::RestrictPtrTraits>(),
        jj.packed_accessor32<long,1,torch::RestrictPtrTraits>(),
        corr.packed_accessor32<scalar_t,6,torch::RestrictPtrTraits>());
  }));

  // fractional parts of the sampling coordinates, broadcast over the
  // (D-1) x (D-1) offset grid
  torch::Tensor x = coords.index({Slice(), Slice(), 0, None, None});
  torch::Tensor y = coords.index({Slice(), Slice(), 1, None, None});
  torch::Tensor dx = x - x.floor(); dx = dx.to(fmap1.dtype());
  torch::Tensor dy = y - y.floor(); dy = dy.to(fmap2.dtype());

  // bilinear interpolation between the four neighbouring integer offsets
  torch::Tensor out;
  out  = (1 - dx) * (1 - dy) * corr.index({Slice(), Slice(), Slice(0, D-1), Slice(0, D-1)});
  out += (dx) * (1 - dy) * corr.index({Slice(), Slice(), Slice(0, D-1), Slice(1, D)});
  out += (1 - dx) * (dy) * corr.index({Slice(), Slice(), Slice(1, D), Slice(0, D-1)});
  out += (dx) * (dy) * corr.index({Slice(), Slice(), Slice(1, D), Slice(1, D)});

  return { out.permute({0,1,3,2,4,5}) };
}
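
// Host wrapper for the correlation backward pass: redistributes the incoming
// gradient over the four bilinear taps (the adjoint of the interpolation in
// corr_cuda_forward), then launches corr_backward_kernel to accumulate the
// gradients into fmap1 and fmap2.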
std::vector<torch::Tensor> corr_cuda_backward(
    torch::Tensor fmap1,
    torch::Tensor fmap2,
    torch::Tensor coords,
    torch::Tensor ii,
    torch::Tensor jj,
    torch::Tensor grad,
    int radius)
{
  const int B = coords.size(0);
  const int M = coords.size(1);
  const int H = coords.size(3);
  const int W = coords.size(4);
  const int D = 2 * radius + 2;

  // undo the permute applied to the forward output
  grad = grad.permute({0,1,3,2,4,5}).contiguous();

  torch::Tensor x = coords.index({Slice(), Slice(), 0, None, None});
  torch::Tensor y = coords.index({Slice(), Slice(), 1, None, None});
  torch::Tensor dx = x - x.floor();
  torch::Tensor dy = y - y.floor();

  // distribute the gradient over the four bilinear taps; the kernel reads
  // corr_grad through a float accessor, so accumulate in float32
  auto opts = torch::TensorOptions().dtype(torch::kFloat).device(torch::kCUDA);
  torch::Tensor g1 = torch::zeros({B, M, D, D, H, W}, opts);
  torch::Tensor g2 = torch::zeros({B, M, D, D, H, W}, opts);
  torch::Tensor g3 = torch::zeros({B, M, D, D, H, W}, opts);
  torch::Tensor g4 = torch::zeros({B, M, D, D, H, W}, opts);

  g1.index_put_({Slice(), Slice(), Slice(0, D-1), Slice(0, D-1)}, (1 - dx) * (1 - dy) * grad);
  g2.index_put_({Slice(), Slice(), Slice(0, D-1), Slice(1, D)}, (dx) * (1 - dy) * grad);
  g3.index_put_({Slice(), Slice(), Slice(1, D), Slice(0, D-1)}, (1 - dx) * (dy) * grad);
  g4.index_put_({Slice(), Slice(), Slice(1, D), Slice(1, D)}, (dx) * (dy) * grad);

  torch::Tensor corr_grad = g1 + g2 + g3 + g4;
  auto fmap1_grad = torch::zeros_like(fmap1);
  auto fmap2_grad = torch::zeros_like(fmap2);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(fmap1.scalar_type(), "corr_backward_kernel", ([&] {
    corr_backward_kernel<scalar_t><<<BLOCKS(B * M * H * W * D * D), THREADS>>>(radius,
        fmap1.packed_accessor32<scalar_t,5,torch::RestrictPtrTraits>(),
        fmap2.packed_accessor32<scalar_t,5,torch::RestrictPtrTraits>(),
        coords.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
        ii.packed_accessor32<long,1,torch::RestrictPtrTraits>(),
        jj.packed_accessor32<long,1,torch::RestrictPtrTraits>(),
        corr_grad.packed_accessor32<float,6,torch::RestrictPtrTraits>(),
        fmap1_grad.packed_accessor32<scalar_t,5,torch::RestrictPtrTraits>(),
        fmap2_grad.packed_accessor32<scalar_t,5,torch::RestrictPtrTraits>());
  }));

  return {fmap1_grad, fmap2_grad};
}
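
// Host wrapper for the patchify forward pass: extracts a D x D feature patch
// around every coordinate; out-of-bounds entries remain zero.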
std::vector<torch::Tensor> patchify_cuda_forward(
    torch::Tensor net, torch::Tensor coords, int radius)
{
  const int B = coords.size(0);
  const int M = coords.size(1);
  const int C = net.size(1);
  const int D = 2 * radius + 2;

  auto opts = net.options();
  auto patches = torch::zeros({B, M, C, D, D}, opts);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(net.scalar_type(), "patchify_forward_kernel", ([&] {
    patchify_forward_kernel<scalar_t><<<BLOCKS(B * M * D * D), THREADS>>>(radius,
        net.packed_accessor32<scalar_t,4,torch::RestrictPtrTraits>(),
        coords.packed_accessor32<float,3,torch::RestrictPtrTraits>(),
        patches.packed_accessor32<scalar_t,5,torch::RestrictPtrTraits>());
  }));

  return { patches };
}
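
// Host wrapper for the patchify backward pass: scatter-adds patch gradients
// back into an image-shaped gradient tensor of the same shape as `net`.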
std::vector<torch::Tensor> patchify_cuda_backward(
    torch::Tensor net,
    torch::Tensor coords,
    torch::Tensor gradient,
    int radius)
{
  const int B = coords.size(0);
  const int M = coords.size(1);
  const int C = net.size(1);
  const int H = net.size(2);
  const int W = net.size(3);
  const int D = 2 * radius + 2;

  torch::Tensor net_gradient = torch::zeros_like(net);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(net.scalar_type(), "patchify_backward_kernel", ([&] {
    patchify_backward_kernel<scalar_t><<<BLOCKS(B * M * D * D), THREADS>>>(radius,
        gradient.packed_accessor32<scalar_t,5,torch::RestrictPtrTraits>(),
        coords.packed_accessor32<float,3,torch::RestrictPtrTraits>(),
        net_gradient.packed_accessor32<scalar_t,4,torch::RestrictPtrTraits>());
  }));

  return { net_gradient };
}
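
// These host functions are typically exposed to Python from a separate
// bindings translation unit. A minimal sketch follows; the module setup and
// the exposed names (corr_forward, corr_backward, patchify_forward,
// patchify_backward) are hypothetical and not part of this file:
//
//   PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
//     m.def("corr_forward", &corr_cuda_forward, "correlation forward (CUDA)");
//     m.def("corr_backward", &corr_cuda_backward, "correlation backward (CUDA)");
//     m.def("patchify_forward", &patchify_cuda_forward, "patchify forward (CUDA)");
//     m.def("patchify_backward", &patchify_cuda_backward, "patchify backward (CUDA)");
//   }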