"""Ravens main evaluation script."""
import os |
import pickle |
import json |
import numpy as np |
import hydra |
from cliport import agents |
from cliport import dataset |
from cliport import tasks |
from cliport.utils import utils |
from cliport.environments.environment import Environment |
from torch.utils.data import DataLoader |
@hydra.main(config_path='./cfg', config_name='eval', version_base="1.2")
def main(vcfg):
    """Evaluate agent checkpoints in the simulated environment.

    For every checkpoint returned by ``list_ckpts_to_eval`` the agent is
    loaded ``vcfg['n_repeats']`` times (seeded with the repeat index) and
    rolled out on the first ``vcfg['n_demos']`` episodes of the evaluation
    dataset. Per-episode ``(total_reward, info)`` pairs and the running mean
    reward are collected per checkpoint and, when ``vcfg['save_results']`` is
    set, merged into a JSON file under ``vcfg['save_path']``.

    Args:
        vcfg: Evaluation config (indexed like a mapping). Keys read here:
            ``train_config``, ``assets_root``, ``disp``, ``shared_memory``,
            ``record``, ``mode``, ``eval_task``, ``type``, ``data_dir``,
            ``n_demos``, ``agent``, ``model_path``, ``model_task``,
            ``save_path``, ``update_results``, ``n_repeats``,
            ``save_results``.

    Raises:
        ValueError: If ``vcfg['mode']`` is not ``train``, ``val`` or ``test``.
    """
    # Config the model was trained with; handed to the dataset and the agent.
    tcfg = utils.load_hydra_config(vcfg['train_config'])

    env = Environment(
        vcfg['assets_root'],
        disp=vcfg['disp'],
        shared_memory=vcfg['shared_memory'],
        hz=480,
        record_cfg=vcfg['record'],
    )

    mode = vcfg['mode']
    eval_task = vcfg['eval_task']
    print("eval_task!!!", eval_task)
    if mode not in {'train', 'val', 'test'}:
        # ValueError is an Exception subclass, so existing handlers still match.
        raise ValueError("Invalid mode. Valid options: train, val, test")

    # Build the evaluation dataset (augmentation disabled).
    dataset_type = vcfg['type']
    if 'multi' in dataset_type:
        ds = dataset.RavensMultiTaskDataset(vcfg['data_dir'],
                                            tcfg,
                                            group=eval_task,
                                            mode=mode,
                                            n_demos=vcfg['n_demos'],
                                            augment=False)
    else:
        ds = dataset.RavensDataset(os.path.join(vcfg['data_dir'], f"{eval_task}-{mode}"),
                                   tcfg,
                                   n_demos=vcfg['n_demos'],
                                   augment=False)

    all_results = {}
    name = '{}-{}-n{}'.format(eval_task, vcfg['agent'], vcfg['n_demos'])

    # Results file location; the "multi" prefix is keyed off the model path.
    json_name = f"multi-results-{mode}.json" if 'multi' in vcfg['model_path'] else f"results-{mode}.json"
    save_path = vcfg['save_path']
    print(f"Save path for results: {save_path}")
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(save_path, exist_ok=True)
    save_json = os.path.join(save_path, f'{name}-{json_name}')

    # Previously saved results decide which checkpoints may be skipped.
    existing_results = {}
    if os.path.exists(save_json):
        with open(save_json, 'r') as f:
            existing_results = json.load(f)

    ckpts_to_eval = list_ckpts_to_eval(vcfg, existing_results)

    # The agent constructor receives this loader for both loader arguments.
    data_loader = DataLoader(ds, shuffle=False,
                             pin_memory=False,
                             num_workers=1)

    print(f"Evaluating: {str(ckpts_to_eval)}")
    for ckpt in ckpts_to_eval:
        model_file = os.path.join(vcfg['model_path'], ckpt)

        # isfile() also rejects the directory produced by an empty ckpt name.
        if not os.path.exists(model_file) or not os.path.isfile(model_file):
            print(f"Checkpoint not found: {model_file}")
            continue
        elif not vcfg['update_results'] and ckpt in existing_results:
            print(f"Skipping because of existing results for {model_file}.")
            continue

        # Episodes from all repeats of this checkpoint accumulate here.
        results = []
        mean_reward = 0.0

        for train_run in range(vcfg['n_repeats']):
            # Seed with the repeat index, then build the agent and load weights.
            utils.set_seed(train_run, torch=True)
            agent = agents.names[vcfg['agent']](name, tcfg, data_loader, data_loader)
            agent.load(model_file)
            print(f"Loaded: {model_file}")

            record = vcfg['record']['save_video']
            n_demos = vcfg['n_demos']

            for i in range(0, n_demos):
                print(f'Test: {i + 1}/{n_demos}')
                try:
                    episode, seed = ds.load(i)
                except Exception as exc:
                    # Best-effort skip of unreadable episodes, but do not
                    # swallow KeyboardInterrupt/SystemExit and do report why.
                    print(f"skip bad example {i} ({type(exc).__name__}: {exc})")
                    continue

                # The last entry of the stored episode is passed as the goal.
                goal = episode[-1]
                total_reward = 0
                np.random.seed(seed)

                # In multi-task mode the dataset reports the task of the
                # episode that was just loaded.
                if 'multi' in dataset_type:
                    task_name = ds.get_curr_task()
                    task = tasks.names[task_name]()
                    print(f'Evaluating on {task_name}')
                else:
                    task_name = vcfg['eval_task']
                    task = tasks.names[task_name]()

                # Reset the environment with the seed stored with the episode.
                task.mode = mode
                env.seed(seed)
                env.set_task(task)
                obs = env.reset()
                info = env.info

                if record:
                    video_name = f'{task_name}-{i+1:06d}'
                    if 'multi' in vcfg['model_task']:
                        video_name = f"{vcfg['model_task']}-{video_name}"
                    env.start_rec(video_name)

                # Roll out the policy until done or the task's step budget.
                for _ in range(task.max_steps):
                    act = agent.act(obs, info, goal)
                    obs, reward, done, info = env.step(act)
                    total_reward += reward
                    if done:
                        break

                # Keep the info of the last transition with the episode reward.
                results.append((total_reward, info))
                mean_reward = np.mean([r for r, _ in results])
                print(f'Mean: {mean_reward} | Task: {task_name} | Ckpt: {ckpt}')

                if record:
                    env.end_rec()

            all_results[ckpt] = {
                'episodes': results,
                'mean_reward': mean_reward,
            }

        if vcfg['save_results']:
            print("save results to:", save_json)
            # Re-read the file so entries written since startup are kept,
            # then overlay the results produced by this run.
            if os.path.exists(save_json):
                with open(save_json, 'r') as f:
                    existing_results = json.load(f)
                existing_results.update(all_results)
                all_results = existing_results

            with open(save_json, 'w') as f:
                json.dump(all_results, f, indent=4)
