def paste_crop_on_canvas(crop, box2d_unnorm, H, W, fast=True, mask=None, canvas=None):
    """Paste each crop into an H x W canvas at the location given by its box.

    This is the inverse of crop_and_resize_box2d: the crop is resized to the
    box extent and placed at the box position inside a larger image.

    Args:
        crop: tensor of shape (B, C, Y, X) holding one crop per batch element.
        box2d_unnorm: tensor of shape (B, 4) holding (ymin, xmin, ymax, xmax)
            per batch element, in canvas pixel units. Values are truncated to
            integers before use.
        H: canvas height in pixels.
        W: canvas width in pixels.
        fast: if True, resample all crops at once with F.grid_sample
            (grid_sample's default interpolation); if False, loop over the
            batch, resize each crop with F.interpolate (its default mode) and
            write it into the canvas by slicing.
        mask: currently unused; kept for interface compatibility.
        canvas: optional (B, C, H, W) tensor to paste onto. When omitted, the
            background is zero.

    Returns:
        A (B, C, H, W) tensor. With fast=False a supplied canvas is modified
        in place and returned; with fast=True a new tensor is returned.

    Raises:
        AssertionError: if the batch sizes disagree, the boxes do not have 4
            columns, or a supplied canvas is not (B, C, H, W).
    """
    B, C, _, _ = list(crop.shape)
    B2, D = list(box2d_unnorm.shape)
    assert B == B2, 'crop and box2d_unnorm must share the batch dimension'
    assert D == 4, 'box2d_unnorm must have 4 columns: ymin, xmin, ymax, xmax'

    if canvas is not None:
        B2, C2, H2, W2 = canvas.shape
        assert B == B2
        assert C == C2
        assert H == H2
        assert W == W2

    if fast:
        # Integer box corners, kept as (B, 1) floats so they broadcast over
        # the flattened pixel grid.
        ymin = box2d_unnorm[:, 0].long().float().unsqueeze(1)
        xmin = box2d_unnorm[:, 1].long().float().unsqueeze(1)
        ymax = box2d_unnorm[:, 2].long().float().unsqueeze(1)
        xmax = box2d_unnorm[:, 3].long().float().unsqueeze(1)
        box_w = xmax - xmin
        box_h = ymax - ymin

        # NOTE(review): gridcloud2d is assumed to return per-pixel (x, y)
        # canvas coordinates reshapeable to (B, H*W, 2) -- confirm in
        # utils.basic. The .to() is a no-op when the grid already lives on
        # the crop's device.
        pix = utils.basic.gridcloud2d(B, H, W).reshape(B, -1, 2).to(crop.device)

        # For each canvas pixel, the location to sample in the crop, mapped
        # so that the box spans [-1, 1] as grid_sample expects. Built out of
        # place so the tensor returned by gridcloud2d is never mutated.
        grid_x = (pix[:, :, 0] - xmin) / box_w * 2.0 - 1.0
        grid_y = (pix[:, :, 1] - ymin) / box_h * 2.0 - 1.0
        grid = torch.stack([grid_x, grid_y], dim=2).reshape(B, H, W, 2)

        pasted = F.grid_sample(crop, grid, align_corners=False)
        if canvas is None:
            # Zero background: the zero-padded sample is already the result.
            return pasted

        # Sampling an all-ones image with the same grid yields, per pixel,
        # the total interpolation weight that fell inside the crop; use it as
        # alpha so the supplied canvas shows through outside the box instead
        # of being discarded.
        alpha = F.grid_sample(torch.ones_like(crop[:, :1]), grid, align_corners=False)
        return canvas * (1.0 - alpha) + pasted

    if canvas is None:
        canvas = torch.zeros((B, C, H, W), device=crop.device)

    for b in range(B):
        ymin, xmin, ymax, xmax = box2d_unnorm[b].long().tolist()
        box_h = ymax - ymin
        box_w = xmax - xmin
        if box_h <= 0 or box_w <= 0:
            # Degenerate box: nothing to paste (interpolate rejects size 0).
            continue

        # Part of the box that actually lies on the canvas.
        top, bottom = max(ymin, 0), min(ymax, H)
        left, right = max(xmin, 0), min(xmax, W)
        if top >= bottom or left >= right:
            continue

        resized = F.interpolate(crop[b:b + 1], size=(box_h, box_w)).squeeze(0)
        # Offsets into the resized crop select the visible window, so boxes
        # that overhang the canvas are clipped rather than raising a shape
        # mismatch or wrapping around through negative indices.
        canvas[b, :, top:bottom, left:right] = resized[
            :, top - ymin:bottom - ymin, left - xmin:right - xmin]
    return canvas