hma / visualize.py
LeroyWaa's picture
draft
246c106
raw
history blame
13.9 kB
#!/usr/bin/env python3
"""
Script to decode tokenized video into images/video.
Example usage: See https://github.com/1x-technologies/1xgpt?tab=readme-ov-file#1x-genie-baseline
"""
import argparse
import math
import os
from PIL import Image, ImageDraw
import numpy as np
import torch
import torch.distributed.optim
import torch.utils.checkpoint
import torch.utils.data
import torchvision.transforms.v2.functional as transforms_f
from diffusers import AutoencoderKLTemporalDecoder
from einops import rearrange
from matplotlib import pyplot as plt
from cont_data import RawFeatureDataset
from data import RawTokenDataset
from datasets.utils import get_image_encoder
from magvit2.config import VQConfig
from magvit2.models.lfqgan import VQModel
from common.eval_utils import decode_tokens, decode_features
import wandb
# Never hard-code credentials in source. With key=None, wandb falls back to the
# WANDB_API_KEY environment variable / ~/.netrc (or an interactive prompt).
wandb.login(key=os.environ.get("WANDB_API_KEY"))
# Divisor applied to continuous features before VAE decoding in `main`;
# presumably the Stable Diffusion VAE latent scaling factor — confirm.
SVD_SCALE = 0.18215
def parse_args(argv=None):
    """
    Build this script's command-line interface and parse it.

    Args:
    - argv (list[str] | None): Arguments to parse. None (the default) parses
      `sys.argv[1:]`, which is what the script does when run directly; passing an
      explicit list lets other code or tests drive the parser.

    Returns:
    - argparse.Namespace with one attribute per option defined below.
    """
    parser = argparse.ArgumentParser(description="Visualize tokenized video as GIF or comic.")
    parser.add_argument(
        "--stride",
        type=int,
        default=1,
        help="Frame skip",
    )
    parser.add_argument(
        "--token_dir",
        type=str,
        default="data/genie_generated",
        help="Directory of tokens, in the format of `video.bin` and `metadata.json`. "
             "Visualized gif and comic will be written here.",
    )
    parser.add_argument(
        "--offset", type=int, default=0, help="Offset to start generating images from"
    )
    parser.add_argument(
        "--fps", type=int, default=2, help="Frames per second"
    )
    parser.add_argument(
        "--max_images", type=int, default=None, help="Maximum number of images to generate. None for all."
    )
    parser.add_argument(
        "--example_ind", type=int, default=0,
        help="The index in the dataset of the example to generate on."
    )
    parser.add_argument(
        "--project_prefix", type=str, default="",
        help="Prefix prepended to the wandb run name and id."
    )
    parser.add_argument(
        "--disable_comic", action="store_true",
        help="Comic generation assumes `token_dir` follows the same format as generate: e.g., "
             "`prompt | predictions | gtruth` in `video.bin`, `window_size` in `metadata.json`. "
             "Therefore, comic should be disabled when visualizing videos without this format, such as the dataset."
    )
    parser.add_argument(
        "--batch_size", type=int, default=4,
        help="Batch size, current script only supports a single GPU."
    )
    parser.add_argument(
        "--max_example", type=int, default=4,
        help="Maximum number of examples."
    )
    parser.add_argument(
        "--use_feature", action="store_true",
        help="visualize the features rather than tokens"
    )
    return parser.parse_args(argv)
def export_to_gif(frames: list, output_gif_path: str, fps: int):
    """
    Export a list of frames to a looping GIF.

    Args:
    - frames (list): Frames as numpy arrays or PIL Image objects (a stacked
      numpy array of frames also works, since it is only iterated and sized).
    - output_gif_path (str): Path to save the output GIF. A trailing ".mp4"
      extension is swapped for ".gif"; the rest of the path is left untouched.
    - fps (int): Desired frames per second; must be positive.

    Returns:
    - str: The path the GIF was written to.

    Raises:
    - ValueError: if `frames` is empty or `fps` is not positive.
    """
    # len() rather than truthiness: `frames` may be a multi-frame numpy array,
    # whose truth value is ambiguous.
    if len(frames) == 0:
        raise ValueError("export_to_gif requires at least one frame")
    if fps <= 0:
        raise ValueError(f"fps must be positive, got {fps}")

    # Only rewrite the extension; str.replace would also rewrite ".mp4"
    # occurring inside directory names.
    root, ext = os.path.splitext(output_gif_path)
    gif_path = root + ".gif" if ext == ".mp4" else output_gif_path

    # Convert numpy arrays to PIL Images if needed
    pil_frames = [Image.fromarray(frame) if isinstance(frame, np.ndarray) else frame
                  for frame in frames]
    duration_ms = 1000 / fps
    pil_frames[0].save(gif_path,
                       format="GIF",
                       append_images=pil_frames[1:],
                       save_all=True,
                       duration=duration_ms,
                       loop=0)
    return gif_path
