|
import torch |
|
import torch.nn as nn |
|
|
|
import cwm.model.model_pretrain as vmae_tranformers |
|
from . import flow_utils |
|
from . import losses as bblosses |
|
|
|
|
|
|
|
def l2_norm(x):
    """Euclidean norm over the channel dimension.

    Squares ``x`` elementwise, sums over dim -3 with ``keepdim=True`` so the
    reduced axis survives as size 1, then takes the square root.
    For a flow tensor shaped (..., 2, H, W) this yields (..., 1, H, W).
    """
    squared = x * x
    total = squared.sum(dim=-3, keepdim=True)
    return total.sqrt()
|
|
|
|
|
|
|
|
|
|
|
def get_occ_masks(flow_fwd, flow_bck, occ_thresh=0.5):
    """Estimate occlusion masks via forward-backward flow consistency.

    A pixel is considered non-occluded when warping the reverse flow back
    along the given flow nearly cancels it (their sum is close to zero,
    relative to the flows' magnitudes).

    flow_fwd: forward flow (frame1 -> frame2), channels on dim -3
    flow_bck: backward flow (frame2 -> frame1), channels on dim -3
    occ_thresh: slope of the magnitude-dependent consistency threshold

    Returns (occ_mask_fwd, occ_mask_bck); each mask is 1.0 where the flow
    is consistent (visible) and 0.0 where it is inconsistent (occluded).
    """

    def _consistency_mask(flow, warped_reverse):
        # Residual of the round trip: zero for perfectly consistent flow.
        residual = flow + warped_reverse
        # Threshold grows with the squared magnitudes of both flows.
        magnitude = l2_norm(flow) ** 2 + l2_norm(warped_reverse) ** 2
        limit = occ_thresh * magnitude + 0.5
        # 1.0 where consistent, 0.0 where the residual exceeds the limit.
        return 1 - (l2_norm(residual) ** 2 > limit).float()

    # Backward flow sampled at forward-flow destinations, and vice versa.
    warped_bck, _ = bblosses.backward_warp(img2=flow_bck, flow=flow_fwd)
    warped_fwd, _ = bblosses.backward_warp(img2=flow_fwd, flow=flow_bck)

    occ_mask_fwd = _consistency_mask(flow_fwd, warped_bck)
    occ_mask_bck = _consistency_mask(flow_bck, warped_fwd)

    return occ_mask_fwd, occ_mask_bck
|
|
|
|
|
class ExtractFlow(nn.Module):
    """Abstract base class for two-frame optical-flow extractors.

    Subclasses implement :meth:`forward` to map a pair of frames to a flow
    field. The base class itself carries no parameters or state.
    """

    def __init__(self):
        super().__init__()

    def forward(self, img1, img2):
        '''
        img1: first frame
        img2: second frame
        returns: flow map (h, w, 2)
        '''
        # The original body was only the docstring and silently returned
        # None despite documenting a return value. Raise instead so using
        # the base class directly fails loudly; subclasses override this.
        raise NotImplementedError("Subclasses of ExtractFlow must implement forward().")
|
|
|
from cwm.data.masking_generator import RotatedTableMaskingGenerator |
|
|
|
class CWM(ExtractFlow):
    """Optical-flow extractor backed by a frozen counterfactual world model.

    Loads a video-MAE predictor by name from ``vmae_tranformers``, moves it
    to the GPU in eval mode with gradients disabled, and pairs it with a
    rotated-table masking generator. Flow is estimated by running the
    predictor in both temporal directions and masking occluded pixels.
    """

    def __init__(self, model_name, patch_size, weights_path):
        """
        model_name: attribute name of the model constructor in vmae_tranformers
        patch_size: spatial patch size of the predictor (stored for reference)
        weights_path: checkpoint file containing a 'model' state dict
        """
        super().__init__()

        self.patch_size = patch_size

        # Instantiate the predictor frozen on GPU for inference only.
        arch = getattr(vmae_tranformers, model_name)
        predictor = arch().cuda().eval().requires_grad_(False)

        # strict=False: checkpoint may omit / carry extra keys; the load
        # result is printed so mismatches remain visible in the logs.
        load_result = predictor.load_state_dict(torch.load(weights_path)['model'], strict=False)
        print(load_result, weights_path)

        self.predictor = predictor

        # mask_ratio=0.0 with a rotated-table layout: only the final frame's
        # tokens are hidden, which is what the flow estimation relies on.
        self.mask_generator = RotatedTableMaskingGenerator(
            input_size=(predictor.num_frames, 28, 28),
            mask_ratio=0.0,
            tube_length=1,
            batch_size=1,
            mask_type='rotated_table'
        )

    def forward(self, img1, img2):
        '''
        img1: [3, 1024, 1024]
        img2: [3, 1024, 1024]
        both images are imagenet normalized
        returns: forward flow [2, H, W] with occluded pixels zeroed
        '''
        with torch.no_grad():
            # Shared settings for both flow directions.
            flow_kwargs = dict(num_scales=2, min_scale=224, N_mask_samples=1)

            # Forward flow: img1 -> img2.
            fwd_flow, _ = flow_utils.scaling_fixed_get_vmae_optical_flow_crop_batched_smoothed(
                self.predictor,
                self.mask_generator,
                img1[None],
                img2[None],
                **flow_kwargs)

            # Backward flow: img2 -> img1 (used only for occlusion checking).
            bck_flow, _ = flow_utils.scaling_fixed_get_vmae_optical_flow_crop_batched_smoothed(
                self.predictor,
                self.mask_generator,
                img2[None],
                img1[None],
                **flow_kwargs)

            # Zero out pixels whose forward/backward flows are inconsistent.
            occ_mask = get_occ_masks(fwd_flow, bck_flow)[0]
            masked_flow = fwd_flow * occ_mask

            # Drop the batch dimension added above.
            return masked_flow[0]
|
|
|
|
|
class CWM_8x8(CWM):
    """CWM flow extractor using the ViT-B, 8x8-patch, 3-frame predictor.

    The checkpoint path was previously hard-coded; it is now an optional
    parameter so the same class works outside the original cluster, while
    the default keeps existing callers unchanged.
    """

    # Cluster-specific default checkpoint location (original hard-coded path).
    DEFAULT_WEIGHTS_PATH = '/ccn2/u/honglinc/cwm_checkpoints/ablation_3frame_no_clumping_mr0.90_extra_data_ep400/checkpoint-399.pth'

    def __init__(self, weights_path=DEFAULT_WEIGHTS_PATH):
        """
        weights_path: checkpoint file with a 'model' state dict; defaults to
            the original hard-coded path for backward compatibility.
        """
        super().__init__('vitb_8x8patch_3frames', 8, weights_path)
|
|