# --------------------------------------------------------
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------

"""
TODO: explain
"""
import h5py
import numpy as np
import cv2
from collections import OrderedDict
import robomimic.utils.file_utils as FileUtils

from sim.robomimic.robomimic_runner import (
    create_env, OBS_KEYS, RESOLUTION
)
from sim.robomimic.robomimic_wrapper import RobomimicLowdimWrapper

from typing import Optional, Iterable

DATASET_DIR = 'data/robomimic/datasets'
SUPPORTED_ENVS = ['lift', 'square', 'can']
NUM_EPISODES_PER_TASK = 200
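
# Episodes are indexed flat across tasks: idx // NUM_EPISODES_PER_TASK selects
# the task in SUPPORTED_ENVS and idx % NUM_EPISODES_PER_TASK the demo within
# it (with 200 episodes per task: 0-199 -> lift, 200-399 -> square,
# 400-599 -> can).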


def render_step(env, state):
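    """Replay a stored flattened MuJoCo sim state and render a frame resized to RESOLUTION."""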
    env.env.env.sim.set_state_from_flattened(state)
    env.env.env.sim.forward()
    img = env.render()
    img = cv2.resize(img, RESOLUTION)
    return img


def robomimic_dataset_size() -> int:
    return len(SUPPORTED_ENVS) * NUM_EPISODES_PER_TASK


def robomimic_dataset_generator(example_inds: Optional[Iterable[int]] = None):
    if example_inds is None:
        example_inds = range(robomimic_dataset_size())
    
    env = None
    curr_env_name = None
    for idx in example_inds:
        # map the flat index to its task
        env_name = SUPPORTED_ENVS[idx // NUM_EPISODES_PER_TASK]
        if curr_env_name != env_name:
            # switching tasks: close the previous env before building the new one
            if env is not None:
                env.close()
            dataset = f"{DATASET_DIR}/{env_name}/ph/image.hdf5"
            env_meta = FileUtils.get_env_metadata_from_dataset(dataset)
            env_meta["use_image_obs"] = True
            env = create_env(env_meta=env_meta, obs_keys=OBS_KEYS)
            env = RobomimicLowdimWrapper(env=env)
            env.reset()     # NOTE: this is necessary to remove green laser bug
            curr_env_name = env_name

        with h5py.File(dataset, "r") as file:
            demos = file["data"]
            local_episode_idx = idx % NUM_EPISODES_PER_TASK
            if f"demo_{local_episode_idx}" not in demos:
                continue

            demo = demos[f"demo_{local_episode_idx}"]
            obs = demo["obs"]
            states = demo["states"]
            action = demo["actions"][:].astype(np.float32)
            step_obs = np.concatenate([obs[key] for key in OBS_KEYS], axis=-1).astype(np.float32)
            steps = []
            for a, o, s in zip(action, step_obs, states):
                # break into step dict
                image = render_step(env, s)
                step = {
                    "observation": {"state": o, "image": image},
                    "action": a,
                    "language_instruction": f"{env_name}",
                }
                steps.append(OrderedDict(step))
            data_dict = {"steps": steps}
            yield data_dict
            
            # Disabled: experimental variant that re-rolls each demo with
            # Gaussian action noise to yield perturbed episodes.
            # import imageio
            # for _ in range(3):
            #     steps = []
            #     perturbed_action = action + np.random.normal(0, 0.2, action.shape)
            #     current_state = states[0]
            #     _ = render_step(env, current_state)
            #     for someindex in range(len(action)):
            #         image = env.render()
            #         step = {
            #             "observation": {"image": image},
            #             "action": action[someindex],
            #             "language_instruction": f"{env_name}",
            #         }
            #         steps.append(OrderedDict(step))

            #         # simulate action
            #         env.step(perturbed_action[someindex])

            #     # # save video
            #     # frames = [step["observation"]["image"] for step in steps]
            #     # imageio.mimsave(f"test.mp4", frames, fps=10)
            #     # while not (user_input := input("Continue? (y/n)")) in ["y", "n"]:
            #     #     print("Invalid input")

            #     data_dict = {"steps": steps}
            #     yield data_dict

    if env is not None:
        env.close()
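

# Minimal smoke-test sketch (not part of the original module; assumes the
# robomimic ph image datasets are already downloaded under DATASET_DIR):
if __name__ == "__main__":
    for i, episode in enumerate(robomimic_dataset_generator(example_inds=range(2))):
        first = episode["steps"][0]
        print(
            f"episode {i}: {len(episode['steps'])} steps, "
            f"task={first['language_instruction']}, "
            f"image shape={first['observation']['image'].shape}"
        )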