Spaces:
Running
on
Zero
Running
on
Zero
import numpy as np | |
import cv2 | |
import torch | |
import einops | |
from genie.st_mask_git import STMaskGIT | |
from genie.st_mar import STMAR | |
from datasets.utils import get_image_encoder | |
from diffusion_policy import diffusion_policy_factory | |
from data import DATA_FREQ_TABLE | |
from train_diffusion import SVD_SCALE | |
from typing import Optional, Tuple, List, Dict, Any | |
import os | |
class Policy:
    """Minimal interface shared by every policy in this module.

    Subclasses must override :meth:`generate_action`; :meth:`reset` is an
    optional episode-start hook that does nothing by default.
    """

    def generate_action(self, obs):
        """Return the next action for observation ``obs`` (subclasses override)."""
        raise NotImplementedError

    def reset(self):
        """Episode-start hook; the base implementation has no state and returns None."""
        return None
class RandomPolicy(Policy):
    """Marker base class for policies that sample their actions at random."""

    def __init__(self) -> None:
        # No configuration of its own; defer to the base initializer.
        super().__init__()
class TeleopPolicy(Policy):
    """Marker base class for policies driven by a human operator."""

    def __init__(self) -> None:
        # No configuration of its own; defer to the base initializer.
        super().__init__()
class LearnedPolicy(Policy):
    """Marker base class for policies backed by a trained model."""

    def __init__(self) -> None:
        # No configuration of its own; defer to the base initializer.
        super().__init__()
class ReplayPolicy(Policy):
    """Replay a recorded action sequence in fixed-size chunks.

    ``actions`` (shape ``(T * S, A)``) is regrouped into ``T`` chunks of
    ``S = action_stride`` consecutive actions; trailing actions that do not
    fill a whole chunk are dropped. The first ``prompt_horizon`` chunks form
    the prompt; :meth:`generate_action` returns the chunks after it, one per
    call.
    """

    def __init__(self,
                 actions: np.ndarray,  # (T * S, A)
                 action_stride: int = 1,
                 prompt_horizon: int = 0,
                 ):
        super().__init__()
        num_chunks = len(actions) // action_stride
        # (T, S, A): one row of S actions per policy step.
        self.actions = actions[:num_chunks * action_stride].reshape(
            num_chunks, action_stride, actions.shape[-1])
        self.action_idx = prompt_horizon
        self.prompt_horizon = prompt_horizon
        self.action_stride = action_stride
        assert self.action_idx < len(self.actions), \
            "prompt_horizon leaves no actions to replay"

    def __len__(self):
        """Number of chunks available for replay after the prompt."""
        return len(self.actions) - self.prompt_horizon

    def generate_action(self, obs):
        """Return the next ``(S, A)`` action chunk; ``obs`` is ignored."""
        assert self.action_idx < len(self.actions), "replay exhausted; call reset()"
        action = self.actions[self.action_idx]
        self.action_idx += 1
        return action

    def reset(self) -> Optional[np.ndarray]:
        """Rewind replay to the first chunk after the prompt.

        Returns:
            The last prompt chunk (the "current" action), or ``None`` when
            ``prompt_horizon`` is 0 and there is no prompt to take it from.
        """
        self.action_idx = self.prompt_horizon
        prompt = self.prompt()
        # An empty prompt has no last element; indexing it would raise IndexError.
        return prompt[-1] if len(prompt) else None

    def prompt(self):
        """Return the prompt chunks, shape ``(prompt_horizon, S, A)``."""
        return self.actions[:self.prompt_horizon]
class RandomJointPositionPolicy(RandomPolicy):
    """Sample joint-position actions uniformly between per-joint bounds.

    Args:
        action_bounds: ``(lower, upper)`` bounds, one entry per joint. Arrays,
            lists or tuples are accepted.
    """

    def __init__(self, action_bounds: Tuple[np.ndarray, np.ndarray]):
        # Run the base initializer like the sibling policy classes do.
        super().__init__()
        # np.asarray is a no-op for ndarrays and lets plain sequences be used too.
        self.lb = np.asarray(action_bounds[0])
        self.ub = np.asarray(action_bounds[1])
        self.action_dim = self.lb.shape[0]

    def generate_action(self, obs):
        """Return one uniform sample in ``[lb, ub)``; ``obs`` is ignored."""
        return np.random.uniform(self.lb, self.ub)
