# Spaces: Running on Zero
import click
import cv2
import imageio
import numpy as np

from sim.main import InteractiveDigitalWorld
from sim.policy import TeleopPlanarQuadDirectionalPolicy
from sim.simulator import GenieSimulator
""" | |
for maskgit: | |
python -m sim.example.genie_langtable_teleop --image_encoder_type magvit --image_encoder_ckpt data/magvit2.ckpt \ | |
--quantize True --prompt_horizon 8 --backbone_type stmaskgit --backbone_ckpt data/maskgit_ckpt/langtable | |
""" | |
@click.command()
@click.option('--image_encoder_type', type=str, default='magvit',
              help='Image encoder type passed to GenieSimulator.')
@click.option('--image_encoder_ckpt', type=str, default='data/magvit2.ckpt',
              help='Image encoder checkpoint path passed to GenieSimulator.')
@click.option('--quantize', type=bool, default=True,
              help='Forwarded to GenieSimulator as `quantize`.')
@click.option('--backbone_type', type=str, default='stmaskgit',
              help='Backbone type passed to GenieSimulator.')
@click.option('--backbone_ckpt', type=str, default='data/maskgit_ckpt/langtable',
              help='Backbone checkpoint path passed to GenieSimulator.')
@click.option('--prompt_horizon', type=int, default=8,
              help='Number of prompt frames used to seed the simulator.')
@click.option('--action_stride', type=int, default=1,
              help='Forwarded to GenieSimulator as `action_stride`.')
@click.option('--video_save_path', type=str, default='genie_langtable_teleop.mp4',
              help='Where the rollout video is written.')
@click.option('--scene_id', type=int, default=0,
              help='Index NN of sim/assets/langtable_prompt/frame_NN.png used as the prompt frame.')
@click.option('--live', is_flag=True, default=False,
              help='Render on screen instead of offscreen.')
@click.option('--num_steps', type=int, default=20,
              help='Number of teleop steps to run before saving the video.')
def main(
    image_encoder_type,
    image_encoder_ckpt,
    quantize,
    backbone_type,
    backbone_ckpt,
    prompt_horizon,
    action_stride,
    video_save_path,
    scene_id,
    live,
    num_steps=20,
):
    """Run a teleoperated Genie rollout on a Language-Table scene and save it.

    Builds a GenieSimulator for the ``language_table`` domain. The simulator is
    seeded with ``prompt_horizon`` copies of one still frame and all-zero
    actions. It is then stepped ``num_steps`` times inside an
    InteractiveDigitalWorld driven by a TeleopPlanarQuadDirectionalPolicy.
    The rollout is saved to ``video_save_path`` (not as a GIF). Each frame is
    annotated with an arrow showing the commanded direction.
    """
    # Read the arrow overlay once; the post-processor below runs on every frame.
    arrow_template = imageio.imread('sim/assets/arrow.jpg')

    def draw_action_arrow_to_image(image: np.ndarray, action: np.ndarray) -> np.ndarray:
        """Paste a direction arrow for ``action`` into the top-left corner of ``image``.

        Args:
            image: Frame to annotate. It is modified in place and also returned.
            action: Action for this step. Only ``action[0]`` (the first entry
                along the stride axis) is visualized.

        Returns:
            ``image`` with the arrow drawn in its top-left corner.

        Raises:
            ValueError: If the action has non-zero components on both axes.
        """
        action = action[0]  # remove `stride` dimension; only the first sub-step is drawn
        if action[0] * action[1] != 0:
            raise ValueError(f'expected a single-axis action, got {action!r}')
        if action[0] > 0:  # `s`
            arrow = np.flipud(arrow_template)
        elif action[1] < 0:  # `a`
            arrow = np.rot90(arrow_template)
        elif action[1] > 0:  # `d`
            arrow = np.rot90(arrow_template, -1)
        else:  # `w` (this branch is also taken for an all-zero action)
            arrow = arrow_template
        # Clip to the frame so a frame smaller than the arrow asset does not raise.
        h = min(arrow.shape[0], image.shape[0])
        w = min(arrow.shape[1], image.shape[1])
        image[:h, :w] = arrow[:h, :w]
        return image

    genie_simulator = GenieSimulator(
        image_encoder_type=image_encoder_type,
        image_encoder_ckpt=image_encoder_ckpt,
        quantize=quantize,
        backbone_type=backbone_type,
        backbone_ckpt=backbone_ckpt,
        prompt_horizon=prompt_horizon,
        action_stride=action_stride,
        domain='language_table',
        post_processor=draw_action_arrow_to_image,
    )

    # Use the chosen still frame as the initial state: repeat it across the
    # prompt horizon and pair each frame with an all-zero action.
    current_image = imageio.imread(f'sim/assets/langtable_prompt/frame_{scene_id:02d}.png')
    image_prompt = np.tile(
        current_image, (genie_simulator.prompt_horizon, 1, 1, 1)
    ).astype(np.uint8)
    action_prompt = np.zeros(
        (genie_simulator.prompt_horizon, genie_simulator.action_stride, 2),
        dtype=np.float32,
    )
    genie_simulator.set_initial_state((image_prompt, action_prompt))

    teleop_policy = TeleopPlanarQuadDirectionalPolicy(increment=0.05)
    playground = InteractiveDigitalWorld(
        simulator=genie_simulator,
        policy=teleop_policy,
        offscreen=not live,
        window_size=(512, 512),
    )
    try:
        for _ in range(num_steps):
            playground.step()
        playground.save_video(save_path=video_save_path, as_gif=False)
    finally:
        # Always close the playground, even if a step or the video export raises.
        playground.close()
if __name__ == "__main__":
    # Run the teleop demo only when this module is executed as a script.
    main()