from typing import List, Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
from decord import VideoReader, cpu
from einops import rearrange
from PIL import Image
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision import transforms
from torchvision.transforms import ToPILImage


def get_frames(
    path: str, transform: transforms.Compose, num_frames: int = 16
) -> Tuple[torch.Tensor, List[int]]:
    """Decode ``num_frames`` frames from the video at ``path`` and apply ``transform``.

    Frames are sampled every other frame starting at frame 60. Returns a tensor
    of shape (C, T, H, W) and the list of sampled frame indices.
    """
    vr = VideoReader(path, ctx=cpu(0))
    # Sample num_frames frames with a temporal stride of 2, starting at frame 60.
    frame_id_list = (np.arange(0, num_frames * 2, 2) + 60).tolist()
    video_data = vr.get_batch(frame_id_list).asnumpy()
    # The transform expects a (list_of_PIL_images, None) tuple and returns (tensor, _).
    frames, _ = transform(
        (
            [
                Image.fromarray(video_data[vid, :, :, :]).convert("RGB")
                for vid, _ in enumerate(frame_id_list)
            ],
            None,
        )
    )
    # (T * C, H, W) -> (T, C, H, W) -> (C, T, H, W)
    frames = frames.view((num_frames, 3) + frames.size()[-2:]).transpose(0, 1)
    return frames, frame_id_list
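

# A minimal stand-in for the ``transform`` expected by ``get_frames`` above.
# This is only an illustrative sketch, not the Space's actual augmentation
# pipeline (the VideoMAE demo ships its own group transforms): it accepts a
# ``(list_of_PIL_images, None)`` tuple, resizes each frame to 224x224,
# normalises with the ImageNet statistics, and stacks the frames along the
# channel axis into the (T * C, H, W) layout that ``get_frames`` reshapes.
class _SimpleVideoTransform:
    def __init__(self, size: int = 224):
        self.frame_transform = transforms.Compose(
            [
                transforms.Resize((size, size)),
                transforms.ToTensor(),
                transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
            ]
        )

    def __call__(self, inp):
        images, _ = inp
        # Concatenate per-frame (C, H, W) tensors into one (T * C, H, W) tensor.
        stacked = torch.cat([self.frame_transform(img) for img in images], dim=0)
        return stacked, None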


def prepare_frames_masks(
    frames: torch.Tensor, masks: torch.Tensor, device: torch.device
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Add a batch dimension, move both tensors to ``device``, and flatten the mask to booleans."""
    frames = frames.unsqueeze(0).to(device, non_blocking=True)
    masks = masks.unsqueeze(0).to(device, non_blocking=True).flatten(1).to(torch.bool)
    return frames, masks
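

# Illustrative only: ``prepare_frames_masks`` expects a boolean-convertible
# mask over the patch tokens, but this file does not show how the Space
# builds it. A simple random patch mask over the (T/2) x 14 x 14 token grid
# (assuming 224x224 input, 16x16 spatial patches, tubelet size 2) could look
# like this; the real demo may use a different masking strategy.
def make_random_mask(
    num_frames: int = 16,
    mask_ratio: float = 0.9,
    tubelet_size: int = 2,
    grid_size: int = 14,
) -> torch.Tensor:
    num_patches = (num_frames // tubelet_size) * grid_size * grid_size
    num_masked = int(mask_ratio * num_patches)
    mask = torch.zeros(num_patches)
    mask[torch.randperm(num_patches)[:num_masked]] = 1
    return mask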


def get_videomae_outputs(
    frames: torch.Tensor,
    masks: torch.Tensor,
    outputs: torch.Tensor,
    ids: List[int],
    patch_size: Tuple[int, ...],
    device: torch.device,
) -> List[List[Image.Image]]:
    """Build per-frame [original, masked, reconstructed] PIL image triples.

    ``outputs`` holds the model's predicted pixel patches for the masked
    positions; de-normalisation uses the per-patch mean/std of the input
    video, following the VideoMAE visualisation recipe. Assumes 224x224
    input, a tubelet size of 2, and a 14x14 grid of spatial patches.
    """
    visualisations = []
    # Undo ImageNet normalisation to recover pixel values in [0, 1].
    mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(device)[None, :, None, None, None]
    std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(device)[None, :, None, None, None]
    ori_img = frames * std + mean  # in [0, 1]
    original_images = [
        ToPILImage()(ori_img[0, :, vid, :, :].cpu()) for vid, _ in enumerate(ids)
    ]

    # Split the video into (tubelet x patch x patch) cubes and normalise each
    # cube, mirroring the target construction used during pre-training.
    img_squeeze = rearrange(
        ori_img,
        "b c (t p0) (h p1) (w p2) -> b (t h w) (p0 p1 p2) c",
        p0=2,
        p1=patch_size[0],
        p2=patch_size[1],
    )
    img_norm = (img_squeeze - img_squeeze.mean(dim=-2, keepdim=True)) / (
        img_squeeze.var(dim=-2, unbiased=True, keepdim=True).sqrt() + 1e-6
    )
    img_patch = rearrange(img_norm, "b n p c -> b n (p c)")
    # Paste the model's predictions into the masked patch positions.
    img_patch[masks] = outputs

    # Binary mask video: 1 for visible patches, 0 for masked ones.
    mask = torch.ones_like(img_patch)
    mask[masks] = 0
    mask = rearrange(mask, "b n (p c) -> b n p c", c=3)
    mask = rearrange(
        mask,
        "b (t h w) (p0 p1 p2) c -> b c (t p0) (h p1) (w p2)",
        p0=2,
        p1=patch_size[0],
        p2=patch_size[1],
        h=14,
        w=14,
    )

    # Reconstruction video: de-normalise patches with the per-cube statistics.
    rec_img = rearrange(img_patch, "b n (p c) -> b n p c", c=3)
    rec_img = rec_img * (
        img_squeeze.var(dim=-2, unbiased=True, keepdim=True).sqrt() + 1e-6
    ) + img_squeeze.mean(dim=-2, keepdim=True)
    rec_img = rearrange(
        rec_img,
        "b (t h w) (p0 p1 p2) c -> b c (t p0) (h p1) (w p2)",
        p0=2,
        p1=patch_size[0],
        p2=patch_size[1],
        h=14,
        w=14,
    )
    reconstructed_images = [
        ToPILImage()(rec_img[0, :, vid, :, :].cpu().clamp(0, 0.996))
        for vid, _ in enumerate(ids)
    ]

    # Masked video: zero out the masked patches of the reconstruction.
    img_mask = rec_img * mask
    masked_images = [
        ToPILImage()(img_mask[0, :, vid, :, :].cpu()) for vid, _ in enumerate(ids)
    ]

    assert len(original_images) == len(reconstructed_images) == len(masked_images)
    for i in range(len(original_images)):
        visualisations.append(
            [original_images[i], masked_images[i], reconstructed_images[i]]
        )
    return visualisations


def create_plot(images: List[List[Image.Image]]):
    """Plot the [original, masked, reconstructed] triples, one frame per row."""
    num_cols = 3
    num_rows = 16  # one row per sampled frame
    column_names = ["Original Patch", "Masked Patch", "Reconstructed Patch"]
    fig, axes = plt.subplots(num_rows, num_cols, figsize=(12, 48))
    for i in range(num_rows):
        for j in range(num_cols):
            axes[i, j].imshow(images[i][j])
            axes[i, j].axis("off")
            if i == 0:
                axes[i, j].set_title(column_names[j], fontsize=16)
    plt.tight_layout()
    return fig
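

# A rough end-to-end sketch of how these helpers fit together. The model
# loading and forward call are placeholders: the Space wires in a pre-trained
# VideoMAE checkpoint with its own transform and mask construction, so treat
# ``load_videomae_model`` and ``video.mp4`` as stand-ins, not real APIs.
#
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model = load_videomae_model().to(device).eval()      # hypothetical loader
# transform = _SimpleVideoTransform()                  # or the Space's own transform
# frames, frame_ids = get_frames("video.mp4", transform)
# masks = make_random_mask()
# frames, masks = prepare_frames_masks(frames, masks, device)
# with torch.no_grad():
#     outputs = model(frames, masks)                   # predicted pixels for masked patches
# visualisations = get_videomae_outputs(
#     frames, masks, outputs, frame_ids, patch_size=(16, 16), device=device
# )
# fig = create_plot(visualisations)
# fig.savefig("videomae_reconstruction.png")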