# Hugging Face Spaces page header: "Running on Zero"
"""
Generate a demonstration dataset by rolling out a learned diffusion policy
in the MuJoCo physics simulator and saving the resulting frames and actions
to an HDF5 file.
"""
from pathlib import Path

from sim.simulator import RobomimicSimulator
from sim.policy import DiffusionPolicy
from diffusion_policy.util.pytorch_util import dict_apply
import h5py
import tqdm
import numpy as np
import torch
import cv2
import imageio
MAX_STEPS = 100  # upper bound on simulator steps recorded per demo
RES = 84  # for DP input: frames are resized to RES x RES before reaching the policy
NUM_TRIALS = 200  # number of rollouts (demos) to generate
CKPT_PATH = 'data/dp_ckpt/dp_lift_sr0.70.ckpt'
OUTPUT_PATH = 'data/my_robomimic_dataset.hdf5'


def _to_policy_obs(image):
    """Resize an HWC frame to RES x RES and reorder its axes to CHW for the policy."""
    return cv2.resize(image, (RES, RES)).transpose(2, 0, 1)


def rollout(env, policy, max_steps=MAX_STEPS):
    """Run one closed-loop episode of ``policy`` in ``env``.

    The policy sees a sliding window of the last ``policy.n_obs_steps`` frames
    and returns an action chunk. Actions from the chunk are executed one at a
    time until the chunk is exhausted, the simulator reports ``done``, or
    ``max_steps`` actions have been executed in total.

    Args:
        env: simulator exposing ``reset()`` (returns a frame) and
            ``step(action)`` (returns a dict with ``'done'`` and
            ``'pred_next_frame'``).
        policy: object exposing ``n_obs_steps``, ``device``, ``dtype`` and
            ``generate_action(obs)`` returning a dict with an ``'action'``
            tensor.
        max_steps: maximum number of actions to execute in this episode.

    Returns:
        dict with ``"images"`` (the frame observed before each executed
        action) and ``"actions"`` (the executed actions), each stacked with
        ``np.array`` and holding one entry per executed step.
    """
    n_obs_steps = policy.n_obs_steps
    image = env.reset()
    images, actions = [], []

    # Seed the observation window by repeating the first frame n_obs_steps times.
    first_obs = {"agentview_image": _to_policy_obs(image)}
    obs_dict_buf = dict_apply(
        first_obs, lambda x: x[np.newaxis].repeat(n_obs_steps, axis=0)
    )

    done = False
    with tqdm.tqdm(total=max_steps) as pbar:
        while not done and len(actions) < max_steps:
            # Slide the window forward one slot and write the newest frame last.
            latest_obs_dict = {"agentview_image": _to_policy_obs(image)}
            obs_dict_buf = dict_apply(obs_dict_buf, lambda x: np.roll(x, -1, axis=0))
            for k, v in latest_obs_dict.items():
                obs_dict_buf[k][-1] = v

            # NOTE(review): frames reach the policy without any /255 scaling;
            # confirm this matches the normalization the checkpoint was trained with.
            obs_tensors = dict_apply(
                obs_dict_buf,
                lambda x: torch.from_numpy(x).to(
                    device=policy.device, dtype=policy.dtype
                ).unsqueeze(0),  # add a batch dimension of size 1
            )
            traj = (
                policy.generate_action(obs_tensors)['action']
                .squeeze(0)
                .detach()
                .cpu()
                .numpy()
            )

            # Execute the action chunk, stopping as soon as the episode ends or
            # the step budget is used up, so no steps past termination or past
            # max_steps end up in the recorded demo.
            for action in traj:
                images.append(image)
                actions.append(action)
                result = env.step(action)
                done = bool(result['done'])
                image = result['pred_next_frame']
                pbar.update(1)
                if done or len(actions) >= max_steps:
                    break

    return {"images": np.array(images), "actions": np.array(actions)}


def _write_demo(data_group, demo_name, demo_data):
    """Store each array of ``demo_data`` as dataset ``<data_group>/<demo_name>/<key>``."""
    demo = data_group.create_group(demo_name)
    for key, value in demo_data.items():
        demo.create_dataset(key, data=value)


def main(num_trials=NUM_TRIALS, ckpt_path=CKPT_PATH,
         output_path=OUTPUT_PATH, max_steps=MAX_STEPS):
    """Roll out the policy ``num_trials`` times and write every demo to HDF5.

    File layout::

        data/
            demo_0/
                images   # stacked frames, one per executed step
                actions  # stacked actions, one per executed step
            demo_1/
            ...

    The output file is opened (and truncated) before the first rollout, and
    each demo is written as soon as it is generated. Memory therefore holds
    only one episode at a time, and demos finished before an exception are
    kept when the ``with`` block closes the file.
    """
    env = RobomimicSimulator(env_name='lift')
    policy = DiffusionPolicy(ckpt_path)
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    with h5py.File(output_path, 'w') as f:
        data = f.create_group("data")
        for trial in range(num_trials):
            _write_demo(data, f"demo_{trial}", rollout(env, policy, max_steps))


if __name__ == '__main__':
    main()