from einops import rearrange
import torch
import numpy as np
from torchvision import transforms
def unpatchify(labels, norm=True):
    # Recover an image tensor from flattened patch predictions.
    B = labels.shape[0]  # batch size
    N_patches = int(np.sqrt(labels.shape[1]))  # number of patches along each dimension
    patch_size = int(np.sqrt(labels.shape[2] / 3))  # patch size along each dimension
    channels = 3  # number of channels
    rec_imgs = rearrange(labels, 'b n (p c) -> b n p c', c=channels)
    # Fold the flattened patches back into a [B, C, T, H, W] video tensor.
    rec_imgs = rearrange(rec_imgs,
                         'b (t h w) (p0 p1 p2) c -> b c (t p0) (h p1) (w p2)',
                         p0=1,
                         p1=patch_size,
                         p2=patch_size,
                         h=N_patches,
                         w=N_patches)
    if norm:
        # Apply ImageNet channel normalization (requires a CUDA device, fp16).
        MEAN = torch.from_numpy(np.array((0.485, 0.456, 0.406))[None, :, None, None, None]).cuda().half()
        STD = torch.from_numpy(np.array((0.229, 0.224, 0.225))[None, :, None, None, None]).cuda().half()
        rec_imgs = (rec_imgs - MEAN) / STD
    return rec_imgs
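

# A minimal shape check for unpatchify (a sketch, assuming one predicted
# 224x224 frame with 16x16 patches: 14*14 = 196 patches, 16*16*3 = 768 values
# per patch). norm=False keeps it on CPU; norm=True requires a CUDA device.
def _demo_unpatchify():
    labels = torch.randn(2, 196, 768)      # [B, n_patches, patch_dim]
    imgs = unpatchify(labels, norm=False)  # -> [B, C, T, H, W]
    assert imgs.shape == (2, 3, 1, 224, 224)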
def upsample_masks(masks, size, thresh=0.5):
    # Resize a stack of patch-level masks to pixel resolution `size` = (H, W).
    shape = masks.shape
    dtype = masks.dtype
    h, w = shape[-2:]
    H, W = size
    if (H == h) and (W == w):
        return masks
    elif (H < h) and (W < w):
        # Downsample by strided slicing.
        s = (h // H, w // W)
        return masks[..., ::s[0], ::s[1]]

    # Upsample by repeating each mask entry into an (H//h) x (W//w) block.
    masks = masks.unsqueeze(-2).unsqueeze(-1)
    masks = masks.repeat(*([1] * (len(shape) - 2)), 1, H // h, 1, W // w)
    if ((H % h) == 0) and ((W % w) == 0):
        masks = masks.view(*shape[:-2], H, W)
    else:
        # Non-integer scale factor: fall back to interpolation + threshold,
        # then restore the dtype saved before the comparison produced bools.
        _H = np.prod(masks.shape[-4:-2])
        _W = np.prod(masks.shape[-2:])
        masks = transforms.Resize(size)(masks.view(-1, 1, _H, _W)) > thresh
        masks = masks.view(*shape[:2], H, W).to(dtype)
    return masks
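

# A small sketch of upsample_masks on a 14x14 boolean patch mask: each entry is
# repeated into a 16x16 pixel block because 224 is an exact multiple of 14 (the
# Resize fallback only triggers for non-integer scale factors).
def _demo_upsample_masks():
    m = torch.zeros(1, 14, 14, dtype=torch.bool)
    m[0, 3, 5] = True                   # one masked patch
    up = upsample_masks(m, (224, 224))  # -> [1, 224, 224]
    assert up.shape == (1, 224, 224) and up.sum().item() == 16 * 16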
def get_keypoints_batch(model, x,
                        n_samples,
                        n_rounds,
                        frac=0.25,
                        mask=None,
                        pool='avg',
                        ):
    """Greedy keypoint discovery for a batch of image pairs.

    x = image pair tensor of shape [B, C, T, H, W]
    n_samples = number of candidate patches evaluated on each round
        (one new patch is unmasked per round)
    n_rounds = total number of patches to unmask
    frac = intended ratio of random to error-based sampling
        (unused in this version; all candidates come from the error ranking)
    mask = initial mask; defaults to all target-frame patches masked
    pool = 'avg' or 'max' pooling for reducing pixel error to patch error
    """
    B = x.shape[0]
    IMAGE_SIZE = [224, 224]
    predictor = model
    patch_size = predictor.patch_size[-1]
    num_frames = predictor.num_frames
    patch_num = IMAGE_SIZE[0] // patch_size

    # Pooling op used to reduce the per-pixel error map to a per-patch error.
    if pool == 'avg':
        pool_op = torch.nn.AvgPool2d(patch_size, stride=patch_size)
    elif pool == 'max':
        pool_op = torch.nn.MaxPool2d(patch_size, stride=patch_size)

    # RNG for the (currently unused) random-sampling branch.
    rng = np.random.RandomState(seed=0)
    n_patches = patch_num * patch_num

    # Initialize the mask at the fully masked state: context-frame patches
    # visible (False), all target-frame patches masked (True).
    mshape = num_frames * patch_num * patch_num
    mshape_masked = (num_frames - 1) * patch_num * patch_num
    if mask is None:
        mask = torch.ones([B, mshape], dtype=torch.bool)
        mask[:, :mshape_masked] = False

    err_array = []
    choices = []
    for round_num in range(n_rounds):
        # Get the current prediction with the current state of the mask.
        out = unpatchify(predictor(x, mask, forward_full=True))

        # Per-pixel reconstruction error against the target (last) frame,
        # averaged over channels.
        err_mat = (out[:, :, 0] - x[:, :, -1]).abs().mean(1)
        # Pool it down to one error value per patch and flatten.
        pooled_err = pool_op(err_mat[:, None])
        flat_pooled_error = pooled_err.flatten(1, 3)
        # Zero the error at already-visible patches so they are not re-picked.
        flat_pooled_error[~mask[:, -n_patches:]] = 0
        # Sort patches by error (ascending, so the highest error comes last).
        err_sort = torch.argsort(flat_pooled_error, -1)

        new_mask = mask.clone().detach()
        errors = []
        tries = []
        err_choices = 0
        # Evaluate the n_samples highest-error patches as reveal candidates.
        for sample_num in range(n_samples):
            err_choices += 1
            # Index the candidate patch within the target frame's mask block.
            new_try = (num_frames - 1) * n_patches + err_sort[:, -err_choices]
            tries.append(new_try)
            # Temporarily reveal the candidate patch.
            for k in range(B):
                new_mask[k, new_try[k]] = False
            # Pixel-space mask of the target frame for scoring this candidate.
            reshaped_new_mask = upsample_masks(
                new_mask.view(B, num_frames, IMAGE_SIZE[1] // patch_size,
                              IMAGE_SIZE[1] // patch_size)[:, (num_frames - 1):],
                IMAGE_SIZE)[:, 0]
            # Re-run the predictor and score the reconstruction error over the
            # still-masked region.
            out = unpatchify(predictor(x, new_mask, forward_full=True))
            abs_error = (out[:, :, 0] - x[:, :, -1]).abs().sum(1).cpu()
            masked_abs_error = abs_error * reshaped_new_mask
            error = masked_abs_error.flatten(1, 2).sum(-1)
            errors.append(error)
            # Re-mask the candidate before trying the next one.
            for k in range(B):
                new_mask[k, new_try[k]] = True
        # Keep the candidate whose reveal lowered the masked-region error most.
        errors = torch.stack(errors, 1)
        tries = torch.stack(tries, 1)
        best_ind = torch.argmin(errors, dim=-1)
        best = torch.tensor([tries[k, best_ind[k]] for k in range(B)])
        choices.append(best)
        err_array.append(errors)
        # Permanently reveal the winning patch for the next round.
        for k in range(B):
            mask[k, best[k]] = False

    # Features of the final prediction under the discovered mask.
    feat = predictor(x, mask, forward_full=True, return_features=True)
    # Convert flat patch indices to (x, y) patch coordinates.
    choices = torch.stack(choices, 1)
    choices = choices % mshape_masked  # strip the target-frame offset
    choices_x = choices % patch_num
    choices_y = choices // patch_num
    choices = torch.stack([choices_x, choices_y], 2)

    # Unnormalized reconstruction of the target frame for visualization.
    out = unpatchify(predictor(x, mask, forward_full=True), norm=False)
    keypoint_recon = out[0, :, 0].permute(1, 2, 0).detach().cpu().numpy() * 255
    return mask, choices, err_array, feat, keypoint_recon.astype('uint8')
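

# --- Usage sketch (hypothetical, not from the original file) -----------------
# A hedged example of how get_keypoints_batch might be driven. `load_predictor`
# is a stand-in for whatever loads the masked video model; the only assumptions
# are the attributes and signature already used above (.patch_size, .num_frames,
# forward(x, mask, forward_full=..., return_features=...)) plus a CUDA device,
# since unpatchify(norm=True) calls .cuda().half().
#
#   predictor = load_predictor()                        # hypothetical helper
#   x = torch.randn(1, 3, 2, 224, 224).cuda().half()    # two-frame clip
#   mask, choices, err_array, feat, recon = get_keypoints_batch(
#       predictor, x, n_samples=8, n_rounds=10)
#   # choices: [1, 10, 2] tensor of (x, y) patch coordinates per keypoint;
#   # recon:   uint8 [224, 224, 3] reconstruction of the target frame.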