from typing import List, Optional, NamedTuple
from mlagents.torch_utils import torch
import numpy as np

from mlagents.trainers.torch_entities.utils import ModelUtils
from mlagents.trainers.buffer import AgentBuffer, BufferKey
from mlagents_envs.base_env import _ActionTupleBase


class LogProbsTuple(_ActionTupleBase):
    """
    An object whose fields correspond to the log probs of actions of different types.
    Continuous and discrete are numpy arrays with dimensions (n_agents, continuous_size)
    and (n_agents, discrete_size), respectively. Note that this also holds when the
    continuous or discrete size is zero.
    """

    @property
    def discrete_dtype(self) -> np.dtype:
        """
        The dtype of a discrete log probability.
        """
        return np.float32

    @staticmethod
    def empty_log_probs() -> "LogProbsTuple":
        """
        Generates a dummy LogProbsTuple
        """
        return LogProbsTuple()
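

# Illustrative sketch (not part of the original API): how a LogProbsTuple is typically
# filled from numpy arrays, using the add_continuous/add_discrete setters inherited
# from _ActionTupleBase (the same calls used by to_log_probs_tuple below). The helper
# name and the shapes are hypothetical, chosen only for demonstration.
def _example_log_probs_tuple(n_agents: int = 2) -> LogProbsTuple:
    log_probs = LogProbsTuple()
    # One log prob per continuous action dimension (here: 3 dimensions).
    log_probs.add_continuous(np.zeros((n_agents, 3), dtype=np.float32))
    # One log prob per sampled discrete action, one column per branch (here: 2 branches).
    log_probs.add_discrete(np.zeros((n_agents, 2), dtype=np.float32))
    return log_probs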


class ActionLogProbs(NamedTuple):
    """
    A NamedTuple containing the tensor of continuous log probs, the list of tensors of
    log probs for the individual discrete actions that were sampled, and the list of
    tensors holding all log probs for each discrete branch.
    Utility functions provide numpy <=> tensor conversions to be used by the optimizers.
    :param continuous_tensor: Torch tensor corresponding to log probs of continuous actions.
    :param discrete_list: List of Torch tensors, each corresponding to log probs of the
    discrete actions that were sampled.
    :param all_discrete_list: List of Torch tensors, one per discrete branch, each holding
    the log probs of every action in that branch, including actions that were not sampled.
    """

    continuous_tensor: torch.Tensor
    discrete_list: Optional[List[torch.Tensor]]
    all_discrete_list: Optional[List[torch.Tensor]]

    @property
    def discrete_tensor(self):
        """
        Returns the discrete log probs list as a stacked tensor
        """
        return torch.stack(self.discrete_list, dim=-1)

    @property
    def all_discrete_tensor(self):
        """
        Returns the discrete log probs of each branch as a tensor
        """
        return torch.cat(self.all_discrete_list, dim=1)

    def to_log_probs_tuple(self) -> LogProbsTuple:
        """
        Returns a LogProbsTuple. A field is only added if the corresponding tensor
        is not None; otherwise the LogProbsTuple keeps its default (empty) value.
        """
        log_probs_tuple = LogProbsTuple()
        if self.continuous_tensor is not None:
            continuous = ModelUtils.to_numpy(self.continuous_tensor)
            log_probs_tuple.add_continuous(continuous)
        if self.discrete_list is not None:
            discrete = ModelUtils.to_numpy(self.discrete_tensor)
            log_probs_tuple.add_discrete(discrete)
        return log_probs_tuple

    def _to_tensor_list(self) -> List[torch.Tensor]:
        """
        Returns the tensors in the ActionLogProbs as a flat List of torch Tensors. This
        is private and serves as a utility for self.flatten()
        """
        tensor_list: List[torch.Tensor] = []
        if self.continuous_tensor is not None:
            tensor_list.append(self.continuous_tensor)
        if self.discrete_list is not None:
            tensor_list.append(self.discrete_tensor)
        return tensor_list

    def flatten(self) -> torch.Tensor:
        """
        A utility method that returns all log probs in ActionLogProbs as a flattened tensor.
        This is useful for algorithms like PPO which can treat all log probs in the same way.
        """
        return torch.cat(self._to_tensor_list(), dim=1)

    @staticmethod
    def from_buffer(buff: AgentBuffer) -> "ActionLogProbs":
        """
        A static method that accesses continuous and discrete log probs fields in an AgentBuffer
        and constructs the corresponding ActionLogProbs from the retrieved np arrays.
        """
        continuous: torch.Tensor = None
        discrete: List[torch.Tensor] = None  # type: ignore

        if BufferKey.CONTINUOUS_LOG_PROBS in buff:
            continuous = ModelUtils.list_to_tensor(buff[BufferKey.CONTINUOUS_LOG_PROBS])
        if BufferKey.DISCRETE_LOG_PROBS in buff:
            discrete_tensor = ModelUtils.list_to_tensor(
                buff[BufferKey.DISCRETE_LOG_PROBS]
            )
            # Leave discrete as None when there are no discrete log probs; an empty
            # list would break the stacking done by discrete_tensor and flatten().
            if discrete_tensor.shape[1] > 0:
                discrete = [
                    discrete_tensor[..., i] for i in range(discrete_tensor.shape[-1])
                ]
        return ActionLogProbs(continuous, discrete, None)
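

if __name__ == "__main__":
    # Illustrative sketch (not part of the original module): round-tripping log probs
    # between torch tensors and the numpy-based LogProbsTuple. The shapes are arbitrary
    # example values: 2 agents, 3 continuous dimensions, 2 discrete branches.
    example_log_probs = ActionLogProbs(
        continuous_tensor=torch.zeros(2, 3),
        discrete_list=[torch.zeros(2), torch.zeros(2)],
        all_discrete_list=None,
    )
    # flatten() concatenates the continuous log probs and the stacked discrete log probs
    # along dim 1, giving one row of log probs per agent.
    print(example_log_probs.flatten().shape)  # torch.Size([2, 5])
    # to_log_probs_tuple() converts the tensors back to numpy arrays; the continuous and
    # discrete accessors come from _ActionTupleBase.
    example_tuple = example_log_probs.to_log_probs_tuple()
    print(example_tuple.continuous.shape, example_tuple.discrete.shape)  # (2, 3) (2, 2)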