def list_ckpts_to_eval(vcfg, existing_results):
    """Return the checkpoint file names selected by ``vcfg['checkpoint_type']``.

    Selection modes:
        * ``'last'``: just ``last.ckpt``.
        * ``'val_missing'``: every file in ``vcfg['model_path']`` whose name
          contains ``"steps="`` and that is not a key of ``existing_results``,
          in sorted order.
        * ``'test_best'``: the checkpoint with the highest ``mean_reward`` in
          a validation results JSON found in ``vcfg['results_path']``
          (multi-task or single-task file depending on whether
          ``vcfg['model_task']`` contains ``"multi"``); ``last.ckpt`` if no
          such file exists or no entry scores above -1.0.
        * anything else: the value is treated as a substring and the first
          (sorted) file in ``vcfg['model_path']`` containing it is used; an
          empty string is returned as the single entry when nothing matches.

    Args:
        vcfg: Evaluation config (indexed like a mapping).
        existing_results: Mapping keyed by checkpoint name of results that
            were already computed; only consulted in ``'val_missing'`` mode.

    Returns:
        list[str]: Checkpoint file names relative to ``vcfg['model_path']``.
    """
    checkpoint_type = vcfg['checkpoint_type']

    if checkpoint_type == 'last':
        ckpts_to_eval = ['last.ckpt']

    elif checkpoint_type == 'val_missing':
        checkpoints = sorted(c for c in os.listdir(vcfg['model_path']) if "steps=" in c)
        ckpts_to_eval = [c for c in checkpoints if c not in existing_results]

    elif checkpoint_type == 'test_best':
        # os.listdir() returns entries in arbitrary order; sort so the file
        # picked below is the same on every run and platform.
        result_jsons = sorted(c for c in os.listdir(vcfg['results_path']) if "results-val" in c)
        want_multi = 'multi' in vcfg['model_task']
        result_jsons = [r for r in result_jsons if ("multi" in r) == want_multi]

        if result_jsons:
            with open(os.path.join(vcfg['results_path'], result_jsons[0]), 'r') as f:
                eval_res = json.load(f)

            # Strict '>' keeps the first entry on ties and falls back to
            # last.ckpt when the file is empty or nothing beats -1.0.
            best_checkpoint = 'last.ckpt'
            best_success = -1.0
            for ckpt, res in eval_res.items():
                if res['mean_reward'] > best_success:
                    best_checkpoint = ckpt
                    best_success = res['mean_reward']
            print(best_checkpoint)
            ckpts_to_eval = [best_checkpoint]
        else:
            print("No best val ckpt found. Using last.ckpt")
            ckpts_to_eval = ['last.ckpt']

    else:
        print(f"Looking for: {checkpoint_type}")
        # Sorted for the same determinism reason as above.
        checkpoints = sorted(c for c in os.listdir(vcfg['model_path']) if checkpoint_type in c)
        # An empty name is kept as the "nothing found" marker for the caller.
        ckpts_to_eval = [checkpoints[0] if checkpoints else ""]

    print("ckpts_to_eval:", ckpts_to_eval)
    return ckpts_to_eval
# Script entry point: run the evaluation only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()