File size: 2,912 Bytes
79943a9
 
 
 
 
 
8e63d2a
79943a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8e63d2a
79943a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import os
import shutil
import tempfile

from dataclasses import dataclass
from typing import NamedTuple, Optional

from rl_algo_impls.shared.vec_env import make_eval_env
from rl_algo_impls.runner.config import Config, EnvHyperparams, Hyperparams, RunArgs
from rl_algo_impls.runner.running_utils import (
    load_hyperparams,
    set_seeds,
    get_device,
    make_policy,
)
from rl_algo_impls.shared.callbacks.eval_callback import evaluate
from rl_algo_impls.shared.policy.policy import Policy
from rl_algo_impls.shared.stats import EpisodesStats


@dataclass
class EvalArgs(RunArgs):
    """Options for evaluating a trained model, on top of ``RunArgs``.

    Each field below is consumed by ``evaluate_model``; the comments note
    where it ends up.
    """

    # Forwarded as ``render`` to both make_eval_env and evaluate.
    render: bool = True
    # Passed as ``best=`` to Config.model_dir_path / model_dir_name to pick
    # which saved model directory (or W&B archive) to load.
    best: bool = True
    # Passed to make_eval_env as ``override_n_envs``.
    n_envs: Optional[int] = 1
    # Number-of-episodes argument handed to evaluate.
    n_episodes: int = 3
    # None means "use config.eval_params['deterministic']", which itself
    # defaults to True when absent.
    deterministic_eval: Optional[bool] = None
    # When True, evaluate is called with print_returns=False.
    no_print_returns: bool = False
    # If set, algo/env/seed/hyperparams and the model archive are fetched from
    # this W&B run instead of being loaded locally.
    wandb_run_path: Optional[str] = None


class Evaluation(NamedTuple):
    """Result bundle returned by ``evaluate_model``."""

    # The loaded policy that was evaluated.
    policy: Policy
    # Statistics returned by ``evaluate`` for the evaluation run.
    stats: EpisodesStats
    # Config the environment and policy were built from.
    config: Config


def _download_model_archive(run, archive_name: str, model_path: str) -> None:
    """Fetch a zipped model directory from a W&B run and unpack it to model_path.

    The archive is downloaded into a private temporary directory rather than
    the current working directory. A stale file left by an interrupted earlier
    run therefore cannot make the download fail. The archive is also always
    removed, even if unpacking raises.

    Args:
        run: W&B public-API run object that has ``archive_name`` among its files.
        archive_name: Name of the archive file stored on the run.
        model_path: Destination directory. An existing directory at this path
            is deleted before unpacking.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        # File.download() returns an open handle to the downloaded file. Close
        # it right away: we only need the file on disk, and an open handle
        # would block the temp-dir cleanup on Windows.
        run.file(archive_name).download(root=tmp_dir).close()
        archive_path = os.path.join(tmp_dir, archive_name)
        if os.path.isdir(model_path):
            shutil.rmtree(model_path)
        # Archive format is inferred from the file extension (".zip").
        shutil.unpack_archive(archive_path, model_path)


def evaluate_model(args: EvalArgs, root_dir: str) -> Evaluation:
    """Load a trained policy and run it through ``evaluate``.

    If ``args.wandb_run_path`` is set, ``algo``, ``env``, ``seed`` and
    ``use_deterministic_algorithms`` are read from that W&B run's config and
    written back onto ``args`` in place. The model archive is downloaded from
    the run and unpacked into the config's "downloaded" model directory.
    Otherwise hyperparameters come from ``load_hyperparams(args.algo, args.env)``
    and the model is loaded from the local model directory.

    Args:
        args: Evaluation options (mutated in place when loading from W&B).
        root_dir: Root directory passed through to ``Config``.

    Returns:
        An ``Evaluation`` holding the policy, the stats returned by
        ``evaluate``, and the config used.
    """
    if args.wandb_run_path:
        # Optional dependency: only needed when evaluating a W&B run.
        import wandb

        api = wandb.Api()
        run = api.run(args.wandb_run_path)
        params = run.config

        args.algo = params["algo"]
        args.env = params["env"]
        args.seed = params.get("seed", None)
        args.use_deterministic_algorithms = params.get(
            "use_deterministic_algorithms", True
        )

        config = Config(args, Hyperparams.from_dict_with_extra_fields(params), root_dir)
        model_path = config.model_dir_path(best=args.best, downloaded=True)
        model_archive_name = config.model_dir_name(best=args.best, extension=".zip")
        _download_model_archive(run, model_archive_name, model_path)
    else:
        hyperparams = load_hyperparams(args.algo, args.env)

        config = Config(args, hyperparams, root_dir)
        model_path = config.model_dir_path(best=args.best)

    print(args)

    set_seeds(args.seed, args.use_deterministic_algorithms)

    env = make_eval_env(
        config,
        EnvHyperparams(**config.env_hyperparams),
        override_n_envs=args.n_envs,
        render=args.render,
        # Normalization statistics (if any) are loaded from the model dir.
        normalize_load_path=model_path,
    )
    device = get_device(config, env)
    policy = make_policy(
        args.algo,
        env,
        device,
        load_path=model_path,
        **config.policy_hyperparams,
    ).eval()

    # An explicit CLI/arg override wins; otherwise use the config's eval
    # setting, defaulting to deterministic actions.
    if args.deterministic_eval is not None:
        deterministic = args.deterministic_eval
    else:
        deterministic = config.eval_params.get("deterministic", True)

    stats = evaluate(
        env,
        policy,
        args.n_episodes,
        render=args.render,
        deterministic=deterministic,
        print_returns=not args.no_print_returns,
    )
    return Evaluation(policy, stats, config)