File size: 2,180 Bytes
246c106
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import numpy as np
import h5py
from sim.main import InteractiveDigitalWorld
from sim.simulator import GenieSimulator, ReplaySimulator
from sim.policy import ReplayPolicy

def _parse_args(argv=None):
    """Parse command-line options; the defaults reproduce the original hard-coded run."""
    import argparse  # local import: only the script entry point needs it

    parser = argparse.ArgumentParser(
        description="Replay a robomimic demo through GenieSimulator and save a "
                    "side-by-side (genie | ground-truth) video.",
    )
    parser.add_argument(
        "--dataset",
        default="data/robomimic_datasets/robomimic_raw/datasets/lift/ph/image.hdf5",
        help="Path to the robomimic HDF5 file.",
    )
    parser.add_argument(
        "--demo-idx", type=int, default=120,
        help="Replay the group 'data/demo_<N>' of the dataset.",
    )
    parser.add_argument(
        "--prompt-horizon", type=int, default=11,
        help="prompt_horizon passed to the replay simulator, replay policy and GenieSimulator.",
    )
    parser.add_argument(
        "--action-stride", type=int, default=1,
        help="action_stride passed to the replay policy and GenieSimulator.",
    )
    parser.add_argument(
        "--backbone-ckpt", default="data/mar_ckpt/robomimic",
        help="backbone_ckpt passed to GenieSimulator.",
    )
    parser.add_argument("--save-path", default="test.mp4", help="Output video path.")
    parser.add_argument("--as-gif", action="store_true", help="Save the video as a GIF.")
    return parser.parse_args(argv)


def _load_demo(dataset_path, demo_idx):
    """Read the actions and agent-view frames of one demo from a robomimic HDF5 file.

    Returns:
        ``(actions, frames)``: ``actions`` cast to float32 and ``frames`` cast
        to uint8. Both are fully read into memory with ``[:]``, so they stay
        valid after the file is closed.
    """
    with h5py.File(dataset_path, 'r') as f:
        demo = f['data'][f'demo_{demo_idx}']
        actions = demo['actions'][:].astype(np.float32)
        # NOTE(review): the original note reads "possible re-render". Presumably the
        # stored frames may differ from a fresh environment re-render; confirm.
        frames = demo['obs']['agentview_image'][:].astype(np.uint8)
    return actions, frames


def main(argv=None):
    """Replay a recorded demo through GenieSimulator and save the comparison video.

    Args:
        argv: Optional argument list for the parser. ``None`` means ``sys.argv``.

    Raises:
        ValueError: If the replay policy and replay simulator report different lengths.
    """
    args = _parse_args(argv)
    actions, frames = _load_demo(args.dataset, args.demo_idx)

    replay_simulator = ReplaySimulator(frames=frames, prompt_horizon=args.prompt_horizon)
    replay_policy = ReplayPolicy(
        actions=actions,
        prompt_horizon=args.prompt_horizon,
        action_stride=args.action_stride,
    )
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if len(replay_policy) != len(replay_simulator):
        raise ValueError(
            f"replay policy length ({len(replay_policy)}) does not match "
            f"replay simulator length ({len(replay_simulator)})"
        )

    # Alternative configuration previously kept here as commented-out arguments:
    #   image_encoder_type="magvit", image_encoder_ckpt="data/magvit2.ckpt",
    #   quantize=True, backbone_type="stmaskgit",
    #   backbone_ckpt="data/serious_robomimic_d256/step_86500"
    #   (or backbone_ckpt="data/genie_lang/step_5")
    genie_simulator = GenieSimulator(
        image_encoder_type='temporalvae',
        image_encoder_ckpt='stabilityai/stable-video-diffusion-img2vid',
        quantize=False,
        backbone_type="stmar",
        backbone_ckpt=args.backbone_ckpt,
        prompt_horizon=args.prompt_horizon,
        action_stride=args.action_stride,
        domain='robomimic',
        physics_simulator=replay_simulator,
        compute_psnr=True,
        compute_delta_psnr=True,
        allow_external_prompt=True,
    )

    # Use the replay simulator's and replay policy's current prompts as the initial state.
    image_prompt = replay_simulator.prompt()
    action_prompt = replay_policy.prompt()
    genie_simulator.set_initial_state((image_prompt, action_prompt))

    playground = InteractiveDigitalWorld(
        simulator=genie_simulator,
        policy=replay_policy,
        offscreen=True,
        window_size=(512 * 2, 512),  # [genie image | GT image] side-by-side
    )
    # Always release the playground, even if a step or the video write raises.
    try:
        for _ in range(len(replay_policy)):
            playground.step()
        playground.save_video(save_path=args.save_path, as_gif=args.as_gif)
    finally:
        playground.close()


if __name__ == '__main__':
    main()