File size: 5,406 Bytes
4f6b78d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import torch
import utils.basic
import torch.nn.functional as F

def bilinear_sample2d(im, x, y, return_inbounds=False):
    """Bilinearly sample `im` at the given (x, y) pixel locations.

    Args:
        im: image tensor, B x C x H x W.
        x: horizontal sample coordinates in pixels, B x N.
        y: vertical sample coordinates in pixels, B x N.
        return_inbounds: when True, also return a B x N float mask.

    Returns:
        A B x C x N tensor of sampled values. With `return_inbounds=True`,
        a tuple `(samples, inbounds)` where `inbounds` is 1.0 wherever
        -0.5 < x < W - 0.5 and -0.5 < y < H - 0.5, and 0.0 elsewhere.

    Corner indices that fall outside the image are clamped to the border
    for the pixel lookup only; the interpolation weights are still computed
    from the unclamped corners.
    """
    B, C, H, W = im.shape
    N = x.shape[1]

    x = x.float()
    y = y.float()

    # integer corners surrounding every sample (these may lie outside the image)
    left = torch.floor(x).int()
    top = torch.floor(y).int()
    right = left + 1
    bottom = top + 1

    # clamped copies, used only for indexing
    left_idx = torch.clamp(left, 0, W - 1)
    right_idx = torch.clamp(right, 0, W - 1)
    top_idx = torch.clamp(top, 0, H - 1)
    bottom_idx = torch.clamp(bottom, 0, H - 1)

    # channels-last, flattened image: one row of C values per pixel
    pixels = im.permute(0, 2, 3, 1).reshape(B * H * W, C)

    # offset of each batch element inside the flattened image; B x 1,
    # broadcasts against the B x N index maps
    batch_offset = torch.arange(B, dtype=torch.int64, device=x.device).reshape(B, 1) * (W * H)
    top_row = batch_offset + top_idx * W
    bottom_row = batch_offset + bottom_idx * W

    # each lookup is B x N x C
    val_tl = pixels[top_row + left_idx]
    val_tr = pixels[top_row + right_idx]
    val_bl = pixels[bottom_row + left_idx]
    val_br = pixels[bottom_row + right_idx]

    # distance to the opposite corner is the weight of a corner
    w_left = right.float() - x
    w_right = x - left.float()
    w_top = bottom.float() - y
    w_bottom = y - top.float()

    blended = (w_left * w_top).unsqueeze(2) * val_tl \
        + (w_right * w_top).unsqueeze(2) * val_tr \
        + (w_left * w_bottom).unsqueeze(2) * val_bl \
        + (w_right * w_bottom).unsqueeze(2) * val_br

    # B x N x C -> B x C x N
    output = blended.view(B, -1, C).permute(0, 2, 1)

    if not return_inbounds:
        return output  # B, C, N

    x_ok = (x > -0.5) & (x < W - 0.5)
    y_ok = (y > -0.5) & (y < H - 0.5)
    # NOTE(review): the original author reported an error at this reshape for
    # B > 1; it is a no-op whenever x and y really are B x N -- confirm the
    # shapes passed by callers.
    inbounds = (x_ok & y_ok).float().reshape(B, N)
    return output, inbounds

def paste_crop_on_canvas(crop, box2d_unnorm, H, W, fast=True, mask=None, canvas=None):
    """Place each crop into an H x W canvas at the location given by its box.

    Per the original author's note, this is the inverse of
    crop_and_resize_box2d.

    Args:
        crop: B x C x Y x X tensor of crops.
        box2d_unnorm: B x 4 tensor of boxes as (ymin, xmin, ymax, xmax) in
            canvas pixels; values are truncated to integers.
        H: canvas height.
        W: canvas width.
        fast: if True, resample all crops at once with `F.grid_sample`
            (bilinear); otherwise resize each crop with `F.interpolate`
            (its default mode) and copy it into the canvas.
        mask: accepted for interface compatibility; not used by either path.
        canvas: optional B x C x H x W background to paste onto. When
            omitted, the background is zeros.

    Returns:
        A B x C x H x W tensor. With `fast=False` and a caller-supplied
        `canvas`, the canvas is modified in place and returned; with
        `fast=True` a new tensor is returned and `canvas` is left untouched.

    NOTE(review): with `fast=True`, a box of zero width or height divides by
    zero when building the sampling grid; callers are presumably expected to
    pass non-degenerate boxes -- confirm.
    """
    B, C, Y, X = crop.shape
    B2, D = box2d_unnorm.shape
    assert B == B2
    assert D == 4

    if canvas is not None:
        B2, C2, H2, W2 = canvas.shape
        assert B == B2
        assert C == C2
        assert H == H2
        assert W == W2

    if fast:
        ymin = box2d_unnorm[:, 0].long()
        xmin = box2d_unnorm[:, 1].long()
        ymax = box2d_unnorm[:, 2].long()
        xmax = box2d_unnorm[:, 3].long()
        w = (xmax - xmin).float()
        h = (ymax - ymin).float()

        # gridcloud2d is defined elsewhere; it is assumed to return the (x, y)
        # pixel coordinates of every canvas pixel as B x (H*W) x 2 -- TODO confirm
        grids = utils.basic.gridcloud2d(B, H, W)
        grids_flat = grids.reshape(B, -1, 2)

        # for each canvas pixel, compute where to sample in the crop, in the
        # [-1, 1] convention of grid_sample (the box spans exactly [-1, 1])
        grids_flat[:, :, 0] = (grids_flat[:, :, 0] - xmin.float().unsqueeze(1)) / w.unsqueeze(1) * 2.0 - 1.0
        grids_flat[:, :, 1] = (grids_flat[:, :, 1] - ymin.float().unsqueeze(1)) / h.unsqueeze(1) * 2.0 - 1.0
        grid = grids_flat.reshape(B, H, W, 2)

        # zero padding: canvas pixels outside the box sample as 0
        pasted = F.grid_sample(crop, grid, align_corners=False)
        if canvas is None:
            return pasted

        # Resampling an all-ones image gives, per canvas pixel, the total
        # bilinear weight that landed inside the crop. Using it as alpha keeps
        # the supplied background outside the box (previously the background
        # was silently discarded) and reduces to `pasted` on a zero canvas.
        coverage = F.grid_sample(crop.new_ones((B, 1, Y, X)), grid, align_corners=False)
        return canvas * (1.0 - coverage) + pasted

    if canvas is None:
        canvas = torch.zeros((B, C, H, W), device=crop.device)

    # one host transfer for all boxes; .long() truncates toward zero
    boxes = box2d_unnorm.long().tolist()
    for b, (ymin, xmin, ymax, xmax) in enumerate(boxes):
        box_h = ymax - ymin
        box_w = xmax - xmin
        if box_h <= 0 or box_w <= 0:
            # empty box: nothing to paste
            continue

        crop_b = F.interpolate(crop[b:b + 1], (box_h, box_w)).squeeze(0)

        # clip the paste region to the canvas so boxes that stick out (or have
        # negative corners, which slicing would wrap around) paste only their
        # visible part instead of failing with a shape mismatch
        y0, y1 = max(ymin, 0), min(ymax, H)
        x0, x1 = max(xmin, 0), min(xmax, W)
        if y0 >= y1 or x0 >= x1:
            # box lies entirely outside the canvas
            continue

        canvas[b, :, y0:y1, x0:x1] = crop_b[:, y0 - ymin:y1 - ymin, x0 - xmin:x1 - xmin]
    return canvas