# hma/sim/simulator.py (commit b4e01e8, "tiny fix", by liruiw; 20 kB)
import cv2
import torch
import numpy as np
import einops
import skimage
import time
from genie.st_mask_git import STMaskGIT
from genie.st_mar import STMAR
from datasets.utils import get_image_encoder
from data import DATA_FREQ_TABLE
from train_diffusion import SVD_SCALE
from typing import Optional, Tuple, Callable, Dict
class Simulator:
    """Abstract interface shared by every simulator in this module.

    Subclasses override the methods below; the base versions all raise
    NotImplementedError.
    """

    def set_initial_state(self, state):
        """Set the initial state of the simulated scene.

        What `state` holds depends on the simulator, for example:
        1. in robomimic, the scene state vector;
        2. in genie, the initial frames used to prompt the model.
        """
        raise NotImplementedError

    @torch.inference_mode()
    def step(self, action):
        """Advance the simulation by one step driven by `action`."""
        raise NotImplementedError

    def reset(self):
        """Return the simulator to its initial state."""
        raise NotImplementedError

    def close(self):
        """Release whatever the simulator holds."""
        raise NotImplementedError

    @property
    def dt(self):
        """Duration of one `step` (e.g. `GenieSimulator` returns 1 / genie_frequency)."""
        raise NotImplementedError
class PhysicsSimulator(Simulator):
    """Simulator backed by a physics engine (or by data that obeys physics)."""

    def __init__(self):
        super(PhysicsSimulator, self).__init__()

    def set_dt(self, dt):
        """Change the engine's time step; a physics engine is expected to support this."""
        raise NotImplementedError

    def get_raw_state(self, port: Optional[str] = None):
        """Return the raw scene state (e.g. robot joint positions, object positions).

        `port` optionally selects which part of the state to query.
        """
        raise NotImplementedError

    @property
    def action_dimension(self):
        """Size of a single action vector accepted by `step`."""
        raise NotImplementedError
class LearnedSimulator(Simulator):
    """Simulator whose dynamics come from a learned model rather than a physics engine."""

    def __init__(self):
        super(LearnedSimulator, self).__init__()
# Replayed data respects physics, so ReplaySimulator inherits from PhysicsSimulator;
# it can be considered a special case of PhysicsSimulator.
class ReplaySimulator(PhysicsSimulator):
    """Replay a pre-recorded frame sequence; actions passed to `step` are ignored.

    The first `prompt_horizon` frames form the prompt. Each call to `step`
    returns the next recorded frame after the prompt.
    """

    def __init__(self,
                 frames,
                 prompt_horizon: int = 0,
                 dt: Optional[float] = None
                 ):
        """
        Args:
            frames: sequence of frames supporting `len()`, integer and slice indexing.
            prompt_horizon: number of leading frames used as the prompt; must be
                smaller than `len(frames)`.
            dt: value reported by the `dt` property (may be None).
        """
        super().__init__()
        self.frames = frames
        self.frame_idx = prompt_horizon
        assert self.frame_idx < len(self.frames)
        self._dt = dt
        self.prompt_horizon = prompt_horizon

    def __len__(self):
        """Number of frames available to replay after the prompt."""
        return len(self.frames) - self.prompt_horizon

    def step(self, action):
        """Return the next recorded frame as {'pred_next_frame': frame}; `action` is ignored.

        Raises:
            IndexError: if every frame has already been replayed (call `reset`).
        """
        # Check the bound before indexing so exhaustion is reported explicitly.
        if self.frame_idx >= len(self.frames):
            raise IndexError(
                f"ReplaySimulator exhausted: all {len(self.frames)} frames "
                f"have been replayed; call reset()."
            )
        frame = self.frames[self.frame_idx]
        self.frame_idx += 1
        return {
            'pred_next_frame': frame
        }

    def reset(self):
        """Rewind to just after the prompt and return the last prompt frame.

        Requires `prompt_horizon >= 1`; with an empty prompt the `[-1]` lookup fails.
        """
        self.frame_idx = self.prompt_horizon
        return self.prompt()[-1]

    def prompt(self):
        """Return the first `prompt_horizon` frames."""
        return self.frames[:self.prompt_horizon]

    @property
    def dt(self):
        return self._dt
