|
def scaling_fixed_get_vmae_optical_flow_crop_batched_smoothed(generator,
                                                              mask_generator,
                                                              img1,
                                                              img2,
                                                              neg_back_flow=True,
                                                              num_scales=1,
                                                              min_scale=400,
                                                              N_mask_samples=100,
                                                              mask_ratio=0.8,
                                                              smoothing_factor=1):
    """Estimate multi-scale optical flow between two frames with a masked video model.

    For each of ``num_scales`` resolutions both frames are bicubically resized
    by ``alpha ** scale_index``, stacked into a two-frame clip, cut into crops
    by ``get_minimal_224_crops_new_batched`` and pushed through
    ``generator.predictor.encoder`` / ``encoder_to_decoder`` under randomly
    drawn masks.  The per-mask flows returned by ``compute_optical_flow`` (and
    smoothed by ``average_crops``) are summed and divided by ``mask_counts``,
    upsampled to 224x224, stitched back together by
    ``reconstruct_video_new_2_batched`` and finally rescaled and averaged
    across scales.

    Args:
        generator: Model wrapper exposing ``patch_size``, ``_preprocess`` and
            ``predictor.encoder`` / ``predictor.encoder_to_decoder``.
        mask_generator: Callable returning a mask; its ``mask_ratio``
            attribute is set to ``mask_ratio`` before every draw.
        img1: First frame batch, a 4-D tensor (presumably ``(B, C, H, W)`` --
            TODO confirm against callers).
        img2: Second frame batch; its last two dimensions define the base
            height/width that every scale is derived from.
        neg_back_flow: When ``True`` the clip is built as ``(img2, img1)`` and
            the resulting flows are negated; when ``False`` the clip is
            ``(img1, img2)`` and a warning is printed.
        num_scales: Number of resolutions to evaluate (``>= 1``).
        min_scale: Target height of the smallest scale; must satisfy
            ``360 <= min_scale < img2.shape[-2]``.
        N_mask_samples: Minimum number of masks drawn per scale.  Sampling
            continues past this number until no entry of ``mask_counts`` is
            zero.
        mask_ratio: Value assigned to ``mask_generator.mask_ratio``.
        smoothing_factor: Forwarded to ``average_crops``.

    Returns:
        A tuple ``(return_flow, per_scale_flows)``: the mean over scales of
        the rescaled flows, and the list of per-scale rescaled flows, each
        with an extra trailing singleton dimension.  Both are negated when
        ``neg_back_flow`` is ``True``.

    Raises:
        AssertionError: If ``img1`` is not 4-D, ``num_scales < 1`` or
            ``min_scale`` is outside ``[360, img2.shape[-2])``.

    Note:
        Crops and accumulators are moved to CUDA, so a GPU is required.
    """
    B = img1.shape[0]
    assert len(img1.shape) == 4
    assert num_scales >= 1

    h1 = img2.shape[-2]
    w1 = img2.shape[-1]
    assert min_scale < h1 and min_scale >= 360

    if neg_back_flow is False:
        print('WARNING: Not calculating negative backward flow')

    # Geometric step between scales: the last scale has height ~min_scale.
    alpha = (min_scale / img1.shape[-2]) ** (1 / (num_scales - 1)) if num_scales > 1 else 1

    crop_size = 224
    patch = generator.patch_size[-1]
    frame_size = crop_size // patch  # patches per crop side

    per_scale_flows = []  # stitched flow of every scale
    scale_factors = []    # (height ratio, width ratio) back to the base size

    for aidx in range(num_scales):
        factor = alpha ** aidx
        scaled_size = [int(factor * h1), int(factor * w1)]

        img1_scaled = F.interpolate(img1.clone(), scaled_size,
                                    mode='bicubic', align_corners=True)
        img2_scaled = F.interpolate(img2.clone(), scaled_size,
                                    mode='bicubic', align_corners=True)

        h2 = img2_scaled.shape[-2]
        w2 = img2_scaled.shape[-1]
        scale_factors.append((h1 / h2, w1 / w2))

        # Reversed frame order here is undone by the negation at the end.
        if neg_back_flow is True:
            video = torch.cat([img2_scaled.unsqueeze(1), img1_scaled.unsqueeze(1)], 1)
        else:
            video = torch.cat([img1_scaled.unsqueeze(1), img2_scaled.unsqueeze(1)], 1)

        crops1, c_pos1 = get_minimal_224_crops_new_batched(video, 1)
        num_crops = len(crops1)

        crop = torch.cat(crops1, 0).cuda()

        # Running sum of flows and per-position sample counts for this scale.
        optical_flows_enc2dec = torch.zeros(B * num_crops, 2, frame_size, frame_size,
                                            device=crop.device)
        mask_counts = torch.zeros(frame_size, frame_size, device=crop.device)

        # Draw at least N_mask_samples masks, and keep drawing until every
        # position has a non-zero count so the division below is well defined.
        i = 0
        while i < N_mask_samples or (mask_counts == 0).any().item():
            mask_generator.mask_ratio = mask_ratio

            # The same mask is shared by every crop in this iteration.
            mask = mask_generator()[None]
            # Inverted mask entries after the first frame_size**2 positions
            # (presumably the visible patches of the second frame -- TODO confirm).
            mask_2f = ~mask[0, frame_size * frame_size:]
            mask_counts += mask_2f.reshape(frame_size, frame_size)

            # NOTE(review): the extent of this autocast region was reconstructed
            # from a source whose indentation was lost -- confirm it matches
            # the intended scope.
            with torch.cuda.amp.autocast(enabled=True):
                # NOTE(review): crop does not change inside this loop; if
                # _preprocess is deterministic it could be hoisted -- confirm.
                processed_x = generator._preprocess(crop)

                encoder_out = generator.predictor.encoder(processed_x.to(torch.float16),
                                                          mask.repeat(B * num_crops, 1))
                encoder_to_decoder = generator.predictor.encoder_to_decoder(encoder_out)

                # Flow is computed one crop at a time.
                optical_flow_e2d = []
                for b in range(B * num_crops):
                    batch_flow = compute_optical_flow(encoder_to_decoder[b].unsqueeze(0), mask, frame_size)
                    optical_flow_e2d.append(average_crops(batch_flow, smoothing_factor).unsqueeze(0))

                optical_flow_e2d = torch.cat(optical_flow_e2d, 0)
                optical_flows_enc2dec += optical_flow_e2d
            i += 1

        # Turn the accumulated sum into a per-position average.
        optical_flows_enc2dec = optical_flows_enc2dec / mask_counts

        # One chunk of B flows per crop position.
        crop_flows_enc2dec = optical_flows_enc2dec.split(B, 0)

        T1 = [
            F.interpolate(crop_flow, [crop_size, crop_size],
                          mode='bicubic', align_corners=True).unsqueeze(1).cpu()
            for crop_flow in crop_flows_enc2dec
        ]

        optical_flows_enc2dec_joined = reconstruct_video_new_2_batched(
            T1, c_pos1, (B, 1, 2, video.shape[-2], video.shape[-1])).squeeze(1)

        per_scale_flows.append(optical_flows_enc2dec_joined)

    # Every scale is resized to the spatial size of the first one.
    target_size = [int(per_scale_flows[0].shape[-2]), int(per_scale_flows[0].shape[-1])]

    all_fwd_flows_e2d_new = []
    for r, (_sh, _sw) in zip(per_scale_flows, scale_factors):
        new_r = F.interpolate(r, target_size, mode='bicubic', align_corners=True)

        # Multiply by the patch size and by the resize ratio of this scale.
        # Both channels use patch_size[-1], as in the original code.
        scaled_new_r = torch.zeros_like(new_r)
        scaled_new_r[:, 0, :, :] = new_r[:, 0, :, :] * patch * _sw
        scaled_new_r[:, 1, :, :] = new_r[:, 1, :, :] * patch * _sh

        all_fwd_flows_e2d_new.append(scaled_new_r.unsqueeze(-1))

    return_flow = torch.cat(all_fwd_flows_e2d_new, -1).mean(-1)

    if neg_back_flow is True:
        return_flow = -return_flow
        all_fwd_flows_e2d_new = [-flow for flow in all_fwd_flows_e2d_new]

    return return_flow, all_fwd_flows_e2d_new