hma / sim /example /genie_langtable_teleop.py
LeroyWaa's picture
draft
246c106
raw
history blame
3.24 kB
import numpy as np
import imageio
import click
from sim.main import InteractiveDigitalWorld
from sim.simulator import GenieSimulator
from sim.policy import TeleopPlanarQuadDirectionalPolicy
import cv2
"""
Example: run with the MaskGIT backbone instead of the default MAR backbone:
python -m sim.example.genie_langtable_teleop --image_encoder_type magvit --image_encoder_ckpt data/magvit2.ckpt \
--quantize True --prompt_horizon 8 --backbone_type stmaskgit --backbone_ckpt data/maskgit_ckpt/langtable
"""
@click.command()
@click.option('--image_encoder_type', type=str, default='temporalvae')
@click.option('--image_encoder_ckpt', type=str, default='stabilityai/stable-video-diffusion-img2vid')
@click.option('--quantize', type=bool, default=False)
@click.option('--backbone_type', type=str, default='stmar')
@click.option('--backbone_ckpt', type=str, default='data/mar_ckpt/langtable')
@click.option('--prompt_horizon', type=int, default=11)
@click.option('--action_stride', type=int, default=1)
@click.option('--video_save_path', type=str, default='test.mp4')
@click.option('--scene_id', type=int, default=6)
@click.option('--live', type=bool, default=True)
@click.option('--num_steps', type=int, default=20)
def main(
    image_encoder_type,
    image_encoder_ckpt,
    quantize,
    backbone_type,
    backbone_ckpt,
    prompt_horizon,
    action_stride,
    video_save_path,
    scene_id,
    live,
    num_steps=20,
):
    """Run a teleoperated rollout of the Genie language-table simulator.

    Builds a ``GenieSimulator`` for the ``language_table`` domain, seeds it
    with a prompt made of one repeated frame and all-zero actions, steps an
    ``InteractiveDigitalWorld`` driven by a
    ``TeleopPlanarQuadDirectionalPolicy`` for ``num_steps`` steps, and then
    saves the result as a video.

    Args:
        image_encoder_type: Forwarded to ``GenieSimulator``.
        image_encoder_ckpt: Forwarded to ``GenieSimulator``.
        quantize: Forwarded to ``GenieSimulator``.
        backbone_type: Forwarded to ``GenieSimulator``.
        backbone_ckpt: Forwarded to ``GenieSimulator``.
        prompt_horizon: Forwarded to ``GenieSimulator``; the prompt length is
            then read back from the simulator instance.
        action_stride: Forwarded to ``GenieSimulator``; also read back from
            the simulator instance when shaping the action prompt.
        video_save_path: Destination passed to ``playground.save_video``.
        scene_id: Selects ``sim/assets/langtable_prompt/frame_XX.png`` as the
            initial frame.
        live: When false, the playground is created with ``offscreen=True``.
        num_steps: Number of ``playground.step()`` calls (default 20, the
            value that used to be hard-coded).
    """
    # Read the overlay asset once instead of on every post-processed frame.
    # Doing it before the simulator is built also makes a missing asset fail
    # fast, before any checkpoint is loaded.
    arrow_base = imageio.imread('sim/assets/arrow.jpg')

    def draw_action_arrow_to_image(image: np.ndarray, action: np.ndarray) -> np.ndarray:
        """Paste an arrow for ``action`` into the top-left corner of ``image``.

        ``image`` is modified in place and also returned. The key letters in
        the comments below are the ones used by the original author.
        """
        action = action[0]  # remove `stride` dimension
        # At most one of the two components may be non-zero.
        assert action[0] * action[1] == 0
        if action[0] > 0:  # `s`
            arrow_image = np.flipud(arrow_base)
        elif action[1] < 0:  # `a`
            arrow_image = np.rot90(arrow_base)
        elif action[1] > 0:  # `d`
            arrow_image = np.rot90(arrow_base, -1)
        else:  # `w`: the asset is used as stored
            arrow_image = arrow_base
        image[0:arrow_image.shape[0], 0:arrow_image.shape[1]] = arrow_image
        return image

    genie_simulator = GenieSimulator(
        image_encoder_type=image_encoder_type,
        image_encoder_ckpt=image_encoder_ckpt,
        quantize=quantize,
        backbone_type=backbone_type,
        backbone_ckpt=backbone_ckpt,
        prompt_horizon=prompt_horizon,
        action_stride=action_stride,
        domain='language_table',
        post_processor=draw_action_arrow_to_image,
    )

    # use whatever current state is as the initial state
    current_image = imageio.imread(f'sim/assets/langtable_prompt/frame_{scene_id:02d}.png')
    # The prompt is the same frame repeated `prompt_horizon` times ...
    image_prompt = np.tile(
        current_image, (genie_simulator.prompt_horizon, 1, 1, 1)
    ).astype(np.uint8)
    # ... paired with all-zero 2-D actions for every (frame, stride) slot.
    action_prompt = np.zeros(
        (genie_simulator.prompt_horizon, genie_simulator.action_stride, 2)
    ).astype(np.float32)
    genie_simulator.set_initial_state((image_prompt, action_prompt))

    teleop_policy = TeleopPlanarQuadDirectionalPolicy(increment=0.05)
    playground = InteractiveDigitalWorld(
        simulator=genie_simulator,
        policy=teleop_policy,
        offscreen=not live,
        window_size=(512, 512),
    )

    # Always release the playground, even if a step or the video export
    # raises (for example when the user interrupts the session).
    try:
        for _ in range(num_steps):
            playground.step()
        playground.save_video(save_path=video_save_path, as_gif=False)
    finally:
        playground.close()
# Script entry point: `main` is a click command, so calling it with no
# arguments makes click parse the options from the command line.
if __name__ == '__main__':
    main()