# hma/sim/policy.py
# Uploaded by LeroyWaa -- commit 246c106 ("draft")
import numpy as np
import cv2
import torch
import einops
from genie.st_mask_git import STMaskGIT
from genie.st_mar import STMAR
from datasets.utils import get_image_encoder
from diffusion_policy import diffusion_policy_factory
from data import DATA_FREQ_TABLE
from train_diffusion import SVD_SCALE
from typing import Optional, Tuple, List, Dict, Any
import os
class Policy:
    """Common interface shared by every policy in this module."""

    def generate_action(self, obs):
        """Return the next action for observation ``obs``.

        The base class provides no implementation; subclasses override this.

        Raises:
            NotImplementedError: always, when called on the base class.
        """
        raise NotImplementedError

    def reset(self):
        """Reset per-episode state. The base class keeps none, so this is a no-op."""
        return None
class RandomPolicy(Policy):
    """Base class for the policies in this module that sample actions at random."""

    def __init__(self) -> None:
        # No state of its own; only chains to the parent initialiser.
        super(RandomPolicy, self).__init__()
class TeleopPolicy(Policy):
    """Base class for the policies in this module that are driven by user input."""

    def __init__(self) -> None:
        # No state of its own; only chains to the parent initialiser.
        super(TeleopPolicy, self).__init__()
class LearnedPolicy(Policy):
    """Base class for the policies in this module that wrap a trained model."""

    def __init__(self) -> None:
        # No state of its own; only chains to the parent initialiser.
        super(LearnedPolicy, self).__init__()
class ReplayPolicy(Policy):
    """Replay a pre-recorded action sequence, one chunk of ``action_stride`` rows per step.

    The first ``prompt_horizon`` chunks are treated as a prompt: they are
    skipped by ``generate_action`` and exposed through ``prompt()``.
    """

    def __init__(self,
                 actions: np.ndarray,  # (T * S, A)
                 action_stride: int = 1,
                 prompt_horizon: int = 0,
                 ):
        """Group ``actions`` into chunks and position the cursor after the prompt.

        Args:
            actions: array of shape (T * S, A); trailing rows that do not fill
                a whole chunk of ``action_stride`` rows are dropped.
            action_stride: number of consecutive rows grouped into one chunk.
            prompt_horizon: number of leading chunks reserved as the prompt.

        Raises:
            AssertionError: if no chunk is left to replay after the prompt.
        """
        super().__init__()
        num_chunks = len(actions) // action_stride
        usable = actions[:num_chunks * action_stride]
        self.actions = usable.reshape(num_chunks, action_stride, actions.shape[-1])
        self.action_idx = prompt_horizon
        self.prompt_horizon = prompt_horizon
        self.action_stride = action_stride
        assert self.action_idx < len(self.actions), \
            "prompt_horizon leaves no actions to replay"

    def __len__(self):
        """Number of chunks available for replay, i.e. excluding the prompt."""
        return len(self.actions) - self.prompt_horizon

    def generate_action(self, obs):
        """Return the next recorded chunk, shape (action_stride, A); ``obs`` is ignored.

        Raises:
            AssertionError: if every recorded chunk has already been replayed.
        """
        assert self.action_idx < len(self.actions), \
            "no recorded actions left to replay"
        action = self.actions[self.action_idx]
        self.action_idx += 1
        return action

    def reset(self):
        """Rewind to the first chunk after the prompt.

        Returns:
            The last prompt chunk (the "current" action), or ``None`` when
            ``prompt_horizon`` is 0 and there is therefore no prompt.
        """
        self.action_idx = self.prompt_horizon
        prompt_actions = self.prompt()
        # With the default prompt_horizon=0 the prompt is empty; indexing
        # [-1] on it used to raise IndexError.
        if len(prompt_actions) == 0:
            return None
        return prompt_actions[-1]

    def prompt(self):
        """Return the leading ``prompt_horizon`` chunks (empty when the horizon is 0)."""
        return self.actions[:self.prompt_horizon]
