|
"""Data collection script.""" |
|
|
|
import os |
|
import hydra |
|
import numpy as np |
|
import random |
|
|
|
from cliport import tasks |
|
from cliport.dataset import RavensDataset |
|
from cliport.environments.environment import Environment |
|
import IPython |
|
import random |
|
|
|
def _print_highlighted_traceback():
    """Print the traceback of the exception currently being handled.

    Uses pygments for terminal syntax highlighting when it is importable and
    falls back to the plain traceback text otherwise, so that reporting a
    failed episode cannot itself abort data collection.
    """
    import traceback

    # Capture first: a failing pygments import below would otherwise become
    # the "current" exception seen by format_exc().
    trace = traceback.format_exc()
    try:
        from pygments import highlight
        from pygments.formatters import TerminalFormatter
        from pygments.lexers import PythonLexer
    except ImportError:
        print(trace)
        return
    print(highlight(trace, PythonLexer(), TerminalFormatter()))


@hydra.main(config_path='./cfg', config_name='data')
def main(cfg):
    """Collect oracle demonstrations for one task and add them to a dataset.

    Args:
        cfg: Hydra config. Keys read here: ``assets_root``, ``disp``,
            ``shared_memory``, ``record`` (including ``record.save_video``),
            ``task``, ``mode``, ``save_data``, ``data_dir``, ``n`` and the
            optional ``regenerate_data``. ``cfg['task']`` is rewritten in
            place with underscores replaced by dashes.

    Raises:
        ValueError: If the dataset has no seed yet and ``cfg['mode']`` is not
            one of ``train``, ``val`` or ``test``.
        RuntimeError: If a ``val`` seed grows past the first test seed.
    """
    # Initialize environment and task.
    env = Environment(
        cfg['assets_root'],
        disp=cfg['disp'],
        shared_memory=cfg['shared_memory'],
        hz=480,
        record_cfg=cfg['record'],
    )
    # Task names are looked up with dashes, so accept underscores as well.
    cfg['task'] = cfg['task'].replace("_", "-")
    task = tasks.names[cfg['task']]()
    task.mode = cfg['mode']
    record = cfg['record']['save_video']
    save_data = cfg['save_data']

    # Initialize scripted oracle agent and dataset.
    agent = task.oracle(env)
    data_path = os.path.join(cfg['data_dir'], "{}-{}".format(cfg['task'], task.mode))
    dataset = RavensDataset(data_path, cfg, n_demos=0, augment=False)
    print(f"Saving to: {data_path}")
    print(f"Mode: {task.mode}")

    # The seed advances by 2 per episode, so a fresh dataset uses
    # 0, 2, 4, ... for train, 1, 3, 5, ... for val and 10001, 10003, ... for
    # test. A non-negative dataset.max_seed is continued from as-is.
    test_seed_start = -1 + 10000
    seed = dataset.max_seed
    # Cap on attempts in this run so repeated failures cannot loop forever.
    max_eps = 3 * cfg['n']

    if seed < 0:
        if task.mode == 'train':
            seed = -2
        elif task.mode == 'val':
            seed = -1
        elif task.mode == 'test':
            seed = test_seed_start
        else:
            raise ValueError("Invalid mode. Valid options: train, val, test")

    # NOTE(review): this tests for the key's presence, not its value, so
    # `regenerate_data: false` also resets the counter -- confirm intent.
    if 'regenerate_data' in cfg:
        dataset.n_episodes = 0

    curr_run_eps = 0

    # Collect data from oracle demonstrations.
    while dataset.n_episodes < cfg['n'] and curr_run_eps < max_eps:
        episode, total_reward = [], 0
        seed += 2

        # Set seeds.
        np.random.seed(seed)
        random.seed(seed)
        print('Oracle demo: {}/{} | Seed: {}'.format(dataset.n_episodes + 1, cfg['n'], seed))

        # Safety check against val/test seed overlap. It must stay outside
        # the try block below: the broad handler there would swallow it and
        # keep resetting the environment until max_eps is exhausted.
        if task.mode == 'val' and seed > test_seed_start:
            raise RuntimeError("!!! Seeds for val set will overlap with the test set !!!")

        # Count every attempt, successful or not, so the loop always exits.
        curr_run_eps += 1

        try:
            env.set_task(task)
            obs = env.reset()
            info = env.info
            reward = 0

            # Start video recording.
            if record:
                env.start_rec(f'{dataset.n_episodes+1:06d}')

            # Roll out the oracle policy.
            for _ in range(task.max_steps):
                act = agent.act(obs, info)
                episode.append((obs, act, reward, info))
                lang_goal = info['lang_goal']
                obs, reward, done, info = env.step(act)
                total_reward += reward
                print(f'Total Reward: {total_reward:.3f} | Done: {done} | Goal: {lang_goal}')
                if done:
                    break

            # End video recording.
            if record:
                env.end_rec()

        except Exception:
            # Best effort: a failed episode is reported and skipped; the
            # next iteration retries with the next seed.
            _print_highlighted_traceback()
            if record:
                env.end_rec()
            continue

        # Final observation has no action attached.
        episode.append((obs, None, reward, info))

        # Only save completed demonstrations.
        if save_data and total_reward > 0.99:
            dataset.add(seed, episode)

            if hasattr(env, 'blender_recorder'):
                blender_path = '{}/blender_demo_{}.pkl'.format(data_path, dataset.n_episodes)
                print("blender pickle saved to ", blender_path)
                env.blender_recorder.save(blender_path)
|
|
|
|
|
# Script entry point: main() is called with no arguments; the @hydra.main
# decorator on it is responsible for providing `cfg`.
if __name__ == "__main__":
    main()
|
|