"""
Example usage: See https://github.com/1x-technologies/1xgpt?tab=readme-ov-file#1x-genie-baseline
"""
import argparse
import json
import os
import re
import sys
from pathlib import Path

import numpy as np
import torch
from einops import rearrange
from torch.utils.data import DataLoader
from transformers import default_data_collator

# Local modules resolve relative to the repository root.
sys.path.append(os.getcwd())
from data import RawTokenDataset
from cont_data import RawFeatureDataset
from genie.st_mar import STMAR
from genie.st_mask_git import STMaskGIT
def parse_args():
    parser = argparse.ArgumentParser(description="Generates samples (as tokens or continuous features) from a GENIE "
                                                 "model. Outputs can optionally be visualized as GIFs or comic strips.")
parser.add_argument(
"--val_data_dir", type=str, default="data/1x_humanoid_magvit_traj10_val",
help="A directory with video data, should have a `metadata.json` and `video.bin` We generate using the first frames of this dataset."
)
parser.add_argument(
"--checkpoint_dir", type=str,
help="Path to a HuggingFace-style checkpoint."
)
parser.add_argument(
"--output_dir", type=str, default="data/genie_generated",
help="Directory to save generated outputs."
)
parser.add_argument(
"--num_prompt_frames", type=int, default=4, help="The number of context frames."
)
parser.add_argument(
"--window_size", type=int, default=12,
help="Will generate `window_size - num_prompt_frames` frames."
)
parser.add_argument(
"--example_ind", type=int, default=0,
help="The index in the dataset of the example to generate on."
)
parser.add_argument(
"--teacher_force_time", action="store_true",
help="If True, teacher-forces generation in time dimension."
)
parser.add_argument(
"--maskgit_steps", type=int, default=2, help="Number of MaskGIT sampling steps."
)
parser.add_argument(
"--temperature", type=float, default=0,
help="Sampling temperature. If `temperature` <= 1e-8, will do greedy sampling."
)
parser.add_argument(
"--add_action_input", action="store_true",
help="If True, uses action in the video output."
)
parser.add_argument(
"--batch_size", type=int, default=4,
help="Batch size, current script only supports a single GPU."
)
parser.add_argument(
"--max_example", type=int, default=16,
help="Maximum number of examples."
)
parser.add_argument(
"--use_feature", action="store_true",
help="visualize the features rather than tokens"
)
return parser.parse_args()
def get_model_step(checkpoint_dir):
    """Reads the trained step count from the checkpoint's saved LR scheduler, if present."""
    if os.path.exists(f"{checkpoint_dir}/scheduler.bin"):
        sch = torch.load(f"{checkpoint_dir}/scheduler.bin")
        return sch['_step_count']
    return 0
def compute_stride_from_model(model, dataset):
    """Infers the frame-subsampling stride as the ratio between the model's
    action dimension over its horizon and the dataset's per-frame action dimension."""
    action_d = len(model.action_preprocessor[dataset].mean)
    action_d_horizon = model.config.d_actions[model.config.action_domains.index(dataset)]
    stride = action_d_horizon // action_d
    print("model stride:", stride)
    return stride
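
# A worked example with hypothetical numbers: if the dataset's raw action is 8-D
# and the model was trained on a 32-D concatenated action vector per frame, the
# stride is 32 // 8 = 4, i.e. the dataset subsamples every 4th frame.
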
@torch.no_grad()
def main():
args = parse_args()
    assert args.num_prompt_frames <= args.window_size, "num_prompt_frames must not exceed window_size"
if not os.path.exists(args.checkpoint_dir + "/config.json"):
        # No config.json at the top level: fall back to the most recently created checkpoint subdirectory.
dirs = [os.path.join(args.checkpoint_dir, f.name) for f in os.scandir(args.checkpoint_dir) if f.is_dir()]
dirs.sort(key=os.path.getctime)
if len(dirs) == 0:
print(f"No checkpoint directories found in {args.checkpoint_dir}")
sys.exit(1)
args.checkpoint_dir = dirs[-1]
    match = re.search(r"data/(.*?)_magvit", args.val_data_dir)
    assert match is not None, f"Could not infer a dataset name from {args.val_data_dir}"
    dataset = match.group(1)
    # HACK: map all robomimic variants onto a single dataset name.
if "robomimic" in args.val_data_dir:
dataset = "robomimic"
# Load the model checkpoint
if not args.use_feature:
print(f"loading STMaskGIT")
model = STMaskGIT.from_pretrained(args.checkpoint_dir).to("cuda")
stride = compute_stride_from_model(model, dataset)
val_dataset = RawTokenDataset(args.val_data_dir, window_size=args.window_size,
compute_stride_from_freq_table=False,
stride=stride,
use_actions=model.config.use_actions)
else:
print(f"loading STMAR")
model = STMAR.from_pretrained(args.checkpoint_dir).to("cuda")
stride = compute_stride_from_model(model, dataset)
args.val_data_dir = args.val_data_dir.replace("magvit", "vae")
val_dataset = RawFeatureDataset(args.val_data_dir,
compute_stride_from_freq_table=False,
stride=stride, window_size=args.window_size,
use_actions=model.config.use_actions)
val_dataset.metadata["token_dtype"] = "float32"
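        # Continuous features are serialized as float32; this `token_dtype` entry is
        # the dtype used when `video.bin` is written below.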
    latent_side_len = val_dataset.data.shape[-1]  # assumes square latents (h == w)
dataloader = DataLoader(val_dataset, collate_fn=default_data_collator, batch_size=args.batch_size)
    # Validate the requested number of examples against the dataset size.
    if args.max_example > len(val_dataset):
        print(f"--max_example {args.max_example} is out of bounds for dataset of length {len(val_dataset)}")
sys.exit(1)
model.eval()
output_list = []
for batch_idx, batch in enumerate(dataloader):
samples = []
if args.use_feature:
example_THW = rearrange(batch["input_ids"].to("cuda"), "b (t h w) c -> b t h w c", t=args.window_size,
h=latent_side_len, w=latent_side_len)
else:
example_THW = rearrange(batch["input_ids"].to("cuda"), "b (t h w) -> b t h w", t=args.window_size,
h=latent_side_len, w=latent_side_len)
example_actions = None
domain = None
if model.config.use_actions:
example_actions = batch["action_ids"].to("cuda")
            # Size the domain list to the actual batch, which may be smaller than args.batch_size.
            domain = [val_dataset.name.replace("_noquant", "")] * example_THW.shape[0]
prompt_THW = example_THW.clone()
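        # Mask every frame from `num_prompt_frames` onward: the continuous model (STMAR)
        # uses a learned mask_token embedding, the discrete one (STMaskGIT) a mask_token_id.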
prompt_THW[:, args.num_prompt_frames:] = model.mask_token if args.use_feature else model.mask_token_id
for timestep in range(args.num_prompt_frames, args.window_size):
# Teacher-forced, maskgit generation
if args.teacher_force_time:
prompt_THW = example_THW.clone()
# Masked prediction for this timestep only, after which we provide ground-truth
prompt_THW[:, timestep:] = model.mask_token if args.use_feature else model.mask_token_id
samples_HW, _, _ = model.maskgit_generate(
prompt_THW, out_t=timestep, temperature=args.temperature,
action_ids=example_actions, domain=domain
)
samples.append(samples_HW)
            if not args.teacher_force_time:
                # Autoregressive: feed this timestep's prediction back in as context.
                prompt_THW[:, timestep] = samples_HW
outputs = torch.stack(samples, dim=1)
# prepend prompt sequence
outputs = torch.cat([example_THW[:, :args.num_prompt_frames], outputs], dim=1)
# append ground-truth targets next to generated outputs for comic strip generation
# [<prompt frames><predicted frames><ground truth frames>]
outputs = torch.cat([outputs, example_THW[:, args.num_prompt_frames:]], dim=1)
output_list.append(outputs)
        # Stop once at least `max_example` examples have been generated.
        if (batch_idx + 1) * args.batch_size >= args.max_example:
            break
outputs = torch.cat(output_list, dim=0)
if args.use_feature:
        # Convert to channel-first (c h w) layout before serialization.
outputs = rearrange(outputs, "b t h w c -> b t c h w")
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
outputs.cpu().numpy().astype(np.dtype(val_dataset.metadata["token_dtype"])).tofile(output_dir / "video.bin")
print(f"Saved generated video to {output_dir / 'video.bin'} {outputs.shape}")
model_steps = get_model_step(args.checkpoint_dir)
with open(output_dir / "metadata.json", "w") as f:
json.dump(vars(args) | val_dataset.metadata | {
"num_images": outputs.shape[1],
"h": latent_side_len,
"w": latent_side_len,
"t": args.window_size,
"model_checkpoint": args.checkpoint_dir,
"dataset": val_dataset.name,
"trained_steps": model_steps,
}, f)
if __name__ == "__main__":
main()