Spaces:
Running
on
Zero
Running
on
Zero
import numpy as np | |
import cv2 | |
import os | |
import torch | |
import tqdm | |
import time | |
import imageio | |
from sim.simulator import GenieSimulator, RobomimicSimulator | |
from diffusion_policy.util.pytorch_util import dict_apply | |
from sim.policy import DiffusionPolicy | |
# Side length, in pixels, of the square image each frame is resized to before
# it is handed to the diffusion policy.
DP_RES = 84
# Upper bound on the number of simulator steps in a single evaluation trial.
MAX_STEPS = 100
# Number of evaluation trials run for every policy checkpoint.
NUM_EVAL_TRIALS = 50
def left_half(frame):
    """Return the left half of ``frame`` along its second (width) axis."""
    return frame[:, :frame.shape[1] // 2]


def frame_to_policy_obs(frame):
    """Build the policy observation dict for a single frame.

    The frame is resized to ``DP_RES`` x ``DP_RES`` and its last axis is moved
    to the front, then stored under the ``'agentview_image'`` key.
    """
    resized = cv2.resize(frame, (DP_RES, DP_RES))
    return {'agentview_image': resized.transpose(2, 0, 1)}


def push_observation(obs_buf, latest_obs):
    """Return a new history buffer with ``latest_obs`` as its newest entry.

    Every array in ``obs_buf`` is shifted by one position along axis 0, so the
    oldest entry wraps to the end, and that last slot is then overwritten with
    the matching value from ``latest_obs``. ``obs_buf`` itself is not modified.

    Args:
        obs_buf: flat dict mapping observation keys to arrays whose first
            axis indexes history steps.
        latest_obs: flat dict mapping the same keys to one observation each.
    """
    rolled = {
        key: np.roll(history, shift=-1, axis=0)
        for key, history in obs_buf.items()
    }
    for key, value in latest_obs.items():
        rolled[key][-1] = value
    return rolled


def run_trial(genie_simulator, diffusion_policy):
    """Roll out one evaluation episode of ``diffusion_policy``.

    Returns:
        A ``(frames, elapsed_seconds)`` tuple: the frames to write to the
        trial video, and the wall-clock time spent in the rollout loop
        (reset and video encoding are not included).
    """
    n_obs_steps = diffusion_policy.n_obs_steps

    # NOTE(review): reset() appears to return the world-model frame and the
    # physics frame side by side -- the left half is what the policy is
    # primed with, and the full frame is the first video frame next to the
    # concatenated frames appended below. TODO confirm against
    # GenieSimulator.reset.
    reset_frame = genie_simulator.reset()
    genie_image = left_half(reset_frame)

    # Prime the observation history with n_obs_steps copies of the first
    # observation.
    obs_dict_buf = dict_apply(
        frame_to_policy_obs(genie_image),
        lambda x: x[np.newaxis].repeat(n_obs_steps, axis=0),
    )

    simulation_frames = [reset_frame]
    done = False
    steps_taken = 0
    start_time = time.perf_counter()
    # The context manager closes the progress bar even if a step raises.
    with tqdm.tqdm(total=MAX_STEPS) as pbar:
        while not done and steps_taken < MAX_STEPS:
            # Feed the policy the same (cropped) view it was primed with;
            # previously the first iteration used the un-cropped reset frame.
            obs_dict_buf = push_observation(
                obs_dict_buf, frame_to_policy_obs(genie_image)
            )

            # Rollout: add a leading batch axis, query the policy, and bring
            # the predicted action sequence back to NumPy.
            policy_input = dict_apply(
                obs_dict_buf,
                lambda x: torch.from_numpy(x).to(
                    device=diffusion_policy.device,
                    dtype=diffusion_policy.dtype,
                ).unsqueeze(0),
            )
            traj = (
                diffusion_policy.generate_action(policy_input)['action']
                .squeeze(0)
                .detach()
                .cpu()
                .numpy()
            )

            # Step the simulator through the predicted actions.
            for action in traj:
                result = genie_simulator.step(action[np.newaxis])
                done = result['done']
                genie_image = result['pred_next_frame']
                phys_image = result['gt_next_frame']
                simulation_frames.append(
                    np.concatenate([genie_image, phys_image], axis=1)
                )
                steps_taken += 1
                pbar.update(1)
                # Stop mid-chunk so a later step cannot overwrite an earlier
                # done=True and the episode never exceeds MAX_STEPS.
                if done or steps_taken >= MAX_STEPS:
                    break
    elapsed = time.perf_counter() - start_time
    return simulation_frames, elapsed


def main():
    """Evaluate each diffusion-policy checkpoint and save one video per trial."""
    robomimic_simulator = RobomimicSimulator(env_name='lift')
    genie_simulator = GenieSimulator(
        image_encoder_type='temporalvae',
        image_encoder_ckpt='stabilityai/stable-video-diffusion-img2vid',
        quantize=False,
        backbone_type="stmar",
        backbone_ckpt="data/mar_ckpt/robomimic_mixed",
        prompt_horizon=11,
        action_stride=1,
        domain='robomimic',
        physics_simulator=robomimic_simulator,
        physics_simulator_teacher_force=None,
    )
    try:
        assert genie_simulator.action_stride == 1, "currently only support action stride of 1"

        # Each value selects a checkpoint file and an output directory by name.
        success_rates = [1.00, 0.70, 0.52, 0.38]
        eval_time_taken = [0.0] * len(success_rates)
        for index, sr in enumerate(success_rates):
            # load the policy
            diffusion_policy = DiffusionPolicy(f'data/dp_ckpt/dp_lift_sr{sr:.2f}.ckpt')

            # One output directory per checkpoint, created once rather than
            # on every trial.
            video_dir = f'data/policy_eval_videos/policy_{sr:.2f}'
            os.makedirs(video_dir, exist_ok=True)

            for trial in range(NUM_EVAL_TRIALS):
                simulation_frames, elapsed = run_trial(genie_simulator, diffusion_policy)
                eval_time_taken[index] += elapsed

                # save the simulation frames
                video_path = f'{video_dir}/{trial:02d}.mp4'
                print(f"Saving {len(simulation_frames)} frames to {video_path}")
                imageio.mimsave(video_path, simulation_frames, fps=10)
            print(f"This checkpoint took {eval_time_taken[index]} seconds to evaluate")

        print("======= Evaluation Time Taken =======")
        for sr, t in zip(success_rates, eval_time_taken):
            print(f"SR={sr:.2f}: {t:.2f} seconds")
        print(f"Average time taken per eval: {np.mean(eval_time_taken) / NUM_EVAL_TRIALS:.2f} seconds")
        print("======= Simulation Done =======")
    finally:
        # Release the simulator even when an evaluation step fails.
        genie_simulator.close()


if __name__ == '__main__':
    main()