# hma/sim/evaluate_policy_learned_sim.py
#
# Evaluate a set of diffusion-policy checkpoints inside a learned (Genie)
# world-model simulator that is grounded by a Robomimic physics simulator,
# saving side-by-side rollout videos and per-checkpoint timing statistics.
import numpy as np
import cv2
import os
import torch
import tqdm
import time
import imageio
from sim.simulator import GenieSimulator, RobomimicSimulator
from diffusion_policy.util.pytorch_util import dict_apply
from sim.policy import DiffusionPolicy
DP_RES = 84
MAX_STEPS = 100
NUM_EVAL_TRIALS = 50
if __name__ == '__main__':
    # Ground the learned (Genie) simulator with a Robomimic physics simulator
    # so each step yields both a predicted frame and a ground-truth frame.
    robomimic_simulator = RobomimicSimulator(env_name='lift')
    genie_simulator = GenieSimulator(
        image_encoder_type='temporalvae',
        image_encoder_ckpt='stabilityai/stable-video-diffusion-img2vid',
        quantize=False,
        backbone_type="stmar",
        backbone_ckpt="data/mar_ckpt/robomimic_mixed",
        prompt_horizon=11,
        action_stride=1,
        domain='robomimic',
        physics_simulator=robomimic_simulator,
        physics_simulator_teacher_force=None,
    )
    assert genie_simulator.action_stride == 1, "currently only support action stride of 1"
    # Evaluate one diffusion-policy checkpoint per nominal success rate.
    success_rates = [1.00, 0.70, 0.52, 0.38]
    eval_time_taken = [0.0] * len(success_rates)
    for index, sr in enumerate(success_rates):
        diffusion_policy = DiffusionPolicy(f'data/dp_ckpt/dp_lift_sr{sr:.2f}.ckpt')
        n_obs_steps = diffusion_policy.n_obs_steps
        for trial in range(NUM_EVAL_TRIALS):
            # Reset returns the initial frame for this episode.
            genie_image = genie_simulator.reset()
            # Initial observation: left half of the reset frame, resized to the
            # policy's resolution, channel-first (C, H, W).
            # NOTE(review): the reset frame is cropped to its left half here but
            # in-loop frames (below) are not — presumably reset() returns a
            # side-by-side concatenation while step() returns a single frame;
            # confirm against the simulator implementation.
            latest_obs_dict = {
                'agentview_image': cv2.resize(
                    genie_image[:, :genie_image.shape[1]//2],
                    (DP_RES, DP_RES)
                ).transpose(2, 0, 1),
            }
            # Seed the rolling observation buffer by tiling the first frame.
            obs_dict_buf = dict_apply(latest_obs_dict, lambda x : x[np.newaxis].repeat(n_obs_steps, axis=0))
            done = False
            pbar = tqdm.tqdm(total=MAX_STEPS)
            simulation_frames = [ genie_image ]
            start_time = time.time()
            while not done and pbar.n < MAX_STEPS:
                # Latest observation from the most recent predicted frame.
                latest_obs_dict = {
                    'agentview_image': cv2.resize(
                        genie_image,
                        (DP_RES, DP_RES)
                    ).transpose(2, 0, 1),
                }
                # Shift the history buffer left by one step...
                obs_dict_buf = dict_apply(
                    obs_dict_buf,
                    lambda x : np.roll(x, shift=-1, axis=0)
                )
                # ...and write the newest observation into the last slot.
                for k, v in latest_obs_dict.items():
                    obs_dict_buf[k][-1] = v
                # Query the policy for an action chunk (batch dim added/removed).
                traj = diffusion_policy.generate_action(dict_apply(
                    obs_dict_buf,
                    lambda x : torch.from_numpy(x).to(
                        device=diffusion_policy.device, dtype=diffusion_policy.dtype
                    ).unsqueeze(0)
                ))['action'].squeeze(0).detach().cpu().numpy()
                # Execute the chunk one action at a time.
                for action in traj:
                    result = genie_simulator.step(action[np.newaxis])
                    done = result['done']
                    genie_image = result['pred_next_frame']
                    phys_image = result['gt_next_frame']
                    # Record predicted vs. ground-truth frames side by side.
                    simulation_frames.append(np.concatenate([genie_image, phys_image], axis=1))
                    pbar.update(1)
                    # BUGFIX: stop consuming the remaining action chunk as soon
                    # as the episode terminates or the step budget is exhausted;
                    # previously these conditions were only checked at the top
                    # of the while loop, so extra post-termination steps were
                    # executed and recorded.
                    if done or pbar.n >= MAX_STEPS:
                        break
            pbar.close()
            end_time = time.time()
            eval_time_taken[index] += end_time - start_time
            # Persist the episode rollout as an mp4 for qualitative inspection.
            os.makedirs(f'data/policy_eval_videos/policy_{sr:.2f}', exist_ok=True)
            print(f"Saving {len(simulation_frames)} frames to data/policy_eval_videos/policy_{sr:.2f}/{trial:02d}.mp4")
            imageio.mimsave(f'data/policy_eval_videos/policy_{sr:.2f}/{trial:02d}.mp4', simulation_frames, fps=10)
        print(f"This checkpoint took {eval_time_taken[index]} seconds to evaluate")
    print("======= Evaluation Time Taken =======")
    for sr, t in zip(success_rates, eval_time_taken):
        print(f"SR={sr:.2f}: {t:.2f} seconds")
    print(f"Average time taken per eval: {np.mean(eval_time_taken) / NUM_EVAL_TRIALS:.2f} seconds")
    print("======= Simulation Done =======")
    genie_simulator.close()