import os
from typing import Optional, Tuple, List, Dict, Any

import numpy as np
import cv2
import torch
import einops

from genie.st_mask_git import STMaskGIT
from genie.st_mar import STMAR
from datasets.utils import get_image_encoder
from diffusion_policy import diffusion_policy_factory
from data import DATA_FREQ_TABLE
from train_diffusion import SVD_SCALE


class Policy:
    def generate_action(self, obs):
        raise NotImplementedError

    def reset(self):
        pass


class RandomPolicy(Policy):
    def __init__(self):
        super().__init__()


class TeleopPolicy(Policy):
    def __init__(self):
        super().__init__()


class LearnedPolicy(Policy):
    def __init__(self):
        super().__init__()


class ReplayPolicy(Policy):
    def __init__(self,
                 actions: np.ndarray,  # (T * S, A)
                 action_stride: int = 1,
                 prompt_horizon: int = 0,
                 ):
        super().__init__()
        T = len(actions) // action_stride
        self.actions = actions[:T * action_stride].reshape(T, action_stride, actions.shape[-1])
        self.action_idx = prompt_horizon
        self.prompt_horizon = prompt_horizon
        self.action_stride = action_stride
        assert self.action_idx < len(self.actions)

    def __len__(self):
        return len(self.actions) - self.prompt_horizon

    def generate_action(self, obs):
        assert self.action_idx < len(self.actions)
        action = self.actions[self.action_idx]
        self.action_idx += 1
        return action

    def reset(self):
        # return current action = last action of prompt
        self.action_idx = self.prompt_horizon
        return self.prompt()[-1]

    def prompt(self):
        return self.actions[:self.prompt_horizon]


class RandomJointPositionPolicy(RandomPolicy):
    def __init__(self, action_bounds: Tuple[np.ndarray, np.ndarray]):
        self.lb = action_bounds[0]
        self.ub = action_bounds[1]
        self.action_dim = action_bounds[0].shape[0]

    def generate_action(self, obs):
        return np.random.uniform(self.lb, self.ub)
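
# A minimal usage sketch for RandomJointPositionPolicy (illustrative only):
# bounds are given as (lower, upper) arrays of per-joint limits, and each
# call samples uniformly inside that box.
#
#     lb = np.array([-1.0, -1.0, -0.5])
#     ub = np.array([ 1.0,  1.0,  0.5])
#     policy = RandomJointPositionPolicy((lb, ub))
#     action = policy.generate_action(obs=None)  # shape (3,), within [lb, ub]
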
class TeleopJointPositionPolicy(TeleopPolicy):
    """
    Example usage:
        teleop = TeleopJointPositionPolicy(
            initial_position=[0, 0, 0, 0, 0, 0, 0],
            increment=0.1,
            keyboard_bindings=['q', 'w', 'e', 'r', 't', 'y', 'u'],
            return_delta=False
        )
        while True:
            print(teleop.generate_action(None))
    """
    def __init__(self,
                 initial_position: List[float],  # initial position for each joint
                 increment: float,               # increment for each joint
                 keyboard_bindings: List[str],   # list of keyboard bindings for each joint;
                                                 # shift + key for the negative direction
                 return_delta: bool = False,     # if True, return delta instead of absolute position
                 ):
        super().__init__()
        self.increment = increment
        self.pos_keys = keyboard_bindings
        self.neg_keys = [self._shift_key(key) for key in keyboard_bindings]
        self.action_dim = len(keyboard_bindings)
        self.return_delta = return_delta
        # use a float dtype so the in-place float increments below are valid
        self.current_position = np.array(initial_position, dtype=np.float64)
        self.shift_pressed = False
        self.delta_position = np.zeros(self.action_dim)

    def generate_action(self, obs):
        while (user_input := input('Waiting for input: ')) not in self.pos_keys + self.neg_keys:
            print(f'Invalid input {user_input}')
        is_pos = user_input in self.pos_keys
        joint_idx = self.pos_keys.index(user_input) if is_pos else self.neg_keys.index(user_input)
        self.delta_position[joint_idx] = self.increment * (1 if is_pos else -1)

        # note: curr_pos aliases self.current_position, so the in-place
        # update below is reflected in the returned absolute position
        curr_pos = self.current_position
        delta_pos = self.delta_position

        # update current position and reset delta
        self.current_position += self.delta_position
        self.delta_position = np.zeros(self.action_dim)

        if self.return_delta:
            return delta_pos
        else:
            return curr_pos

    def _shift_key(self, key):
        if key.isalpha():
            return key.upper()
        return {
            '1': '!', '2': '@', '3': '#', '4': '$', '5': '%',
            '6': '^', '7': '&', '8': '*', '9': '(', '0': ')',
            '-': '_', '=': '+', '[': '{', ']': '}', '\\': '|',
            ';': ':', "'": '"', ',': '<', '.': '>', '/': '?', '`': '~',
        }.get(key, key)


class RandomPlanarQuadDirectionalPolicy(RandomPolicy):
    def __init__(self, increment: float = 0.5):
        self.increment = increment

    def generate_action(self, obs):
        actions = [
            np.array([0, self.increment]),
            np.array([0, -self.increment]),
            np.array([self.increment, 0]),
            np.array([-self.increment, 0]),
        ]
        return actions[np.random.choice(4)]


class TeleopPlanarQuadDirectionalPolicy(TeleopPolicy):
    # control with: w, a, s, d
    def __init__(self,
                 increment: float = 0.5,  # increment for each direction
                 ):
        super().__init__()
        self.increment = increment

    def generate_action(self, obs):
        while (user_input := input('Waiting for input: ')) not in ['w', 'a', 's', 'd']:
            print(f'Invalid input {user_input}')
        # follow IRASIM's convention
        if user_input == 'd':
            return np.array([0, self.increment])
        elif user_input == 'a':
            return np.array([0, -self.increment])
        elif user_input == 's':
            return np.array([self.increment, 0])
        elif user_input == 'w':
            return np.array([-self.increment, 0])
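
# A minimal rollout sketch (illustrative only; `env` is a hypothetical
# environment whose reset()/step() return {'image': ...}-style observations,
# not something defined in this module):
#
#     policy = TeleopPlanarQuadDirectionalPolicy(increment=0.5)
#     obs = env.reset()
#     policy.reset()
#     for _ in range(100):
#         action = policy.generate_action(obs)
#         obs = env.step(action)
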
class GeniePolicy(LearnedPolicy):
    average_delta_psnr_over = 5

    def __init__(self,
                 # image preprocessing
                 max_image_resolution: int = 1024,
                 resize_image: bool = True,
                 resize_image_resolution: int = 256,
                 # tokenizer setting
                 image_encoder_type: str = "magvit",
                 image_encoder_ckpt: str = "data/magvit2.ckpt",
                 quantize: bool = True,
                 quantization_slice_size: int = 16,
                 # dynamics backbone setting
                 backbone_type: str = "stmaskgit",
                 backbone_ckpt: str = "data/genie_model/final_checkpt",
                 prompt_horizon: int = 4,
                 prediction_horizon: int = 4,
                 execution_horizon: int = 2,  # half of the prediction context
                 inference_iterations: Optional[int] = None,
                 sampling_temperature: float = 0.0,
                 action_stride: Optional[int] = None,
                 domain: str = "robomimic",
                 genie_frequency: int = 2,
                 diffusion_steps: int = 10,
                 # misc
                 is_full_dynamics: bool = False,
                 device: str = 'cuda',
                 use_raw_image: bool = False,
                 ):
        super().__init__()
        assert quantize == (image_encoder_type == "magvit"), \
            "Currently, quantization is used if and only if magvit is the image encoder."
        assert image_encoder_type in ["magvit", "temporalvae"], \
            "Image encoder type must be either 'magvit' or 'temporalvae'."
        assert not quantize or image_encoder_type == "magvit", \
            "If quantize is enabled, image encoder type must be 'magvit'."
        assert backbone_type in ["stmaskgit", "stmar"], \
            "Backbone type must be either 'stmaskgit' or 'stmar'."

        if action_stride is None:
            action_stride = DATA_FREQ_TABLE[domain] // genie_frequency
        if inference_iterations is None:
            # currently the same default for both backbone types
            if backbone_type == "stmaskgit":
                inference_iterations = 2
            elif backbone_type == "stmar":
                inference_iterations = 2

        # misc
        self.use_raw_image = use_raw_image
        self.device = torch.device(device)
        self.is_full_dynamics = is_full_dynamics
        self.prediction_horizon = prediction_horizon
        self.execution_horizon = execution_horizon
        self.open_loop_step = self.execution_horizon - 1

        # image preprocessing
        self.max_image_resolution = max_image_resolution
        self.resize_image = resize_image
        self.resize_image_resolution = resize_image_resolution

        # load image encoder
        self.image_encoding_dtype = torch.bfloat16
        self.quantize = quantize
        self.quant_slice_size = quantization_slice_size
        self.image_encoder_type = image_encoder_type
        self.image_encoder = get_image_encoder(
            image_encoder_type, image_encoder_ckpt
        ).to(device=self.device, dtype=self.image_encoding_dtype).eval()

        # load STMaskGIT model (STMAR inherits from STMaskGIT)
        self.prompt_horizon = prompt_horizon
        self.domain = domain
        self.genie_frequency = genie_frequency
        self.inference_iterations = inference_iterations
        self.sampling_temperature = sampling_temperature
        self.action_stride = action_stride
        self.backbone_type = backbone_type

        if not os.path.exists(backbone_ckpt + "/config.json"):
            # search and find the latest modified checkpoint folder
            dirs = [os.path.join(backbone_ckpt, f.name)
                    for f in os.scandir(backbone_ckpt) if f.is_dir()]
            dirs.sort(key=os.path.getctime)
            backbone_ckpt = dirs[-1]
        print("backbone_ckpt:", backbone_ckpt)

        if backbone_type == "stmaskgit":
            self.backbone = STMaskGIT.from_pretrained(backbone_ckpt)
        else:
            self.backbone = STMAR.from_pretrained(backbone_ckpt)
            self.backbone.action_diff_losses[domain].gen_diffusion.num_timesteps = diffusion_steps
            self.backbone.diffloss.gen_diffusion.num_timesteps = diffusion_steps
        self.backbone = self.backbone.to(device=self.device).eval()

        # history buffer, i.e., the input to the model
        self.cached_actions = None        # (prompt_horizon, action_stride, A)
        self.cached_latent_frames = None  # (prompt_horizon, ...)
        self.init_prompt = None           # (prompt_frames, prompt_actions)

        # report model size
        print(
            "================ Model Size Report ================\n"
            f" encoder size: {sum(p.numel() for p in self.image_encoder.parameters()) / 1e6:.3f}M\n"
            f" backbone size: {sum(p.numel() for p in self.backbone.parameters()) / 1e6:.3f}M\n"
            "==================================================="
        )

    def set_initial_state(self, state: Tuple[np.ndarray, np.ndarray]):
        self.init_prompt = state
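
    # A minimal prompting sketch (illustrative only, assuming the default
    # checkpoints exist on disk; P = prompt_horizon, S = action_stride,
    # A = action dim, with shapes following the cache comments in __init__):
    #
    #     policy = GeniePolicy(domain="robomimic")
    #     prompt_frames = ...   # np.ndarray (P, H, W, 3), uint8 frames
    #     prompt_actions = ...  # np.ndarray (P, S, A), float32 actions
    #     policy.set_initial_state((prompt_frames, prompt_actions))
    #     obs_image = policy.reset()                              # last prompt frame
    #     action = policy.generate_action({'image': obs_image})  # (S, A)
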
    @torch.inference_mode()
    def generate_action(self, obs: Dict[str, Any]) -> np.ndarray:
        # obs: {'image': np.ndarray (H, W, 3), ...}
        # return: np.ndarray (stride, A)
        assert self.cached_latent_frames is not None, "Model is not prompted yet."
        this_image = obs['image']

        # encode
        this_latent = self._encode_image(this_image)

        # update cache for the current image: prompt_horizon + 1 timesteps
        self.cached_latent_frames = torch.cat([
            self.cached_latent_frames, this_latent.unsqueeze(0)
        ]).to(torch.float32)

        # new video tokens: s_{t - prompt_horizon} to s_{t + 1};
        # s_{t + 1} to s_{t + execution_horizon} are masked tokens
        mask_tokens = torch.zeros(self.execution_horizon - 1,
                                  *self.cached_latent_frames.shape[1:],
                                  dtype=self.cached_latent_frames.dtype,
                                  device=self.device)
        input_latent_states = torch.cat([
            self.cached_latent_frames, mask_tokens
        ]).unsqueeze(0).to(torch.float32)  # add batch dimension

        # new action tokens: a_{t - prompt_horizon} to a_t;
        # a_t to a_{t + execution_horizon} are masked tokens
        self.cached_actions = torch.cat([
            self.cached_actions,
            torch.zeros(self.execution_horizon, *self.cached_actions.shape[1:],
                        dtype=self.cached_actions.dtype, device=self.device)
        ]).to(torch.float32)
        cached_actions = einops.rearrange(self.cached_actions, "h b c -> b h c")
        action_mask = torch.zeros(cached_actions.shape[0], cached_actions.shape[1], 1, 1,
                                  dtype=self.cached_actions.dtype, device=self.device)
        action_mask[:, self.prompt_horizon:] = 1

        # dtype conversion and mask token
        if self.backbone_type == "stmaskgit":
            input_latent_states = input_latent_states.long()
            input_latent_states[:, self.prompt_horizon + 1:] = self.backbone.mask_token_id
            # we should experiment with the other way to do this as well
            # cached_actions[:, self.prompt_horizon:] = self.backbone.action_mask_tokens
        elif self.backbone_type == "stmar":
            input_latent_states[:, self.prompt_horizon + 1:] = self.backbone.mask_token
            # cached_actions[:, self.prompt_horizon:] = self.backbone.action_mask_tokens

        if self.open_loop_step != self.execution_horizon - 1:
            self.open_loop_step += 1
        else:
            cached_actions = cached_actions[:, -input_latent_states.shape[1]:]
            if self.execution_horizon == 1:
                self.pred_action = self.backbone.maskgit_generate(
                    input_latent_states,
                    out_t=self.prompt_horizon,
                    maskgit_steps=self.inference_iterations,
                    temperature=self.sampling_temperature,
                    action_ids=cached_actions,  # if self.is_full_dynamics else None
                    domain=[self.domain],
                    action_mask=action_mask
                )[-1].squeeze(0)
            else:
                self.pred_action = self.backbone.maskgit_generate_horizon(
                    input_latent_states,
                    out_t_min=self.prompt_horizon,
                    out_t_max=self.prompt_horizon + self.execution_horizon,
                    maskgit_steps=self.inference_iterations,
                    temperature=self.sampling_temperature,
                    action_ids=cached_actions,  # if self.is_full_dynamics else None
                    domain=[self.domain],
                    action_mask=action_mask
                )[-1].squeeze(0)
            self.open_loop_step = 0

        pred_action = self.pred_action[
            self.prompt_horizon + self.open_loop_step:
            self.prompt_horizon + self.open_loop_step + 1
        ]
        self.cached_actions = torch.cat([
            self.cached_actions, pred_action.unsqueeze(0)
        ]).to(torch.float32)

        # keep only the most recent `prompt_horizon` entries in the history buffers
        self.cached_actions = self.cached_actions[-self.prompt_horizon:]
        self.cached_latent_frames = self.cached_latent_frames[-self.prompt_horizon:]

        return pred_action.detach().cpu().numpy()

    @torch.inference_mode()
    def _encode_image(self, image: np.ndarray) -> torch.Tensor:
        # (H, W, 3)
        if self.quantize:
            image = torch.from_numpy(
                self._normalize_image(image).transpose(2, 0, 1)
            ).to(device=self.device, dtype=self.image_encoding_dtype).unsqueeze(0)
            H, W = image.shape[-2:]
            H //= self.quant_slice_size
            W //= self.quant_slice_size
            _, _, indices, _ = self.image_encoder.encode(image, flip=True)
            indices = einops.rearrange(indices, "(h w) -> h w", h=H, w=W)
            indices = indices.to(torch.int32)
            return indices
        elif self.use_raw_image:
            image = torch.from_numpy(image).permute(2, 0, 1)
            norm_image = torch.nn.functional.interpolate(image[None] / 255.0, (32, 32)) - 0.5
            norm_image = einops.rearrange(norm_image, "b c h w -> b h w c")
            norm_image = norm_image.squeeze(0).to(torch.float32).to(self.device)
            return norm_image
        else:
            image = torch.from_numpy(
                self._normalize_image(image).transpose(2, 0, 1)
            ).to(device=self.device, dtype=self.image_encoding_dtype).unsqueeze(0)
            H, W = image.shape[-2:]
            if self.image_encoder_type == "magvit":
                latent = self.image_encoder.encode_without_quantize(image)
            elif self.image_encoder_type == "temporalvae":
                latent_dist = self.image_encoder.encode(image).latent_dist
                latent = latent_dist.mean
                latent *= SVD_SCALE
                latent = einops.rearrange(latent, "b c h w -> b h w c")
            else:
                # unreachable: encoder type is validated in __init__
                pass
            latent = latent.squeeze(0).to(torch.float32)
            return latent
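
    # Shape walkthrough for one generate_action() call (illustrative, with
    # P = prompt_horizon, E = execution_horizon, S = action_stride, A = action dim):
    #   cached_latent_frames: (P, ...)  -> append current frame -> (P + 1, ...)
    #   input_latent_states:  (1, P + E, ...), where the last E - 1 frame
    #                         slots are filled with mask tokens
    #   cached_actions:       (P, S, A) -> append E zero chunks -> (P + E, S, A),
    #                         with the last E chunks flagged by action_mask
    #   after prediction, the new action chunk is appended and both buffers
    #   are truncated back to their last P entries.
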
    def _normalize_image(self, image: np.ndarray) -> np.ndarray:
        # (H, W, 3) normalized to [-1, 1].
        # If `resize_image` is set, resize the shorter side to
        # `resize_image_resolution` and then do a center crop.
        image = np.asarray(image, dtype=np.float32)
        image /= 255.
        H, W = image.shape[:2]

        # resize if asked
        if self.resize_image:
            resized_res = self.resize_image_resolution
            if H < W:
                Hnew, Wnew = resized_res, int(resized_res * W / H)
            else:
                Hnew, Wnew = int(resized_res * H / W), resized_res
            image = cv2.resize(image, (Wnew, Hnew))

            # center crop
            H, W = image.shape[:2]
            Hstart = (H - resized_res) // 2
            Wstart = (W - resized_res) // 2
            image = image[Hstart:Hstart + resized_res, Wstart:Wstart + resized_res]

        # resize if resolution is too large
        elif H > self.max_image_resolution or W > self.max_image_resolution:
            if H < W:
                Hnew, Wnew = int(self.max_image_resolution * H / W), self.max_image_resolution
            else:
                Hnew, Wnew = self.max_image_resolution, int(self.max_image_resolution * W / H)
            image = cv2.resize(image, (Wnew, Hnew))

        image = (image * 2 - 1.)
        return image

    def reset(self) -> np.ndarray:
        # prompt the model with the initial state and
        # return the current (last prompt) frame
        assert self.init_prompt is not None, "Initial state is not set."
        prompt_frames, prompt_actions = self.init_prompt
        current_image = prompt_frames[-1]
        prompt_actions = torch.from_numpy(prompt_actions).to(device=self.device, dtype=torch.float32)
        self.cached_actions = prompt_actions

        # convert to latent
        self.cached_latent_frames = torch.stack([
            self._encode_image(frame) for frame in prompt_frames
        ], dim=0)

        if self.resize_image:
            current_image = cv2.resize(current_image,
                                       (self.resize_image_resolution, self.resize_image_resolution))
        return current_image

    def close(self):
        pass

    @property
    def dt(self):
        return 1.0 / self.genie_frequency


class DiffusionPolicy(LearnedPolicy):
    def __init__(self, checkpoint: str):
        super().__init__()
        self.policy = diffusion_policy_factory(checkpoint)

    def __getattr__(self, name):
        # __getattr__ is only invoked when normal attribute lookup fails,
        # so this effectively delegates unknown attributes to the wrapped policy
        try:
            return self.__dict__[name]
        except KeyError:
            return getattr(self.policy, name)

    def generate_action(self, obs):
        return self.policy.predict_action(obs)

    def reset(self):
        pass

    def close(self):
        pass
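

# A minimal smoke-test sketch (illustrative only): run this module directly to
# replay a synthetic action sequence through ReplayPolicy, which is the one
# policy here that requires no checkpoints or hardware. The flat (T * S, A)
# array is chunked into T steps of `action_stride` actions each.
if __name__ == "__main__":
    actions = np.arange(20 * 7, dtype=np.float32).reshape(20, 7)  # (T * S, A) with S = 2
    policy = ReplayPolicy(actions, action_stride=2, prompt_horizon=3)
    print("last prompt chunk shape:", policy.reset().shape)  # (2, 7)
    for _ in range(len(policy)):
        chunk = policy.generate_action(obs=None)  # next (2, 7) action chunk
    print("replayed", len(policy), "action chunks")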