class TeleopJointPositionPolicy(TeleopPolicy):
    """
    Keyboard teleoperation of joint positions.

    Each key in ``keyboard_bindings`` moves its joint by ``+increment``; the
    shifted version of the key (see ``_shift_key``) moves it by ``-increment``.

    Example usage:
        teleop = TeleopJointPositionPolicy(
            initial_position=[0, 0, 0, 0, 0, 0, 0],
            increment=0.1,
            keyboard_bindings=['q', 'w', 'e', 'r', 't', 'y', 'u'],
            return_delta=False
        )
        while True:
            print(teleop.generate_action(None))
    """

    def __init__(self,
                 initial_position: List[float],  # initial position for each joint
                 increment: float,  # increment for each joint
                 keyboard_bindings: List[str],  # list of keyboard bindings for each joint
                 # shift + key for negative direction
                 return_delta: bool = False,  # if True, return delta instead of absolute position
                 ):
        super().__init__()
        self.increment = increment
        self.pos_keys = keyboard_bindings
        self.neg_keys = [self._shift_key(key) for key in keyboard_bindings]
        self.action_dim = len(keyboard_bindings)
        self.return_delta = return_delta
        position = np.array(initial_position)
        # Integer positions (as in the example above) would make the in-place
        # `+=` of a float delta fail with a numpy casting error, so promote them.
        if not np.issubdtype(position.dtype, np.floating):
            position = position.astype(float)
        self.current_position = position
        self.shift_pressed = False
        self.delta_position = np.zeros(self.action_dim)

    def generate_action(self, obs):
        """Block on stdin until a bound key is entered, then return the action.

        Returns the per-joint delta if ``return_delta`` is set, otherwise the
        updated absolute joint position. Either way the returned array is not
        shared with the policy's internal state.
        """
        valid_keys = self.pos_keys + self.neg_keys
        while (user_input := input('Waiting for input: ')) not in valid_keys:
            print(f'Invalid input {user_input}')
        is_pos = user_input in self.pos_keys
        joint_idx = self.pos_keys.index(user_input) if is_pos else self.neg_keys.index(user_input)
        self.delta_position[joint_idx] = self.increment * (1 if is_pos else -1)
        # delta_pos stays valid after self.delta_position is rebound below.
        delta_pos = self.delta_position
        # update current position and reset delta
        self.current_position += delta_pos
        self.delta_position = np.zeros(self.action_dim)
        if self.return_delta:
            return delta_pos
        # Copy so callers cannot alias (and later see mutations of) the internal state.
        return self.current_position.copy()

    def _shift_key(self, key):
        """Return the character produced by shift+``key`` on a US keyboard layout."""
        if key.isalpha():
            return key.upper()
        return {
            '1': '!', '2': '@', '3': '#', '4': '$', '5': '%',
            '6': '^', '7': '&', '8': '*', '9': '(', '0': ')',
            '-': '_', '=': '+', '[': '{', ']': '}', '\\': '|',
            ';': ':', "'": '"', ',': '<', '.': '>', '/': '?',
            '`': '~'
        }.get(key, key)
class RandomPlanarQuadDirectionalPolicy(RandomPolicy):
    """Pick one of four axis-aligned planar moves of size ``increment`` uniformly."""

    def __init__(self, increment: float = 0.5):
        self.increment = increment

    def generate_action(self, obs):
        """Return a fresh 2-element array; ``obs`` is ignored."""
        step = self.increment
        # Index order: (0, +step), (0, -step), (+step, 0), (-step, 0).
        moves = ((0, step), (0, -step), (step, 0), (-step, 0))
        pick = np.random.choice(4)
        return np.array(moves[pick])
class TeleopPlanarQuadDirectionalPolicy(TeleopPolicy):
    """Keyboard teleoperation of planar motion with the w/a/s/d keys."""

    def __init__(self,
                 increment: float = 0.5,  # increment for each direction
                 ):
        super().__init__()
        self.increment = increment

    def generate_action(self, obs):
        """Block on stdin until w/a/s/d is entered and return the matching 2-D step."""
        allowed = ('w', 'a', 's', 'd')
        while (key := input('Waiting for input: ')) not in allowed:
            print(f'Invalid input {key}')
        inc = self.increment
        # follow IRASIM's convention: d/a move along the second component, s/w along the first
        direction = {
            'd': [0, inc],
            'a': [0, -inc],
            's': [inc, 0],
            'w': [-inc, 0],
        }[key]
        return np.array(direction)
