"""
Example usage: See https://github.com/1x-technologies/1xgpt?tab=readme-ov-file#1x-genie-baseline
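
A hypothetical invocation, for illustration only (the script path and the
checkpoint path are placeholders; the flags match the argparse definitions
below):

    python genie/generate.py \
        --checkpoint_dir checkpoints/my_run \
        --val_data_dir data/1x_humanoid_magvit_traj10_val \
        --output_dir data/genie_generated \
        --num_prompt_frames 4 --window_size 12 --maskgit_steps 2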
"""

import argparse
import json
import os
import sys
from pathlib import Path

import torch
import numpy as np

sys.path.append(os.getcwd())
from data import RawTokenDataset
from genie.st_mask_git import STMaskGIT

from cont_data import RawFeatureDataset
from genie.st_mar import STMAR
from torch.utils.data import DataLoader
from einops import rearrange
import re
from transformers import default_data_collator

def parse_args():
    parser = argparse.ArgumentParser(description="Generates samples (as tokens) from a GENIE model. "
                                                 "Optionally visualizes these tokens as GIFs or comics.")
    parser.add_argument(
        "--val_data_dir", type=str, default="data/1x_humanoid_magvit_traj10_val",
        help="A directory with video data; should contain `metadata.json` and `video.bin`. "
             "Generation is prompted with the first frames of this dataset."
    )
    parser.add_argument(
        "--checkpoint_dir", type=str,
        help="Path to a HuggingFace-style checkpoint."
    )
    parser.add_argument(
        "--output_dir", type=str, default="data/genie_generated",
        help="Directory to save generated outputs."
    )
    parser.add_argument(
        "--num_prompt_frames", type=int, default=4, help="The number of context frames."
    )
    parser.add_argument(
        "--window_size", type=int, default=12,
        help="Will generate `window_size - num_prompt_frames` frames."
    )
    parser.add_argument(
        "--example_ind", type=int, default=0,
        help="The index in the dataset of the example to generate on."
    )
    parser.add_argument(
        "--teacher_force_time", action="store_true",
        help="If True, teacher-forces generation along the time dimension."
    )
    parser.add_argument(
        "--maskgit_steps", type=int, default=2, help="Number of MaskGIT sampling steps."
    )
    parser.add_argument(
        "--temperature", type=float, default=0,
        help="Sampling temperature. If `temperature` <= 1e-8, sampling is greedy (argmax)."
    )
    parser.add_argument(
        "--add_action_input", action="store_true",
        help="If True, uses actions as input when generating video."
    )
    parser.add_argument(
        "--batch_size", type=int, default=4,
        help="Batch size; the current script only supports a single GPU."
    )
    parser.add_argument(
        "--max_example", type=int, default=16,
        help="Maximum number of examples to generate."
    )
    parser.add_argument(
        "--use_feature", action="store_true",
        help="If True, operates on continuous features (STMAR) rather than discrete tokens (STMaskGIT)."
    )
    return parser.parse_args()

def get_model_step(checkpoint_dir):
    """Return the trained step count recorded in the saved LR scheduler state, or 0 if absent."""
    if os.path.exists(f"{checkpoint_dir}/scheduler.bin"):
        sch = torch.load(f"{checkpoint_dir}/scheduler.bin")
        return sch["_step_count"]
    return 0

def compute_stride_from_model(model, dataset):
    """Derive the temporal action stride as the ratio of the model's per-window
    action dimension to the dataset's per-step action dimension."""
    action_d = len(model.action_preprocessor[dataset].mean)
    action_d_horizon = model.config.d_actions[model.config.action_domains.index(dataset)]
    stride = action_d_horizon // action_d
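    # Worked example (illustrative numbers only): a 7-dim per-step action with
    # a 21-dim per-window entry in `d_actions` gives stride = 21 // 7 = 3.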
    print("model stride:", stride)
    return stride

@torch.no_grad()
def main():
    args = parse_args()
    assert args.num_prompt_frames <= args.window_size

    if not os.path.exists(args.checkpoint_dir + "/config.json"):
        # No config.json at the top level: fall back to the most recently
        # created checkpoint subdirectory.
        dirs = [os.path.join(args.checkpoint_dir, f.name) for f in os.scandir(args.checkpoint_dir) if f.is_dir()]
        if len(dirs) == 0:
            print(f"No checkpoint directories found in {args.checkpoint_dir}")
            sys.exit(1)
        dirs.sort(key=os.path.getctime)
        args.checkpoint_dir = dirs[-1]

    dataset = re.search(r"data/(.*?)_magvit", args.val_data_dir).group(1)
    # HACK: collapse all robomimic variants onto the single "robomimic" action domain.
    if "robomimic" in args.val_data_dir:
        dataset = "robomimic"
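    # For reference, with the default val_data_dir the regex above maps
    # "data/1x_humanoid_magvit_traj10_val" -> "1x_humanoid".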

    # Load the model checkpoint
    if not args.use_feature:
        print("Loading STMaskGIT")
        model = STMaskGIT.from_pretrained(args.checkpoint_dir).to("cuda")
        stride = compute_stride_from_model(model, dataset)
        val_dataset = RawTokenDataset(args.val_data_dir, window_size=args.window_size,
                                      compute_stride_from_freq_table=False,
                                      stride=stride,
                                      use_actions=model.config.use_actions)
    else:
        print("Loading STMAR")
        model = STMAR.from_pretrained(args.checkpoint_dir).to("cuda")
        stride = compute_stride_from_model(model, dataset)
        # STMAR consumes continuous VAE features, so point at the "vae" variant
        # of the validation data rather than the MAGVIT token variant.
        args.val_data_dir = args.val_data_dir.replace("magvit", "vae")

        val_dataset = RawFeatureDataset(args.val_data_dir,
                                        compute_stride_from_freq_table=False,
                                        stride=stride, window_size=args.window_size,
                                        use_actions=model.config.use_actions)
        val_dataset.metadata["token_dtype"] = "float32"

    latent_side_len = val_dataset.data.shape[-1]  # assume square latents
    dataloader = DataLoader(val_dataset, collate_fn=default_data_collator, batch_size=args.batch_size)

    # Bound the number of generated examples by the dataset size.
    if args.max_example > len(val_dataset):
        print(f"max_example ({args.max_example}) exceeds dataset length ({len(val_dataset)})")
        sys.exit(1)

    model.eval()
    output_list = []

    for batch_idx, batch in enumerate(dataloader):
        samples = []
        if args.use_feature:
            # Continuous features: (B, T*H*W, C) -> (B, T, H, W, C)
            example_THW = rearrange(batch["input_ids"].to("cuda"), "b (t h w) c -> b t h w c",
                                    t=args.window_size, h=latent_side_len, w=latent_side_len)
        else:
            # Discrete tokens: (B, T*H*W) -> (B, T, H, W)
            example_THW = rearrange(batch["input_ids"].to("cuda"), "b (t h w) -> b t h w",
                                    t=args.window_size, h=latent_side_len, w=latent_side_len)
        example_actions = None
        domain = None

        if model.config.use_actions:
            example_actions = batch["action_ids"].to("cuda")
            domain = [val_dataset.name.replace("_noquant", "")] * args.batch_size

        # Keep the prompt frames; mask out every frame after them.
        prompt_THW = example_THW.clone()
        prompt_THW[:, args.num_prompt_frames:] = model.mask_token if args.use_feature else model.mask_token_id

        for timestep in range(args.num_prompt_frames, args.window_size):
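            # Each iteration fills in only the frame at `timestep`; earlier frames
            # (ground truth or previously generated) act as conditioning context.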
            # Teacher-forced, maskgit generation
            if args.teacher_force_time:
                prompt_THW = example_THW.clone()
                # Masked prediction for this timestep only, after which we provide ground-truth
                prompt_THW[:, timestep:] = model.mask_token if args.use_feature else model.mask_token_id

            samples_HW, _, _ = model.maskgit_generate(
                prompt_THW, out_t=timestep, temperature=args.temperature,
                action_ids=example_actions, domain=domain
            )

            samples.append(samples_HW)
            if not args.teacher_force_time:
                # Autoregressive: feed the generated frame back in as context
                # for the next timestep.
                prompt_THW[:, timestep] = samples_HW

        outputs = torch.stack(samples, dim=1)
        # Prepend the ground-truth prompt frames.
        outputs = torch.cat([example_THW[:, :args.num_prompt_frames], outputs], dim=1)

        # Append the ground-truth target frames after the generated ones, so a
        # comic strip reads [<prompt frames><predicted frames><ground-truth frames>].
        outputs = torch.cat([outputs, example_THW[:, args.num_prompt_frames:]], dim=1)
        output_list.append(outputs)
        # Stop once at least `max_example` examples have been generated.
        if (batch_idx + 1) * args.batch_size >= args.max_example:
            break

    outputs = torch.cat(output_list, dim=0)
    if args.use_feature:
        # Convert to channels-first (B, T, C, H, W) layout for saving.
        outputs = rearrange(outputs, "b t h w c -> b t c h w")

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    outputs.cpu().numpy().astype(np.dtype(val_dataset.metadata["token_dtype"])).tofile(output_dir / "video.bin")
    print(f"Saved generated video to {output_dir / 'video.bin'} {outputs.shape}")
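    # Note: video.bin is a raw flat array. It can be reloaded with, e.g.,
    #   np.fromfile(output_dir / "video.bin", dtype=...).reshape(outputs.shape)
    # using the dtype recorded as "token_dtype" in metadata.json.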
    model_steps = get_model_step(args.checkpoint_dir)

    with open(output_dir / "metadata.json", "w") as f:
        json.dump(vars(args) | val_dataset.metadata | {
            "num_images": outputs.shape[1],
            "h": latent_side_len,
            "w": latent_side_len,
            "t": args.window_size,
            "model_checkpoint": args.checkpoint_dir,
            "dataset": val_dataset.name,
            "trained_steps": model_steps,
        }, f)


if __name__ == "__main__":
    main()