class GenieSimulator(LearnedSimulator):
    """Learned simulator that rolls out a Genie-style dynamics model in latent space.

    Frames are encoded by an image tokenizer ("magvit" or "temporalvae"). The
    backbone ("stmaskgit" or "stmar") predicts the next latent frame from a
    sliding window of `prompt_horizon` latent frames plus their actions, and the
    prediction is decoded back to an RGB image. When a ground-truth
    `physics_simulator` is supplied, it is stepped with the same actions so that
    its frames and state (and optionally PSNR / delta PSNR) are reported
    alongside the prediction.
    """

    # number of random-action rollouts averaged when computing delta PSNR
    average_delta_psnr_over = 5

    def __init__(self,
                 # image preprocessing
                 max_image_resolution: int = 1024,
                 resize_image: bool = True,
                 resize_image_resolution: int = 256,
                 # tokenizer setting
                 image_encoder_type: str = "temporalvae",
                 image_encoder_ckpt: str = "stabilityai/stable-video-diffusion-img2vid",
                 quantize: bool = False,
                 quantization_slice_size: int = 16,
                 # dynamics backbone setting
                 backbone_type: str = "stmar",
                 backbone_ckpt: str = "data/mar_ckpt/robomimic",
                 prompt_horizon: int = 11,
                 inference_iterations: Optional[int] = None,
                 sampling_temperature: float = 0.0,
                 action_stride: Optional[int] = None,
                 domain: str = "robomimic",
                 genie_frequency: int = 2,
                 # misc
                 measure_step_time: bool = False,
                 compute_psnr: bool = False,
                 compute_delta_psnr: bool = False,  # act as a signal for controlability
                 gaussian_action_perturbation_scale: Optional[float] = None,
                 device: str = 'cuda',
                 physics_simulator: Optional[PhysicsSimulator] = None,
                 physics_simulator_teacher_force: Optional[int] = None,
                 post_processor: Optional[Callable] = None,  # on the predicted image, e.g., add action
                 allow_external_prompt: bool = False
                 ):
        """Load the image encoder and the dynamics backbone.

        Args:
            max_image_resolution: when `resize_image` is False, images whose
                height or width exceeds this are downscaled, keeping aspect ratio.
            resize_image: resize the shorter side to `resize_image_resolution`,
                then center-crop to a square.
            resize_image_resolution: target side length used when `resize_image`.
            image_encoder_type: "magvit" or "temporalvae".
            image_encoder_ckpt: checkpoint passed to `get_image_encoder`.
            quantize: must be True exactly when the encoder is "magvit".
            quantization_slice_size: factor by which image height/width is
                divided to reshape the token indices into a grid.
            backbone_type: "stmaskgit" or "stmar".
            backbone_ckpt: checkpoint loaded with `from_pretrained`.
            prompt_horizon: number of past latent frames fed to the backbone.
            inference_iterations: `maskgit_steps` per prediction (default 2).
            sampling_temperature: sampling temperature passed to the backbone.
            action_stride: low-level actions per model step; defaults to
                `DATA_FREQ_TABLE[domain] // genie_frequency`.
            domain: dataset/domain name given to the backbone.
            genie_frequency: model steps per unit time; `dt == 1 / genie_frequency`.
            measure_step_time: report the prediction+decoding time as 'step_time'.
            compute_psnr: report PSNR against the physics simulator's frame.
            compute_delta_psnr: report the PSNR drop under random actions
                (implies `compute_psnr`).
            gaussian_action_perturbation_scale: std of Gaussian noise added to actions.
            device: torch device for encoder and backbone.
            physics_simulator: optional ground-truth simulator stepped in lockstep.
            physics_simulator_teacher_force: if set to N, every N steps the
                predicted latent is replaced by the encoding of the GT frame.
            post_processor: called as `post_processor(frame, action)` on the
                returned predicted frame.
            allow_external_prompt: use the prompt from `set_initial_state` even
                when a physics simulator is attached.
        """
        super().__init__()
        assert quantize == (image_encoder_type == "magvit"), \
            "Currently quantization if and only if magvit is the image encoder."
        assert image_encoder_type in ["magvit", "temporalvae"], \
            "Image encoder type must be either 'magvit' or 'temporalvae'."
        assert backbone_type in ["stmaskgit", "stmar"], \
            "Backbone type must be either 'stmaskgit' or 'stmar'."
        if physics_simulator is None:
            assert physics_simulator_teacher_force is None, \
                "Physics simulator teacher force is only available when physics simulator is provided."
            assert compute_psnr is False, \
                "PSNR computation is only available when physics simulator is provided."
            assert compute_delta_psnr is False, \
                "Delta PSNR computation is only available when physics simulator is provided."

        if action_stride is None:
            action_stride = DATA_FREQ_TABLE[domain] // genie_frequency
        if compute_delta_psnr:
            compute_psnr = True  # delta PSNR is defined relative to the PSNR
        if inference_iterations is None:
            inference_iterations = 2  # same default for both backbones

        # misc
        self.device = torch.device(device)
        self.measure_step_time = measure_step_time
        self.compute_psnr = compute_psnr
        self.compute_delta_psnr = compute_delta_psnr
        self.allow_external_prompt = allow_external_prompt

        # image preprocessing
        self.max_image_resolution = max_image_resolution
        self.resize_image = resize_image
        self.resize_image_resolution = resize_image_resolution

        # load image encoder
        self.image_encoding_dtype = torch.bfloat16
        self.quantize = quantize
        self.quant_slice_size = quantization_slice_size
        self.image_encoder_type = image_encoder_type
        self.image_encoder = get_image_encoder(
            image_encoder_type,
            image_encoder_ckpt
        ).to(device=self.device, dtype=self.image_encoding_dtype).eval()

        # load dynamics backbone (STMAR is inherited from STMaskGIT)
        self.prompt_horizon = prompt_horizon
        self.domain = domain
        self.genie_frequency = genie_frequency
        self.inference_iterations = inference_iterations
        self.sampling_temperature = sampling_temperature
        self.action_stride = action_stride
        self.gauss_act_perturb_scale = gaussian_action_perturbation_scale
        self.backbone_type = backbone_type
        if backbone_type == "stmaskgit":
            self.backbone = STMaskGIT.from_pretrained(backbone_ckpt)
        else:
            self.backbone = STMAR.from_pretrained(backbone_ckpt)
        self.backbone = self.backbone.to(device=self.device).eval()

        self.post_processor = post_processor

        # ground-truth physics simulator; assumed to have prompt frames aligned with ours
        self.gt_phys_sim = physics_simulator
        self.gt_teacher_force = physics_simulator_teacher_force

        # history buffer, i.e., the input to the model (filled by `reset`)
        self.cached_actions = None  # (prompt_horizon, action_stride, A)
        self.cached_latent_frames = None  # (prompt_horizon, ...)
        self.init_prompt = None  # (prompt_frames, prompt_actions)
        self.step_count = 0

        # report model size
        print(
            "================ Model Size Report ================\n"
            f" encoder size: {sum(p.numel() for p in self.image_encoder.parameters()) / 1e6:.3f}M \n"
            f" backbone size: {sum(p.numel() for p in self.backbone.parameters()) / 1e6:.3f}M\n"
            "==================================================="
        )

    def set_initial_state(self, state: Tuple[np.ndarray, np.ndarray]):
        """Store an external prompt `(frames, actions)` to be used by the next `reset`.

        Raises:
            NotImplementedError: if a physics simulator provides the prompt and
                `allow_external_prompt` is False.
        """
        if not self.allow_external_prompt and self.gt_phys_sim is not None:
            raise NotImplementedError("Initial state is set by the physics simulator.")
        self.init_prompt = state

    def _mask_latent_frame(self, latent_states: torch.Tensor, t: int) -> torch.Tensor:
        """Overwrite frame `t` of (B, T, ...) latents with the backbone's mask value.

        For "stmaskgit" the latents are first cast to int64 token ids (which may
        return a new tensor); the masked tensor is returned either way.
        """
        if self.backbone_type == "stmaskgit":
            latent_states = latent_states.long()
            latent_states[:, t] = self.backbone.mask_token_id
        elif self.backbone_type == "stmar":
            latent_states[:, t] = self.backbone.mask_token
        return latent_states

    def _compute_delta_psnr(self,
                            input_latent_states: torch.Tensor,
                            input_actions: torch.Tensor,
                            gt_next_frame: np.ndarray,
                            psnr: float) -> float:
        """Average PSNR drop when the true actions are replaced by N(0, 1) noise.

        Used as a controllability signal: the more the prediction depends on
        the actions, the larger the drop.
        """
        delta_psnr = 0.0
        for _ in range(self.average_delta_psnr_over):
            # re-mask the frame to predict before every masked generation
            input_latent_states = self._mask_latent_frame(input_latent_states, self.prompt_horizon)
            random_input_actions = torch.randn_like(input_actions)
            random_pred_next_latent_state = self.backbone.maskgit_generate(
                input_latent_states,
                out_t=self.prompt_horizon,
                maskgit_steps=self.inference_iterations,
                temperature=self.sampling_temperature,
                action_ids=random_input_actions,
                domain=[self.domain],
                # NOTE(review): presumably random actions are already in normalized space
                skip_normalization=True
            )[0].squeeze(0)
            random_pred_next_frame = self._decode_image(random_pred_next_latent_state)
            random_psnr = skimage.metrics.peak_signal_noise_ratio(
                image_true=gt_next_frame / 255.,
                image_test=random_pred_next_frame / 255.,
                data_range=1.0
            )
            delta_psnr += (psnr - random_psnr) / self.average_delta_psnr_over
        return delta_psnr

    @torch.inference_mode()
    def step(self, action: np.ndarray) -> Dict:
        """Predict the next frame for one action chunk and advance the history buffer.

        Args:
            action: (action_stride, A), or (A,) which is repeated `action_stride` times.

        Returns:
            dict with 'pred_next_frame' ((H, W, 3) uint8, post-processed when a
            `post_processor` is set), plus 'step_time' if `measure_step_time`.
            With a physics simulator it also holds 'gt_next_frame', the physics
            simulator's other step outputs, its raw state (when implemented), and
            'psnr' / 'delta_psnr' when enabled.
        """
        assert self.cached_latent_frames is not None and self.cached_actions is not None, \
            "Model is not prompted yet. Please call `reset` first (after `set_initial_state` for an external prompt)."
        if action.ndim == 1:
            action = np.tile(action, (self.action_stride, 1))
        if self.gauss_act_perturb_scale is not None:
            action = np.random.normal(action, self.gauss_act_perturb_scale)

        # latent history plus one masked slot for the frame to predict
        input_latent_states = torch.cat([
            self.cached_latent_frames,
            torch.zeros_like(self.cached_latent_frames[[0]]),
        ]).unsqueeze(0).to(torch.float32)
        input_latent_states = input_latent_states[:, :self.prompt_horizon + 1]
        input_latent_states = self._mask_latent_frame(input_latent_states, -1)

        # keep model-side actions float32 so `cached_actions` never mixes dtypes
        action_tensor = torch.from_numpy(np.asarray(action, dtype=np.float32)).to(device=self.device)
        # NOTE(review): the view below yields (1, (prompt_horizon + 2) * action_stride, A),
        # so the slice keeps the first prompt_horizon + 1 rows of that flattened sequence,
        # not whole action chunks, unless action_stride == 1 — confirm this matches
        # what `maskgit_generate` expects.
        input_actions = torch.cat([
            self.cached_actions,
            action_tensor.unsqueeze(0),
            action_tensor.unsqueeze(0)  # the last action is not used, but we need a_{t-1}, s_{t-1} to predict s_t
        ]).view(1, -1, action_tensor.shape[-1]).to(torch.float32)
        input_actions = input_actions[:, :self.prompt_horizon + 1]

        if self.measure_step_time:
            start_time = time.perf_counter()
        pred_next_latent_state = self.backbone.maskgit_generate(
            input_latent_states,
            out_t=input_latent_states.shape[1] - 1,
            maskgit_steps=self.inference_iterations,
            temperature=self.sampling_temperature,
            action_ids=input_actions,
            domain=[self.domain]
        )[0].squeeze(0)
        pred_next_frame = self._decode_image(pred_next_latent_state)
        step_result = {'pred_next_frame': pred_next_frame}
        if self.measure_step_time:
            step_result['step_time'] = time.perf_counter() - start_time

        if self.gt_phys_sim is not None:
            # cv2.resize expects (width, height), not (height, width)
            target_size = (pred_next_frame.shape[1], pred_next_frame.shape[0])
            for sub_action in action:
                gt_result = self.gt_phys_sim.step(sub_action)
                gt_next_frame = cv2.resize(gt_result.pop('pred_next_frame'), target_size)
                step_result['gt_next_frame'] = gt_next_frame
                step_result.update(gt_result)

            try:
                raw_state = self.gt_phys_sim.get_raw_state()
            except NotImplementedError:
                pass  # this physics simulator does not expose its raw state
            else:
                step_result.update(raw_state)

            if self.compute_psnr:
                step_result['psnr'] = skimage.metrics.peak_signal_noise_ratio(
                    image_true=gt_next_frame / 255.,
                    image_test=pred_next_frame / 255.,
                    data_range=1.0
                )
            if self.compute_delta_psnr:
                step_result['delta_psnr'] = self._compute_delta_psnr(
                    input_latent_states, input_actions, gt_next_frame, step_result['psnr']
                )

            if self.gt_teacher_force is not None and self.step_count % self.gt_teacher_force == 0:
                pred_next_latent_state = self._encode_image(gt_next_frame)

        # slide the history window by one step
        self.cached_latent_frames = torch.cat([
            self.cached_latent_frames[1:], pred_next_latent_state.unsqueeze(0)
        ])
        self.cached_actions = torch.cat([
            self.cached_actions[1:], action_tensor.unsqueeze(0)
        ])

        # the post-processed frame is what the caller receives; metrics above use the raw prediction
        if self.post_processor is not None:
            step_result['pred_next_frame'] = self.post_processor(pred_next_frame, action_tensor)

        self.step_count += 1
        return step_result

    @torch.inference_mode()
    def _encode_image(self, image: np.ndarray) -> torch.Tensor:
        """Encode one (H, W, 3) image with values in [0, 255] into the backbone's latent space.

        Returns:
            (h, w) int32 token indices when `quantize`; otherwise a float32 latent
            (channel-last for "temporalvae", scaled by SVD_SCALE).
        """
        pixels = torch.from_numpy(
            self._normalize_image(image).transpose(2, 0, 1)
        ).to(device=self.device, dtype=self.image_encoding_dtype).unsqueeze(0)
        if self.quantize:
            H = pixels.shape[-2] // self.quant_slice_size
            W = pixels.shape[-1] // self.quant_slice_size
            _, _, indices, _ = self.image_encoder.encode(pixels, flip=True)
            indices = einops.rearrange(indices, "(h w) -> h w", h=H, w=W)
            return indices.to(torch.int32)
        if self.image_encoder_type == "magvit":
            latent = self.image_encoder.encode_without_quantize(pixels)
        elif self.image_encoder_type == "temporalvae":
            latent = self.image_encoder.encode(pixels).latent_dist.mean * SVD_SCALE
            latent = einops.rearrange(latent, "b c h w -> b h w c")
        else:
            raise ValueError(f"Unsupported image encoder type: {self.image_encoder_type}")
        return latent.squeeze(0).to(torch.float32)

    @torch.inference_mode()
    def _decode_image(self, latent: torch.Tensor) -> np.ndarray:
        """Decode quantized indices or a raw latent into an (H, W, 3) uint8 image."""
        latent = latent.to(device=self.device).unsqueeze(0)
        if self.quantize:
            latent = self.image_encoder.quantize.get_codebook_entry(
                einops.rearrange(latent, "b h w -> b (h w)"),
                bhwc=(*latent.shape, self.image_encoder.quantize.codebook_dim)
            ).flip(1)
        latent = latent.to(device=self.device, dtype=self.image_encoding_dtype)
        if self.image_encoder_type == "magvit":
            decoded_image = self.image_encoder.decode(latent)
        elif self.image_encoder_type == "temporalvae":
            latent = einops.rearrange(latent, "b h w c -> b c h w")
            # out-of-place: `.to` may return the caller's tensor, which is later cached
            latent = latent / SVD_SCALE
            # HACK: clip for less visual artifacts
            latent = torch.clamp(latent, -25, 25)
            decoded_image = self.image_encoder.decode(latent, num_frames=1).sample
        else:
            raise ValueError(f"Unsupported image encoder type: {self.image_encoder_type}")
        decoded_image = decoded_image.squeeze(0).to(torch.float32).detach().cpu().numpy()
        return self._unnormalize_image(decoded_image).transpose(1, 2, 0)

    def _normalize_image(self, image: np.ndarray) -> np.ndarray:
        """Map an (H, W, 3) image in [0, 255] to float32 in [-1, 1].

        With `resize_image`, the shorter side is resized to
        `resize_image_resolution` and the result is center-cropped to a square.
        Otherwise images larger than `max_image_resolution` on either side are
        downscaled keeping aspect ratio. The input array is never modified.
        """
        # out-of-place division: np.asarray returns the caller's array when it is already float32
        image = np.asarray(image, dtype=np.float32) / 255.
        H, W = image.shape[:2]
        if self.resize_image:
            resized_res = self.resize_image_resolution
            if H < W:
                Hnew, Wnew = resized_res, int(resized_res * W / H)
            else:
                Hnew, Wnew = int(resized_res * H / W), resized_res
            image = cv2.resize(image, (Wnew, Hnew))
            # center crop
            H, W = image.shape[:2]
            Hstart = (H - resized_res) // 2
            Wstart = (W - resized_res) // 2
            image = image[Hstart:Hstart + resized_res, Wstart:Wstart + resized_res]
        elif H > self.max_image_resolution or W > self.max_image_resolution:
            if H < W:
                Hnew, Wnew = int(self.max_image_resolution * H / W), self.max_image_resolution
            else:
                Hnew, Wnew = self.max_image_resolution, int(self.max_image_resolution * W / H)
            image = cv2.resize(image, (Wnew, Hnew))
        return image * 2 - 1.

    def _unnormalize_image(self, image: np.ndarray) -> np.ndarray:
        """Map values from [-1, 1] to uint8 [0, 255], clipping out-of-range values."""
        image = (image + 1.) * 127.5
        return np.clip(image, 0, 255).astype(np.uint8)

    def reset(self) -> np.ndarray:
        """Rebuild the history buffer from the prompt and return the current frame.

        Without an external prompt, the physics simulator is reset and its first
        frame is repeated `prompt_horizon` times with all-zero actions; otherwise
        the prompt stored by `set_initial_state` is used.

        Returns:
            The last prompt frame as (H, W, 3) uint8; when a physics simulator is
            attached it is concatenated side by side with itself.
        """
        if not self.allow_external_prompt and self.gt_phys_sim is not None:
            image_prompt = np.tile(
                self.gt_phys_sim.reset(), (self.prompt_horizon, 1, 1, 1)
            ).astype(np.uint8)
            action_prompt = np.zeros(
                (self.prompt_horizon, self.action_stride, self.gt_phys_sim.action_dimension),
                dtype=np.float32,
            )
        else:
            assert self.init_prompt is not None, "Initial state is not set."
            image_prompt, action_prompt = self.init_prompt

        # standardize frames exactly as the encoder will see them
        image_prompt = [self._unnormalize_image(self._normalize_image(frame)) for frame in image_prompt]
        current_image = image_prompt[-1]

        self.cached_actions = torch.from_numpy(
            np.asarray(action_prompt, dtype=np.float32)
        ).to(device=self.device)
        self.cached_latent_frames = torch.stack(
            [self._encode_image(frame) for frame in image_prompt], dim=0
        )

        if self.resize_image:
            current_image = cv2.resize(current_image,
                                       (self.resize_image_resolution, self.resize_image_resolution))
        if self.gt_phys_sim is not None:
            current_image = np.concatenate([current_image, current_image], axis=1)
        self.step_count = 0
        return current_image

    def close(self):
        """Close the attached physics simulator, if any and if it supports closing."""
        if self.gt_phys_sim is not None:
            try:
                self.gt_phys_sim.close()
            except NotImplementedError:
                pass

    @property
    def dt(self):
        return 1.0 / self.genie_frequency