def unnormalize_imgs(normalized_imgs):
    """
    Convert images from the [-1, 1] range to uint8 pixel values in [0, 255].

    The input is clamped to [-1, 1] first; after rescaling, values are clamped
    to [0, 255] again before the uint8 cast (which truncates fractional parts).
    The returned tensor is detached from the autograd graph and lives on CPU.
    """
    bounded = normalized_imgs.clamp(min=-1, max=1)
    pixels = bounded.detach().cpu().add(1).mul(127.5)
    return pixels.clamp(0, 255).to(dtype=torch.uint8)
def decode_latents_wrapper(
    batch_size: int = 16,
    encoder_type: str = "magvit",
    encoder_name_or_path: str = "data/magvit2.ckpt",
    max_images: int = None,
    device: str = "cuda",
):
    """
    Load an image decoder once and return a closure that decodes latents to PIL images.

    Args:
    - batch_size (int): Number of frames decoded per forward pass ("shard").
    - encoder_type (str): Passed to `get_image_encoder` (this script uses
      "magvit" and "temporalvae").
    - encoder_name_or_path (str): Checkpoint path or hub id for `get_image_encoder`.
    - max_images (int | None): If set, decoding stops once at least this many
      frames have been decoded. Whole shards are kept, so up to
      `batch_size - 1` extra frames may be returned.
    - device (str): Device the model and each shard are moved to.

    Returns:
    - The `decode_latents(video_data)` closure described below.
    """
    dtype = torch.bfloat16
    model = get_image_encoder(encoder_type, encoder_name_or_path)
    model = model.to(device=device, dtype=dtype)

    @torch.no_grad()
    def decode_latents(video_data: np.ndarray):
        """
        video_data: (b, h, w) integer codes for a VQModel, or (b, c, h, w)
        continuous latents for an AutoencoderKLTemporalDecoder, where b is the
        number of frames to decode (unrelated to the training/eval batch size).

        Returns a list of PIL images, one per decoded frame; empty input yields [].
        """
        decoded_imgs = []
        for shard_ind in range(math.ceil(len(video_data) / batch_size)):
            shard_data = video_data[shard_ind * batch_size: (shard_ind + 1) * batch_size]
            if isinstance(model, VQModel):  # TODO: class agnostic wrapper
                # Quantized codes: look up codebook vectors, then decode them.
                assert shard_data.ndim == 3, f"{shard_data.shape=} {shard_data.dtype=}"
                torch_shard = torch.from_numpy(shard_data.astype(np.int64))
                # The EMA scope is deliberately not entered: EMA does nothing in the bugged VQModel.
                # NOTE(review): flip(1) reverses dim 1 of the codebook output, presumably
                # to match the decoder's expected ordering — confirm.
                quant = model.quantize.get_codebook_entry(
                    rearrange(torch_shard, "b h w -> b (h w)"),
                    bhwc=torch_shard.shape + (model.quantize.codebook_dim,),
                ).flip(1)
                normalized_imgs = model.decode(quant.to(device=device, dtype=dtype))
            elif isinstance(model, AutoencoderKLTemporalDecoder):
                # Continuous latents, decoded one frame at a time (num_frames=1).
                assert shard_data.ndim == 4, f"{shard_data.shape=} {shard_data.dtype=}"
                torch_shard = torch.from_numpy(shard_data)
                # NOTE(review): manual clip to [-25, 25], presumably to keep outlier
                # latents from blowing up the decoder; the bound looks empirical.
                torch_shard = torch.clamp(torch_shard, -25, 25)
                normalized_imgs = model.decode(torch_shard.to(device=device, dtype=dtype), num_frames=1).sample
            else:
                raise NotImplementedError(f"{model=}")

            decoded_imgs.append(unnormalize_imgs(normalized_imgs))
            if max_images and len(decoded_imgs) * batch_size >= max_images:
                break

        if not decoded_imgs:
            # torch.cat([]) raises; empty input simply decodes to no images.
            return []
        return [transforms_f.to_pil_image(img) for img in torch.cat(decoded_imgs)]

    return decode_latents
