# File size: 4,572 Bytes
# Revision: 9dc837c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import os
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Dict, NamedTuple, Optional, TypedDict, Union


@dataclass
class RunArgs:
    """Arguments that identify a single run.

    Attributes:
        algo: Name of the algorithm.
        env: Name of the environment entry for the run.
        seed: Optional integer seed; ``None`` when no seed was supplied.
        use_deterministic_algorithms: Boolean flag, on by default.
            NOTE(review): presumably forwarded to the ML framework's
            determinism switch -- confirm against the caller.
    """

    # Required, positional-friendly fields.
    algo: str
    env: str

    # Optional fields with defaults.
    seed: Optional[int] = None
    use_deterministic_algorithms: bool = True


class EnvHyperparams(NamedTuple):
    """Immutable record of environment-related settings with defaults.

    Every field has a default, so ``EnvHyperparams()`` is valid. Field order
    is part of the interface because instances are tuples.

    NOTE(review): the meaning of each field is defined by the code that
    consumes this tuple, which is not part of this module -- the grouping
    comments below are based on field names only; confirm before relying
    on them.
    """

    # Environment creation / vectorisation (presumed from names).
    is_procgen: bool = False
    n_envs: int = 1
    frame_stack: int = 1
    make_kwargs: Optional[Dict[str, Any]] = None

    # Optional step thresholds; ``None`` leaves them unset.
    no_reward_timeout_steps: Optional[int] = None
    no_reward_fire_steps: Optional[int] = None

    vec_env_class: str = "dummy"

    # Normalisation switch and its optional keyword arguments.
    normalize: bool = False
    normalize_kwargs: Optional[Dict[str, Any]] = None

    rolling_length: int = 100

    # Video recording during training (presumed from names).
    train_record_video: bool = False
    video_step_interval: Union[int, float] = 1_000_000

    initial_steps_to_truncate: Optional[int] = None


class Hyperparams(TypedDict, total=False):
    """Dictionary schema for a run's hyperparameters.

    Declared with ``total=False``, so every key is optional; readers are
    expected to use ``.get(...)`` with a fallback.

    Keys:
        device: Device selector string.
        n_timesteps: Number of timesteps; may be given as a float
            (e.g. ``1e6``) and converted to ``int`` by the reader.
        env_hyperparams: Environment settings.
        policy_hyperparams: Policy settings.
        algo_hyperparams: Algorithm settings.
        eval_params: Evaluation settings.
        env_id: Optional environment id. When present and non-empty it takes
            precedence over the environment name given in the run arguments.
    """

    device: str
    n_timesteps: Union[int, float]
    env_hyperparams: Dict[str, Any]
    policy_hyperparams: Dict[str, Any]
    algo_hyperparams: Dict[str, Any]
    eval_params: Dict[str, Any]
    # Read via ``hyperparams.get("env_id")`` elsewhere in this module, so it
    # is declared here to keep the schema in sync with its readers.
    env_id: Optional[str]


