# hma/sim/save_policy_rollout_gt_sim.py
# Committed by LeroyWaa in 246c106 ("draft"); 2.86 kB.
"""
Generate data based on the learned policy and physics simulator - mujoco
"""
from sim.simulator import RobomimicSimulator
from sim.policy import DiffusionPolicy
from diffusion_policy.util.pytorch_util import dict_apply
import h5py
import tqdm
import numpy as np
import torch
import cv2
import imageio
MAX_STEPS = 100  # per-episode cap on the number of executed env steps
RES = 84  # side length (pixels) of the square image fed to the diffusion policy
NUM_TRIALS = 200  # number of episodes to roll out
ENV_NAME = 'lift'
CKPT_PATH = 'data/dp_ckpt/dp_lift_sr0.70.ckpt'
OUTPUT_PATH = 'data/my_robomimic_dataset.hdf5'


def _make_obs(image):
    """Build the policy observation dict for a single simulator frame.

    The frame is resized to RES x RES and transposed with (2, 0, 1), i.e.
    channels-last to channels-first.
    NOTE(review): assumes `image` is an (H, W, C) array — confirm against the simulator.
    """
    return {"agentview_image": cv2.resize(image, (RES, RES)).transpose(2, 0, 1)}


def _init_obs_history(latest_obs, n_obs_steps):
    """Return a history buffer filled with `n_obs_steps` copies of each observation.

    Each value gains a new leading time axis of length `n_obs_steps`.
    """
    return {k: v[np.newaxis].repeat(n_obs_steps, axis=0) for k, v in latest_obs.items()}


def _push_obs(obs_buf, latest_obs):
    """Shift every history buffer one step toward the past and store `latest_obs` last.

    Returns a new dict; `obs_buf` is not modified because `np.roll` returns a copy.
    """
    shifted = {k: np.roll(v, -1, axis=0) for k, v in obs_buf.items()}
    for k, v in latest_obs.items():
        shifted[k][-1] = v
    return shifted


def _rollout_episode(env, policy, n_obs_steps, max_steps=MAX_STEPS):
    """Run one episode of `policy` in `env` and record the frames and actions.

    Each recorded image is the frame that was current when the paired action was
    issued. The episode ends when the env reports done or after `max_steps` steps,
    whichever comes first.

    Returns:
        dict with "images" and "actions", each stacked with `np.array` and holding
        at most `max_steps` entries.
    """
    image = env.reset()
    images, actions = [], []
    obs_buf = _init_obs_history(_make_obs(image), n_obs_steps)
    done = False
    # Context manager guarantees the progress bar is closed, even on error.
    with tqdm.tqdm(total=max_steps) as pbar:
        while not done and len(actions) < max_steps:
            obs_buf = _push_obs(obs_buf, _make_obs(image))
            # NOTE(review): frames are cast to policy.dtype without rescaling; confirm
            # this matches the pixel range the policy was trained on.
            policy_input = dict_apply(
                obs_buf,
                lambda x: torch.from_numpy(x).to(
                    device=policy.device, dtype=policy.dtype
                ).unsqueeze(0),
            )
            traj = policy.generate_action(policy_input)['action'].squeeze(0).detach().cpu().numpy()
            # Execute the predicted action chunk, but stop as soon as the episode
            # terminates or the step cap is hit; otherwise the episode overruns
            # max_steps and the env keeps being stepped after it reported done.
            for action in traj:
                images.append(image)
                actions.append(action)
                result = env.step(action)
                done = bool(result['done'])
                image = result['pred_next_frame']
                pbar.update(1)
                if done or len(actions) >= max_steps:
                    break
    return {"images": np.array(images), "actions": np.array(actions)}


def _save_demos(demos, path):
    """Write `demos` to an HDF5 file at `path`, overwriting any existing file.

    Layout:
        /data/<demo_name>/<key>   e.g. /data/demo_0/images, /data/demo_0/actions
    """
    with h5py.File(path, 'w') as f:
        data = f.create_group("data")
        for demo_name, demo_data in demos.items():
            group = data.create_group(demo_name)
            for key, value in demo_data.items():
                group.create_dataset(key, data=value)


def main():
    """Roll out NUM_TRIALS episodes of the policy and save them to OUTPUT_PATH."""
    env = RobomimicSimulator(env_name=ENV_NAME)
    policy = DiffusionPolicy(CKPT_PATH)
    n_obs_steps = policy.n_obs_steps
    demos = {}
    for trial in range(NUM_TRIALS):
        demos[f"demo_{trial}"] = _rollout_episode(env, policy, n_obs_steps)
    # All episodes are kept in memory and written in one pass at the end, so a
    # crash mid-run leaves any existing output file untouched.
    _save_demos(demos, OUTPUT_PATH)


if __name__ == '__main__':
    main()