def caption_image(pil_image: Image.Image, caption: str):
    """
    Add a white band above the image and draw `caption` centered in it.

    Args:
    - pil_image: A PIL image, or anything torchvision's `to_pil_image` accepts
      (e.g. an HWC uint8 numpy array, as `main` passes).
    - caption (str): Text to draw.

    Returns:
    - A new RGB PIL image, `border_size` pixels taller than the input.
    """
    border_size = 36
    font_size = 24
    # convert pil_image to PIL.Image.Image if it's not already
    if not isinstance(pil_image, Image.Image):
        pil_image = transforms_f.to_pil_image(pil_image)
    width, height = pil_image.size
    new_image = Image.new("RGB", (width, height + border_size), "white")
    new_image.paste(pil_image, (0, border_size))

    # Draw the caption, centered manually (`align` only applies to multi-line text).
    draw = ImageDraw.Draw(new_image)
    # textbbox returns (left, top, right, bottom) relative to the anchor; the ink
    # may start right of / below the anchor, so use the extent (not right/bottom)
    # as the size and subtract the offsets to center the visible text.
    left, top, right, bottom = draw.textbbox((0, 0), caption, font_size=font_size)
    text_w, text_h = right - left, bottom - top
    x = (width - text_w) / 2 - left
    y = (border_size - text_h) / 2 - top
    draw.text((x, y), caption, fill="black", font_size=font_size)
    return new_image
def _caption_frames(video_frames, metadata):
    """Caption each frame of a `prompt | predictions | ground truth` example."""
    captioned_frames = []
    for i, frame in enumerate(video_frames):
        if i < metadata["num_prompt_frames"]:
            caption = "Prompt"
        elif i < metadata["window_size"]:
            caption = "Generated"
        else:
            caption = "Ground truth"
        captioned_frames.append(caption_image(frame, caption))
    return captioned_frames


def _save_comic(video_frames, metadata, output_comic_path):
    """Save a 2-row comic: prompt + predictions on top, prompt + ground truth below."""
    window_size = metadata["window_size"]
    num_prompt_frames = metadata["num_prompt_frames"]
    # squeeze=False keeps `axs` 2-D even when window_size == 1.
    fig, axs = plt.subplots(nrows=2, ncols=window_size, figsize=(3 * window_size, 3 * 2), squeeze=False)
    for i, image in enumerate(video_frames):
        if i < num_prompt_frames:
            # Prompt frames are shown in both rows.
            curr_axs = [axs[0, i], axs[1, i]]
            title = "Prompt"
        elif i < window_size:
            curr_axs = [axs[0, i]]
            title = "Prediction"
        else:
            curr_axs = [axs[1, i - window_size + num_prompt_frames]]
            title = "Ground truth"
        for ax in curr_axs:
            ax.set_title(title)
            ax.imshow(image)
            ax.axis("off")
    fig.savefig(output_comic_path, bbox_inches="tight")
    plt.close(fig)


