import torch
import torch.nn.functional as F
from torchvision import transforms


def sampling_grid(height, width):
    """Integer pixel-coordinate grid of shape [1, H, W, 2], channels ordered (x, y)."""
    H, W = height, width
    grid = torch.stack([
        torch.arange(W).view(1, -1).repeat(H, 1),  # x coordinate of each pixel
        torch.arange(H).view(-1, 1).repeat(1, W),  # y coordinate of each pixel
    ], -1)
    grid = grid.view(1, H, W, 2)
    return grid
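
# A quick illustration (added note, not from the original file):
# sampling_grid(2, 3)[0, y, x] == (x, y), i.e. the 2x3 grid is
#   [[[0, 0], [1, 0], [2, 0]],
#    [[0, 1], [1, 1], [2, 1]]]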


def normalize_sampling_grid(coords):
    """Map pixel coordinates to the [-1, 1] range expected by F.grid_sample
    with align_corners=True."""
    assert len(coords.shape) == 4, coords.shape
    assert coords.size(-1) == 2, coords.shape
    H, W = coords.shape[-3:-1]
    xs, ys = coords.split([1, 1], -1)
    xs = 2 * xs / (W - 1) - 1
    ys = 2 * ys / (H - 1) - 1
    return torch.cat([xs, ys], -1)
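
# Worked example (added note, not from the original file): for a 4x4 image,
# normalize_sampling_grid maps pixel (0, 0) to (-1, -1) and pixel (3, 3) to
# (1, 1), matching grid_sample's align_corners=True convention.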


def backward_warp(img2, flow, do_mask=False):
    """
    Grid sample from img2 using the flow from img1 -> img2 to get a prediction
    of img1.

    flow: [B, 2, H', W'] in units of pixels at its current resolution. The two
    channels should be (x, y), where larger y values correspond to lower parts
    of the image.

    Returns (img1_pred, mask), where mask marks in-bounds samples if do_mask
    is True and is all ones otherwise.
    """
    # If the flow resolution differs from the image's, resize it to match and
    # rescale its values, since flow is expressed in pixels at its own resolution.
    if list(img2.shape[-2:]) != list(flow.shape[-2:]):
        scale = [img2.size(-1) / flow.size(-1),
                 img2.size(-2) / flow.size(-2)]
        scale = torch.tensor(scale).view(1, 2, 1, 1).to(flow.device)
        flow = scale * transforms.Resize(img2.shape[-2:])(flow)
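    # Added illustration (not from the original file): if flow is predicted at
    # half the image resolution, a 5-pixel displacement there spans 10 pixels
    # at full resolution, so `scale` doubles the upsampled flow values.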

    B, C, H, W = img2.shape

    # Absolute sampling locations: base pixel grid plus per-pixel flow displacement.
    grid = sampling_grid(H, W).to(flow.device) + flow.permute(0, 2, 3, 1)

    # Convert to the [-1, 1] coordinate range grid_sample expects.
    grid = normalize_sampling_grid(grid)

    img1_pred = F.grid_sample(img2, grid, align_corners=True)

    if do_mask:
        # Mark samples that fall inside img2; out-of-bounds locations are filled
        # by grid_sample's default zero padding and should be ignored. The bounds
        # are inclusive so that samples exactly on the border (e.g. under zero
        # flow) count as valid.
        mask = (grid[..., 0] >= -1) & (grid[..., 0] <= 1) & \
               (grid[..., 1] >= -1) & (grid[..., 1] <= 1)
        mask = mask[:, None].to(img2.dtype)
        return (img1_pred, mask)
    else:
        return (img1_pred, torch.ones_like(grid[..., 0][:, None]).float())
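

if __name__ == "__main__":
    # Minimal smoke test (an added sketch, not part of the original module):
    # with zero flow, the backward warp should reproduce img2 exactly, and the
    # in-bounds mask (with inclusive bounds) should be all ones.
    img2 = torch.rand(1, 3, 8, 8)
    flow = torch.zeros(1, 2, 8, 8)
    pred, mask = backward_warp(img2, flow, do_mask=True)
    assert torch.allclose(pred, img2, atol=1e-5)
    assert bool((mask == 1).all())
    print("backward_warp identity check passed")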