"""Visualisation helpers for VideoMAE masked-video reconstructions."""

from typing import List, Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
from decord import VideoReader, cpu
from einops import rearrange
from PIL import Image
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision import transforms
from torchvision.transforms import ToPILImage


def get_frames(
    path: str, transform: transforms.Compose, num_frames: int = 16
) -> Tuple[torch.Tensor, List[int]]:
    """Decode ``num_frames`` frames from the video at ``path`` and apply ``transform``."""
    vr = VideoReader(path, ctx=cpu(0))
    # Sample every second frame, starting at a fixed offset of frame 60.
    frame_id_list = (np.arange(0, num_frames * 2, 2) + 60).tolist()
    video_data = vr.get_batch(frame_id_list).asnumpy()
    # The transform expects an (images, label) pair and returns one as well.
    frames, _ = transform(
        (
            [
                Image.fromarray(video_data[vid, :, :, :]).convert("RGB")
                for vid, _ in enumerate(frame_id_list)
            ],
            None,
        )
    )
    # Reshape the stacked (T*C, H, W) output to (C, T, H, W).
    frames = frames.view((num_frames, 3) + frames.size()[-2:]).transpose(0, 1)
    return frames, frame_id_list


def prepare_frames_masks(
    frames: torch.Tensor, masks: torch.Tensor, device: torch.device
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Add a batch dimension, move both tensors to ``device``, and flatten the mask."""
    frames = frames.unsqueeze(0).to(device, non_blocking=True)
    masks = masks.unsqueeze(0).to(device, non_blocking=True).flatten(1).to(torch.bool)
    return frames, masks


def get_videomae_outputs(
    frames: torch.Tensor,
    masks: torch.Tensor,
    outputs: torch.Tensor,
    ids: List[int],
    patch_size: Tuple[int, ...],
    device: torch.device,
) -> List[List[Image.Image]]:
    """Build (original, masked, reconstructed) PIL-image triplets for each frame."""
    visualisations = []
    # Undo the ImageNet normalisation so pixel values land back in [0, 1].
    mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(device)[None, :, None, None, None]
    std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(device)[None, :, None, None, None]
    ori_img = frames * std + mean  # in [0, 1]
    original_images = [
        ToPILImage()(ori_img[0, :, vid, :, :].cpu()) for vid, _ in enumerate(ids)
    ]

    # Split the video into tubelets of 2 frames x patch_size[0] x patch_size[1].
    img_squeeze = rearrange(
        ori_img,
        "b c (t p0) (h p1) (w p2) -> b (t h w) (p0 p1 p2) c",
        p0=2,
        p1=patch_size[0],
        p2=patch_size[1],
    )
    # Per-patch normalisation, matching VideoMAE's normalised-pixel targets.
    img_norm = (img_squeeze - img_squeeze.mean(dim=-2, keepdim=True)) / (
        img_squeeze.var(dim=-2, unbiased=True, keepdim=True).sqrt() + 1e-6
    )
    img_patch = rearrange(img_norm, "b n p c -> b n (p c)")
    # Paste the model's predictions into the masked positions.
    img_patch[masks] = outputs

    # Binary mask video: 1 where the patch was visible, 0 where it was masked.
    # h=14, w=14 assumes 224x224 inputs with 16x16 spatial patches.
    mask = torch.ones_like(img_patch)
    mask[masks] = 0
    mask = rearrange(mask, "b n (p c) -> b n p c", c=3)
    mask = rearrange(
        mask,
        "b (t h w) (p0 p1 p2) c -> b c (t p0) (h p1) (w p2)",
        p0=2,
        p1=patch_size[0],
        p2=patch_size[1],
        h=14,
        w=14,
    )

    # Reconstruction video: undo the per-patch normalisation, then fold the
    # patch sequence back into pixel space.
    rec_img = rearrange(img_patch, "b n (p c) -> b n p c", c=3)
    rec_img = rec_img * (
        img_squeeze.var(dim=-2, unbiased=True, keepdim=True).sqrt() + 1e-6
    ) + img_squeeze.mean(dim=-2, keepdim=True)
    rec_img = rearrange(
        rec_img,
        "b (t h w) (p0 p1 p2) c -> b c (t p0) (h p1) (w p2)",
        p0=2,
        p1=patch_size[0],
        p2=patch_size[1],
        h=14,
        w=14,
    )
    reconstructed_images = [
        # Clamp just below 1.0 to avoid uint8 rounding overflow in ToPILImage.
        ToPILImage()(rec_img[0, :, vid, :, :].cpu().clamp(0, 0.996))
        for vid, _ in enumerate(ids)
    ]

    # Masked video: visible patches (which equal the original pixels) are kept,
    # masked patches are blacked out.
    img_mask = rec_img * mask
    masked_images = [
        ToPILImage()(img_mask[0, :, vid, :, :].cpu()) for vid, _ in enumerate(ids)
    ]

    assert len(original_images) == len(reconstructed_images) == len(masked_images)
    for i in range(len(original_images)):
        visualisations.append(
            [original_images[i], masked_images[i], reconstructed_images[i]]
        )
    return visualisations


def create_plot(images: List[List[Image.Image]]) -> plt.Figure:
    """Lay out the per-frame triplets in a (num_frames x 3) grid."""
    num_cols = 3
    num_rows = len(images)  # one row per sampled frame (16 by default)
    column_names = ["Original Patch", "Masked Patch", "Reconstructed Patch"]
    fig, axes = plt.subplots(num_rows, num_cols, figsize=(12, 3 * num_rows))
    for i in range(num_rows):
        for j in range(num_cols):
            axes[i, j].imshow(images[i][j])
            axes[i, j].axis("off")
            if i == 0:
                axes[i, j].set_title(column_names[j], fontsize=16)
    plt.tight_layout()
    return fig
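

# ------------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the module's API): wires the helpers
# above together end to end. The transform, tube mask, and model below are
# hypothetical stand-ins; in practice you would use the VideoMAE codebase's
# DataAugmentationForVideoMAE transform, its TubeMaskingGenerator, and a
# pretrained pretraining checkpoint whose forward maps (frames, masks) to
# predictions for the masked patches.
if __name__ == "__main__":

    def demo_transform(inp):
        # Minimal stand-in for DataAugmentationForVideoMAE: resize, normalise,
        # and stack the T frames into a single (T*C, H, W) tensor, returning
        # the (tensor, label) pair that get_frames expects.
        images, label = inp
        to_tensor = transforms.Compose(
            [transforms.Resize((224, 224)), transforms.ToTensor()]
        )
        normalize = transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
        return torch.cat([normalize(to_tensor(im)) for im in images], dim=0), label

    class DummyReconstructor(torch.nn.Module):
        # Stand-in for a pretrained VideoMAE model: returns zero predictions of
        # the right shape, (num_masked_patches, tubelet_depth * pH * pW * C),
        # so the pipeline runs; a real model returns learned reconstructions.
        def forward(self, frames: torch.Tensor, masks: torch.Tensor) -> torch.Tensor:
            return torch.zeros(int(masks.sum()), 2 * 16 * 16 * 3, device=frames.device)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Needs a video with at least ~92 frames, since sampling starts at frame 60.
    frames, frame_ids = get_frames("example.mp4", demo_transform, num_frames=16)

    # Tube mask: one random spatial mask over the 14x14 patch grid (~90% masked),
    # repeated across the 8 tubelets so the same patches are hidden at every
    # time step.
    spatial_mask = torch.rand(14 * 14) < 0.9
    tube_mask = spatial_mask.repeat(16 // 2)

    frames, masks = prepare_frames_masks(frames, tube_mask, device)
    with torch.no_grad():
        outputs = DummyReconstructor().to(device)(frames, masks)

    visualisations = get_videomae_outputs(
        frames, masks, outputs, frame_ids, patch_size=(16, 16), device=device
    )
    create_plot(visualisations).savefig("videomae_reconstruction.png")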