@dataclass
class Config:
    """Run arguments, hyperparameters and the names/paths derived from them.

    Attributes:
        args: Arguments identifying the run (algorithm, env name, seed).
        hyperparams: Hyperparameter dictionary; all keys are optional.
        root_dir: Directory under which the ``saved_models``,
            ``downloaded_models``, ``runs`` and ``videos`` directories live.
        run_id: Identifier appended to the model name to build ``run_name``.
            Defaults to the ISO-8601 timestamp of instance creation.
    """

    args: RunArgs
    hyperparams: Hyperparams
    root_dir: str
    # A default_factory is required here: a plain ``datetime.now()`` default
    # would be evaluated once at class definition, giving every Config created
    # in the process the same (import-time) run_id.
    run_id: str = field(default_factory=lambda: datetime.now().isoformat())

    def seed(self, training: bool = True) -> Optional[int]:
        """Return the seed to use, or ``None`` if no seed was configured.

        Args:
            training: When True the configured seed is returned unchanged.
                When False the seed is offset by ``env_hyperparams["n_envs"]``
                (default 1).
                NOTE(review): presumably so non-training environments do not
                reuse the training seeds -- confirm.
        """
        seed = self.args.seed
        if training or seed is None:
            return seed
        return seed + self.env_hyperparams.get("n_envs", 1)

    @property
    def device(self) -> str:
        """Configured device string, ``"auto"`` when unset."""
        return self.hyperparams.get("device", "auto")

    @property
    def n_timesteps(self) -> int:
        """Configured timestep count as an ``int`` (default 100,000)."""
        return int(self.hyperparams.get("n_timesteps", 100_000))

    @property
    def env_hyperparams(self) -> Dict[str, Any]:
        """Environment settings, or an empty dict when unset."""
        return self.hyperparams.get("env_hyperparams", {})

    @property
    def policy_hyperparams(self) -> Dict[str, Any]:
        """Policy settings, or an empty dict when unset."""
        return self.hyperparams.get("policy_hyperparams", {})

    @property
    def algo_hyperparams(self) -> Dict[str, Any]:
        """Algorithm settings, or an empty dict when unset."""
        return self.hyperparams.get("algo_hyperparams", {})

    @property
    def eval_params(self) -> Dict[str, Any]:
        """Evaluation settings, or an empty dict when unset."""
        return self.hyperparams.get("eval_params", {})

    @property
    def algo(self) -> str:
        """Algorithm name taken from the run arguments."""
        return self.args.algo

    @property
    def env_id(self) -> str:
        """Environment id: the ``env_id`` hyperparameter if set, else ``args.env``."""
        return self.hyperparams.get("env_id") or self.args.env

    def model_name(self, include_seed: bool = True) -> str:
        """Build the dash-joined model name.

        The name is ``<algo>-<args.env>``, followed by ``S<seed>`` when
        ``include_seed`` is True and a seed is set. If no ``env_id``
        hyperparameter is given, each ``make_kwargs`` entry contributes one
        more component:

        * a True bool contributes its key,
        * a non-zero int contributes ``<key><value>``,
        * anything else (including ``False`` and ``0``) contributes
          ``str(value)``.

        Args:
            include_seed: Whether to add the ``S<seed>`` component.

        Returns:
            The components joined with ``"-"``.
        """
        # Use arg env name instead of environment name
        parts = [self.algo, self.args.env]
        if include_seed and self.args.seed is not None:
            parts.append(f"S{self.args.seed}")

        # Assume that the custom arg name already has the necessary information
        if not self.hyperparams.get("env_id"):
            # ``or {}`` also covers an explicit ``make_kwargs: None``.
            make_kwargs = self.env_hyperparams.get("make_kwargs") or {}
            for k, v in make_kwargs.items():
                # bool is a subclass of int, so it must be tested first.
                if isinstance(v, bool) and v:
                    parts.append(k)
                elif isinstance(v, int) and v:
                    parts.append(f"{k}{v}")
                else:
                    parts.append(str(v))

        return "-".join(parts)

    @property
    def run_name(self) -> str:
        """Model name (with seed) and ``run_id`` joined with ``"-"``."""
        parts = [self.model_name(), self.run_id]
        return "-".join(parts)

    @property
    def saved_models_dir(self) -> str:
        """``<root_dir>/saved_models``."""
        return os.path.join(self.root_dir, "saved_models")

    @property
    def downloaded_models_dir(self) -> str:
        """``<root_dir>/downloaded_models``."""
        return os.path.join(self.root_dir, "downloaded_models")

    def model_dir_name(
        self,
        best: bool = False,
        extension: str = "",
    ) -> str:
        """Return the model directory name.

        Args:
            best: Append ``"-best"`` to the model name when True.
            extension: Suffix appended verbatim after the name.
        """
        return self.model_name() + ("-best" if best else "") + extension

    def model_dir_path(self, best: bool = False, downloaded: bool = False) -> str:
        """Return the model directory path.

        Args:
            best: Use the ``"-best"`` directory name when True.
            downloaded: Place the directory under ``downloaded_models_dir``
                instead of ``saved_models_dir`` when True.
        """
        return os.path.join(
            self.saved_models_dir if not downloaded else self.downloaded_models_dir,
            self.model_dir_name(best=best),
        )

    @property
    def runs_dir(self) -> str:
        """``<root_dir>/runs``."""
        return os.path.join(self.root_dir, "runs")

    @property
    def tensorboard_summary_path(self) -> str:
        """``<runs_dir>/<run_name>``."""
        return os.path.join(self.runs_dir, self.run_name)

    @property
    def logs_path(self) -> str:
        """``<runs_dir>/log.yml``."""
        return os.path.join(self.runs_dir, "log.yml")

    @property
    def videos_dir(self) -> str:
        """``<root_dir>/videos``."""
        return os.path.join(self.root_dir, "videos")

    @property
    def video_prefix(self) -> str:
        """``<videos_dir>/<model_name>``."""
        return os.path.join(self.videos_dir, self.model_name())

    @property
    def best_videos_dir(self) -> str:
        """``<videos_dir>/<model_name>-best``."""
        return os.path.join(self.videos_dir, f"{self.model_name()}-best")