# Spaces: Running on Zero
import time
from typing import Optional, Tuple, Callable, Dict

import cv2
import einops
import numpy as np
import skimage
import skimage.metrics
import torch

from data import DATA_FREQ_TABLE
from datasets.utils import get_image_encoder
from genie.st_mar import STMAR
from genie.st_mask_git import STMaskGIT
from train_diffusion import SVD_SCALE
class Simulator:
    """Minimal interface shared by all simulators in this module.

    Every method raises ``NotImplementedError``; subclasses override the
    ones they support.
    """

    def set_initial_state(self, state):
        """Set the initial state of the simulated scene.

        For example:
            1. in robomimic, it's the scene state vector
            2. in genie, it's the initial frames to prompt the model
        """
        raise NotImplementedError()

    def step(self, action):
        """Advance the simulation by one step using ``action``."""
        raise NotImplementedError()

    def reset(self):
        """Reset the simulator."""
        raise NotImplementedError()

    def close(self):
        """Close the simulator."""
        raise NotImplementedError()

    def dt(self):
        """Return the simulator's time step."""
        raise NotImplementedError()
class PhysicsSimulator(Simulator):
    """Interface for simulators backed by a physics engine.

    On top of the ``Simulator`` interface, it declares hooks for changing
    the time step, reading the raw scene state and querying the action
    dimension. All hooks raise ``NotImplementedError`` here.
    """

    def __init__(self):
        super().__init__()

    def set_dt(self, dt):
        """Update the time step; a physics engine should be able to do so."""
        raise NotImplementedError()

    def get_raw_state(self, port: Optional[str] = None):
        """Return the scene state of the physics engine.

        E.g., robot joint positions, object positions, etc.
        """
        raise NotImplementedError()

    def action_dimension(self):
        """Return the dimension of the action vector."""
        raise NotImplementedError()
class LearnedSimulator(Simulator):
    """Base class for simulators whose dynamics come from a learned model."""

    def __init__(self):
        # no state of its own; just run the parent initializer
        super().__init__()
# Replayed data respects physics, so this class inherits from PhysicsSimulator;
# it can be considered a special case of PhysicsSimulator.
class ReplaySimulator(PhysicsSimulator):
    """Replay a recorded sequence of frames, ignoring the actions it is given.

    The first ``prompt_horizon`` frames form the prompt; stepping starts at
    the frame right after the prompt.

    Args:
        frames: indexable, sliceable sequence of frames supporting ``len``.
        prompt_horizon: number of leading frames treated as the prompt.
        dt: value returned by ``dt()``.
    """

    def __init__(self,
                 frames,
                 prompt_horizon: int = 0,
                 dt: Optional[float] = None
                 ):
        super().__init__()
        self.frames = frames
        self.frame_idx = prompt_horizon
        assert self.frame_idx < len(self.frames)
        self._dt = dt
        self.prompt_horizon = prompt_horizon

    def __len__(self):
        """Number of frames that can be replayed after the prompt."""
        return len(self.frames) - self.prompt_horizon

    def step(self, action):
        """Return the next recorded frame; ``action`` is unused.

        Returns:
            A dict with the frame under the key ``'pred_next_frame'``.

        Raises:
            IndexError: if every recorded frame has already been replayed.
        """
        # Check bounds BEFORE indexing so the failure carries a useful message.
        if self.frame_idx >= len(self.frames):
            raise IndexError(
                f"ReplaySimulator has no frame left to replay "
                f"(frame_idx={self.frame_idx}, len(frames)={len(self.frames)})."
            )
        frame = self.frames[self.frame_idx]
        self.frame_idx += 1
        return {
            'pred_next_frame': frame
        }

    def reset(self):
        """Rewind to the end of the prompt.

        Returns:
            The current frame, i.e. the last frame of the prompt, or ``None``
            when ``prompt_horizon`` is 0 and there is no prompt frame.
        """
        self.frame_idx = self.prompt_horizon
        prompt_frames = self.prompt()
        if len(prompt_frames) == 0:
            return None
        return prompt_frames[-1]

    def prompt(self):
        """Return the first ``prompt_horizon`` frames."""
        return self.frames[:self.prompt_horizon]

    def dt(self):
        """Return the time step passed at construction (may be ``None``)."""
        return self._dt
