File size: 3,244 Bytes
246c106
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import numpy as np
import imageio
import click
from sim.main import InteractiveDigitalWorld
from sim.simulator import GenieSimulator
from sim.policy import TeleopPlanarQuadDirectionalPolicy
import cv2


"""
for maskgit:
python -m sim.example.genie_langtable_teleop --image_encoder_type magvit --image_encoder_ckpt data/magvit2.ckpt \
    --quantize True --prompt_horizon 8 --backbone_type stmaskgit --backbone_ckpt data/maskgit_ckpt/langtable
"""
_ARROW_IMAGE_PATH = 'sim/assets/arrow.jpg'


def _overlay_action_arrow(
    image: np.ndarray, action: np.ndarray, arrow_image: np.ndarray
) -> np.ndarray:
    """Draw the direction arrow for a teleop action into the top-left corner of a frame.

    Args:
        image: Frame to annotate. It is modified in place and also returned.
        action: Action chunk whose leading axis is the action stride; only
            ``action[0]`` (the first step of the chunk) selects the arrow.
            That step must be non-zero on at most one of its two components.
        arrow_image: Arrow bitmap as stored on disk. It is drawn unchanged for
            the ``w`` key, flipped vertically for ``s``, and rotated 90 degrees
            counter-clockwise / clockwise for ``a`` / ``d``.

    Returns:
        ``image`` with the oriented arrow pasted over its top-left corner
        (clipped to the frame if the arrow is larger than it).

    Raises:
        ValueError: if the first action step is non-zero on both components.
    """
    step = action[0]  # drop the `stride` dimension
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if step[0] * step[1] != 0:
        raise ValueError(f'expected an action along a single axis, got {step!r}')

    if step[0] > 0:      # `s`
        arrow = np.flipud(arrow_image)
    elif step[1] < 0:    # `a`
        arrow = np.rot90(arrow_image)
    elif step[1] > 0:    # `d`
        arrow = np.rot90(arrow_image, -1)
    else:                # `w` -- NOTE(review): an all-zero action also lands here
        arrow = arrow_image

    # Clip rather than let numpy raise a broadcast error when the (possibly
    # rotated) arrow is taller or wider than the frame.
    height = min(arrow.shape[0], image.shape[0])
    width = min(arrow.shape[1], image.shape[1])
    image[:height, :width] = arrow[:height, :width]
    return image


@click.command()
@click.option('--image_encoder_type', type=str, default='temporalvae')
@click.option('--image_encoder_ckpt', type=str, default='stabilityai/stable-video-diffusion-img2vid')
@click.option('--quantize', type=bool, default=False)
@click.option('--backbone_type', type=str, default='stmar')
@click.option('--backbone_ckpt', type=str, default='data/mar_ckpt/langtable')
@click.option('--prompt_horizon', type=int, default=11)
@click.option('--action_stride', type=int, default=1)
@click.option('--video_save_path', type=str, default='test.mp4')
@click.option('--scene_id', type=int, default=6)
@click.option('--live', type=bool, default=True)
@click.option('--num_steps', type=click.IntRange(min=1), default=20,
              help='Number of interactive steps to simulate before saving the video.')
def main(
    image_encoder_type,
    image_encoder_ckpt,
    quantize,
    backbone_type,
    backbone_ckpt,
    prompt_horizon,
    action_stride,
    video_save_path,
    scene_id,
    live,
    num_steps=20,
):
    """Teleoperate the Genie simulator on a Language Table scene and save the rollout as a video."""
    # Load the arrow overlay once, instead of re-reading it from disk for
    # every post-processed frame.
    arrow_image = imageio.imread(_ARROW_IMAGE_PATH)

    def draw_action_arrow_to_image(image: np.ndarray, action: np.ndarray) -> np.ndarray:
        """Post-processor hook: overlay the arrow for ``action`` onto ``image``."""
        return _overlay_action_arrow(image, action, arrow_image)

    genie_simulator = GenieSimulator(
        image_encoder_type=image_encoder_type,
        image_encoder_ckpt=image_encoder_ckpt,
        quantize=quantize,
        backbone_type=backbone_type,
        backbone_ckpt=backbone_ckpt,
        prompt_horizon=prompt_horizon,
        action_stride=action_stride,
        domain='language_table',
        post_processor=draw_action_arrow_to_image,
    )

    # Seed the simulator with a static history. The chosen scene frame is
    # repeated `prompt_horizon` times and paired with all-zero actions of
    # shape (prompt_horizon, action_stride, 2).
    # NOTE(review): this assumes the PNG decodes to (H, W, C). A grayscale
    # file would tile to a different rank. Confirm against the prompt assets.
    current_image = imageio.imread(f'sim/assets/langtable_prompt/frame_{scene_id:02d}.png')
    image_prompt = np.tile(
        current_image, (genie_simulator.prompt_horizon, 1, 1, 1)
    ).astype(np.uint8)
    action_prompt = np.zeros(
        (genie_simulator.prompt_horizon, genie_simulator.action_stride, 2),
        dtype=np.float32,
    )
    genie_simulator.set_initial_state((image_prompt, action_prompt))

    teleop_policy = TeleopPlanarQuadDirectionalPolicy(increment=0.05)
    playground = InteractiveDigitalWorld(
        simulator=genie_simulator,
        policy=teleop_policy,
        offscreen=not live,
        window_size=(512, 512),
    )

    try:
        for _ in range(num_steps):
            playground.step()
        playground.save_video(save_path=video_save_path, as_gif=False)
    finally:
        # Close the playground even if a step raises or the run is interrupted.
        playground.close()


if __name__ == "__main__":
    # Script entry point: click parses sys.argv and invokes the command.
    main()