File size: 4,306 Bytes
246c106
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import numpy as np
import cv2
import os
import torch
import tqdm
import time
import imageio
from sim.simulator import GenieSimulator, RobomimicSimulator
from diffusion_policy.util.pytorch_util import dict_apply
from sim.policy import DiffusionPolicy

# Side length, in pixels, of the square image handed to the diffusion policy.
DP_RES = 84
# Upper bound on the number of simulator steps executed in a single trial.
MAX_STEPS = 100
# Number of rollouts performed for every policy checkpoint.
NUM_EVAL_TRIALS = 50


def _frame_to_obs(frame):
    """Build the policy observation dict from one image frame.

    The frame is resized to ``DP_RES x DP_RES`` with OpenCV and its last axis
    is moved to the front via ``transpose(2, 0, 1)``.

    Args:
        frame: image array accepted by ``cv2.resize`` (assumed H x W x C --
            TODO confirm against the simulator output).

    Returns:
        A dict with the single key ``'agentview_image'``.
    """
    resized = cv2.resize(frame, (DP_RES, DP_RES))
    return {'agentview_image': resized.transpose(2, 0, 1)}


def _run_trial(genie_simulator, diffusion_policy):
    """Roll out one evaluation episode.

    Args:
        genie_simulator: simulator exposing ``reset()`` and ``step(action)``.
        diffusion_policy: policy exposing ``n_obs_steps``, ``device``,
            ``dtype`` and ``generate_action(obs)``.

    Returns:
        ``(simulation_frames, elapsed_seconds)`` where ``simulation_frames``
        starts with the reset frame followed by one side-by-side
        (predicted | ground-truth) frame per executed step.
    """
    n_obs_steps = diffusion_policy.n_obs_steps

    # reset
    genie_image = genie_simulator.reset()

    # obs dict construction
    # NOTE(review): only the left half of the reset frame seeds the buffer,
    # whereas the loop below resizes the whole of ``genie_image``. Presumably
    # reset() returns a side-by-side frame -- confirm against GenieSimulator.
    initial_obs = _frame_to_obs(genie_image[:, :genie_image.shape[1] // 2])
    obs_dict_buf = dict_apply(
        initial_obs,
        lambda x: x[np.newaxis].repeat(n_obs_steps, axis=0)
    )

    done = False
    simulation_frames = [genie_image]
    start_time = time.time()

    # The context manager closes the progress bar even if a step raises.
    with tqdm.tqdm(total=MAX_STEPS) as pbar:
        while not done and pbar.n < MAX_STEPS:

            # get the latest observation
            latest_obs_dict = _frame_to_obs(genie_image)

            # roll the obs dict buffer
            obs_dict_buf = dict_apply(
                obs_dict_buf,
                lambda x: np.roll(x, shift=-1, axis=0)
            )

            # update the obs dict buffer with the latest observation
            for k, v in latest_obs_dict.items():
                obs_dict_buf[k][-1] = v

            # rollout
            traj = diffusion_policy.generate_action(dict_apply(
                obs_dict_buf,
                lambda x: torch.from_numpy(x).to(
                    device=diffusion_policy.device, dtype=diffusion_policy.dtype
                ).unsqueeze(0)
            ))['action'].squeeze(0).detach().cpu().numpy()

            # step the simulator
            for action in traj:
                result = genie_simulator.step(action[np.newaxis])
                done = result['done']
                genie_image = result['pred_next_frame']
                phys_image = result['gt_next_frame']
                simulation_frames.append(np.concatenate([genie_image, phys_image], axis=1))
                pbar.update(1)

                # Stop mid-chunk: otherwise a later action could overwrite a
                # true ``done`` flag and the trial could exceed MAX_STEPS.
                if done or pbar.n >= MAX_STEPS:
                    break

    elapsed = time.time() - start_time
    return simulation_frames, elapsed


def main():
    """Evaluate each diffusion-policy checkpoint inside the Genie simulator.

    For every checkpoint, runs ``NUM_EVAL_TRIALS`` rollouts, writes one video
    per rollout under ``data/policy_eval_videos/`` and reports the wall-clock
    time spent. The simulator is closed even when evaluation fails.
    """
    robomimic_simulator = RobomimicSimulator(env_name='lift')
    genie_simulator = GenieSimulator(

        image_encoder_type='temporalvae',
        image_encoder_ckpt='stabilityai/stable-video-diffusion-img2vid',
        quantize=False,
        backbone_type="stmar",
        backbone_ckpt="data/mar_ckpt/robomimic_mixed",
        prompt_horizon=11,

        action_stride=1,
        domain='robomimic',
        physics_simulator=robomimic_simulator,
        physics_simulator_teacher_force=None,
    )

    try:
        assert genie_simulator.action_stride == 1, "currently only support action stride of 1"

        # load the policy
        success_rates = [1.00, 0.70, 0.52, 0.38]
        eval_time_taken = [0.0] * len(success_rates)
        for index, sr in enumerate(success_rates):
            diffusion_policy = DiffusionPolicy(f'data/dp_ckpt/dp_lift_sr{sr:.2f}.ckpt')

            # The output directory depends only on the checkpoint, so create
            # it once instead of once per trial.
            video_dir = f'data/policy_eval_videos/policy_{sr:.2f}'
            os.makedirs(video_dir, exist_ok=True)

            for trial in range(NUM_EVAL_TRIALS):
                simulation_frames, elapsed = _run_trial(genie_simulator, diffusion_policy)
                eval_time_taken[index] += elapsed

                # save the simulation frames
                video_path = f'{video_dir}/{trial:02d}.mp4'
                print(f"Saving {len(simulation_frames)} frames to {video_path}")
                imageio.mimsave(video_path, simulation_frames, fps=10)

            print(f"This checkpoint took {eval_time_taken[index]} seconds to evaluate")

        print("======= Evaluation Time Taken =======")
        for sr, t in zip(success_rates, eval_time_taken):
            print(f"SR={sr:.2f}: {t:.2f} seconds")
        print(f"Average time taken per eval: {np.mean(eval_time_taken) / NUM_EVAL_TRIALS:.2f} seconds")

        print("======= Simulation Done =======")
    finally:
        # Release the simulator whether or not evaluation completed.
        genie_simulator.close()


if __name__ == '__main__':
    main()