"""Ravens evaluation script."""

import os
import json

import numpy as np
import hydra
from cliport import agents
from cliport import dataset
from cliport import tasks
from cliport.utils import utils
from cliport.environments.environment import Environment
from torch.utils.data import DataLoader


@hydra.main(config_path='./cfg', config_name='eval', version_base="1.2")
def main(vcfg):
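    # Load the training config that the checkpoints were trained with.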
    tcfg = utils.load_hydra_config(vcfg['train_config'])
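
    # Initialize the simulation environment (stepped at 480 Hz).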
    env = Environment(
        vcfg['assets_root'],
        disp=vcfg['disp'],
        shared_memory=vcfg['shared_memory'],
        hz=480,
        record_cfg=vcfg['record']
    )

    mode = vcfg['mode']
    eval_task = vcfg['eval_task']
    print(f"Evaluating task: {eval_task}")

    if mode not in {'train', 'val', 'test'}:
        raise ValueError("Invalid mode. Valid options: train, val, test")
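
    # Multi-task datasets group several tasks under one label; single-task
    # data lives in a '<task>-<mode>' directory.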
    dataset_type = vcfg['type']
    if 'multi' in dataset_type:
        ds = dataset.RavensMultiTaskDataset(vcfg['data_dir'],
                                            tcfg,
                                            group=eval_task,
                                            mode=mode,
                                            n_demos=vcfg['n_demos'],
                                            augment=False)
    else:
        ds = dataset.RavensDataset(os.path.join(vcfg['data_dir'], f"{eval_task}-{mode}"),
                                   tcfg,
                                   n_demos=vcfg['n_demos'],
                                   augment=False)

    all_results = {}
    name = f"{eval_task}-{vcfg['agent']}-n{vcfg['n_demos']}"
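
    # Results are written to a per-run JSON file; multi-task model paths get
    # a distinct 'multi-' prefix so they do not collide with single-task runs.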
    json_name = f"multi-results-{mode}.json" if 'multi' in vcfg['model_path'] else f"results-{mode}.json"
    save_path = vcfg['save_path']
    print(f"Save path for results: {save_path}")
    os.makedirs(save_path, exist_ok=True)
    save_json = os.path.join(save_path, f'{name}-{json_name}')
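
    # Load any earlier results so finished checkpoints can be skipped.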
    existing_results = {}
    if os.path.exists(save_json):
        with open(save_json, 'r') as f:
            existing_results = json.load(f)
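
    # Resolve the checkpoints to evaluate; the loader is only passed to the
    # agent constructor, which expects train/val loaders.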
    ckpts_to_eval = list_ckpts_to_eval(vcfg, existing_results)
    data_loader = DataLoader(ds, shuffle=False,
                             pin_memory=False,
                             num_workers=1)

    print(f"Evaluating: {ckpts_to_eval}")
    for ckpt in ckpts_to_eval:
        model_file = os.path.join(vcfg['model_path'], ckpt)
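
        # Skip checkpoints that are missing or already have recorded results.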
        if not os.path.exists(model_file) or not os.path.isfile(model_file):
            print(f"Checkpoint not found: {model_file}")
            continue
        elif not vcfg['update_results'] and ckpt in existing_results:
            print(f"Skipping because of existing results for {model_file}.")
            continue

        results = []
        mean_reward = 0.0
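
        # Repeat the evaluation under different seeds to average out
        # stochasticity in the rollouts.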
        for train_run in range(vcfg['n_repeats']):
            utils.set_seed(train_run, torch=True)
            agent = agents.names[vcfg['agent']](name, tcfg, data_loader, data_loader)
            agent.load(model_file)
            print(f"Loaded: {model_file}")

            record = vcfg['record']['save_video']
            n_demos = vcfg['n_demos']
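
            # Evaluate each stored demonstration episode once.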
            for i in range(n_demos):
                print(f'Test: {i + 1}/{n_demos}')
                try:
                    episode, seed = ds.load(i)
                except Exception:
                    print(f"Skipping bad example {i}")
                    continue
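                # The final step of the stored episode serves as the goal.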
                goal = episode[-1]
                total_reward = 0
                np.random.seed(seed)
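
                # Resolve the task: a multi-task dataset reports the current
                # episode's task; otherwise use the configured eval task.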
                if 'multi' in dataset_type:
                    task_name = ds.get_curr_task()
                    task = tasks.names[task_name]()
                    print(f'Evaluating on {task_name}')
                else:
                    task_name = vcfg['eval_task']
                    task = tasks.names[task_name]()

                task.mode = mode
                env.seed(seed)
                env.set_task(task)
                obs = env.reset()
                info = env.info
                reward = 0
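
                # Optionally record a video of this rollout.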
                if record:
                    video_name = f'{task_name}-{i+1:06d}'
                    if 'multi' in vcfg['model_task']:
                        video_name = f"{vcfg['model_task']}-{video_name}"
                    env.start_rec(video_name)
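
                # Roll out the policy until the episode ends or the task's
                # step budget is exhausted.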
                for _ in range(task.max_steps):
                    act = agent.act(obs, info, goal)
                    lang_goal = info['lang_goal']
                    print(f'Lang Goal: {lang_goal}')

                    obs, reward, done, info = env.step(act)
                    total_reward += reward
                    if done:
                        break

                results.append((total_reward, info))
                mean_reward = np.mean([r for r, i in results])
                print(f'Mean: {mean_reward} | Task: {task_name} | Ckpt: {ckpt}')

                if record:
                    env.end_rec()

        all_results[ckpt] = {
            'episodes': results,
            'mean_reward': mean_reward,
        }

        if vcfg['save_results']:
            print(f"Saving results to: {save_json}")
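
            # Merge with results already on disk so entries from earlier
            # checkpoints are preserved.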
            if os.path.exists(save_json):
                with open(save_json, 'r') as f:
                    existing_results = json.load(f)
                existing_results.update(all_results)
                all_results = existing_results

            with open(save_json, 'w') as f:
                json.dump(all_results, f, indent=4)


def list_ckpts_to_eval(vcfg, existing_results):
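    """Return the checkpoint filenames to evaluate, chosen according to
    vcfg['checkpoint_type'] and the results recorded so far."""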
    ckpts_to_eval = []
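
    # 'last': evaluate only the most recent checkpoint.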
    if vcfg['checkpoint_type'] == 'last':
        last_ckpt = 'last.ckpt'
        ckpts_to_eval.append(last_ckpt)
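
    # 'val_missing': evaluate every step checkpoint without recorded results.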
    elif vcfg['checkpoint_type'] == 'val_missing':
        checkpoints = sorted([c for c in os.listdir(vcfg['model_path']) if "steps=" in c])
        ckpts_to_eval = [c for c in checkpoints if c not in existing_results]
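
    # 'test_best': pick the checkpoint with the highest mean validation reward.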
    elif vcfg['checkpoint_type'] == 'test_best':
        result_jsons = [c for c in os.listdir(vcfg['results_path']) if "results-val" in c]
        if 'multi' in vcfg['model_task']:
            result_jsons = [r for r in result_jsons if "multi" in r]
        else:
            result_jsons = [r for r in result_jsons if "multi" not in r]

        if len(result_jsons) > 0:
            result_json = result_jsons[0]
            with open(os.path.join(vcfg['results_path'], result_json), 'r') as f:
                eval_res = json.load(f)
            best_checkpoint = 'last.ckpt'
            best_success = -1.0
            for ckpt, res in eval_res.items():
                if res['mean_reward'] > best_success:
                    best_checkpoint = ckpt
                    best_success = res['mean_reward']
            print(best_checkpoint)
            ckpts_to_eval.append(best_checkpoint)
        else:
            print("No best val ckpt found. Using last.ckpt")
            ckpts_to_eval.append('last.ckpt')
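
    # Otherwise, treat checkpoint_type as a substring to match against
    # checkpoint filenames.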
    else:
        print(f"Looking for: {vcfg['checkpoint_type']}")
        checkpoints = [c for c in os.listdir(vcfg['model_path']) if vcfg['checkpoint_type'] in c]
        checkpoint = checkpoints[0] if checkpoints else ""
        ckpts_to_eval.append(checkpoint)

    print("ckpts_to_eval:", ckpts_to_eval)
    return ckpts_to_eval


if __name__ == '__main__':
    main()