Spaces:
Running
on
Zero
Running
on
Zero
File size: 3,244 Bytes
246c106 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 |
import numpy as np
import imageio
import click
from sim.main import InteractiveDigitalWorld
from sim.simulator import GenieSimulator
from sim.policy import TeleopPlanarQuadDirectionalPolicy
import cv2
"""
for maskgit:
python -m sim.example.genie_langtable_teleop --image_encoder_type magvit --image_encoder_ckpt data/magvit2.ckpt \
--quantize True --prompt_horizon 8 --backbone_type stmaskgit --backbone_ckpt data/maskgit_ckpt/langtable
"""
@click.command()
@click.option('--image_encoder_type', type=str, default='temporalvae')
@click.option('--image_encoder_ckpt', type=str, default='stabilityai/stable-video-diffusion-img2vid')
@click.option('--quantize', type=bool, default=False)
@click.option('--backbone_type', type=str, default='stmar')
@click.option('--backbone_ckpt', type=str, default='data/mar_ckpt/langtable')
@click.option('--prompt_horizon', type=int, default=11)
@click.option('--action_stride', type=int, default=1)
@click.option('--video_save_path', type=str, default='test.mp4')
@click.option('--scene_id', type=int, default=6)
@click.option('--live', type=bool, default=True)
@click.option('--num_steps', type=int, default=20)
def main(
    image_encoder_type,
    image_encoder_ckpt,
    quantize,
    backbone_type,
    backbone_ckpt,
    prompt_horizon,
    action_stride,
    video_save_path,
    scene_id,
    live,
    num_steps,
):
    """Teleoperate a Genie simulator on a Language-Table scene and save the rollout video."""
    # Load the arrow overlay once and precompute one orientation per
    # direction, instead of re-reading the JPEG for every rendered frame.
    # The unrotated asset is used for `w`, matching the original branch logic.
    base_arrow = imageio.imread('sim/assets/arrow.jpg')
    arrows = {
        'w': base_arrow,
        's': np.flipud(base_arrow),
        'a': np.rot90(base_arrow),
        'd': np.rot90(base_arrow, -1),
    }

    def draw_action_arrow_to_image(image: np.ndarray, action: np.ndarray) -> np.ndarray:
        """Overlay an arrow for the teleop direction onto the top-left corner of `image`.

        Only the first entry along the leading (stride) axis of `action` is drawn.
        `image` is modified in place and also returned.

        Raises:
            ValueError: if the action moves along both axes at once.
        """
        action = action[0]  # remove `stride` dimension; only the first step is visualized
        # Teleop actions are expected to move along a single axis.
        if action[0] * action[1] != 0:
            raise ValueError(f'expected a single-axis action, got {action}')
        if action[0] > 0:  # `s`
            arrow = arrows['s']
        elif action[1] < 0:  # `a`
            arrow = arrows['a']
        elif action[1] > 0:  # `d`
            arrow = arrows['d']
        else:
            # `w`. NOTE(review): an all-zero action also lands here and draws the up arrow.
            arrow = arrows['w']
        # NOTE(review): writes into the caller's array; assumes `image` is writable
        # and at least as large as the arrow asset.
        image[:arrow.shape[0], :arrow.shape[1]] = arrow
        return image

    genie_simulator = GenieSimulator(
        image_encoder_type=image_encoder_type,
        image_encoder_ckpt=image_encoder_ckpt,
        quantize=quantize,
        backbone_type=backbone_type,
        backbone_ckpt=backbone_ckpt,
        prompt_horizon=prompt_horizon,
        action_stride=action_stride,
        domain='language_table',
        post_processor=draw_action_arrow_to_image,
    )

    # Seed the prompt by repeating the scene's starting frame across the whole
    # prompt horizon (assumes the PNG decodes to an (H, W, C) array), paired
    # with all-zero actions of shape (horizon, stride, 2).
    current_image = imageio.imread(f'sim/assets/langtable_prompt/frame_{scene_id:02d}.png')
    image_prompt = np.tile(
        current_image, (genie_simulator.prompt_horizon, 1, 1, 1)
    ).astype(np.uint8)
    action_prompt = np.zeros(
        (genie_simulator.prompt_horizon, genie_simulator.action_stride, 2),
        dtype=np.float32,
    )
    genie_simulator.set_initial_state((image_prompt, action_prompt))

    teleop_policy = TeleopPlanarQuadDirectionalPolicy(increment=0.05)
    playground = InteractiveDigitalWorld(
        simulator=genie_simulator,
        policy=teleop_policy,
        offscreen=not live,
        window_size=(512, 512),
    )
    try:
        for _ in range(num_steps):
            playground.step()
        playground.save_video(save_path=video_save_path, as_gif=False)
    finally:
        # Always call close(), even if a step or the video export raises
        # (including a KeyboardInterrupt during a live session).
        playground.close()
if __name__ == '__main__':
    # Run the click command only when executed as a script / `python -m`.
    main()