#!/usr/bin/env python3

"""
Decode tokenized video (or continuous latent features) back into images, export an
animated GIF and a comic-strip figure, and log the results to Weights & Biases.
Example usage: See https://github.com/1x-technologies/1xgpt?tab=readme-ov-file#1x-genie-baseline
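
Illustrative invocation (script and data paths are examples, not fixed):
    python visualize.py --token_dir data/genie_generated --fps 2 --max_example 4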
"""

import argparse
import math
import os

import numpy as np
import torch
import torch.distributed.optim
import torch.utils.checkpoint
import torch.utils.data
import torchvision.transforms.v2.functional as transforms_f
import wandb
from diffusers import AutoencoderKLTemporalDecoder
from einops import rearrange
from matplotlib import pyplot as plt
from PIL import Image, ImageDraw

from cont_data import RawFeatureDataset
from data import RawTokenDataset
from datasets.utils import get_image_encoder
from magvit2.config import VQConfig
from magvit2.models.lfqgan import VQModel
from common.eval_utils import decode_tokens, decode_features

# Authenticate through the WANDB_API_KEY environment variable or a prior
# `wandb login` rather than hard-coding an API key in the source.
wandb.login()

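# Latent scaling factor used by Stable Diffusion-family VAEs; continuous features
# are divided by this before being passed to the decoder.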
SVD_SCALE = 0.18215

def parse_args():
    parser = argparse.ArgumentParser(description="Visualize tokenized video as GIF or comic.")
    parser.add_argument(
        "--stride",
        type=int,
        default=1,
        help="Frame skip",
    )
    parser.add_argument(
        "--token_dir",
        type=str,
        default="data/genie_generated",
        help="Directory of tokens, in the format of `video.bin` and `metadata.json`. "
             "Visualized gif and comic will be written here.",
    )
    parser.add_argument(
        "--offset", type=int, default=0, help="Offset to start generating images from"
    )
    parser.add_argument(
        "--fps", type=int, default=2, help="Frames per second"
    )
    parser.add_argument(
        "--max_images", type=int, default=None, help="Maximum number of images to generate. None for all."
    )
    parser.add_argument(
        "--example_ind", type=int, default=0,
        help="The index in the dataset of the example to generate on."
    )
    parser.add_argument(
        "--project_prefix", type=str, default="", help="Project suffix."
    )
    parser.add_argument(
        "--disable_comic", action="store_true",
        help="Comic generation assumes `token_dir` follows the same format as generate: e.g., "
             "`prompt | predictions | gtruth` in `video.bin`, `window_size` in `metadata.json`."
             "Therefore, comic should be disabled when visualizing videos without this format, such as the dataset."
    )
    parser.add_argument(
        "--batch_size", type=int, default=4,
        help="Batch size, current script only supports a single GPU."
    )
    parser.add_argument(
        "--max_example", type=int, default=4,
        help="Maximum number of examples."
    )
    parser.add_argument(
        "--use_feature", action="store_true",
        help="visualize the features rather than tokens"
    )
    args = parser.parse_args()

    return args


def export_to_gif(frames: list, output_gif_path: str, fps: int):
    """
    Export a list of frames to a GIF.

    Args:
    - frames (list): List of frames (as numpy arrays or PIL Image objects).
    - output_gif_path (str): Path to save the output GIF.
    - fps (int): Desired frames per second.
    """
    # Convert numpy arrays to PIL Images if needed
    pil_frames = [Image.fromarray(frame) if isinstance(
        frame, np.ndarray) else frame for frame in frames]

    duration_ms = 1000 / fps
    gif_path = output_gif_path.replace(".mp4", ".gif")
    pil_frames[0].save(gif_path,
                       format="GIF",
                       append_images=pil_frames[1:],
                       save_all=True,
                       duration=duration_ms,
                       loop=0)
    # Return the path of the written GIF.
    return gif_path

def unnormalize_imgs(normalized_imgs):
    """
    [-1, 1] -> [0, 255]

    Important: clip to [0, 255]
    """
    normalized_imgs = torch.clamp(normalized_imgs, -1, 1)
    rescaled_output = ((normalized_imgs.detach().cpu() + 1) * 127.5)
    clipped_output = torch.clamp(rescaled_output, 0, 255).to(dtype=torch.uint8)
    return clipped_output

def decode_latents_wrapper(
    batch_size: int = 16,
    encoder_type: str = "magvit",
    encoder_name_or_path: str = "data/magvit2.ckpt",
    max_images: int = None,
    device: str = "cuda",
):
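    """
    Build a closure that decodes latents into PIL images, `batch_size` frames at a time.

    Supports discrete MAGVIT-style tokens (VQModel) and continuous VAE features
    (AutoencoderKLTemporalDecoder), selected via `encoder_type`.
    """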
    dtype = torch.bfloat16
    model = get_image_encoder(encoder_type, encoder_name_or_path)
    model = model.to(device=device, dtype=dtype)

    @torch.no_grad()
    def decode_latents(video_data: np.ndarray):
        """
        video_data: (b, h, w) for quantized data, or (b, c, h, w) for continuous data,
        where b is `batch_size` and different from training/eval batch size.
        """
        decoded_imgs = []

        for shard_ind in range(math.ceil(len(video_data) / batch_size)):
            shard_data = video_data[shard_ind * batch_size: (shard_ind + 1) * batch_size]
            if isinstance(model, VQModel):  # TODO: class agnostic wrapper
                # expecting quantized
                assert shard_data.ndim == 3, f"{shard_data.shape=} {shard_data.dtype=}"
                torch_shard = torch.from_numpy(shard_data.astype(np.int64))
                # if model.use_ema:  # EMA does nothing in bugged VQModel
                #     with model.ema_scope():
                quant = model.quantize.get_codebook_entry(rearrange(torch_shard, "b h w -> b (h w)"),
                                                          bhwc=torch_shard.shape + (model.quantize.codebook_dim,)).flip(1)
                normalized_imgs = model.decode(quant.to(device=device, dtype=dtype))
            elif isinstance(model, AutoencoderKLTemporalDecoder):
                # expecting continuous
                assert shard_data.ndim == 4, f"{shard_data.shape=} {shard_data.dtype=}"
                torch_shard = torch.from_numpy(shard_data)
                # Clamp outlier latent values so a few extreme activations do not
                # saturate the VAE decoder.
                torch_shard = torch.clamp(torch_shard, -25, 25)
                # `.sample` is the decoded image tensor (deterministic, despite the name).
                normalized_imgs = model.decode(torch_shard.to(device=device, dtype=dtype), num_frames=1).sample

            else:
                raise NotImplementedError(f"{model=}")

            decoded_imgs.append(unnormalize_imgs(normalized_imgs))
            if max_images and len(decoded_imgs) * batch_size >= max_images:
                break

        return [transforms_f.to_pil_image(img) for img in torch.cat(decoded_imgs)]

    return decode_latents


def caption_image(pil_image: Image.Image, caption: str):
    """
    Add a bit of empty space at the top, and add the caption there
    """
    border_size = 36
    font_size = 24
    # convert pil_image to PIL.Image.Image if it's not already
    if not isinstance(pil_image, Image.Image):
        pil_image = transforms_f.to_pil_image(pil_image)

    width, height = pil_image.size
    new_width = width
    new_height = height + border_size

    new_image = Image.new("RGB", (new_width, new_height), "white")
    new_image.paste(pil_image, (0, border_size))

    # Draw the caption
    draw = ImageDraw.Draw(new_image)

    # Center text (`align` keyword doesn't work)
    _, _, text_w, text_h = draw.textbbox((0, 0), caption, font_size=font_size)
    draw.text(((width - text_w) / 2, (border_size - text_h) / 2), caption, fill="black", font_size=font_size)

    return new_image


@torch.no_grad()
def main():
    args = parse_args()
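    # Recover the model and dataset names from the parent directory of `token_dir`;
    # assumes the naming convention used by the generation scripts, where a
    # "...nodes..." segment separates the model name from the dataset name.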
    name = args.token_dir.split('/')[-2]
    name_split = name.find('nodes')
    model = name[:name_split-7]
    dataset = name[name_split+8:]

    # Load tokens
    if args.use_feature:
        token_dataset = RawFeatureDataset(args.token_dir, 1, compute_stride_from_freq_table=False,
                                          filter_interrupts=False, filter_overlaps=False)
        video_tokens = token_dataset.data

        print(f"Loaded {video_tokens.shape=}")
    else:
        token_dataset = RawTokenDataset(args.token_dir, 1, compute_stride_from_freq_table=False,
                                        filter_interrupts=False, filter_overlaps=False)
        video_tokens = token_dataset.data
        print(f"Loaded {video_tokens.shape=}")

    metadata = token_dataset.metadata
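    # Each example holds prompt frames, generated frames, and ground-truth frames,
    # i.e. window_size * 2 - num_prompt_frames frames in total.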
    video_tokens = video_tokens.reshape(-1, metadata["window_size"] * 2 - metadata["num_prompt_frames"], *video_tokens.shape[1:])
    decode_func = decode_latents_wrapper
    print(metadata)
    print(f"Reshape {video_tokens.shape=}")

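    # One wandb run per model; reusing the run id lets repeated invocations resume it.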
    wandb.init(
        project='video_eval_vis',
        settings=wandb.Settings(start_method="thread"),
        name=f"{args.project_prefix}vis_{model}",
        id=f"{args.project_prefix}vis_{model}",
        resume="allow",
    )
    for example_id in range(min(args.max_example, len(video_tokens))):
        if args.use_feature:
            if "encoder_type" not in metadata:
                metadata["encoder_type"] = "temporalvae"
                metadata["encoder_name_or_path"] = "stabilityai/stable-video-diffusion-img2vid"
            decode_latents = decode_func(max_images=args.max_images, encoder_name_or_path=metadata["encoder_name_or_path"],
                                       encoder_type=metadata["encoder_type"])  # args.offset::args.stride
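            # Undo the latent scaling applied at encoding time before decoding.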
            this_video_token = torch.FloatTensor(video_tokens[example_id].copy())[None] / SVD_SCALE
            this_video_token = rearrange(this_video_token, "b t c h w -> b t h w c")
            video_frames = decode_features(this_video_token, decode_latents)
            video_frames = rearrange(video_frames, "b t c h w -> b t h w c")
            video_frames = video_frames.detach().cpu().numpy()[0].astype(np.uint8)
        else:
            decode_latents = decode_func(max_images=args.max_images)
            this_video_token = torch.LongTensor(video_tokens[example_id])[None]
            video_frames = decode_tokens(this_video_token, decode_latents)
            video_frames = rearrange(video_frames, "b t c h w -> b t h w c")
            video_frames = video_frames.detach().cpu().numpy()[0].astype(np.uint8)

        output_gif_path = os.path.join(args.token_dir, f"example{example_id}.gif")

        # `generate` should populate `metadata.json` with these keys, while ground truth metadata does not have them
        is_generated_data = all(key in metadata for key in ("num_prompt_frames", "window_size"))
        if is_generated_data:
            if video_tokens[example_id].shape[0] != metadata["window_size"] * 2 - metadata["num_prompt_frames"]:
                raise ValueError(f"Unexpected {video_tokens.shape=} given {metadata['window_size']=}, {metadata['num_prompt_frames']=}")

            captioned_frames = []
            for i, frame in enumerate(video_frames):
                if i < metadata["num_prompt_frames"]:
                    caption = "Prompt"
                elif i < metadata["window_size"]:
                    caption = "Generated"
                else:
                    caption = "Ground truth"

                captioned_frames.append(caption_image(frame, caption))
        else:
            # Leave ground truth frames uncaptioned
            captioned_frames = video_frames

        gif_path = export_to_gif(captioned_frames, output_gif_path, args.fps)
        print(f"Saved to {output_gif_path}")

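        # Comic layout: top row shows prompt + predictions, bottom row shows prompt + ground truth.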
        if not args.disable_comic:
            fig, axs = plt.subplots(nrows=2, ncols=metadata["window_size"], figsize=(3 * metadata["window_size"], 3 * 2))
            for i, image in enumerate(video_frames):
                if i < metadata["num_prompt_frames"]:
                    curr_axs = [axs[0, i], axs[1, i]]
                    title = "Prompt"

                elif i < metadata["window_size"]:
                    curr_axs = [axs[0, i]]
                    title = "Prediction"
                else:
                    curr_axs = [axs[1, i - metadata["window_size"] + metadata["num_prompt_frames"]]]
                    title = "Ground truth"

                for ax in curr_axs:
                    ax.set_title(title)
                    ax.imshow(image)
                    ax.axis("off")

            output_comic_path = os.path.join(args.token_dir, f"example{example_id}.png")
            plt.savefig(output_comic_path, bbox_inches="tight")
            plt.close()
            print(f"Saved to {output_comic_path}")
        wandb.log({f"{dataset}/gif_{example_id}": wandb.Video(gif_path)})

    # Record run-level metadata from `metadata.json` in the wandb summary.
    wandb.run.summary["model_checkpoint"] = metadata["model_checkpoint"]
    wandb.run.summary["dataset"] = metadata["dataset"]
    wandb.run.summary["trained_steps"] = metadata["trained_steps"]

    wandb.finish()


if __name__ == "__main__":
    main()