File size: 2,825 Bytes
79943a9
8e63d2a
 
 
79943a9
 
 
 
 
8e63d2a
 
79943a9
 
 
 
 
 
 
 
 
 
8e63d2a
79943a9
 
 
 
 
 
 
 
 
8e63d2a
 
 
79943a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8e63d2a
 
79943a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
from collections import deque
from typing import Any, Dict, List, Optional

import numpy as np
from torch.utils.tensorboard.writer import SummaryWriter

from rl_algo_impls.shared.stats import Episode, EpisodesStats
from rl_algo_impls.wrappers.vectorable_wrapper import (
    VecEnvObs,
    VecEnvStepReturn,
    VecotarableWrapper,
)


class EpisodeStatsWriter(VecotarableWrapper):
    """Wrapper that logs episode statistics to TensorBoard and stdout.

    After every step the per-environment ``info`` dicts are scanned for a
    truthy ``"episode"`` entry. Each such entry is turned into an ``Episode``
    built from its ``"r"`` and ``"l"`` fields (presumably episode return and
    length as produced by an upstream episode-statistics wrapper -- confirm
    against the env stack). Statistics for the episodes seen on that step and
    for a rolling window of recent episodes are written to ``tb_writer``.

    Args:
        env: Environment to wrap; forwarded to ``VecotarableWrapper``.
        tb_writer: Writer handed to ``EpisodesStats.write_to_tensorboard``.
        training: Selects the TensorBoard tag prefix: ``"train"`` when true,
            ``"eval"`` otherwise.
        rolling_length: Maximum number of episodes kept in the rolling window,
            and also the number of episodes between printed summaries.
        additional_keys_to_log: Keys copied from an ``info`` dict into the
            ``Episode``'s ``info``. Every listed key must be present in any
            ``info`` dict that carries an ``"episode"`` entry; a missing key
            raises ``KeyError``. ``None`` means no additional keys.
    """

    def __init__(
        self,
        env,
        tb_writer: SummaryWriter,
        training: bool = True,
        rolling_length: int = 100,
        additional_keys_to_log: Optional[List[str]] = None,
    ):
        super().__init__(env)
        self.training = training
        self.tb_writer = tb_writer
        self.rolling_length = rolling_length
        # Bounded window: appending beyond maxlen drops the oldest episode.
        self.episodes = deque(maxlen=rolling_length)
        self.total_steps = 0
        self.episode_cnt = 0
        # Episode count at the start of the current print window.
        self.last_episode_cnt_print = 0
        self.additional_keys_to_log = (
            additional_keys_to_log if additional_keys_to_log is not None else []
        )

    def step(self, actions: np.ndarray) -> VecEnvStepReturn:
        """Step the wrapped env, record stats from ``infos``, and pass the result through."""
        obs, rews, dones, infos = self.env.step(actions)
        self._record_stats(infos)
        return obs, rews, dones, infos

    # Support for stable_baselines3.common.vec_env.VecEnvWrapper
    def step_wait(self) -> VecEnvStepReturn:
        """Same as ``step`` but delegating to the wrapped env's ``step_wait``."""
        obs, rews, dones, infos = self.env.step_wait()
        self._record_stats(infos)
        return obs, rews, dones, infos

    def _record_stats(self, infos: List[Dict[str, Any]]) -> None:
        """Update counters and emit statistics for episodes reported in ``infos``.

        ``total_steps`` grows by the wrapped env's ``num_envs`` (1 when the
        attribute is absent) on every call. Nothing is written or printed
        unless at least one ``info`` dict carries a truthy ``"episode"`` entry.

        Raises:
            KeyError: If an ``info`` dict with an ``"episode"`` entry lacks one
                of ``additional_keys_to_log``.
        """
        self.total_steps += getattr(self.env, "num_envs", 1)

        step_episodes = []
        for info in infos:
            ep_info = info.get("episode")
            if not ep_info:
                continue
            additional_info = {k: info[k] for k in self.additional_keys_to_log}
            episode = Episode(ep_info["r"], ep_info["l"], info=additional_info)
            step_episodes.append(episode)
            self.episodes.append(episode)

        if not step_episodes:
            return

        tag = "train" if self.training else "eval"
        step_stats = EpisodesStats(step_episodes, simple=True)
        step_stats.write_to_tensorboard(self.tb_writer, tag, self.total_steps)
        rolling_stats = EpisodesStats(self.episodes)
        rolling_stats.write_to_tensorboard(
            self.tb_writer, f"{tag}_rolling", self.total_steps
        )

        self.episode_cnt += len(step_episodes)
        episodes_since_print = self.episode_cnt - self.last_episode_cnt_print
        if episodes_since_print >= self.rolling_length:
            print(
                f"Episode: {self.episode_cnt} | "
                f"Steps: {self.total_steps} | "
                f"{rolling_stats}"
            )
            # Advance by every whole window that has elapsed, not just one.
            # Otherwise, when more than ``rolling_length`` episodes finish
            # between prints, the marker lags behind and the summary is
            # printed on every subsequent step that finishes an episode.
            if self.rolling_length > 0:
                elapsed_windows = episodes_since_print // self.rolling_length
                self.last_episode_cnt_print += (
                    elapsed_windows * self.rolling_length
                )

    def reset(self) -> VecEnvObs:
        """Reset the wrapped env and return its observation; counters are kept."""
        return self.env.reset()