# videomae-vis/src/utils.py
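"""Utility helpers for visualising VideoMAE reconstructions.

These functions decode frames from a video, prepare model inputs, and turn
model outputs into original / masked / reconstructed image triples for
plotting.
"""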
from typing import List, Tuple
import matplotlib.pyplot as plt
import numpy as np
import torch
from decord import VideoReader, cpu
from einops import rearrange
from PIL import Image
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision import transforms
from torchvision.transforms import ToPILImage


def get_frames(
    path: str, transform: transforms.Compose, num_frames: int = 16
) -> Tuple[torch.Tensor, List[int]]:
    """Decode ``num_frames`` frames from ``path`` and apply ``transform``.

    Returns a ``(C, T, H, W)`` frame tensor together with the list of
    sampled frame indices.
    """
    vr = VideoReader(path, ctx=cpu(0))
    # Sample every other frame, starting at a fixed offset of frame 60.
    tmp = np.arange(0, num_frames * 2, 2) + 60
    frame_id_list = tmp.tolist()
    video_data = vr.get_batch(frame_id_list).asnumpy()
    # The transform follows the VideoMAE convention of mapping a
    # (list of PIL images, label) pair to a (tensor, label) pair.
    frames, _ = transform(
        (
            [
                Image.fromarray(video_data[vid, :, :, :]).convert("RGB")
                for vid, _ in enumerate(frame_id_list)
            ],
            None,
        )
    )
    # Interpret the flat leading dimension as (T, C), then swap the time and
    # channel axes to get the (C, T, H, W) layout VideoMAE expects.
    frames = frames.view((num_frames, 3) + frames.size()[-2:]).transpose(0, 1)
    return frames, frame_id_list
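
# Usage sketch (assumptions flagged): `get_frames` expects a transform that
# maps a (list of PIL images, label) pair to a (tensor, label) pair, as in
# the VideoMAE data pipeline. The composition below is hypothetical, not
# part of this module:
#
#   transform = transforms.Compose([...])  # video-level transform pipeline
#   frames, frame_ids = get_frames("video.mp4", transform, num_frames=16)
#   # frames: (3, 16, H, W) tensor; frame_ids: the 16 sampled frame indices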


def prepare_frames_masks(
    frames: torch.Tensor, masks: torch.Tensor, device: "torch.device"
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Add a batch dimension and move frames and masks to ``device``."""
    frames = frames.unsqueeze(0)
    masks = masks.unsqueeze(0)
    frames = frames.to(device, non_blocking=True)
    # Flatten the per-patch mask to (B, N) and cast it to boolean so it can
    # be used for fancy indexing into the patch sequence.
    masks = masks.to(device, non_blocking=True).flatten(1).to(torch.bool)
    return frames, masks
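
# Usage sketch: move a single clip and its tube mask onto the model's
# device. `model` below is a hypothetical pretrained VideoMAE; only the
# tensor plumbing is part of this module:
#
#   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#   frames, masks = prepare_frames_masks(frames, masks, device)
#   with torch.no_grad():
#       outputs = model(frames, masks)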


def get_videomae_outputs(
    frames: torch.Tensor,
    masks: torch.Tensor,
    outputs: torch.Tensor,
    ids: List[int],
    patch_size: Tuple[int, ...],
    device: "torch.device",
) -> List[List[Image.Image]]:
    """Build per-frame [original, masked, reconstructed] image triples."""
    visualisations = []
    # Undo the ImageNet normalisation so pixel values are back in [0, 1].
    mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(device)[None, :, None, None, None]
    std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(device)[None, :, None, None, None]
    ori_img = frames * std + mean  # in [0, 1]
    original_images = [
        ToPILImage()(ori_img[0, :, vid, :, :].cpu()) for vid, _ in enumerate(ids)
    ]
    # Split the video into tubelets of 2 frames x patch_size x patch_size,
    # matching VideoMAE's tokenisation.
    img_squeeze = rearrange(
        ori_img,
        "b c (t p0) (h p1) (w p2) -> b (t h w) (p0 p1 p2) c",
        p0=2,
        p1=patch_size[0],
        p2=patch_size[1],
    )
    # Normalise each patch to zero mean and unit variance, since the model's
    # reconstruction targets are per-patch normalised pixels.
    img_norm = (img_squeeze - img_squeeze.mean(dim=-2, keepdim=True)) / (
        img_squeeze.var(dim=-2, unbiased=True, keepdim=True).sqrt() + 1e-6
    )
    img_patch = rearrange(img_norm, "b n p c -> b n (p c)")
    # Paste the model's predictions into the masked patch positions.
    img_patch[masks] = outputs

    # Build a binary mask video: 0 at masked patches, 1 elsewhere.
    mask = torch.ones_like(img_patch)
    mask[masks] = 0
    mask = rearrange(mask, "b n (p c) -> b n p c", c=3)
    # h = w = 14 assumes a 224 x 224 input with 16-pixel spatial patches.
    mask = rearrange(
        mask,
        "b (t h w) (p0 p1 p2) c -> b c (t p0) (h p1) (w p2)",
        p0=2,
        p1=patch_size[0],
        p2=patch_size[1],
        h=14,
        w=14,
    )

    # Reconstruction video: undo the per-patch normalisation, then fold the
    # patch sequence back into a (B, C, T, H, W) video.
    rec_img = rearrange(img_patch, "b n (p c) -> b n p c", c=3)
    rec_img = rec_img * (
        img_squeeze.var(dim=-2, unbiased=True, keepdim=True).sqrt() + 1e-6
    ) + img_squeeze.mean(dim=-2, keepdim=True)
    rec_img = rearrange(
        rec_img,
        "b (t h w) (p0 p1 p2) c -> b c (t p0) (h p1) (w p2)",
        p0=2,
        p1=patch_size[0],
        p2=patch_size[1],
        h=14,
        w=14,
    )
    # Clamp so the denormalised values stay in a valid [0, 1] pixel range
    # for ToPILImage.
    reconstructed_images = [
        ToPILImage()(rec_img[0, :, vid, :, :].cpu().clamp(0, 0.996))
        for vid, _ in enumerate(ids)
    ]

    # Masked video: zero out the masked patches of the reconstruction.
    img_mask = rec_img * mask
    masked_images = [
        ToPILImage()(img_mask[0, :, vid, :, :].cpu()) for vid, _ in enumerate(ids)
    ]

    assert len(original_images) == len(reconstructed_images) == len(masked_images)
    for i in range(len(original_images)):
        visualisations.append(
            [original_images[i], masked_images[i], reconstructed_images[i]]
        )
    return visualisations
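
# Usage sketch: `outputs` is assumed to be the model's predicted patch
# pixels for the masked positions (shape: num_masked_patches x patch_dim).
# The 16 x 16 spatial patch size below is an assumption that matches the
# hardcoded h = w = 14 grid above:
#
#   visualisations = get_videomae_outputs(
#       frames, masks, outputs, frame_ids, patch_size=(16, 16), device=device
#   )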


def create_plot(images: List[List[Image.Image]]) -> plt.Figure:
    """Plot original / masked / reconstructed patches in a 16 x 3 grid."""
    num_cols = 3
    num_rows = 16  # matches the default num_frames in get_frames
    column_names = ["Original Patch", "Masked Patch", "Reconstructed Patch"]
    fig, axes = plt.subplots(num_rows, num_cols, figsize=(12, 48))
    for i in range(num_rows):
        for j in range(num_cols):
            axes[i, j].imshow(images[i][j])
            axes[i, j].axis("off")
            if i == 0:
                axes[i, j].set_title(column_names[j], fontsize=16)
    plt.tight_layout()
    return fig
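
# Usage sketch: render and save the grid (hypothetical output path):
#
#   fig = create_plot(visualisations)
#   fig.savefig("videomae_vis.png", bbox_inches="tight")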