File size: 4,710 Bytes
9b19c29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import logging
from abc import ABC, abstractmethod
from collections.abc import Callable
from enum import Enum
from typing import TYPE_CHECKING

import torch

from tianshou.highlevel.world import World

if TYPE_CHECKING:
    from tianshou.highlevel.module.core import TDevice

log = logging.getLogger(__name__)


class PersistEvent(Enum):
    """Events upon which a :class:`Persistence` handler may save additional data."""

    #: Emitted when the policy neural network is persisted (a new best policy was found).
    PERSIST_POLICY = "persist_policy"


class RestoreEvent(Enum):
    """Events upon which a :class:`Persistence` handler may restore additional data."""

    #: Emitted when the policy neural network parameters are restored.
    RESTORE_POLICY = "restore_policy"


class Persistence(ABC):
    """Abstract handler that reacts to persistence and restoration events.

    Concrete subclasses decide what (if anything) to save or load when one of the
    events in :class:`PersistEvent` / :class:`RestoreEvent` occurs.
    """

    @abstractmethod
    def persist(self, event: PersistEvent, world: World) -> None:
        """Handle a persistence event.

        :param event: the event that triggered persistence
        :param world: the world object supplied by the caller
        """

    @abstractmethod
    def restore(self, event: RestoreEvent, world: World) -> None:
        """Handle a restoration event.

        :param event: the event that triggered restoration
        :param world: the world object supplied by the caller
        """


class PersistenceGroup(Persistence):
    """Composite handler that forwards events to several persistence handlers collectively.

    The ``enabled`` flag gates only :meth:`persist`; :meth:`restore` always forwards
    to every contained handler.
    """

    def __init__(self, *p: Persistence, enabled: bool = True):
        """:param p: the handlers to forward events to, invoked in the given order
        :param enabled: whether :meth:`persist` forwards events (restoration is unaffected)
        """
        self.enabled = enabled
        self.items = p

    def persist(self, event: PersistEvent, world: World) -> None:
        """Forward a persistence event to every handler, unless this group is disabled."""
        if self.enabled:
            for handler in self.items:
                handler.persist(event, world)

    def restore(self, event: RestoreEvent, world: World) -> None:
        """Forward a restoration event to every handler."""
        # NOTE: unlike persist, restoration is not gated by `enabled`.
        for handler in self.items:
            handler.restore(event, world)


class PolicyPersistence:
    class Mode(Enum):
        """Mode of persistence."""

        POLICY_STATE_DICT = "policy_state_dict"
        """Persist only the policy's state dictionary. Note that for a policy to be restored from
        such a dictionary, it is necessary to first create a structurally equivalent object which can
        accept the respective state."""
        POLICY = "policy"
        """Persist the entire policy. This is larger but has the advantage of the policy being loadable
        without requiring an environment to be instantiated.
        It has the potential disadvantage that upon breaking code changes in the policy implementation
        (e.g. renamed/moved class), it will no longer be loadable.
        Note that a precondition is that the policy be picklable in its entirety.
        """

        def get_filename(self) -> str:
            return self.value + ".pt"

    def __init__(
        self,
        additional_persistence: Persistence | None = None,
        enabled: bool = True,
        mode: Mode = Mode.POLICY,
    ):
        """Handles persistence of the policy.

        :param additional_persistence: a persistence instance which is to be invoked whenever
            this object is used to persist/restore data
        :param enabled: whether persistence is enabled (restoration is always enabled)
        :param mode: the persistence mode
        """
        self.additional_persistence = additional_persistence
        self.enabled = enabled
        self.mode = mode

    def persist(self, policy: torch.nn.Module, world: World) -> None:
        if not self.enabled:
            return
        path = world.persist_path(self.mode.get_filename())
        match self.mode:
            case self.Mode.POLICY_STATE_DICT:
                log.info(f"Saving policy state dictionary in {path}")
                torch.save(policy.state_dict(), path)
            case self.Mode.POLICY:
                log.info(f"Saving policy object in {path}")
                torch.save(policy, path)
            case _:
                raise NotImplementedError
        if self.additional_persistence is not None:
            self.additional_persistence.persist(PersistEvent.PERSIST_POLICY, world)

    def restore(self, policy: torch.nn.Module, world: World, device: "TDevice") -> None:
        path = world.restore_path(self.mode.get_filename())
        log.info(f"Restoring policy from {path}")
        match self.mode:
            case self.Mode.POLICY_STATE_DICT:
                state_dict = torch.load(path, map_location=device)
            case self.Mode.POLICY:
                loaded_policy: torch.nn.Module = torch.load(path, map_location=device)
                state_dict = loaded_policy.state_dict()
            case _:
                raise NotImplementedError
        policy.load_state_dict(state_dict)
        if self.additional_persistence is not None:
            self.additional_persistence.restore(RestoreEvent.RESTORE_POLICY, world)

    def get_save_best_fn(self, world: World) -> Callable[[torch.nn.Module], None]:
        def save_best_fn(pol: torch.nn.Module) -> None:
            self.persist(pol, world)

        return save_best_fn