class GenieSimulator(LearnedSimulator):
    """Learned simulator that predicts the next frame with a Genie backbone.

    Frames are encoded to latents by an image encoder, a masked-prediction
    backbone (``STMaskGIT`` or ``STMAR``) generates the next latent frame from
    the cached history of latents and actions, and the latent is decoded back
    to an image. Optionally a ground-truth physics simulator is stepped in
    lockstep to provide reference frames, PSNR metrics and teacher forcing.

    Args:
        max_image_resolution: cap on the longer image side; only applied when
            ``resize_image`` is False.
        resize_image: resize the shorter side to ``resize_image_resolution``
            and center-crop to a square.
        resize_image_resolution: side length used when ``resize_image`` is set.
        image_encoder_type: ``"magvit"`` or ``"temporalvae"``.
        image_encoder_ckpt: checkpoint passed to ``get_image_encoder``.
        quantize: work on quantized token indices (requires ``"magvit"``).
        quantization_slice_size: factor between image size and token grid size.
        backbone_type: ``"stmaskgit"`` or ``"stmar"``.
        backbone_ckpt: checkpoint passed to ``from_pretrained``.
        prompt_horizon: number of history frames fed to the backbone.
        inference_iterations: generation steps per frame (defaults to 2).
        sampling_temperature: temperature passed to ``maskgit_generate``.
        action_stride: low-level actions per genie step; defaults to
            ``DATA_FREQ_TABLE[domain] // genie_frequency``.
        domain: domain name passed to the backbone.
        genie_frequency: steps per second; ``dt()`` returns its inverse.
        measure_step_time: add ``'step_time'`` to the step result.
        compute_psnr: add ``'psnr'`` against the ground-truth frame.
        compute_delta_psnr: add ``'delta_psnr'`` (a controllability signal);
            implies ``compute_psnr``.
        gaussian_action_perturbation_scale: std of Gaussian noise added to the
            action, or ``None`` for no perturbation.
        device: torch device string.
        physics_simulator: optional ground-truth simulator; assumed to have
            prompt frames aligned with this simulator.
        physics_simulator_teacher_force: when set, every that many steps the
            predicted latent is replaced by the encoded ground-truth frame.
        post_processor: callable applied to ``(pred_frame, action)``, e.g. to
            draw the action on the predicted image.
        allow_external_prompt: take the prompt from ``set_initial_state`` even
            when a physics simulator is given.
    """

    # number of random-action rollouts averaged for the delta-PSNR metric
    average_delta_psnr_over = 5

    def __init__(self,
                 # image preprocessing
                 max_image_resolution: int = 1024,
                 resize_image: bool = True,
                 resize_image_resolution: int = 256,
                 # tokenizer setting
                 image_encoder_type: str = "temporalvae",
                 image_encoder_ckpt: str = "stabilityai/stable-video-diffusion-img2vid",
                 quantize: bool = False,
                 quantization_slice_size: int = 16,
                 # dynamics backbone setting
                 backbone_type: str = "stmar",
                 backbone_ckpt: str = "data/mar_ckpt/robomimic",
                 prompt_horizon: int = 11,
                 inference_iterations: Optional[int] = None,
                 sampling_temperature: float = 0.0,
                 action_stride: Optional[int] = None,
                 domain: str = "robomimic",
                 genie_frequency: int = 2,
                 # misc
                 measure_step_time: bool = False,
                 compute_psnr: bool = False,
                 compute_delta_psnr: bool = False,  # act as a signal for controllability
                 gaussian_action_perturbation_scale: Optional[float] = None,
                 device: str = 'cuda',
                 physics_simulator: Optional[PhysicsSimulator] = None,
                 physics_simulator_teacher_force: Optional[int] = None,
                 post_processor: Optional[Callable] = None,  # on the predicted image, e.g., add action
                 allow_external_prompt: bool = False
                 ):
        super().__init__()

        # The first assert already implies "quantize => magvit", so no
        # separate check for that direction is needed.
        assert quantize == (image_encoder_type == "magvit"), \
            "Currently quantization if and only if magvit is the image encoder."
        assert image_encoder_type in ["magvit", "temporalvae"], \
            "Image encoder type must be either 'magvit' or 'temporalvae'."
        assert backbone_type in ["stmaskgit", "stmar"], \
            "Backbone type must be either 'stmaskgit' or 'stmar'."
        if physics_simulator is None:
            assert physics_simulator_teacher_force is None, \
                "Physics simulator teacher force is only available when physics simulator is provided."
            assert compute_psnr is False, \
                "PSNR computation is only available when physics simulator is provided."
            assert compute_delta_psnr is False, \
                "Delta PSNR computation is only available when physics simulator is provided."

        if action_stride is None:
            action_stride = DATA_FREQ_TABLE[domain] // genie_frequency
        if compute_delta_psnr:
            compute_psnr = True  # to compute delta psnr, psnr must be computed
        if inference_iterations is None:
            # both supported backbones use the same default
            inference_iterations = 2

        # misc
        self.device = torch.device(device)
        self.measure_step_time = measure_step_time
        self.compute_psnr = compute_psnr
        self.compute_delta_psnr = compute_delta_psnr
        self.allow_external_prompt = allow_external_prompt

        # image preprocessing
        self.max_image_resolution = max_image_resolution
        self.resize_image = resize_image
        self.resize_image_resolution = resize_image_resolution

        # load image encoder
        self.image_encoding_dtype = torch.bfloat16
        self.quantize = quantize
        self.quant_slice_size = quantization_slice_size
        self.image_encoder_type = image_encoder_type
        self.image_encoder = get_image_encoder(
            image_encoder_type,
            image_encoder_ckpt
        ).to(device=self.device, dtype=self.image_encoding_dtype).eval()

        # load the dynamics backbone
        self.prompt_horizon = prompt_horizon
        self.domain = domain
        self.genie_frequency = genie_frequency
        self.inference_iterations = inference_iterations
        self.sampling_temperature = sampling_temperature
        self.action_stride = action_stride
        self.gauss_act_perturb_scale = gaussian_action_perturbation_scale
        self.backbone_type = backbone_type
        if backbone_type == "stmaskgit":
            self.backbone = STMaskGIT.from_pretrained(backbone_ckpt)
        else:
            self.backbone = STMAR.from_pretrained(backbone_ckpt)
        self.backbone = self.backbone.to(device=self.device).eval()

        self.post_processor = post_processor

        # physics simulator used to get ground-truth images, if available;
        # assume the phys sim has aligned prompt frames
        self.gt_phys_sim = physics_simulator
        self.gt_teacher_force = physics_simulator_teacher_force

        # history buffer, i.e., the input to the model
        self.cached_actions = None          # (prompt_horizon, action_stride, A)
        self.cached_latent_frames = None    # (prompt_horizon, ...)
        self.init_prompt = None             # (prompt_frames, prompt_actions)
        self.step_count = 0

        # report model size
        print(
            "================ Model Size Report ================\n"
            f" encoder size: {sum(p.numel() for p in self.image_encoder.parameters()) / 1e6:.3f}M \n"
            f" backbone size: {sum(p.numel() for p in self.backbone.parameters()) / 1e6:.3f}M\n"
            "==================================================="
        )

    def set_initial_state(self, state: Tuple[np.ndarray, np.ndarray]):
        """Store an external prompt ``(frames, actions)`` for the next ``reset``.

        Raises:
            NotImplementedError: if a physics simulator is attached and
                external prompts are not allowed; the prompt then comes from
                the physics simulator.
        """
        if not self.allow_external_prompt and self.gt_phys_sim is not None:
            raise NotImplementedError("Initial state is set by the physics simulator.")
        self.init_prompt = state

    def _mask_target_frame(self, latent_states: torch.Tensor) -> torch.Tensor:
        """Mask the frame at index ``prompt_horizon`` for masked prediction."""
        if self.backbone_type == "stmaskgit":
            latent_states = latent_states.long()
            latent_states[:, self.prompt_horizon] = self.backbone.mask_token_id
        elif self.backbone_type == "stmar":
            latent_states[:, self.prompt_horizon] = self.backbone.mask_token
        return latent_states

    @staticmethod
    def _psnr(reference: np.ndarray, prediction: np.ndarray) -> float:
        """PSNR between two images given in the [0, 255] range."""
        return skimage.metrics.peak_signal_noise_ratio(
            image_true=reference / 255.,
            image_test=prediction / 255.,
            data_range=1.0
        )

    @torch.no_grad()  # inference only: do not keep autograd graphs in the cache
    def step(self, action: np.ndarray) -> Dict:
        """Predict the next frame for ``action`` and advance the history buffer.

        Args:
            action: array of shape ``(action_stride, A)`` or ``(A,)``; a 1-D
                action is repeated ``action_stride`` times.

        Returns:
            A dict with ``'pred_next_frame'`` (the decoded prediction, after
            the optional post-processor). Depending on configuration it also
            holds ``'step_time'``, ``'gt_next_frame'``, whatever the physics
            simulator's ``step`` and ``get_raw_state`` return, ``'psnr'`` and
            ``'delta_psnr'``.
        """
        assert self.cached_latent_frames is not None and self.cached_actions is not None, \
            "Model is not prompted yet. Please call `reset` first " \
            "(after `set_initial_state` when using an external prompt)."

        action = np.asarray(action)
        if action.ndim == 1:
            action = np.tile(action, (self.action_stride, 1))

        # perturb action
        if self.gauss_act_perturb_scale is not None:
            action = np.random.normal(action, self.gauss_act_perturb_scale)

        # history latents followed by a placeholder for the frame to predict
        input_latent_states = torch.cat([
            self.cached_latent_frames,
            torch.zeros_like(self.cached_latent_frames[-1:]),
        ]).unsqueeze(0).to(torch.float32)
        # dtype conversion and mask token
        input_latent_states = self._mask_target_frame(input_latent_states)

        # dynamics rollout
        action = torch.from_numpy(action).to(device=self.device)
        input_actions = torch.cat([     # (1, prompt_horizon + 1, action_stride * A)
            self.cached_actions,
            action.unsqueeze(0)
        ]).view(1, self.prompt_horizon + 1, -1).to(torch.float32)

        if self.measure_step_time:
            start_time = time.perf_counter()

        pred_next_latent_state = self.backbone.maskgit_generate(
            input_latent_states,
            out_t=self.prompt_horizon,
            maskgit_steps=self.inference_iterations,
            temperature=self.sampling_temperature,
            action_ids=input_actions,
            domain=[self.domain]
        )[0].squeeze(0)

        # decoding
        pred_next_frame = self._decode_image(pred_next_latent_state)

        step_result = {'pred_next_frame': pred_next_frame}
        if self.measure_step_time:
            step_result['step_time'] = time.perf_counter() - start_time

        # physics simulation
        if self.gt_phys_sim is not None:
            # one physics step per low-level action; keep the last result
            for a in action.cpu().numpy():
                gt_result = self.gt_phys_sim.step(a)

            # cv2.resize expects dsize as (width, height), not (H, W)
            pred_height, pred_width = pred_next_frame.shape[:2]
            gt_next_frame = cv2.resize(
                gt_result['pred_next_frame'], (pred_width, pred_height)
            )
            step_result['gt_next_frame'] = gt_next_frame
            gt_result.pop('pred_next_frame')
            step_result.update(gt_result)

            # gt state observation; optional for physics simulators
            try:
                raw_state = self.gt_phys_sim.get_raw_state()
            except NotImplementedError:
                pass
            else:
                step_result.update(raw_state)

            # compute PSNR against ground truth
            if self.compute_psnr:
                step_result['psnr'] = self._psnr(gt_next_frame, pred_next_frame)

            # controllability metric: how much better the prediction is with
            # the real action than with random actions
            if self.compute_delta_psnr:
                delta_psnr = 0.0
                for _ in range(self.average_delta_psnr_over):
                    # re-mask the input latent states for masked prediction
                    input_latent_states = self._mask_target_frame(input_latent_states)

                    # sample random action from N(0, 1)
                    random_input_actions = torch.randn_like(input_actions)
                    random_pred_next_latent_state = self.backbone.maskgit_generate(
                        input_latent_states,
                        out_t=self.prompt_horizon,
                        maskgit_steps=self.inference_iterations,
                        temperature=self.sampling_temperature,
                        action_ids=random_input_actions,
                        domain=[self.domain],
                        skip_normalization=True
                    )[0].squeeze(0)
                    random_pred_next_frame = self._decode_image(random_pred_next_latent_state)
                    this_delta_psnr = step_result['psnr'] - self._psnr(
                        gt_next_frame, random_pred_next_frame
                    )
                    delta_psnr += this_delta_psnr / self.average_delta_psnr_over
                step_result['delta_psnr'] = delta_psnr

            # teacher forcing: feed the ground-truth frame back into the history
            if self.gt_teacher_force is not None and self.step_count % self.gt_teacher_force == 0:
                pred_next_latent_state = self._encode_image(gt_next_frame)

        # update history buffer
        self.cached_latent_frames = torch.cat([
            self.cached_latent_frames[1:], pred_next_latent_state.unsqueeze(0)
        ])
        self.cached_actions = torch.cat([
            self.cached_actions[1:], action.unsqueeze(0)
        ])

        # post processing; write the result back so the caller receives it
        if self.post_processor is not None:
            step_result['pred_next_frame'] = self.post_processor(pred_next_frame, action)

        self.step_count += 1

        return step_result

    @torch.no_grad()
    def _encode_image(self, image: np.ndarray) -> torch.Tensor:
        """Encode an ``(H, W, 3)`` image into a latent frame.

        Returns:
            int32 token indices of shape ``(h, w)`` when quantizing, otherwise
            a float32 latent with the batch dimension removed.

        Raises:
            ValueError: if ``image_encoder_type`` is not supported.
        """
        image = torch.from_numpy(
            self._normalize_image(image).transpose(2, 0, 1)
        ).to(
            device=self.device, dtype=self.image_encoding_dtype
        ).unsqueeze(0)
        H, W = image.shape[-2:]

        if self.quantize:
            H //= self.quant_slice_size
            W //= self.quant_slice_size
            _, _, indices, _ = self.image_encoder.encode(image, flip=True)
            indices = einops.rearrange(indices, "(h w) -> h w", h=H, w=W)
            return indices.to(torch.int32)

        if self.image_encoder_type == "magvit":
            latent = self.image_encoder.encode_without_quantize(image)
        elif self.image_encoder_type == "temporalvae":
            latent_dist = self.image_encoder.encode(image).latent_dist
            latent = latent_dist.mean * SVD_SCALE
            latent = einops.rearrange(latent, "b c h w -> b h w c")
        else:
            raise ValueError(
                f"Unsupported image encoder type: {self.image_encoder_type!r}"
            )
        return latent.squeeze(0).to(torch.float32)

    @torch.no_grad()
    def _decode_image(self, latent: torch.Tensor) -> np.ndarray:
        """Decode a latent frame (quantized indices or raw latent) to an image.

        Returns:
            uint8 array; ``(H, W, 3)`` assuming the decoder outputs
            ``(1, 3, H, W)`` -- TODO confirm against the encoder classes.
        """
        latent = latent.to(device=self.device).unsqueeze(0)
        if self.quantize:
            latent = self.image_encoder.quantize.get_codebook_entry(
                einops.rearrange(latent, "b h w -> b (h w)"),
                bhwc=(*latent.shape, self.image_encoder.quantize.codebook_dim)
            ).flip(1)
        latent = latent.to(device=self.device, dtype=self.image_encoding_dtype)

        if self.image_encoder_type == "magvit":
            decoded_image = self.image_encoder.decode(latent)
        elif self.image_encoder_type == "temporalvae":
            latent = einops.rearrange(latent, "b h w c -> b c h w")
            latent = latent / SVD_SCALE
            # HACK: clip for less visual artifacts
            latent = torch.clamp(latent, -25, 25)
            decoded_image = self.image_encoder.decode(latent, num_frames=1).sample

        decoded_image = decoded_image.squeeze(0).to(torch.float32).detach().cpu().numpy()
        decoded_image = self._unnormalize_image(decoded_image).transpose(1, 2, 0)
        return decoded_image

    def _normalize_image(self, image: np.ndarray) -> np.ndarray:
        """Convert an ``(H, W, 3)`` image in [0, 255] to float32 in [-1, 1].

        If ``resize_image`` is set, the shorter side is resized to
        ``resize_image_resolution`` and the image is center-cropped to a
        square. Otherwise the image is only shrunk when a side exceeds
        ``max_image_resolution``. The input array is never modified.
        """
        # Out-of-place division: np.asarray returns the caller's own array when
        # it is already float32, so `/=` would silently rescale their data.
        image = np.asarray(image, dtype=np.float32) / 255.
        H, W = image.shape[:2]

        if self.resize_image:
            resized_res = self.resize_image_resolution
            if H < W:
                Hnew, Wnew = resized_res, int(resized_res * W / H)
            else:
                Hnew, Wnew = int(resized_res * H / W), resized_res
            image = cv2.resize(image, (Wnew, Hnew))
            # center crop
            H, W = image.shape[:2]
            Hstart = (H - resized_res) // 2
            Wstart = (W - resized_res) // 2
            image = image[Hstart:Hstart + resized_res, Wstart:Wstart + resized_res]
        elif H > self.max_image_resolution or W > self.max_image_resolution:
            # resolution is too large: shrink the longer side to the maximum
            if H < W:
                Hnew, Wnew = int(self.max_image_resolution * H / W), self.max_image_resolution
            else:
                Hnew, Wnew = self.max_image_resolution, int(self.max_image_resolution * W / H)
            image = cv2.resize(image, (Wnew, Hnew))

        image = (image * 2 - 1.)
        return image

    def _unnormalize_image(self, image: np.ndarray) -> np.ndarray:
        """Map an image from [-1, 1] to uint8 in [0, 255]."""
        image = (image + 1.) * 127.5
        # NOTE: clip happens here
        image = np.clip(image, 0, 255).astype(np.uint8)
        return image

    def reset(self) -> np.ndarray:
        """Rebuild the history buffer from the prompt and return the current frame.

        The prompt comes from the physics simulator (its reset frame repeated
        ``prompt_horizon`` times with zero actions) unless an external prompt
        is allowed or no physics simulator is attached; then the prompt stored
        by ``set_initial_state`` is used.

        Returns:
            The last prompt frame; when a physics simulator is attached the
            frame is concatenated with itself side by side.
        """
        if not self.allow_external_prompt and self.gt_phys_sim is not None:
            image_prompt = np.tile(
                self.gt_phys_sim.reset(), (self.prompt_horizon, 1, 1, 1)
            ).astype(np.uint8)
            # `action_dimension` is declared as a method on PhysicsSimulator,
            # but a subclass may expose it as a property; accept both.
            action_dim = self.gt_phys_sim.action_dimension
            if callable(action_dim):
                action_dim = action_dim()
            action_prompt = np.zeros(
                (self.prompt_horizon, self.action_stride, action_dim)
            ).astype(np.float32)
        else:
            assert self.init_prompt is not None, "Initial state is not set."
            image_prompt, action_prompt = self.init_prompt
            # `step` masks index `prompt_horizon` and reshapes the actions to
            # `prompt_horizon + 1` steps, so the prompt length must match.
            assert len(image_prompt) == self.prompt_horizon \
                and len(action_prompt) == self.prompt_horizon, \
                "Prompt frames and actions must both have length `prompt_horizon`."

        # standardize the image (resize / crop, back to uint8)
        image_prompt = [
            self._unnormalize_image(self._normalize_image(frame))
            for frame in image_prompt
        ]
        # With `resize_image` the frames are already square at
        # `resize_image_resolution` here, so no further resize is needed.
        current_image = image_prompt[-1]

        self.cached_actions = torch.from_numpy(action_prompt).to(device=self.device)

        # convert to latent
        self.cached_latent_frames = torch.stack([
            self._encode_image(frame) for frame in image_prompt
        ], dim=0)

        # side-by-side image when a ground-truth simulator is attached
        if self.gt_phys_sim is not None:
            current_image = np.concatenate([current_image, current_image], axis=1)

        self.step_count = 0

        return current_image

    def close(self):
        """Close the attached physics simulator, if it supports closing."""
        if self.gt_phys_sim is not None:
            try:
                self.gt_phys_sim.close()
            except NotImplementedError:
                # closing is optional for physics simulators
                pass

    def dt(self):
        """Return the time step, i.e. the inverse of ``genie_frequency``."""
        return 1.0 / self.genie_frequency