"""
Adapted from:
https://github.com/genmoai/mochi/blob/main/demos/fine_tuner/encode_videos.py
https://github.com/genmoai/mochi/blob/main/demos/fine_tuner/embed_captions.py
"""

import traceback
from pathlib import Path

import click
import torch
import torchvision
from diffusers import AutoencoderKLMochi, MochiPipeline
from tqdm.auto import tqdm
from transformers import T5EncoderModel, T5Tokenizer


def encode_videos(model: torch.nn.Module, vid_path: Path, shape: str):
    """Encode a single video into VAE latents and save them next to the video."""
    T, H, W = [int(s) for s in shape.split("x")]
    assert (T - 1) % 6 == 0, f"Expected T to be 1 mod 6, got T={T}"
    video, _, metadata = torchvision.io.read_video(str(vid_path), output_format="THWC", pts_unit="secs")
    fps = metadata["video_fps"]  # read but unused in this adaptation
    video = video.permute(3, 0, 1, 2)  # THWC -> CTHW
    og_shape = video.shape
    assert video.shape[2] == H, f"Expected {vid_path} to have height {H}, got {video.shape}"
    assert video.shape[3] == W, f"Expected {vid_path} to have width {W}, got {video.shape}"
    assert video.shape[1] >= T, f"Expected {vid_path} to have at least {T} frames, got {video.shape}"
    if video.shape[1] > T:
        video = video[:, :T]
        print(f"Trimmed video from {og_shape[1]} to first {T} frames")
    video = video.unsqueeze(0)  # add batch dim: CTHW -> BCTHW
    video = video.float() / 127.5 - 1.0  # uint8 [0, 255] -> float [-1, 1]
    video = video.to(model.device)

    assert video.ndim == 5

    with torch.inference_mode():
        with torch.autocast("cuda", dtype=torch.bfloat16):
            ldist = model._encode(video)

        torch.save(dict(ldist=ldist), vid_path.with_suffix(".latent.pt"))
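
# A hedged sketch (not part of the original script) of how the saved latents
# might be consumed downstream. `model._encode` is assumed here to return the
# raw encoder moments, so a latent sample could be drawn with diffusers'
# DiagonalGaussianDistribution (import path varies across diffusers versions):
#
#   from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
#   moments = torch.load(video_path.with_suffix(".latent.pt"))["ldist"]
#   latents = DiagonalGaussianDistribution(moments).sample()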


@click.command()
@click.argument("output_dir", type=click.Path(exists=True, file_okay=False, dir_okay=True, path_type=Path))
@click.option(
    "--model_id",
    type=str,
    help="Repo id. Should be genmo/mochi-1-preview",
    default="genmo/mochi-1-preview",
)
@click.option("--shape", default="163x480x848", help="Shape of the video to encode")
@click.option("--overwrite", "-ow", is_flag=True, help="Overwrite existing latents and caption embeddings.")
def batch_process(output_dir: Path, model_id: str, shape: str, overwrite: bool) -> None:
    """Process all videos and captions in a directory using a single GPU."""
    # TF32 matmuls speed things up on Ampere and newer GPUs;
    # comment out when running on unsupported hardware.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    # Get all video paths
    video_paths = list(output_dir.glob("**/*.mp4"))
    if not video_paths:
        print(f"No MP4 files found in {output_dir}")
        return

    text_paths = list(output_dir.glob("**/*.txt"))
    if not text_paths:
        print(f"No text files found in {output_dir}")
        return

    # Load the models: the VAE encodes videos, and the pipeline (loaded without
    # its transformer and VAE to save memory) wraps the T5 text encoder for captions.
    vae = AutoencoderKLMochi.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32).to("cuda")
    text_encoder = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder")
    tokenizer = T5Tokenizer.from_pretrained(model_id, subfolder="tokenizer")
    pipeline = MochiPipeline.from_pretrained(
        model_id, text_encoder=text_encoder, tokenizer=tokenizer, transformer=None, vae=None
    ).to("cuda")

    for video_path in tqdm(sorted(video_paths)):
        print(f"Processing {video_path}")
        try:
            if video_path.with_suffix(".latent.pt").exists() and not overwrite:
                print(f"Skipping {video_path} - latents already exist")
                continue

            # encode videos.
            encode_videos(vae, vid_path=video_path, shape=shape)

            # embed captions.
            prompt_path = video_path.with_suffix(".txt")
            embed_path = prompt_path.with_suffix(".embed.pt")

            if embed_path.exists() and not overwrite:
                print(f"Skipping {prompt_path} - embeddings already exist")
                continue

            with open(prompt_path) as f:
                text = f.read().strip()
            with torch.inference_mode():
                conditioning = pipeline.encode_prompt(prompt=[text])

            conditioning = {"prompt_embeds": conditioning[0], "prompt_attention_mask": conditioning[1]}
            torch.save(conditioning, embed_path)
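            # A sketch (not in the original script) of downstream use:
            #   loaded = torch.load(embed_path)
            #   embeds, mask = loaded["prompt_embeds"], loaded["prompt_attention_mask"]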

        except Exception as e:
            traceback.print_exc()
            print(f"Error processing {video_path}: {e}")


if __name__ == "__main__":
    batch_process()