class RandomJointPositionPolicy(RandomPolicy):
    """Sample joint positions uniformly at random within per-joint bounds."""

    def __init__(self, action_bounds: Tuple[np.ndarray, np.ndarray]):
        """Store the sampling bounds.

        Args:
            action_bounds: ``(lower, upper)`` pair of 1-D arrays with one entry
                per joint. Array-likes such as lists are accepted as well.
        """
        # Chain to the base initialiser like the other policies in this file.
        super().__init__()
        # np.asarray leaves ndarrays untouched and lets plain sequences work,
        # since the action dimension is read from `.shape` below.
        self.lb = np.asarray(action_bounds[0])
        self.ub = np.asarray(action_bounds[1])
        self.action_dim = self.lb.shape[0]

    def generate_action(self, obs):
        """Draw one sample from ``np.random.uniform(lb, ub)``; ``obs`` is ignored."""
        return np.random.uniform(self.lb, self.ub)
class TeleopJointPositionPolicy(TeleopPolicy):
    """
    Keyboard teleoperation of joint positions, one joint per key.

    Typing a bound key moves its joint by ``+increment``; typing the shifted
    variant of the key (see ``_shift_key``) moves it by ``-increment``.

    Example usage:
    teleop = TeleopJointPositionPolicy(
        initial_position=[0, 0, 0, 0, 0, 0, 0],
        increment=0.1,
        keyboard_bindings=['q', 'w', 'e', 'r', 't', 'y', 'u'],
        return_delta=False
    )
    while True:
        print(teleop.generate_action(None))
    """

    # Character produced by shift + key for the non-alphabetic keys.
    _SHIFTED_SYMBOLS = {
        '1': '!', '2': '@', '3': '#', '4': '$', '5': '%',
        '6': '^', '7': '&', '8': '*', '9': '(', '0': ')',
        '-': '_', '=': '+', '[': '{', ']': '}', '\\': '|',
        ';': ':', "'": '"', ',': '<', '.': '>', '/': '?',
        '`': '~'
    }

    def __init__(self,
                 initial_position: List[float],  # initial position for each joint
                 increment: float,               # increment for each joint
                 keyboard_bindings: List[str],   # list of keyboard bindings for each joint
                                                 # shift + key for negative direction
                 return_delta: bool = False,     # if True, return delta instead of absolute position
                 ):
        """Set up the key bindings and the tracked joint position."""
        super().__init__()
        self.increment = increment
        self.pos_keys = keyboard_bindings
        self.neg_keys = [self._shift_key(key) for key in keyboard_bindings]
        self.action_dim = len(keyboard_bindings)
        self.return_delta = return_delta
        position = np.array(initial_position)
        # Integer inputs (as in the example above) would give an int array,
        # and the in-place `+=` of a float delta below would then raise a
        # casting error. Promote non-float inputs to float up front.
        if not np.issubdtype(position.dtype, np.floating):
            position = position.astype(float)
        self.current_position = position
        self.shift_pressed = False
        self.delta_position = np.zeros(self.action_dim)

    def generate_action(self, obs):
        """Block until a bound key is entered, then apply and return the move.

        ``obs`` is ignored.

        Returns:
            The per-joint delta when ``return_delta`` is set, otherwise a copy
            of the updated absolute joint position.
        """
        valid_keys = list(self.pos_keys) + self.neg_keys
        while (user_input := input('Waiting for input: ')) not in valid_keys:
            print(f'Invalid input {user_input}')

        # A key with no distinct shifted form appears in both lists and is
        # treated as the positive direction.
        is_pos = user_input in self.pos_keys
        joint_idx = self.pos_keys.index(user_input) if is_pos else self.neg_keys.index(user_input)
        self.delta_position[joint_idx] = self.increment * (1 if is_pos else -1)

        delta_pos = self.delta_position
        # update current position and reset delta
        self.current_position += delta_pos
        self.delta_position = np.zeros(self.action_dim)

        if self.return_delta:
            return delta_pos
        # Return a copy: handing out the internal array would make earlier
        # return values change on the next call.
        return self.current_position.copy()

    def _shift_key(self, key):
        """Return the character typed by shift + ``key`` (``key`` itself if unknown)."""
        if key.isalpha():
            return key.upper()
        return self._SHIFTED_SYMBOLS.get(key, key)
