Spaces:
Running
Running
File size: 4,632 Bytes
b7eedf7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
import torch
import torch.nn.functional as F
import droid_backends
class CorrSampler(torch.autograd.Function):
    """Autograd bridge to the compiled ``corr_index`` kernels in ``droid_backends``.

    ``forward`` indexes a precomputed correlation ``volume`` at ``coords``
    using ``radius``. The exact sampling pattern is defined by the compiled
    kernel; it is presumably a window of radius ``radius`` around each
    coordinate, but confirm against ``droid_backends``. Only ``volume``
    receives a gradient.
    """

    @staticmethod
    def forward(ctx, volume, coords, radius):
        """Run ``corr_index_forward`` and stash what the backward pass needs."""
        ctx.radius = radius
        ctx.save_for_backward(volume, coords)
        (sampled,) = droid_backends.corr_index_forward(volume, coords, radius)
        return sampled

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradient for ``volume``; ``coords`` and ``radius`` get ``None``."""
        volume, coords = ctx.saved_tensors
        # The compiled kernel is handed a contiguous gradient buffer.
        (volume_grad,) = droid_backends.corr_index_backward(
            volume, coords, grad_output.contiguous(), ctx.radius
        )
        return volume_grad, None, None
class CorrBlock:
    """Multi-level all-pairs correlation volume, sampled with ``CorrSampler``.

    The full correlation between every pixel of ``fmap1`` and every pixel of
    ``fmap2`` is computed once in ``__init__``. A pyramid is then built by
    repeatedly 2x2 average-pooling over the ``fmap2`` spatial dimensions only.
    """

    def __init__(self, fmap1, fmap2, num_levels=4, radius=3):
        """Build the correlation pyramid.

        Args:
            fmap1: features of shape (batch, num, dim, h1, w1).
            fmap2: features of shape (batch, num, dim, h2, w2).
            num_levels: number of pyramid levels to build.
            radius: lookup radius forwarded to ``CorrSampler`` in ``__call__``.
        """
        self.num_levels = num_levels
        self.radius = radius
        self.corr_pyramid = []

        # all pairs correlation: (batch, num, h1, w1, h2, w2)
        corr = CorrBlock.corr(fmap1, fmap2)

        batch, num, h1, w1, h2, w2 = corr.shape
        # Treat each fmap1 pixel's (h2, w2) response map as a 1-channel image
        # so that avg_pool2d downsamples only over the fmap2 dimensions.
        corr = corr.reshape(batch * num * h1 * w1, 1, h2, w2)

        for i in range(self.num_levels):
            # avg_pool2d(2, stride=2) floors odd sizes, so level i is exactly
            # (h2 // 2**i, w2 // 2**i).
            self.corr_pyramid.append(
                corr.view(batch * num, h1, w1, h2 // 2**i, w2 // 2**i))
            corr = F.avg_pool2d(corr, 2, stride=2)

    def __call__(self, coords):
        """Sample every pyramid level at ``coords`` and concatenate the results.

        Args:
            coords: tensor of shape (batch, num, ht, wd, 2) in level-0 pixel
                units; it is divided by ``2**i`` before sampling level ``i``.

        Returns:
            Tensor of shape (batch, num, C, ht, wd), where C is the sum of the
            per-level channel counts produced by ``CorrSampler``.
        """
        out_pyramid = []
        batch, num, ht, wd, _ = coords.shape
        coords = coords.permute(0, 1, 4, 2, 3)
        coords = coords.contiguous().view(batch * num, 2, ht, wd)

        for i in range(self.num_levels):
            corr = CorrSampler.apply(self.corr_pyramid[i], coords / 2**i, self.radius)
            out_pyramid.append(corr.view(batch, num, -1, ht, wd))

        return torch.cat(out_pyramid, dim=2)

    def cat(self, other):
        """Append ``other``'s pyramid to this one along dim 0, IN PLACE.

        ``other`` must have at least ``self.num_levels`` levels with matching
        trailing dimensions. Returns ``self`` to allow chaining.
        """
        for i in range(self.num_levels):
            self.corr_pyramid[i] = torch.cat([self.corr_pyramid[i], other.corr_pyramid[i]], 0)
        return self

    def __getitem__(self, index):
        """Index every pyramid level along dim 0, IN PLACE, and return ``self``.

        Note that this mutates the block rather than returning a new one.
        """
        for i in range(self.num_levels):
            self.corr_pyramid[i] = self.corr_pyramid[i][index]
        return self

    @staticmethod
    def corr(fmap1, fmap2):
        """All-pairs correlation between two feature maps.

        Args:
            fmap1: tensor of shape (batch, num, dim, ht, wd).
            fmap2: tensor of shape (batch, num, dim, ht2, wd2); its spatial
                size may differ from ``fmap1``'s.

        Returns:
            Tensor of shape (batch, num, ht, wd, ht2, wd2) holding the dot
            product of every fmap1 feature vector with every fmap2 feature
            vector, scaled by 1/16 (each input is divided by 4.0).
        """
        batch, num, dim, ht, wd = fmap1.shape
        ht2, wd2 = fmap2.shape[-2:]
        fmap1 = fmap1.reshape(batch * num, dim, ht * wd) / 4.0
        # Use fmap2's own spatial size. Reshaping it with fmap1's ht*wd would
        # scramble pixels, or fail, whenever the two maps differ in shape.
        fmap2 = fmap2.reshape(batch * num, dim, ht2 * wd2) / 4.0

        corr = torch.matmul(fmap1.transpose(1, 2), fmap2)
        return corr.view(batch, num, ht, wd, ht2, wd2)
class CorrLayer(torch.autograd.Function):
    """Autograd bridge to the compiled ``altcorr`` kernels in ``droid_backends``.

    ``forward`` computes correlation between ``fmap1`` and ``fmap2`` at
    ``coords`` with lookup radius ``r``. The exact sampling pattern is defined
    by the compiled kernel; confirm against ``droid_backends``. ``backward``
    returns whatever gradients ``altcorr_backward`` produces for the three
    tensor inputs, and ``None`` for ``r``.
    """

    @staticmethod
    def forward(ctx, fmap1, fmap2, coords, r):
        """Run ``altcorr_forward`` and stash the inputs for the backward pass."""
        ctx.r = r
        ctx.save_for_backward(fmap1, fmap2, coords)
        (out,) = droid_backends.altcorr_forward(fmap1, fmap2, coords, r)
        return out

    @staticmethod
    def backward(ctx, grad_corr):
        """Return gradients for ``fmap1``, ``fmap2``, ``coords`` and ``None`` for ``r``."""
        fmap1, fmap2, coords = ctx.saved_tensors
        # The compiled kernel is handed a contiguous gradient buffer.
        grads = droid_backends.altcorr_backward(
            fmap1, fmap2, coords, grad_corr.contiguous(), ctx.r
        )
        grad_fmap1, grad_fmap2, grad_coords = grads
        return grad_fmap1, grad_fmap2, grad_coords, None
class AltCorrBlock:
    """On-demand correlation lookup over a stored feature pyramid.

    Unlike an all-pairs volume, this class only stores a channels-last
    feature pyramid. Correlation is delegated to ``CorrLayer`` for each
    lookup: level-0 features of frames ``ii`` against level-``i`` features of
    frames ``jj``.
    """

    def __init__(self, fmaps, num_levels=4, radius=3):
        """Build a channels-last feature pyramid.

        Args:
            fmaps: features of shape (B, N, C, H, W).
            num_levels: number of pyramid levels to build.
            radius: lookup radius forwarded to ``CorrLayer``.
        """
        self.num_levels = num_levels
        self.radius = radius

        B, N, C, H, W = fmaps.shape
        # reshape (not view) so non-contiguous inputs, e.g. with a permuted
        # frame axis, are accepted as well.
        fmaps = fmaps.reshape(B * N, C, H, W) / 4.0

        self.pyramid = []
        for i in range(self.num_levels):
            sz = (B, N, H // 2**i, W // 2**i, C)
            # Each level is stored channels-last: (B, N, H_i, W_i, C).
            fmap_lvl = fmaps.permute(0, 2, 3, 1).contiguous()
            self.pyramid.append(fmap_lvl.view(*sz))
            fmaps = F.avg_pool2d(fmaps, 2, stride=2)

    def corr_fn(self, coords, ii, jj):
        """Correlate frames ``ii`` against frames ``jj`` at every pyramid level.

        Args:
            coords: tensor of shape (B, N, H, W, S, 2) in level-0 pixel units;
                it is divided by ``2**i`` for level ``i``.
            ii, jj: indices into the frame axis (dim 1) of the pyramid; each
                is expected to select N frames.

        Returns:
            Tensor of shape (B, N, K_total, H, W, S), where K_total is the sum
            of the per-level channel counts produced by ``CorrLayer``.
        """
        B, N, H, W, S, _ = coords.shape
        coords = coords.permute(0, 1, 4, 2, 3, 5)

        # Source features always come from level 0, so gather, flatten and
        # cast them once instead of once per level.
        fmap1 = self.pyramid[0][:, ii]
        fmap1 = fmap1.reshape((B * N,) + fmap1.shape[2:]).float()

        corr_list = []
        for i in range(self.num_levels):
            fmap2_i = self.pyramid[i][:, jj]
            fmap2_i = fmap2_i.reshape((B * N,) + fmap2_i.shape[2:]).float()
            coords_i = (coords / 2**i).reshape(B * N, S, H, W, 2).contiguous()

            corr = CorrLayer.apply(fmap1, fmap2_i, coords_i, self.radius)
            corr = corr.view(B, N, S, -1, H, W).permute(0, 1, 3, 4, 5, 2)
            corr_list.append(corr)

        return torch.cat(corr_list, dim=2)

    def __call__(self, coords, ii, jj):
        """Look up correlation features for 5-D or 6-D ``coords``.

        A 5-D ``coords`` of shape (B, N, H, W, 2) is treated as a single
        sample (S=1), and that trailing singleton dimension is squeezed from
        the result. The output is always made contiguous.
        """
        squeeze_output = False
        if len(coords.shape) == 5:
            coords = coords.unsqueeze(dim=-2)
            squeeze_output = True

        corr = self.corr_fn(coords, ii, jj)

        if squeeze_output:
            corr = corr.squeeze(dim=-1)

        return corr.contiguous()
|