"""
Example usage: See https://github.com/1x-technologies/1xgpt?tab=readme-ov-file#1x-genie-baseline
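A typical invocation might look like this (the script path is a placeholder; the flags are defined in parse_args below):
    python generate.py --checkpoint_dir <path_to_checkpoint> \
        --val_data_dir data/1x_humanoid_magvit_traj10_val --output_dir data/genie_generated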
"""
import argparse
import json
import os
import sys
from pathlib import Path
import torch
import numpy as np
sys.path.append(os.getcwd())
from data import RawTokenDataset
from genie.st_mask_git import STMaskGIT
from cont_data import RawFeatureDataset
from genie.st_mar import STMAR
from torch.utils.data import DataLoader
from einops import rearrange
import re
from transformers import default_data_collator
def parse_args():
parser = argparse.ArgumentParser(description="Generates samples (as tokens) from GENIE model. "
"Optionally visualizes these tokens as GIFs or comics.")
parser.add_argument(
"--val_data_dir", type=str, default="data/1x_humanoid_magvit_traj10_val",
help="A directory with video data, should have a `metadata.json` and `video.bin` We generate using the first frames of this dataset."
)
parser.add_argument(
"--checkpoint_dir", type=str,
help="Path to a HuggingFace-style checkpoint."
)
parser.add_argument(
"--output_dir", type=str, default="data/genie_generated",
help="Directory to save generated outputs."
)
parser.add_argument(
"--num_prompt_frames", type=int, default=4, help="The number of context frames."
)
parser.add_argument(
"--window_size", type=int, default=12,
help="Will generate `window_size - num_prompt_frames` frames."
)
parser.add_argument(
"--example_ind", type=int, default=0,
help="The index in the dataset of the example to generate on."
)
parser.add_argument(
"--teacher_force_time", action="store_true",
help="If True, teacher-forces generation in time dimension."
)
parser.add_argument(
"--maskgit_steps", type=int, default=2, help="Number of MaskGIT sampling steps."
)
parser.add_argument(
"--temperature", type=float, default=0,
help="Sampling temperature. If `temperature` <= 1e-8, will do greedy sampling."
)
parser.add_argument(
"--add_action_input", action="store_true",
help="If True, uses action in the video output."
)
parser.add_argument(
"--batch_size", type=int, default=4,
help="Batch size, current script only supports a single GPU."
)
parser.add_argument(
"--max_example", type=int, default=16,
help="Maximum number of examples."
)
parser.add_argument(
"--use_feature", action="store_true",
help="visualize the features rather than tokens"
)
return parser.parse_args()
def get_model_step(checkpoint_dir):
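    """Return the optimizer step count stored in the checkpoint's scheduler state, or 0 if none exists."""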
if os.path.exists(f"{checkpoint_dir}/scheduler.bin"):
sch = torch.load(f"{checkpoint_dir}/scheduler.bin")
return sch['_step_count']
return 0
def compute_stride_from_model(model, dataset):
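    """Infer the temporal stride: the flattened action horizon (`d_actions` for this domain) divided by the per-step action dimension."""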
action_d = len(model.action_preprocessor[dataset].mean)
action_d_horizon = model.config.d_actions[model.config.action_domains.index(dataset)]
stride = action_d_horizon // action_d
print("model stride:", stride)
return stride
@torch.no_grad()
def main():
args = parse_args()
    assert args.num_prompt_frames <= args.window_size, "num_prompt_frames must not exceed window_size"
if not os.path.exists(args.checkpoint_dir + "/config.json"):
        # No top-level config.json: fall back to the most recently created checkpoint subdirectory
        dirs = [os.path.join(args.checkpoint_dir, f.name) for f in os.scandir(args.checkpoint_dir) if f.is_dir()]
        dirs.sort(key=os.path.getctime)
if len(dirs) == 0:
print(f"No checkpoint directories found in {args.checkpoint_dir}")
sys.exit(1)
args.checkpoint_dir = dirs[-1]
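    # Infer the dataset name from the path, e.g. "data/1x_humanoid_magvit_..." -> "1x_humanoid"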
dataset = re.search(r"data/(.*?)_magvit", args.val_data_dir).group(1)
    # HACK: collapse all robomimic variants onto a single dataset name
if "robomimic" in args.val_data_dir:
dataset = "robomimic"
# Load the model checkpoint
if not args.use_feature:
print(f"loading STMaskGIT")
model = STMaskGIT.from_pretrained(args.checkpoint_dir).to("cuda")
stride = compute_stride_from_model(model, dataset)
val_dataset = RawTokenDataset(args.val_data_dir, window_size=args.window_size,
compute_stride_from_freq_table=False,
stride=stride,
use_actions=model.config.use_actions)
else:
print(f"loading STMAR")
model = STMAR.from_pretrained(args.checkpoint_dir).to("cuda")
stride = compute_stride_from_model(model, dataset)
args.val_data_dir = args.val_data_dir.replace("magvit", "vae")
val_dataset = RawFeatureDataset(args.val_data_dir,
compute_stride_from_freq_table=False,
stride=stride, window_size=args.window_size,
use_actions=model.config.use_actions)
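        # Continuous features are stored as float32, not integer token ids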
val_dataset.metadata["token_dtype"] = "float32"
latent_side_len = val_dataset.data.shape[-1] # assume square
dataloader = DataLoader(val_dataset, collate_fn=default_data_collator, batch_size=args.batch_size)
    # Generate over batches until `max_example` examples have been produced
    if args.max_example > len(val_dataset):
        print(f"--max_example {args.max_example} exceeds dataset length {len(val_dataset)}")
        sys.exit(1)
model.eval()
output_list = []
for batch_idx, batch in enumerate(dataloader):
samples = []
if args.use_feature:
example_THW = rearrange(batch["input_ids"].to("cuda"), "b (t h w) c -> b t h w c", t=args.window_size,
h=latent_side_len, w=latent_side_len)
else:
example_THW = rearrange(batch["input_ids"].to("cuda"), "b (t h w) -> b t h w", t=args.window_size,
h=latent_side_len, w=latent_side_len)
example_actions = None
domain = None
if model.config.use_actions:
example_actions = batch["action_ids"].to("cuda")
domain = [val_dataset.name.replace("_noquant", "")] * args.batch_size
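        # Mask every frame after the prompt; the model fills these in below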
prompt_THW = example_THW.clone()
prompt_THW[:, args.num_prompt_frames:] = model.mask_token if args.use_feature else model.mask_token_id
for timestep in range(args.num_prompt_frames, args.window_size):
            if args.teacher_force_time:
                # Teacher forcing in time: condition on ground truth up to this
                # timestep, masking only the frames from `timestep` onward
                prompt_THW = example_THW.clone()
                prompt_THW[:, timestep:] = model.mask_token if args.use_feature else model.mask_token_id
samples_HW, _, _ = model.maskgit_generate(
prompt_THW, out_t=timestep, temperature=args.temperature,
action_ids=example_actions, domain=domain
)
samples.append(samples_HW)
            if not args.teacher_force_time:
                # Autoregressive in time: feed the generated frame back in as context
                prompt_THW[:, timestep] = samples_HW
outputs = torch.stack(samples, dim=1)
# prepend prompt sequence
outputs = torch.cat([example_THW[:, :args.num_prompt_frames], outputs], dim=1)
# append ground-truth targets next to generated outputs for comic strip generation
# [<prompt frames><predicted frames><ground truth frames>]
outputs = torch.cat([outputs, example_THW[:, args.num_prompt_frames:]], dim=1)
output_list.append(outputs)
        # Stop once `max_example` examples have been generated
        if (batch_idx + 1) * args.batch_size >= args.max_example:
            break
outputs = torch.cat(output_list, dim=0)
if args.use_feature:
        # Convert channel-last features to (b, t, c, h, w) for saving
outputs = rearrange(outputs, "b t h w c -> b t c h w")
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
outputs.cpu().numpy().astype(np.dtype(val_dataset.metadata["token_dtype"])).tofile(output_dir / "video.bin")
print(f"Saved generated video to {output_dir / 'video.bin'} {outputs.shape}")
model_steps = get_model_step(args.checkpoint_dir)
with open(output_dir / "metadata.json", "w") as f:
json.dump(vars(args) | val_dataset.metadata | {
"num_images": outputs.shape[1],
"h": latent_side_len,
"w": latent_side_len,
"t": args.window_size,
"model_checkpoint": args.checkpoint_dir,
"dataset": val_dataset.name,
"trained_steps": model_steps,
}, f)
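# A minimal sketch (not executed here) for reloading the saved outputs;
# the paths assume the default --output_dir:
#   meta = json.load(open("data/genie_generated/metadata.json"))
#   video = np.fromfile("data/genie_generated/video.bin", dtype=np.dtype(meta["token_dtype"]))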
if __name__ == "__main__":
main()