class GeniePolicy(LearnedPolicy):
    """Policy that predicts actions with a Genie-style backbone (STMaskGIT / STMAR).

    Each observation image is encoded into a latent: MAGVIT token indices when
    ``quantize`` is set, continuous encoder latents otherwise. The backbone is
    given the last ``prompt_horizon`` latent frames and actions plus the
    current frame. It fills in the masked actions from position
    ``prompt_horizon`` onward for ``execution_horizon`` steps.

    That plan is executed open loop: one action is returned per
    :meth:`generate_action` call, and the backbone is only re-run once the
    plan is used up.

    Usage: :meth:`set_initial_state` with ``(prompt_frames, prompt_actions)``,
    then :meth:`reset`, then repeated :meth:`generate_action`.
    """
    average_delta_psnr_over = 5

    def __init__(self,
                 # image preprocessing
                 max_image_resolution: int = 1024,
                 resize_image: bool = True,
                 resize_image_resolution: int = 256,
                 # tokenizer setting
                 image_encoder_type: str = "magvit",
                 image_encoder_ckpt: str = "data/magvit2.ckpt",
                 quantize: bool = True,
                 quantization_slice_size: int = 16,
                 # dynamics backbone setting
                 backbone_type: str = "stmaskgit",
                 backbone_ckpt: str = "data/genie_model/final_checkpt",
                 prompt_horizon: int = 4,
                 prediction_horizon: int = 4,
                 execution_horizon: int = 2,  # half of the prediction context
                 inference_iterations: Optional[int] = None,
                 sampling_temperature: float = 0.0,
                 action_stride: Optional[int] = None,
                 domain: str = "robomimic",
                 genie_frequency: int = 2,
                 diffusion_steps=10,
                 # misc
                 is_full_dynamics: bool = False,
                 device: str = 'cuda',
                 use_raw_image=False,
                 ):
        """Load the image encoder and backbone and set up empty history buffers.

        ``backbone_ckpt`` may point directly at a checkpoint folder (one that
        contains ``config.json``), or at a parent folder. In the latter case
        the subfolder with the newest ``os.path.getctime`` is used.
        ``action_stride`` defaults to ``DATA_FREQ_TABLE[domain] //
        genie_frequency``.
        """
        super().__init__()
        assert quantize == (image_encoder_type == "magvit"), \
            "Currently quantization if and only if magvit is the image encoder."
        assert image_encoder_type in ["magvit", "temporalvae"], \
            "Image encoder type must be either 'magvit' or 'temporalvae'."
        assert not quantize or image_encoder_type == "magvit", \
            "If quantize is enabled, image encoder type must be 'magvit'."
        assert backbone_type in ["stmaskgit", "stmar"], \
            "Backbone type must be either 'stmaskgit' or 'stmar'."
        if action_stride is None:
            action_stride = DATA_FREQ_TABLE[domain] // genie_frequency
        if inference_iterations is None:
            # Both supported backbones currently default to 2 iterations.
            inference_iterations = 2
        # misc
        self.use_raw_image = use_raw_image
        self.device = torch.device(device)
        self.is_full_dynamics = is_full_dynamics
        self.prediction_horizon = prediction_horizon
        self.execution_horizon = execution_horizon
        # Start at the end of an open-loop window so the first generate_action()
        # call runs the backbone instead of reading a (non-existent) plan.
        self.open_loop_step = self.execution_horizon - 1
        self.pred_action = None
        # image preprocessing
        self.max_image_resolution = max_image_resolution
        self.resize_image = resize_image
        self.resize_image_resolution = resize_image_resolution
        # load image encoder
        self.image_encoding_dtype = torch.bfloat16
        self.quantize = quantize
        self.quant_slice_size = quantization_slice_size
        self.image_encoder_type = image_encoder_type
        self.image_encoder = get_image_encoder(
            image_encoder_type,
            image_encoder_ckpt
        ).to(device=self.device, dtype=self.image_encoding_dtype).eval()
        # load STMaskGIT model (STMAR is inherited from STMaskGIT)
        self.prompt_horizon = prompt_horizon
        self.domain = domain
        self.genie_frequency = genie_frequency
        self.inference_iterations = inference_iterations
        self.sampling_temperature = sampling_temperature
        self.action_stride = action_stride
        self.backbone_type = backbone_type
        if not os.path.exists(os.path.join(backbone_ckpt, "config.json")):
            backbone_ckpt = self._latest_checkpoint_dir(backbone_ckpt)
        print("backbone_ckpt:", backbone_ckpt)
        if backbone_type == "stmaskgit":
            self.backbone = STMaskGIT.from_pretrained(backbone_ckpt)
        else:
            self.backbone = STMAR.from_pretrained(backbone_ckpt)
            # NOTE(review): diffusion heads are assumed to exist only on STMAR; confirm.
            self.backbone.action_diff_losses[domain].gen_diffusion.num_timesteps = diffusion_steps
            self.backbone.diffloss.gen_diffusion.num_timesteps = diffusion_steps
        self.backbone = self.backbone.to(device=self.device).eval()
        # history buffer, i.e., the input to the model
        self.cached_actions = None  # (prompt_horizon, action_stride, A)
        self.cached_latent_frames = None  # (prompt_horizon, ...)
        self.init_prompt = None  # (prompt_frames, prompt_actions)
        # report model size
        print(
            "================ Model Size Report ================\n"
            f" encoder size: {sum(p.numel() for p in self.image_encoder.parameters()) / 1e6:.3f}M \n"
            f" backbone size: {sum(p.numel() for p in self.backbone.parameters()) / 1e6:.3f}M\n"
            "==================================================="
        )

    @staticmethod
    def _latest_checkpoint_dir(root: str) -> str:
        """Return the subdirectory of ``root`` with the newest ``os.path.getctime``.

        ``getctime`` is the creation time on Windows and the last metadata
        change on Unix.

        Raises:
            FileNotFoundError: if ``root`` does not exist or has no subdirectories.
        """
        subdirs = [entry.path for entry in os.scandir(root) if entry.is_dir()]
        if not subdirs:
            raise FileNotFoundError(
                f"{root!r} has no config.json and no checkpoint subdirectories")
        return max(subdirs, key=os.path.getctime)

    def set_initial_state(self, state: Tuple[np.ndarray, np.ndarray]):
        """Store the ``(prompt_frames, prompt_actions)`` pair that :meth:`reset` consumes."""
        self.init_prompt = state

    def generate_action(self, obs: Dict[str, Any]) -> np.ndarray:
        """Return the action for the current step and advance the history buffers.

        Args:
            obs: dict with at least ``'image'``, an (H, W, 3) image array.

        Returns:
            A length-1 slice (along the first axis) of the backbone's action
            prediction, as a numpy array. The original note gives the shape
            as (stride, A) — TODO confirm against the backbone output.
        """
        assert self.cached_latent_frames is not None, "Model is not prompted yet."
        this_latent = self._encode_image(obs['image'])

        # History frames + current frame: prompt_horizon + 1 timesteps.
        self.cached_latent_frames = torch.cat(
            [self.cached_latent_frames, this_latent.unsqueeze(0)]).to(torch.float32)
        # Placeholder frames for s_t+1 .. s_t+execution_horizon-1; overwritten with mask tokens below.
        mask_tokens = torch.zeros(self.execution_horizon - 1, *self.cached_latent_frames.shape[1:],
                                  dtype=self.cached_latent_frames.dtype,
                                  device=self.device)
        input_latent_states = torch.cat(
            [self.cached_latent_frames, mask_tokens]).unsqueeze(0).to(torch.float32)  # add batch dimension

        # Pad the action history with execution_horizon zero slots for the actions to predict.
        # The padding is kept local: storing it in self.cached_actions would leak the zero
        # placeholders into the history fed back to the model on later steps, e.g.
        # [a3, 0, 0, pred] instead of [a1, a2, a3, pred], misaligned with the frame cache.
        padded_actions = torch.cat([
            self.cached_actions,
            torch.zeros(self.execution_horizon, *self.cached_actions.shape[1:],
                        dtype=self.cached_actions.dtype,
                        device=self.device)]).to(torch.float32)
        cached_actions = einops.rearrange(padded_actions, "h b c -> b h c")
        action_mask = torch.zeros(cached_actions.shape[0], cached_actions.shape[1], 1, 1,
                                  dtype=padded_actions.dtype, device=self.device)
        action_mask[:, self.prompt_horizon:] = 1  # positions the backbone must predict

        # Mark future frames as masked, using the token type each backbone expects.
        if self.backbone_type == "stmaskgit":
            input_latent_states = input_latent_states.long()
            input_latent_states[:, self.prompt_horizon + 1:] = self.backbone.mask_token_id
        elif self.backbone_type == "stmar":
            input_latent_states[:, self.prompt_horizon + 1:] = self.backbone.mask_token

        if self.open_loop_step != self.execution_horizon - 1:
            # Still inside the current open-loop plan: reuse self.pred_action.
            self.open_loop_step += 1
        else:
            cached_actions = cached_actions[:, -input_latent_states.shape[1]:]
            if self.execution_horizon == 1:
                self.pred_action = self.backbone.maskgit_generate(
                    input_latent_states,
                    out_t=self.prompt_horizon,
                    maskgit_steps=self.inference_iterations,
                    temperature=self.sampling_temperature,
                    action_ids=cached_actions,
                    domain=[self.domain],
                    action_mask=action_mask,
                )[-1].squeeze(0)
            else:
                self.pred_action = self.backbone.maskgit_generate_horizon(
                    input_latent_states,
                    out_t_min=self.prompt_horizon,
                    out_t_max=self.prompt_horizon + self.execution_horizon,
                    maskgit_steps=self.inference_iterations,
                    temperature=self.sampling_temperature,
                    action_ids=cached_actions,
                    domain=[self.domain],
                    action_mask=action_mask,
                )[-1].squeeze(0)
            self.open_loop_step = 0

        step = self.prompt_horizon + self.open_loop_step
        pred_action = self.pred_action[step:step + 1]
        # Slide both history windows forward by one step.
        self.cached_actions = torch.cat(
            [self.cached_actions, pred_action.unsqueeze(0)]).to(torch.float32)
        self.cached_actions = self.cached_actions[-self.prompt_horizon:]
        self.cached_latent_frames = self.cached_latent_frames[-self.prompt_horizon:]
        return pred_action.detach().cpu().numpy()

    def _prepare_encoder_input(self, image: np.ndarray) -> torch.Tensor:
        """Normalize an (H, W, 3) image into a (1, 3, H', W') tensor in the encoder's device/dtype."""
        chw = self._normalize_image(image).transpose(2, 0, 1)
        return torch.from_numpy(chw).to(
            device=self.device, dtype=self.image_encoding_dtype).unsqueeze(0)

    def _encode_image(self, image: np.ndarray) -> torch.Tensor:
        """Encode one (H, W, 3) image into the latent the backbone consumes.

        Returns int32 token indices of shape (H // slice, W // slice) when
        ``quantize`` is set. Otherwise it returns a float32 latent with the
        channel axis last.

        Raises:
            ValueError: for an unsupported ``image_encoder_type`` (normally
                rejected in ``__init__``).
        """
        if self.quantize:
            image = self._prepare_encoder_input(image)
            H, W = image.shape[-2:]
            H //= self.quant_slice_size
            W //= self.quant_slice_size
            _, _, indices, _ = self.image_encoder.encode(image, flip=True)
            indices = einops.rearrange(indices, "(h w) -> h w", h=H, w=W)
            return indices.to(torch.int32)
        if self.use_raw_image:
            image = torch.from_numpy(image).permute(2, 0, 1)
            norm_image = torch.nn.functional.interpolate(image[None] / 255.0, (32, 32)) - 0.5
            norm_image = einops.rearrange(norm_image, "b c h w -> b h w c")
            return norm_image.squeeze(0).to(torch.float32).to(self.device)
        image = self._prepare_encoder_input(image)
        if self.image_encoder_type == "magvit":
            latent = self.image_encoder.encode_without_quantize(image)
        elif self.image_encoder_type == "temporalvae":
            latent_dist = self.image_encoder.encode(image).latent_dist
            # Out-of-place scale: don't mutate the distribution's mean tensor in place.
            latent = latent_dist.mean * SVD_SCALE
            latent = einops.rearrange(latent, "b c h w -> b h w c")
        else:
            raise ValueError(f"Unsupported image encoder type: {self.image_encoder_type}")
        return latent.squeeze(0).to(torch.float32)

    def _normalize_image(self, image: np.ndarray) -> np.ndarray:
        """Convert an (H, W, 3) image in [0, 255] to float32 in [-1, 1].

        If ``resize_image`` is set, the shorter side is resized to
        ``resize_image_resolution`` and the result is center-cropped to a
        square. Otherwise, images larger than ``max_image_resolution`` are
        downscaled so the longer side equals it.

        The input array is never modified.
        """
        # Out-of-place division: np.asarray returns the caller's array unchanged when it
        # is already float32, so an in-place `/=` would corrupt the caller's data.
        image = np.asarray(image, dtype=np.float32) / 255.
        H, W = image.shape[:2]
        # resize if asked
        if self.resize_image:
            resized_res = self.resize_image_resolution
            if H < W:
                Hnew, Wnew = resized_res, int(resized_res * W / H)
            else:
                Hnew, Wnew = int(resized_res * H / W), resized_res
            image = cv2.resize(image, (Wnew, Hnew))
            # center crop
            H, W = image.shape[:2]
            Hstart = (H - resized_res) // 2
            Wstart = (W - resized_res) // 2
            image = image[Hstart:Hstart + resized_res, Wstart:Wstart + resized_res]
        # resize if resolution is too large
        elif H > self.max_image_resolution or W > self.max_image_resolution:
            if H < W:
                Hnew, Wnew = int(self.max_image_resolution * H / W), self.max_image_resolution
            else:
                Hnew, Wnew = self.max_image_resolution, int(self.max_image_resolution * W / H)
            image = cv2.resize(image, (Wnew, Hnew))
        return image * 2 - 1.

    def reset(self) -> np.ndarray:
        """Load the stored prompt into the history buffers and restart the open-loop plan.

        Returns:
            The last prompt frame, resized to a square of
            ``resize_image_resolution`` when ``resize_image`` is set.

        Raises:
            AssertionError: if :meth:`set_initial_state` was not called.
        """
        assert self.init_prompt is not None, "Initial state is not set."
        prompt_frames, prompt_actions = self.init_prompt
        current_image = prompt_frames[-1]
        self.cached_actions = torch.from_numpy(prompt_actions).to(
            device=self.device, dtype=torch.float32)
        # convert to latent
        self.cached_latent_frames = torch.stack(
            [self._encode_image(frame) for frame in prompt_frames], dim=0)
        # Drop any plan from a previous episode so the first step re-runs the backbone.
        self.open_loop_step = self.execution_horizon - 1
        self.pred_action = None
        if self.resize_image:
            current_image = cv2.resize(
                current_image, (self.resize_image_resolution, self.resize_image_resolution))
        return current_image

    def close(self):
        """No resources to release."""
        pass

    def dt(self):
        """Control period in seconds (1 / genie_frequency)."""
        return 1.0 / self.genie_frequency
class DiffusionPolicy(LearnedPolicy):
    """Adapter exposing a diffusion-policy model through the ``Policy`` interface.

    Attribute lookups that fail on this wrapper are forwarded to the wrapped
    object returned by ``diffusion_policy_factory``.
    """

    def __init__(self, checkpoint: str):
        super().__init__()
        self.policy = diffusion_policy_factory(checkpoint)

    def __getattr__(self, name):
        # Only reached after normal attribute lookup fails. Read ``policy`` straight
        # from the instance dict: evaluating ``self.policy`` while it is unset
        # (e.g. during copy/unpickling, or after the factory raised in __init__)
        # would re-enter __getattr__ and recurse until RecursionError.
        try:
            policy = self.__dict__['policy']
        except KeyError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}") from None
        return getattr(policy, name)

    def generate_action(self, obs):
        """Delegate to the wrapped policy's ``predict_action``."""
        return self.policy.predict_action(obs)

    def reset(self):
        """No per-episode state is kept by this wrapper."""
        pass

    def close(self):
        """No resources to release."""
        pass