"""Generate data based on the learned policy and physics simulator - mujoco.

Rolls out a diffusion policy in the simulator for a fixed number of trials,
records the frame/action pairs of every trial, and writes them to an HDF5
file.
"""
# NOTE: import order is kept exactly as in the original script.
from sim.simulator import RobomimicSimulator
from sim.policy import DiffusionPolicy
from diffusion_policy.util.pytorch_util import dict_apply
import h5py
import tqdm
import numpy as np
import torch
import cv2
import imageio

MAX_STEPS = 100
RES = 84  # for DP input


def _frame_to_obs(image):
    """Build the policy observation dict for a single simulator frame.

    The frame is resized to ``RES x RES`` and its last axis is moved to the
    front (``transpose(2, 0, 1)``).
    """
    return {"agentview_image": cv2.resize(image, (RES, RES)).transpose(2, 0, 1)}


def _rollout_episode(env, policy, n_obs_steps):
    """Run one episode and return its recorded frames and actions.

    Args:
        env: simulator exposing ``reset()`` and ``step(action)``; ``step``
            must return a mapping with ``'done'`` and ``'pred_next_frame'``.
        policy: policy exposing ``generate_action``, ``device`` and ``dtype``.
        n_obs_steps: number of stacked observations fed to the policy.

    Returns:
        ``{"images": np.ndarray, "actions": np.ndarray}`` where entry ``i`` of
        ``images`` is the frame observed right before ``actions[i]`` was
        executed.
    """
    image = env.reset()
    images = []
    actions = []

    # Fill the whole observation history with copies of the first frame.
    obs_dict_buf = dict_apply(
        _frame_to_obs(image),
        lambda x: x[np.newaxis].repeat(n_obs_steps, axis=0),
    )

    done = False
    # The step budget is tracked with an explicit counter rather than
    # `pbar.n`: tqdm does not advance `n` when the bar is disabled, which
    # would silently remove the MAX_STEPS cap.
    steps = 0
    # The context manager closes the bar at the end of every episode.
    with tqdm.tqdm(total=MAX_STEPS) as pbar:
        while not done and steps < MAX_STEPS:
            # get latest obs: shift the history by one and append the new frame
            latest_obs_dict = _frame_to_obs(image)
            obs_dict_buf = dict_apply(
                obs_dict_buf, lambda x: np.roll(x, -1, axis=0)
            )
            for key, value in latest_obs_dict.items():
                obs_dict_buf[key][-1] = value

            # rollout: add a leading batch axis, query the policy, drop it again
            policy_input = dict_apply(
                obs_dict_buf,
                lambda x: torch.from_numpy(x).to(
                    device=policy.device, dtype=policy.dtype
                ).unsqueeze(0),
            )
            traj = (
                policy.generate_action(policy_input)['action']
                .squeeze(0)
                .detach()
                .cpu()
                .numpy()
            )

            # step the simulator
            # NOTE(review): every action of the chunk is executed even after
            # `done` is set, so an episode may record steps past termination
            # and slightly more than MAX_STEPS steps - confirm this is intended.
            for action in traj:
                images.append(image)
                actions.append(action)
                result = env.step(action)
                done = done or result['done']
                image = result['pred_next_frame']
                steps += 1
                pbar.update(1)

    return {"images": np.array(images), "actions": np.array(actions)}


def _save_demos(path, demos):
    """Write all demos to an HDF5 file at ``path`` (overwriting it).

    Resulting layout::

        data/
            demo_0/
                images
                actions
            demo_1/
                images
                actions
            ...

    Args:
        path: destination file path.
        demos: mapping of demo name to a mapping of dataset name to array.
    """
    with h5py.File(path, 'w') as f:
        data = f.create_group("data")
        for demo_name, demo_data in demos.items():
            demo = data.create_group(demo_name)
            for key, value in demo_data.items():
                demo.create_dataset(key, data=value)


def main():
    """Roll out the policy for 200 trials and save the collected demos."""
    env = RobomimicSimulator(env_name='lift')
    policy = DiffusionPolicy('data/dp_ckpt/dp_lift_sr0.70.ckpt')
    n_obs_steps = policy.n_obs_steps

    demos = {}
    for trial in range(200):
        demos[f"demo_{trial}"] = _rollout_episode(env, policy, n_obs_steps)

    _save_demos('data/my_robomimic_dataset.hdf5', demos)


if __name__ == '__main__':
    main()