import numpy as np
import torch
import torch.nn.functional as F

from physion_evaluator.feature_extract_interface import PhysionFeatureExtractor
from physion_evaluator.utils import DataAugmentationForVideoMAE

from cwm.eval.Flow.flow_utils import get_occ_masks
from cwm.model.model_factory import model_factory


def load_predictor(model_func_, load_path_, **kwargs):
    """Build a model with model_func_ and load checkpoint weights from load_path_ on CPU."""
    predictor = model_func_(**kwargs).eval().requires_grad_(False)

    did_load = predictor.load_state_dict(
        torch.load(load_path_, map_location=torch.device("cpu"))['model'])
    predictor._predictor_load_path = load_path_
    print(did_load, load_path_)
    return predictor
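
# Usage sketch for load_predictor (the constructor and checkpoint path below are
# hypothetical and shown only for illustration; the checkpoint is expected to be
# a dict with a 'model' key, as load_predictor assumes above):
#
#   predictor = load_predictor(
#       my_model_constructor,          # hypothetical callable returning an nn.Module
#       "/path/to/checkpoint.pth",     # hypothetical checkpoint path
#   )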


class CWM(PhysionFeatureExtractor):
    def __init__(self, model_name, aggregate_embeddings=False):
        super().__init__()

        self.model = model_factory.load_model(model_name).cuda().half()

        self.num_frames = self.model.num_frames
        self.timestamps = np.arange(self.num_frames)

        # Number of patch tokens per 224x224 frame.
        ps = (224 // self.model.patch_size[1]) ** 2

        # Mask every patch of the last frame; all earlier frames stay visible.
        self.bool_masked_pos = np.zeros([ps * self.num_frames])
        self.bool_masked_pos[ps * (self.num_frames - 1):] = 1

        self.ps = ps
        self.aggregate_embeddings = aggregate_embeddings

    def transform(self):
        return DataAugmentationForVideoMAE(
            imagenet_normalize=True,
            rescale_size=224,
        ), 150, 4

    def fwd(self, videos):
        # Broadcast the last-frame mask over the batch and run a masked forward
        # pass that returns encoder features.
        bool_masked_pos = torch.tensor(self.bool_masked_pos).to(videos.device).unsqueeze(0).bool()
        bool_masked_pos = bool_masked_pos.repeat(videos.shape[0], 1)
        x_encoded = self.model(videos.half(), bool_masked_pos, forward_full=True,
                               return_features=True)
        return x_encoded

    def extract_features(self, videos, for_flow=False):
        '''
        videos: [B, T, C, H, W], T is usually 4 and videos are normalized with ImageNet stats
        returns: [B, T, D] extracted features
        '''
        videos = videos.transpose(1, 2)

        all_features = []

        # Repeat the last frame so the sliding windows below cover the full clip.
        videos = torch.cat([videos, videos[:, :, -1:]], dim=2)

        # Slide a window of num_frames frames over the clip and encode each window.
        for x in range(0, 4, self.num_frames - 1):
            vid = videos[:, :, x:x + self.num_frames, :, :]
            all_features.append(self.fwd(vid))
            if self.aggregate_embeddings:
                # Mean-pool the token embeddings of this window.
                feats = all_features[-1].mean(dim=1, keepdim=True)
                all_features[-1] = feats

        x_encoded = torch.cat(all_features, dim=1)

        return x_encoded


class CWM_Keypoints(PhysionFeatureExtractor):
    def __init__(self, model_name):
        super().__init__()

        self.model = model_factory.load_model(model_name).cuda().half()

        # Overlapping 3-frame windows over the 4 input frames.
        self.frames = [[0, 1, 2], [1, 2, 3]]

        self.num_frames = self.model.num_frames
        self.ps = (224 // self.model.patch_size[1]) ** 2

        # Mask every patch of the last frame; all earlier frames stay visible.
        self.bool_masked_pos = np.zeros([self.ps * self.num_frames])
        self.bool_masked_pos[self.ps * (self.num_frames - 1):] = 1

        self.frame_gap = 150
        self.num_frames_dataset = 4
        self.res = 224

    def transform(self):
        return DataAugmentationForVideoMAE(
            imagenet_normalize=True,
            rescale_size=self.res,
        ), self.frame_gap, self.num_frames_dataset

    def fwd(self, videos):
        bool_masked_pos = torch.tensor(self.bool_masked_pos).to(videos.device).unsqueeze(0).bool()
        bool_masked_pos = bool_masked_pos.repeat(videos.shape[0], 1)
        _, x_encoded = self.model(videos.half(), bool_masked_pos, forward_full=True,
                                  return_features=True)
        return x_encoded

    def extract_features(self, videos, segments=None):
        '''
        videos: [B, T, C, H, W], T is usually 4 and videos are normalized with ImageNet stats
        returns: [B, D] keypoint features concatenated over the 3-frame windows
        '''
        videos = videos.transpose(1, 2)

        all_features = []

        for x, arr in enumerate(self.frames):
            vid = videos[:, :, arr, :, :].half()
            frame0 = vid[:, :, 0]
            frame1 = vid[:, :, 1]
            frame2 = vid[:, :, 2]

            # Keypoint features for this 3-frame window.
            mask, choices, err_array, k_feat, keypoint_recon = self.model.get_keypoints(frame0, frame1, frame2, 10, 1)

            # Flatten the per-keypoint features into one vector per clip.
            k_feat = k_feat.view(k_feat.shape[0], -1)

            all_features.append(k_feat)

        x_encoded = torch.cat(all_features, dim=1)

        return x_encoded


class CWM_KeypointsFlow(PhysionFeatureExtractor):
    def __init__(self, model_name):
        super().__init__()

        self.model = model_factory.load_model(model_name).cuda().half()

        # Frame-index triplets used as 3-frame windows (the last triplet repeats
        # its final index).
        self.frames = [[0, 3, 6], [3, 6, 9], [6, 9, 9]]

        self.num_frames = self.model.num_frames
        self.timestamps = np.arange(self.num_frames)

        self.ps = (224 // self.model.patch_size[1]) ** 2

        # Mask every patch of the last frame; all earlier frames stay visible.
        self.bool_masked_pos = np.zeros([self.ps * self.num_frames])
        self.bool_masked_pos[self.ps * (self.num_frames - 1):] = 1

        self.frame_gap = 50
        self.num_frames_dataset = 9
        self.res = 512

    def transform(self):
        return DataAugmentationForVideoMAE(
            imagenet_normalize=True,
            rescale_size=self.res,
        ), self.frame_gap, self.num_frames_dataset

    def fwd(self, videos):
        bool_masked_pos = torch.tensor(self.bool_masked_pos).to(videos.device).unsqueeze(0).bool()
        bool_masked_pos = bool_masked_pos.repeat(videos.shape[0], 1)
        _, x_encoded = self.model(videos.half(), bool_masked_pos, forward_full=True,
                                  return_features=True)
        return x_encoded

    def get_forward_flow(self, videos):
        fid = 6

        # Forward and backward flow around frame fid, each conditioned on an
        # adjacent third frame.
        forward_flow = self.model.get_flow(videos[:, :, fid], videos[:, :, fid + 1],
                                           conditioning_img=videos[:, :, fid + 2], mode='cosine')
        backward_flow = self.model.get_flow(videos[:, :, fid + 1], videos[:, :, fid],
                                            conditioning_img=videos[:, :, fid - 1], mode='cosine')

        # Zero out flow in occluded regions using the forward/backward occlusion mask.
        occlusion_mask = get_occ_masks(forward_flow, backward_flow)[0]
        forward_flow = forward_flow * occlusion_mask

        forward_flow = torch.stack([forward_flow, forward_flow, forward_flow], dim=1)
        forward_flow = forward_flow.to(videos.device)

        # Resize the flow maps to the 224x224 keypoint resolution.
        forward_flow = F.interpolate(forward_flow, size=(2, 224, 224), mode='nearest')

        return forward_flow

    def extract_features(self, videos, segments=None):
        '''
        videos: [B, T, C, H, W], T is set by num_frames_dataset and videos are normalized with ImageNet stats
        returns: [B, D] keypoint features concatenated over the 3-frame windows, with flow
                 features appended for the first window

        Note:
        For efficiency, the optical flow is computed and added for a single frame (300ms) as we found this to be
        sufficient for capturing temporal dynamics in our experiments. This approach can be extended to multiple
        frames if needed, depending on the complexity of the task.
        '''

        # 224x224 copy for keypoint extraction; 1024x1024 copy for flow estimation.
        videos_downsampled = F.interpolate(videos.flatten(0, 1), size=(224, 224), mode='bilinear', align_corners=False)
        videos_downsampled = videos_downsampled.view(videos.shape[0], videos.shape[1], videos.shape[2], 224, 224)

        videos_ = F.interpolate(videos.flatten(0, 1), size=(1024, 1024), mode='bilinear', align_corners=False)
        videos = videos_.view(videos.shape[0], videos.shape[1], videos.shape[2], 1024, 1024)

        videos = videos.transpose(1, 2).half()
        videos_downsampled = videos_downsampled.transpose(1, 2).half()

        forward_flow = self.get_forward_flow(videos)

        assert not torch.isnan(forward_flow).any(), "Forward flow is nan"

        all_features = []

        for x, arr in enumerate(self.frames):
            vid = videos_downsampled[:, :, arr, :, :]
            frame0 = vid[:, :, 0]
            frame1 = vid[:, :, 1]
            frame2 = vid[:, :, 2]

            # Keypoint features for this 3-frame window.
            mask, choices, err_array, k_feat, keypoint_recon = self.model.get_keypoints(frame0, frame1, frame2, 10, 1)

            # The last window repeats its final frame, so keep only the last 10
            # keypoint features.
            if x == 2:
                k_feat = k_feat[:, -10:, :]

            k_feat = k_feat.view(k_feat.shape[0], -1)

            # Keypoint locations in pixel coordinates.
            choices_image_resolution = choices * self.model.patch_size[1]

            if x == 0:
                # Append flow features around each keypoint of the first window.
                # forward_flow holds three identical copies; take one flow map.
                flow_keyp = forward_flow[:, 2]

                # One 8x8x2 flow patch per keypoint (10 keypoints).
                flow = torch.zeros(vid.shape[0], 8 * 8 * 2, 10).to(videos.device)

                shift = 8

                for b in range(flow_keyp.size(0)):
                    x_indices = choices_image_resolution[b, :, 0]
                    y_indices = choices_image_resolution[b, :, 1]

                    for ind in range(10):
                        flow[b, :, ind] = flow_keyp[b, :, y_indices[ind]:y_indices[ind] + shift,
                                                    x_indices[ind]:x_indices[ind] + shift].flatten()

                flow = flow.view(flow.shape[0], -1)

                k_feat = torch.cat([k_feat, flow], dim=1)

            all_features.append(k_feat)

        x_encoded = torch.cat(all_features, dim=1)

        return x_encoded


class CWM_base_8x8_3frame(CWM):
    def __init__(self):
        super().__init__('vitb_8x8patch_3frames')


class CWM_base_8x8_3frame_mean_embed(CWM):
    def __init__(self):
        super().__init__('vitb_8x8patch_3frames', aggregate_embeddings=True)


class CWM_base_8x8_3frame_keypoints(CWM_Keypoints):
    def __init__(self):
        super().__init__('vitb_8x8patch_3frames')


class CWM_base_8x8_3frame_keypoints_flow(CWM_KeypointsFlow):
    def __init__(self):
        super().__init__('vitb_8x8patch_3frames')
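
# Minimal usage sketch (a smoke test under assumptions: a CUDA device is available
# and the 'vitb_8x8patch_3frames' checkpoint can be loaded through model_factory;
# the random input is hypothetical and only illustrates the [B, T, C, H, W] layout
# described in the extract_features docstrings above):
#
#   if __name__ == "__main__":
#       extractor = CWM_base_8x8_3frame()
#       dummy_videos = torch.randn(1, 4, 3, 224, 224).cuda().half()
#       features = extractor.extract_features(dummy_videos)
#       print(features.shape)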