@torch.no_grad()
def main():
    """
    Decode every example in `--token_dir`, write a GIF (and optionally a comic
    PNG) per example into that directory, and log the GIFs to wandb.
    """
    args = parse_args()

    # The run/dataset names are parsed from the second-to-last '/'-separated
    # component of token_dir, around the substring "nodes".
    # NOTE(review): the fixed slice offsets (-7, +8) encode a directory naming
    # convention not visible here; if "nodes" is absent, find() returns -1 and
    # the names are garbage — confirm.
    name = args.token_dir.split('/')[-2]
    name_split = name.find('nodes')
    model = name[:name_split - 7]
    dataset = name[name_split + 8:]

    # Load tokens (or continuous features).
    dataset_cls = RawFeatureDataset if args.use_feature else RawTokenDataset
    token_dataset = dataset_cls(args.token_dir, 1, compute_stride_from_freq_table=False,
                                filter_interrupts=False, filter_overlaps=False)
    video_tokens = token_dataset.data
    print(f"Loaded {video_tokens.shape=}")

    metadata = token_dataset.metadata
    # Each example is `prompt | predictions | ground truth`: num_prompt_frames prompt
    # frames, then (window_size - num_prompt_frames) predicted and as many
    # ground-truth frames.
    # NOTE(review): this requires generate-style metadata; ground-truth metadata
    # without these keys raises KeyError here.
    frames_per_example = metadata["window_size"] * 2 - metadata["num_prompt_frames"]
    video_tokens = video_tokens.reshape(-1, frames_per_example, *video_tokens.shape[1:])
    print(metadata)
    print(f"Reshape {video_tokens.shape=}")

    # Build the decoder once, outside the per-example loop, so the checkpoint is
    # loaded a single time.
    if args.use_feature:
        if "encoder_type" not in metadata:
            metadata["encoder_type"] = "temporalvae"
            metadata["encoder_name_or_path"] = "stabilityai/stable-video-diffusion-img2vid"
        decode_latents = decode_latents_wrapper(
            batch_size=args.batch_size,
            max_images=args.max_images,
            encoder_name_or_path=metadata["encoder_name_or_path"],
            encoder_type=metadata["encoder_type"],
        )
    else:
        decode_latents = decode_latents_wrapper(batch_size=args.batch_size, max_images=args.max_images)

    # `generate` should populate `metadata.json` with these keys, while ground truth metadata does not have them
    is_generated_data = all(key in metadata for key in ("num_prompt_frames", "window_size"))

    wandb.init(project='video_eval_vis', settings=wandb.Settings(start_method="thread"),
               name=f"{args.project_prefix}vis_{model}", id=f"{args.project_prefix}vis_{model}", resume="allow")

    for example_id in range(min(args.max_example, len(video_tokens))):
        if args.use_feature:
            this_video_token = torch.FloatTensor(video_tokens[example_id].copy())[None] / SVD_SCALE
            this_video_token = rearrange(this_video_token, "b t c h w -> b t h w c")
            video_frames = decode_features(this_video_token, decode_latents)
        else:
            this_video_token = torch.LongTensor(video_tokens[example_id])[None]
            video_frames = decode_tokens(this_video_token, decode_latents)
        video_frames = rearrange(video_frames, "b t c h w -> b t h w c")
        video_frames = video_frames.detach().cpu().numpy()[0].astype(np.uint8)

        # One file per example; a fixed name would make later examples overwrite earlier ones.
        example_name = f"example{args.offset + example_id}"
        output_gif_path = os.path.join(args.token_dir, f"{example_name}.gif")

        if is_generated_data:
            if video_tokens[example_id].shape[0] != frames_per_example:
                raise ValueError(f"Unexpected {video_tokens.shape=} given {metadata['window_size']=}, {metadata['num_prompt_frames']=}")
            captioned_frames = _caption_frames(video_frames, metadata)
        else:
            # Leave ground truth frames uncaptioned
            captioned_frames = video_frames

        gif_path = export_to_gif(captioned_frames, output_gif_path, args.fps)
        print(f"Saved to {output_gif_path}")

        if not args.disable_comic:
            output_comic_path = os.path.join(args.token_dir, f"{example_name}.png")
            _save_comic(video_frames, metadata, output_comic_path)
            print(f"Saved to {output_comic_path}")

        wandb.log({f"{dataset}/gif_{example_id}": wandb.Video(gif_path)})

    # Missing keys are recorded as None rather than crashing after all the decoding work.
    wandb.run.summary["model_checkpoint"] = metadata.get("model_checkpoint")
    wandb.run.summary["dataset"] = metadata.get("dataset")
    wandb.run.summary["trained_steps"] = metadata.get("trained_steps")
    wandb.finish()


if __name__ == "__main__":
    main()