class RandomPlanarQuadDirectionalPolicy(RandomPolicy):
    """Pick one of four axis-aligned planar moves uniformly at random."""

    def __init__(self, increment: float = 0.5):
        """Store the step size used for every move.

        Args:
            increment: magnitude of the non-zero component of each action.
        """
        # Chain to the base initialiser like the other policies in this file.
        super().__init__()
        self.increment = increment

    def generate_action(self, obs):
        """Return a random 2-vector with one component equal to +/-increment; ``obs`` is ignored."""
        step = self.increment
        candidates = (
            [0, step],
            [0, -step],
            [step, 0],
            [-step, 0],
        )
        # One draw from np.random.choice(4), as before, selects the move.
        return np.array(candidates[np.random.choice(4)])
class TeleopPlanarQuadDirectionalPolicy(TeleopPolicy):
    """Keyboard teleoperation of a planar agent; control with: w, a, s, d."""

    def __init__(self,
                 increment: float = 0.5,  # increment for each direction
                 ):
        super().__init__()
        self.increment = increment

    def generate_action(self, obs):
        """Prompt until one of w/a/s/d is entered and return the matching 2-vector.

        ``obs`` is ignored.
        """
        key = input('Waiting for input: ')
        while key not in ('w', 'a', 's', 'd'):
            print(f'Invalid input {key}')
            key = input('Waiting for input: ')

        step = self.increment
        # follow IRASIM's convention
        moves = {
            'd': [0, step],
            'a': [0, -step],
            's': [step, 0],
            'w': [-step, 0],
        }
        return np.array(moves[key])
