File size: 5,922 Bytes
6dfcb0f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 |
def scaling_fixed_get_vmae_optical_flow_crop_batched_smoothed(generator,
                                                              mask_generator,
                                                              img1,
                                                              img2,
                                                              neg_back_flow=True,
                                                              num_scales=1,
                                                              min_scale=400,
                                                              N_mask_samples=100,
                                                              mask_ratio=0.8,
                                                              smoothing_factor=1):
    """Estimate optical flow between two image batches with a masked VMAE-style
    predictor, averaged over multiple spatial scales and random mask samples.

    The frame pair is (by default) ordered so the model produces negated
    backward flow, resized to ``num_scales`` geometrically spaced resolutions
    down to ``min_scale``, split into minimal 224x224 crops, and pushed through
    ``generator``'s encoder/decoder under many random token masks. Per-pixel
    flow estimates accumulate only where the mask exposed the second frame and
    are normalized by the per-pixel visibility count; per-scale flows are then
    rescaled back to the base resolution and averaged.

    Args:
        generator: model wrapper exposing ``patch_size``, ``_preprocess``, and
            a ``predictor`` with ``encoder`` / ``encoder_to_decoder`` modules.
        mask_generator: callable returning a token mask. NOTE(review): its
            ``mask_ratio`` attribute is overwritten here — a visible side
            effect on the shared object.
        img1: first frames; assumed shape (B, C, H, W) — TODO confirm channel
            convention with callers.
        img2: second frames, same shape as ``img1``.
        neg_back_flow: if True (default), feed frames as (img2, img1) so the
            model's output is negated backward flow, and negate the results
            before returning.
        num_scales: number of pyramid scales, must be >= 1.
        min_scale: smallest pyramid height; must satisfy 360 <= min_scale < H
            (below 360p the flows look terrible).
        N_mask_samples: minimum number of random mask samples per scale; the
            sampling loop also continues until every pixel has been unmasked
            at least once.
        mask_ratio: masking ratio applied for every sample.
        smoothing_factor: forwarded to ``average_crops`` to smooth each
            per-element flow map.

    Returns:
        Tuple ``(return_flow, all_fwd_flows_e2d_new)``: the scale-averaged
        flow of shape (B, 2, H, W) and the list of per-scale flows (each with
        a trailing singleton scale dim); both are negated when
        ``neg_back_flow`` is True.
    """
    B = img1.shape[0]
    assert len(img1.shape) == 4
    assert num_scales >= 1
    h1 = img2.shape[-2]
    w1 = img2.shape[-1]
    # Below 360p, the flows look terrible.
    assert min_scale < h1 and min_scale >= 360
    if neg_back_flow is False:
        print('WARNING: Not calculating negative backward flow')
    # Geometric step so that scale index num_scales-1 lands exactly on min_scale.
    alpha = (min_scale / img1.shape[-2]) ** (1 / (num_scales - 1)) if num_scales > 1 else 1
    # Flow maps are predicted on the token grid of a 224x224 crop.
    frame_size = 224 // generator.patch_size[-1]
    all_fwd_flows_e2d = []
    s_hs = []
    s_ws = []
    for aidx in range(num_scales):
        scaled_size = [int((alpha ** aidx) * h1), int((alpha ** aidx) * w1)]
        img1_scaled = F.interpolate(img1.clone(), scaled_size,
                                    mode='bicubic', align_corners=True)
        img2_scaled = F.interpolate(img2.clone(), scaled_size,
                                    mode='bicubic', align_corners=True)
        h2 = img2_scaled.shape[-2]
        w2 = img2_scaled.shape[-1]
        # Remember how much this scale shrank each axis, to undo it later.
        s_hs.append(h1 / h2)
        s_ws.append(w1 / w2)
        # compute_optical_flow technically returns negated backward flow, so
        # order the frame pair accordingly.
        if neg_back_flow is True:
            video = torch.cat([img2_scaled.unsqueeze(1), img1_scaled.unsqueeze(1)], 1)
        else:
            video = torch.cat([img1_scaled.unsqueeze(1), img2_scaled.unsqueeze(1)], 1)
        # Works even when the incoming video is already 224x224.
        crops1, c_pos1 = get_minimal_224_crops_new_batched(video, 1)
        num_crops = len(crops1)
        # Stack all crops into one batch so the model runs once per mask sample.
        crop = torch.cat(crops1, 0).cuda()
        optical_flows_enc2dec = torch.zeros(B * num_crops, 2, frame_size, frame_size).cuda()
        # Per-pixel count of how often the second frame was visible (unmasked).
        mask_counts = torch.zeros(frame_size, frame_size).cuda()
        i = 0
        # Keep sampling until we hit the sample budget AND every pixel has been
        # unmasked at least once (so the normalization below never divides by 0).
        while i < N_mask_samples or (mask_counts == 0).any().item():
            mask_generator.mask_ratio = mask_ratio
            # Every element of the batch shares the same mask for this sample.
            mask = mask_generator()[None]
            # Visibility of the second frame's tokens (mask covers two frames).
            mask_2f = ~mask[0, frame_size * frame_size:]
            mask_counts += mask_2f.reshape(frame_size, frame_size)
            # NOTE(review): autocast scope reconstructed from mangled
            # indentation — assumed to cover only the model forward pass.
            with torch.cuda.amp.autocast(enabled=True):
                processed_x = generator._preprocess(crop)
                encoder_out = generator.predictor.encoder(processed_x.to(torch.float16),
                                                          mask.repeat(B * num_crops, 1))
                encoder_to_decoder = generator.predictor.encoder_to_decoder(encoder_out)
            # One smoothed flow estimate per batch element / crop for now.
            optical_flow_e2d = []
            for b in range(B * num_crops):
                batch_flow = compute_optical_flow(encoder_to_decoder[b].unsqueeze(0), mask, frame_size)
                optical_flow_e2d.append(average_crops(batch_flow, smoothing_factor).unsqueeze(0))
            optical_flows_enc2dec += torch.cat(optical_flow_e2d, 0)
            i += 1
        # Average each pixel over the samples in which it was actually visible.
        optical_flows_enc2dec = optical_flows_enc2dec / mask_counts
        # Split the concatenated crops back up into per-crop flow maps.
        crop_flows_enc2dec = optical_flows_enc2dec.split(B, 0)
        T1 = [F.interpolate(_, [int(224), int(224)], mode='bicubic', align_corners=True).unsqueeze(1).cpu()
              for _ in crop_flows_enc2dec]
        optical_flows_enc2dec_joined = reconstruct_video_new_2_batched(
            T1, c_pos1, (B, 1, 2, video.shape[-2], video.shape[-1])).squeeze(1)
        all_fwd_flows_e2d.append(optical_flows_enc2dec_joined)
    all_fwd_flows_e2d_new = []
    for ridx, r in enumerate(all_fwd_flows_e2d):
        _sh = s_hs[ridx]
        _sw = s_ws[ridx]
        # Flow is predicted in patch units; convert to pixels at full resolution.
        _sfy = generator.patch_size[-1]
        _sfx = generator.patch_size[-1]
        # Resample every scale's flow onto the base (scale-0) resolution.
        new_r = F.interpolate(r, [int(all_fwd_flows_e2d[0].shape[-2]), int(all_fwd_flows_e2d[0].shape[-1])],
                              mode='bicubic', align_corners=True)
        scaled_new_r = torch.zeros_like(new_r)
        scaled_new_r[:, 0, :, :] = new_r[:, 0, :, :] * _sfx * _sw
        scaled_new_r[:, 1, :, :] = new_r[:, 1, :, :] * _sfy * _sh
        all_fwd_flows_e2d_new.append(scaled_new_r.unsqueeze(-1))
    # Average across scales along the trailing singleton dim.
    return_flow = torch.cat(all_fwd_flows_e2d_new, -1).mean(-1)
    if neg_back_flow is True:
        return_flow = -return_flow
        all_fwd_flows_e2d_new = [-_ for _ in all_fwd_flows_e2d_new]
    return return_flow, all_fwd_flows_e2d_new