#!/usr/bin/env python3
"""
Script to decode tokenized video into images/video.
Example usage: See https://github.com/1x-technologies/1xgpt?tab=readme-ov-file#1x-genie-baseline
"""
import argparse
import math
import os

from PIL import Image, ImageDraw
import numpy as np
import torch
import torch.distributed.optim
import torch.utils.checkpoint
import torch.utils.data
import torchvision.transforms.v2.functional as transforms_f
from diffusers import AutoencoderKLTemporalDecoder
from einops import rearrange
from matplotlib import pyplot as plt

from cont_data import RawFeatureDataset
from data import RawTokenDataset
from datasets.utils import get_image_encoder
from magvit2.config import VQConfig
from magvit2.models.lfqgan import VQModel
from common.eval_utils import decode_tokens, decode_features

import wandb

# Do not hard-code API keys; wandb.login() reads WANDB_API_KEY from the environment.
wandb.login()

SVD_SCALE = 0.18215


def parse_args():
    parser = argparse.ArgumentParser(description="Visualize tokenized video as GIF or comic.")
    parser.add_argument(
        "--stride",
        type=int,
        default=1,
        help="Frame skip",
    )
    parser.add_argument(
        "--token_dir",
        type=str,
        default="data/genie_generated",
        help="Directory of tokens, in the format of `video.bin` and `metadata.json`. "
             "Visualized gif and comic will be written here.",
    )
    parser.add_argument(
        "--offset", type=int, default=0, help="Offset to start generating images from"
    )
    parser.add_argument(
        "--fps", type=int, default=2, help="Frames per second"
    )
    parser.add_argument(
        "--max_images", type=int, default=None, help="Maximum number of images to generate. None for all."
    )
    parser.add_argument(
        "--example_ind", type=int, default=0,
        help="The index in the dataset of the example to generate on."
    )
    parser.add_argument(
        "--project_prefix", type=str, default="", help="Prefix for the wandb run name."
    )
    parser.add_argument(
        "--disable_comic", action="store_true",
        help="Comic generation assumes `token_dir` follows the same format as generate: e.g., "
             "`prompt | predictions | gtruth` in `video.bin`, `window_size` in `metadata.json`. "
             "Therefore, comic should be disabled when visualizing videos without this format, such as the dataset."
    )
    parser.add_argument(
        "--batch_size", type=int, default=4,
        help="Batch size; the current script only supports a single GPU."
    )
    parser.add_argument(
        "--max_example", type=int, default=4,
        help="Maximum number of examples to visualize."
    )
    parser.add_argument(
        "--use_feature", action="store_true",
        help="Visualize continuous features rather than discrete tokens."
    )

    args = parser.parse_args()
    return args
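
# A hypothetical invocation (the script file name below is an assumption for illustration):
#   python visualize.py --token_dir data/genie_generated --fps 2 --max_example 4
#   python visualize.py --token_dir data/genie_generated --use_feature --max_images 16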


def export_to_gif(frames: list, output_gif_path: str, fps: int):
    """
    Export a list of frames to a GIF.

    Args:
    - frames (list): List of frames (as numpy arrays or PIL Image objects).
    - output_gif_path (str): Path to save the output GIF.
    - fps (int): Desired frames per second.
    """
    # Convert numpy arrays to PIL Images if needed
    pil_frames = [Image.fromarray(frame) if isinstance(frame, np.ndarray) else frame
                  for frame in frames]

    duration_ms = 1000 / fps
    pil_frames[0].save(output_gif_path.replace(".mp4", ".gif"),
                       format="GIF",
                       append_images=pil_frames[1:],
                       save_all=True,
                       duration=duration_ms,
                       loop=0)
    # Return the path of the written GIF
    return output_gif_path.replace(".mp4", ".gif")
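
# Example (hypothetical frames and output path, for illustration only):
#   frames = [np.zeros((256, 256, 3), dtype=np.uint8) for _ in range(8)]
#   export_to_gif(frames, "out.gif", fps=2)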


def unnormalize_imgs(normalized_imgs):
    """
    [-1, 1] -> [0, 255]
    Important: clip to [0, 255]
    """
    normalized_imgs = torch.clamp(normalized_imgs, -1, 1)
    rescaled_output = (normalized_imgs.detach().cpu() + 1) * 127.5
    clipped_output = torch.clamp(rescaled_output, 0, 255).to(dtype=torch.uint8)
    return clipped_output
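
# Worked example for unnormalize_imgs (illustrative values): (x + 1) * 127.5 maps
# -1.0 -> 0, 0.0 -> 127 (after the uint8 cast), and 1.0 -> 255.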


def decode_latents_wrapper(
    batch_size: int = 16,
    encoder_type: str = "magvit",
    encoder_name_or_path: str = "data/magvit2.ckpt",
    max_images: int = None,
    device: str = "cuda",
):
    dtype = torch.bfloat16
    model = get_image_encoder(encoder_type, encoder_name_or_path)
    model = model.to(device=device, dtype=dtype)

    def decode_latents(video_data: np.ndarray):
        """
        video_data: (b, h, w) for quantized data, or (b, c, h, w) for continuous data,
        where b is `batch_size`, which may differ from the training/eval batch size.
        """
        decoded_imgs = []

        for shard_ind in range(math.ceil(len(video_data) / batch_size)):
            shard_data = video_data[shard_ind * batch_size: (shard_ind + 1) * batch_size]
            if isinstance(model, VQModel):  # TODO: class-agnostic wrapper
                # expecting quantized tokens
                assert shard_data.ndim == 3, f"{shard_data.shape=} {shard_data.dtype=}"
                torch_shard = torch.from_numpy(shard_data.astype(np.int64))
                # if model.use_ema:  # EMA does nothing in bugged VQModel
                #     with model.ema_scope():
                quant = model.quantize.get_codebook_entry(
                    rearrange(torch_shard, "b h w -> b (h w)"),
                    bhwc=torch_shard.shape + (model.quantize.codebook_dim,)
                ).flip(1)
                normalized_imgs = model.decode(quant.to(device=device, dtype=dtype))
            elif isinstance(model, AutoencoderKLTemporalDecoder):
                # expecting continuous latents
                assert shard_data.ndim == 4, f"{shard_data.shape=} {shard_data.dtype=}"
                torch_shard = torch.from_numpy(shard_data)
                # Manually clip extreme latent values before decoding
                torch_shard = torch.clamp(torch_shard, -25, 25)
                normalized_imgs = model.decode(torch_shard.to(device=device, dtype=dtype), num_frames=1).sample  # .sample holds the decoded images
            else:
                raise NotImplementedError(f"{model=}")

            decoded_imgs.append(unnormalize_imgs(normalized_imgs))
            if max_images and len(decoded_imgs) * batch_size >= max_images:
                break

        return [transforms_f.to_pil_image(img) for img in torch.cat(decoded_imgs)]

    return decode_latents
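
# Example usage (hypothetical token array and path, for illustration only):
#   decode = decode_latents_wrapper(batch_size=8, encoder_type="magvit",
#                                   encoder_name_or_path="data/magvit2.ckpt")
#   tokens = np.load("tokens.npy")   # (num_frames, h, w) integer token grid
#   pil_frames = decode(tokens)      # list of PIL.Image frames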


def caption_image(pil_image: Image.Image, caption: str):
    """
    Add a bit of empty space at the top, and add the caption there.
    """
    border_size = 36
    font_size = 24

    # Convert to PIL.Image.Image if it's not already (e.g., a uint8 numpy array)
    if not isinstance(pil_image, Image.Image):
        pil_image = transforms_f.to_pil_image(pil_image)
    width, height = pil_image.size
    new_width = width
    new_height = height + border_size
    new_image = Image.new("RGB", (new_width, new_height), "white")
    new_image.paste(pil_image, (0, border_size))

    # Draw the caption
    draw = ImageDraw.Draw(new_image)
    # Center text (`align` keyword doesn't work)
    _, _, text_w, text_h = draw.textbbox((0, 0), caption, font_size=font_size)
    draw.text(((width - text_w) / 2, (border_size - text_h) / 2), caption, fill="black", font_size=font_size)

    return new_image


def main():
    args = parse_args()

    # Derive the model and dataset names for wandb from the parent directory of `token_dir`,
    # which is assumed to contain the substring "nodes" in its name.
    name = args.token_dir.split('/')[-2]
    name_split = name.find('nodes')
    model = name[:name_split - 7]
    dataset = name[name_split + 8:]
    # Load tokens
    if args.use_feature:
        token_dataset = RawFeatureDataset(args.token_dir, 1, compute_stride_from_freq_table=False,
                                          filter_interrupts=False, filter_overlaps=False)
    else:
        token_dataset = RawTokenDataset(args.token_dir, 1, compute_stride_from_freq_table=False,
                                        filter_interrupts=False, filter_overlaps=False)
    video_tokens = token_dataset.data
    print(f"Loaded {video_tokens.shape=}")

    metadata = token_dataset.metadata
    video_tokens = video_tokens.reshape(
        -1, metadata["window_size"] * 2 - metadata["num_prompt_frames"], *video_tokens.shape[1:]
    )
    decode_func = decode_latents_wrapper

    print(metadata)
    print(f"Reshape {video_tokens.shape=}")

    wandb.init(project='video_eval_vis', settings=wandb.Settings(start_method="thread"),
               name=f"{args.project_prefix}vis_{model}", id=f"{args.project_prefix}vis_{model}",
               resume="allow")

    for example_id in range(min(args.max_example, len(video_tokens))):
        if args.use_feature:
            if "encoder_type" not in metadata:
                metadata["encoder_type"] = "temporalvae"
                metadata["encoder_name_or_path"] = "stabilityai/stable-video-diffusion-img2vid"
            decode_latents = decode_func(max_images=args.max_images,
                                         encoder_name_or_path=metadata["encoder_name_or_path"],
                                         encoder_type=metadata["encoder_type"])
            this_video_token = torch.FloatTensor(video_tokens[example_id].copy())[None] / SVD_SCALE
            this_video_token = rearrange(this_video_token, "b t c h w -> b t h w c")
            video_frames = decode_features(this_video_token, decode_latents)
        else:
            decode_latents = decode_func(max_images=args.max_images)
            this_video_token = torch.LongTensor(video_tokens[example_id])[None]
            video_frames = decode_tokens(this_video_token, decode_latents)
        video_frames = rearrange(video_frames, "b t c h w -> b t h w c")
        video_frames = video_frames.detach().cpu().numpy()[0].astype(np.uint8)

        output_gif_path = os.path.join(args.token_dir, f"example{args.offset}.gif")

        # `generate` should populate `metadata.json` with these keys, while ground truth metadata does not have them
        is_generated_data = all(key in metadata for key in ("num_prompt_frames", "window_size"))
        if is_generated_data:
            if video_tokens[example_id].shape[0] != metadata["window_size"] * 2 - metadata["num_prompt_frames"]:
                raise ValueError(f"Unexpected {video_tokens.shape=} given {metadata['window_size']=}, {metadata['num_prompt_frames']=}")
            captioned_frames = []
            for i, frame in enumerate(video_frames):
                if i < metadata["num_prompt_frames"]:
                    caption = "Prompt"
                elif i < metadata["window_size"]:
                    caption = "Generated"
                else:
                    caption = "Ground truth"

                captioned_frames.append(caption_image(frame, caption))
        else:
            # Leave ground truth frames uncaptioned
            captioned_frames = video_frames

        gif_path = export_to_gif(captioned_frames, output_gif_path, args.fps)
        print(f"Saved to {gif_path}")

        if not args.disable_comic:
            fig, axs = plt.subplots(nrows=2, ncols=metadata["window_size"],
                                    figsize=(3 * metadata["window_size"], 3 * 2))
            for i, image in enumerate(video_frames):
                if i < metadata["num_prompt_frames"]:
                    curr_axs = [axs[0, i], axs[1, i]]
                    title = "Prompt"
                elif i < metadata["window_size"]:
                    curr_axs = [axs[0, i]]
                    title = "Prediction"
                else:
                    curr_axs = [axs[1, i - metadata["window_size"] + metadata["num_prompt_frames"]]]
                    title = "Ground truth"

                for ax in curr_axs:
                    ax.set_title(title)
                    ax.imshow(image)
                    ax.axis("off")

            output_comic_path = os.path.join(args.token_dir, f"example{args.offset}.png")
            plt.savefig(output_comic_path, bbox_inches="tight")
            plt.close()
            print(f"Saved to {output_comic_path}")

        wandb.log({f"{dataset}/gif_{example_id}": wandb.Video(gif_path)})

    wandb.run.summary["model_checkpoint"] = metadata["model_checkpoint"]
    wandb.run.summary["dataset"] = metadata["dataset"]
    wandb.run.summary["trained_steps"] = metadata["trained_steps"]
    wandb.finish()


if __name__ == "__main__":
    main()