class GeniePolicy(LearnedPolicy):
    """Policy that reads actions out of a Genie-style latent dynamics backbone.

    Usage, as far as this class shows: call ``set_initial_state`` with a
    ``(prompt_frames, prompt_actions)`` pair, call ``reset`` to encode that
    prompt into the history buffers, then call ``generate_action`` once per
    new observation.
    """

    average_delta_psnr_over = 5

    def __init__(self,
                 # image preprocessing
                 max_image_resolution: int = 1024,
                 resize_image: bool = True,
                 resize_image_resolution: int = 256,
                 # tokenizer setting
                 image_encoder_type: str = "magvit",
                 image_encoder_ckpt: str = "data/magvit2.ckpt",
                 quantize: bool = True,
                 quantization_slice_size: int = 16,
                 # dynamics backbone setting
                 backbone_type: str = "stmaskgit",
                 backbone_ckpt: str = "data/genie_model/final_checkpt",
                 prompt_horizon: int = 4,
                 prediction_horizon: int = 4,
                 execution_horizon: int = 2,  # half of the prediction context
                 inference_iterations: Optional[int] = None,
                 sampling_temperature: float = 0.0,
                 action_stride: Optional[int] = None,
                 domain: str = "robomimic",
                 genie_frequency: int = 2,
                 diffusion_steps=10,
                 # misc
                 is_full_dynamics: bool = False,
                 device: str = 'cuda',
                 use_raw_image=False,
                 ):
        """Load the image encoder and the dynamics backbone.

        Args:
            max_image_resolution: when ``resize_image`` is False, images with a
                side larger than this are scaled down to fit.
            resize_image: resize the shorter side to ``resize_image_resolution``
                and center-crop to a square.
            resize_image_resolution: side length of the square crop.
            image_encoder_type: "magvit" or "temporalvae".
            image_encoder_ckpt: checkpoint passed to ``get_image_encoder``.
            quantize: encode frames to token indices; must be True exactly when
                the encoder is "magvit".
            quantization_slice_size: divisor applied to the image height and
                width to get the token grid used to reshape the indices.
            backbone_type: "stmaskgit" or "stmar".
            backbone_ckpt: directory holding ``config.json``, or a parent
                directory whose most recent sub-directory is used instead.
            prompt_horizon: number of history steps kept in the caches.
            prediction_horizon: stored on the instance; not used in this class.
            execution_horizon: number of consecutive steps served from one
                backbone call before it is called again.
            inference_iterations: passed as ``maskgit_steps``; defaults to 2.
            sampling_temperature: passed as ``temperature`` to the backbone.
            action_stride: defaults to
                ``DATA_FREQ_TABLE[domain] // genie_frequency``.
            domain: key into ``DATA_FREQ_TABLE``; also passed to the backbone.
            genie_frequency: model step frequency; ``dt`` is its reciprocal.
            diffusion_steps: number of sampling timesteps written into the
                STMAR diffusion heads.
            is_full_dynamics: stored on the instance; not used in this class.
            device: torch device string.
            use_raw_image: when not quantizing, feed a 32x32 downsampled raw
                image instead of an encoder latent.

        Raises:
            AssertionError: on an unsupported encoder/backbone combination.
            FileNotFoundError: if no checkpoint directory can be resolved.
        """
        super().__init__()
        assert quantize == (image_encoder_type == "magvit"), \
            "Currently quantization if and only if magvit is the image encoder."
        assert image_encoder_type in ["magvit", "temporalvae"], \
            "Image encoder type must be either 'magvit' or 'temporalvae'."
        assert not quantize or image_encoder_type == "magvit", \
            "If quantize is enabled, image encoder type must be 'magvit'."
        assert backbone_type in ["stmaskgit", "stmar"], \
            "Backbone type must be either 'stmaskgit' or 'stmar'."

        if action_stride is None:
            action_stride = DATA_FREQ_TABLE[domain] // genie_frequency
        if inference_iterations is None:
            # Both supported backbones currently share the same default.
            if backbone_type == "stmaskgit":
                inference_iterations = 2
            elif backbone_type == "stmar":
                inference_iterations = 2

        # misc
        self.use_raw_image = use_raw_image
        self.device = torch.device(device)
        self.is_full_dynamics = is_full_dynamics
        self.prediction_horizon = prediction_horizon
        self.execution_horizon = execution_horizon
        # Starting at the last open-loop index makes the first
        # generate_action call query the backbone.
        self.open_loop_step = self.execution_horizon - 1
        self.pred_action = None

        # image preprocessing
        self.max_image_resolution = max_image_resolution
        self.resize_image = resize_image
        self.resize_image_resolution = resize_image_resolution

        # load image encoder
        self.image_encoding_dtype = torch.bfloat16
        self.quantize = quantize
        self.quant_slice_size = quantization_slice_size
        self.image_encoder_type = image_encoder_type
        self.image_encoder = get_image_encoder(
            image_encoder_type,
            image_encoder_ckpt
        ).to(device=self.device, dtype=self.image_encoding_dtype).eval()

        # load STMaskGIT model (STMAR is inherited from STMaskGIT)
        self.prompt_horizon = prompt_horizon
        self.domain = domain
        self.genie_frequency = genie_frequency
        self.inference_iterations = inference_iterations
        self.sampling_temperature = sampling_temperature
        self.action_stride = action_stride
        self.backbone_type = backbone_type

        if not os.path.exists(os.path.join(backbone_ckpt, "config.json")):
            # No config here: fall back to the sub-directory with the largest
            # os.path.getctime value.
            dirs = [os.path.join(backbone_ckpt, f.name)
                    for f in os.scandir(backbone_ckpt) if f.is_dir()]
            if not dirs:
                raise FileNotFoundError(
                    f"No config.json and no checkpoint sub-directory under {backbone_ckpt!r}"
                )
            dirs.sort(key=os.path.getctime)
            backbone_ckpt = dirs[-1]
        print("backbone_ckpt:", backbone_ckpt)

        if backbone_type == "stmaskgit":
            self.backbone = STMaskGIT.from_pretrained(backbone_ckpt)
        else:
            self.backbone = STMAR.from_pretrained(backbone_ckpt)
            # NOTE(review): these overrides are applied to the STMAR backbone
            # only, on the assumption that STMaskGIT has no diffusion heads;
            # confirm against the model definitions.
            self.backbone.action_diff_losses[domain].gen_diffusion.num_timesteps = diffusion_steps
            self.backbone.diffloss.gen_diffusion.num_timesteps = diffusion_steps
        self.backbone = self.backbone.to(device=self.device).eval()

        # history buffer, i.e., the input to the model
        self.cached_actions = None        # (prompt_horizon, action_stride, A)
        self.cached_latent_frames = None  # (prompt_horizon, ...)
        self.init_prompt = None           # (prompt_frames, prompt_actions)

        # report model size
        print(
            "================ Model Size Report ================\n"
            f"  encoder size: {sum(p.numel() for p in self.image_encoder.parameters()) / 1e6:.3f}M \n"
            f"  backbone size: {sum(p.numel() for p in self.backbone.parameters()) / 1e6:.3f}M\n"
            "==================================================="
        )

    def set_initial_state(self, state: Tuple[np.ndarray, np.ndarray]):
        """Store the ``(prompt_frames, prompt_actions)`` pair consumed by ``reset``."""
        self.init_prompt = state

    @torch.inference_mode()
    def generate_action(self, obs: Dict[str, Any]) -> np.ndarray:
        """Encode the new frame, query the backbone if needed, return one action.

        The backbone is called once every ``execution_horizon`` steps; the
        steps in between are served from the stored prediction.

        Args:
            obs: {'image': np.ndarray (H, W, 3), ...}

        Returns:
            np.ndarray (stride, A)

        Raises:
            AssertionError: if ``reset`` has not populated the history yet.
        """
        assert self.cached_latent_frames is not None, "Model is not prompted yet."
        this_image = obs['image']

        # encode
        this_latent = self._encode_image(this_image)

        # update cache for the current image. prompt_horizon+1 timesteps
        self.cached_latent_frames = torch.cat(
            [self.cached_latent_frames, this_latent.unsqueeze(0)]
        ).to(torch.float32)

        # new video tokens. s_t-prompt_horizon to s_t+1, s_t+1 to s_t+execution_horizon are masked tokens
        mask_tokens = torch.zeros(self.execution_horizon - 1, *self.cached_latent_frames.shape[1:],
                                  dtype=self.cached_latent_frames.dtype,
                                  device=self.device)
        input_latent_states = torch.cat(
            [self.cached_latent_frames, mask_tokens]
        ).unsqueeze(0).to(torch.float32)  # add batch dimension

        # new action tokens. a_t-prompt_horizon to a_t, a_t to a_t+execution_horizon are masked tokens
        # NOTE(review): this zero padding is written back to self.cached_actions
        # and survives the trim to the last prompt_horizon entries at the end of
        # this method, so padding can end up in the action history. Left as is;
        # confirm whether that is intended.
        self.cached_actions = torch.cat([
            self.cached_actions,
            torch.zeros(self.execution_horizon, *self.cached_actions.shape[1:],
                        dtype=self.cached_actions.dtype,
                        device=self.device)]).to(torch.float32)
        cached_actions = einops.rearrange(self.cached_actions, "h b c -> b h c")
        action_mask = torch.zeros(cached_actions.shape[0], cached_actions.shape[1], 1, 1,
                                  dtype=self.cached_actions.dtype, device=self.device)
        action_mask[:, self.prompt_horizon:] = 1

        # dtype conversion and mask token
        if self.backbone_type == "stmaskgit":
            input_latent_states = input_latent_states.long()
            input_latent_states[:, self.prompt_horizon + 1:] = self.backbone.mask_token_id
            # we should experiment with the other way to do this as well
            # cached_actions[:, self.prompt_horizon:] = self.backbone.action_mask_tokens
        elif self.backbone_type == "stmar":
            input_latent_states[:, self.prompt_horizon + 1:] = self.backbone.mask_token
            # cached_actions[:, self.prompt_horizon:] = self.backbone.action_mask_tokens

        if self.open_loop_step != self.execution_horizon - 1:
            # Still inside the current chunk: reuse the stored prediction.
            self.open_loop_step += 1
        else:
            # Chunk exhausted: query the backbone for a fresh prediction.
            cached_actions = cached_actions[:, -input_latent_states.shape[1]:]
            if self.execution_horizon == 1:
                self.pred_action = self.backbone.maskgit_generate(
                    input_latent_states,
                    out_t=self.prompt_horizon,
                    maskgit_steps=self.inference_iterations,
                    temperature=self.sampling_temperature,
                    action_ids=cached_actions,  # if self.is_full_dynamics else None
                    domain=[self.domain],
                    action_mask=action_mask
                )[-1].squeeze(0)
            else:
                self.pred_action = self.backbone.maskgit_generate_horizon(
                    input_latent_states,
                    out_t_min=self.prompt_horizon,
                    out_t_max=self.prompt_horizon + self.execution_horizon,
                    maskgit_steps=self.inference_iterations,
                    temperature=self.sampling_temperature,
                    action_ids=cached_actions,  # if self.is_full_dynamics else None
                    domain=[self.domain],
                    action_mask=action_mask
                )[-1].squeeze(0)
            self.open_loop_step = 0

        # Pick the entry of the stored prediction for the current open-loop step.
        start = self.prompt_horizon + self.open_loop_step
        pred_action = self.pred_action[start:start + 1]

        # Append to the history and keep only the last prompt_horizon entries.
        self.cached_actions = torch.cat(
            [self.cached_actions, pred_action.unsqueeze(0)]
        ).to(torch.float32)
        self.cached_actions = self.cached_actions[-self.prompt_horizon:]
        self.cached_latent_frames = self.cached_latent_frames[-self.prompt_horizon:]
        return pred_action.detach().cpu().numpy()

    def _image_to_encoder_input(self, image: np.ndarray) -> torch.Tensor:
        """Normalize an (H, W, 3) image and return it as a (1, 3, H', W') encoder input."""
        chw = self._normalize_image(image).transpose(2, 0, 1)
        return torch.from_numpy(chw).to(
            device=self.device, dtype=self.image_encoding_dtype
        ).unsqueeze(0)

    @torch.inference_mode()
    def _encode_image(self, image: np.ndarray) -> torch.Tensor:
        """Turn one (H, W, 3) image into the per-frame representation fed to the backbone.

        Returns token indices when ``quantize`` is set, a 32x32 downsampled
        raw image when ``use_raw_image`` is set, and an encoder latent otherwise.

        Raises:
            ValueError: if ``image_encoder_type`` is not a supported encoder.
        """
        if self.quantize:
            batch = self._image_to_encoder_input(image)
            H, W = batch.shape[-2:]
            H //= self.quant_slice_size
            W //= self.quant_slice_size
            _, _, indices, _ = self.image_encoder.encode(batch, flip=True)
            indices = einops.rearrange(indices, "(h w) -> h w", h=H, w=W)
            return indices.to(torch.int32)

        if self.use_raw_image:
            raw = torch.from_numpy(image).permute(2, 0, 1)
            # Scale to [0, 1], downsample to 32x32, then shift to [-0.5, 0.5].
            norm_image = torch.nn.functional.interpolate(raw[None] / 255.0, (32, 32)) - 0.5
            norm_image = einops.rearrange(norm_image, "b c h w -> b h w c")
            return norm_image.squeeze(0).to(torch.float32).to(self.device)

        batch = self._image_to_encoder_input(image)
        if self.image_encoder_type == "magvit":
            latent = self.image_encoder.encode_without_quantize(batch)
        elif self.image_encoder_type == "temporalvae":
            latent_dist = self.image_encoder.encode(batch).latent_dist
            latent = latent_dist.mean
            latent *= SVD_SCALE
            latent = einops.rearrange(latent, "b c h w -> b h w c")
        else:
            # The old `else: pass` left `latent` unbound and failed on the next
            # line with an opaque error.
            raise ValueError(f"Unsupported image encoder type: {self.image_encoder_type!r}")
        return latent.squeeze(0).to(torch.float32)

    def _normalize_image(self, image: np.ndarray) -> np.ndarray:
        """Return ``image`` as float32 in [-1, 1], optionally resized.

        If ``resize_image`` is set, the shorter side is resized to
        ``resize_image_resolution`` and the result is center-cropped to a
        square. Otherwise the image is only scaled down when a side exceeds
        ``max_image_resolution``. The input array is never modified.
        """
        # Divide out of place: np.asarray returns the caller's own array when
        # it is already float32, and `/=` would then rescale it in place.
        image = np.asarray(image, dtype=np.float32) / 255.
        H, W = image.shape[:2]

        # resize if asked
        if self.resize_image:
            resized_res = self.resize_image_resolution
            if H < W:
                Hnew, Wnew = resized_res, int(resized_res * W / H)
            else:
                Hnew, Wnew = int(resized_res * H / W), resized_res
            image = cv2.resize(image, (Wnew, Hnew))
            # center crop
            H, W = image.shape[:2]
            Hstart = (H - resized_res) // 2
            Wstart = (W - resized_res) // 2
            image = image[Hstart:Hstart + resized_res, Wstart:Wstart + resized_res]
        # resize if resolution is too large
        elif H > self.max_image_resolution or W > self.max_image_resolution:
            if H < W:
                Hnew, Wnew = int(self.max_image_resolution * H / W), self.max_image_resolution
            else:
                Hnew, Wnew = self.max_image_resolution, int(self.max_image_resolution * W / H)
            image = cv2.resize(image, (Wnew, Hnew))

        image = (image * 2 - 1.)
        return image

    def reset(self) -> np.ndarray:
        """Re-initialise the history buffers from the stored prompt.

        Returns:
            The last prompt frame, resized to a square of side
            ``resize_image_resolution`` when ``resize_image`` is set.

        Raises:
            AssertionError: if ``set_initial_state`` has not been called.
        """
        assert self.init_prompt is not None, "Initial state is not set."
        prompt_frames, prompt_actions = self.init_prompt
        current_image = prompt_frames[-1]

        self.cached_actions = torch.from_numpy(prompt_actions).to(
            device=self.device, dtype=torch.float32)

        # convert to latent
        self.cached_latent_frames = torch.stack([
            self._encode_image(frame) for frame in prompt_frames
        ], dim=0)

        # Re-arm the open-loop counter and drop any stored prediction, so the
        # first action after a reset comes from a fresh backbone call rather
        # than from the previous rollout.
        self.open_loop_step = self.execution_horizon - 1
        self.pred_action = None

        if self.resize_image:
            current_image = cv2.resize(
                current_image,
                (self.resize_image_resolution, self.resize_image_resolution))
        return current_image

    def close(self):
        """No resources to release; this is a no-op."""
        pass

    @property
    def dt(self):
        """Seconds per model step, i.e. ``1 / genie_frequency``."""
        return 1.0 / self.genie_frequency
class DiffusionPolicy(LearnedPolicy):
    """Thin wrapper around the policy built by ``diffusion_policy_factory``.

    Attributes not found on the wrapper are looked up on the wrapped policy.
    """

    def __init__(self, checkpoint: str):
        """Build the wrapped policy from ``checkpoint``."""
        super().__init__()
        self.policy = diffusion_policy_factory(checkpoint)

    def __getattr__(self, name):
        """Forward attributes missing on the wrapper to the wrapped policy.

        Python calls this only after normal attribute lookup has failed.

        Raises:
            AttributeError: if the wrapped policy is not set yet, or if it
                does not have ``name`` either.
        """
        # Read `policy` from the instance dict directly. Going through
        # `self.policy` would re-enter __getattr__ and recurse without bound
        # whenever `policy` is missing (failed __init__, copy, unpickle).
        try:
            wrapped = self.__dict__["policy"]
        except KeyError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            ) from None
        return getattr(wrapped, name)

    def generate_action(self, obs):
        """Return ``self.policy.predict_action(obs)``."""
        return self.policy.predict_action(obs)

    def reset(self):
        """No per-episode state is kept by the wrapper; this is a no-op."""
        pass

    def close(self):
        """No resources to release; this is a no-op."""
        pass