from typing import List, Optional, Tuple, Dict
from mlagents.torch_utils import torch, nn
from mlagents.trainers.torch_entities.layers import LinearEncoder, Initialization
import numpy as np

from mlagents.trainers.torch_entities.encoders import (
    SimpleVisualEncoder,
    ResNetVisualEncoder,
    NatureVisualEncoder,
    SmallVisualEncoder,
    FullyConnectedVisualEncoder,
    VectorInput,
)
from mlagents.trainers.settings import EncoderType, ScheduleType
from mlagents.trainers.torch_entities.attention import (
    EntityEmbedding,
    ResidualSelfAttention,
)
from mlagents.trainers.exception import UnityTrainerException
from mlagents_envs.base_env import ObservationSpec, DimensionProperty


class ModelUtils:
    # Minimum supported height/width for each encoder type. If refactoring an encoder,
    # please adjust these values as well.
    MIN_RESOLUTION_FOR_ENCODER = {
        EncoderType.FULLY_CONNECTED: 1,
        EncoderType.MATCH3: 5,
        EncoderType.SIMPLE: 20,
        EncoderType.NATURE_CNN: 36,
        EncoderType.RESNET: 15,
    }

    VALID_VISUAL_PROP = frozenset(
        [
            (
                DimensionProperty.TRANSLATIONAL_EQUIVARIANCE,
                DimensionProperty.TRANSLATIONAL_EQUIVARIANCE,
                DimensionProperty.NONE,
            ),
            (DimensionProperty.UNSPECIFIED,) * 3,
        ]
    )

    VALID_VECTOR_PROP = frozenset(
        [(DimensionProperty.NONE,), (DimensionProperty.UNSPECIFIED,)]
    )

    VALID_VAR_LEN_PROP = frozenset(
        [(DimensionProperty.VARIABLE_SIZE, DimensionProperty.NONE)]
    )

    @staticmethod
    def update_learning_rate(optim: torch.optim.Optimizer, lr: float) -> None:
        """
        Apply a learning rate to a torch optimizer.
        :param optim: Optimizer
        :param lr: Learning rate
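
        Example (illustrative): lowering the learning rate of an existing optimizer.
        >>> optim = torch.optim.Adam(nn.Linear(2, 2).parameters(), lr=1e-3)
        >>> ModelUtils.update_learning_rate(optim, 1e-4)
        >>> optim.param_groups[0]["lr"]
        0.0001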
        """
        for param_group in optim.param_groups:
            param_group["lr"] = lr

    class DecayedValue:
        def __init__(
            self,
            schedule: ScheduleType,
            initial_value: float,
            min_value: float,
            max_step: int,
        ):
            """
            Object that represents the value of a parameter that should be decayed as
            a function of global_step.
            :param schedule: Type of learning rate schedule.
            :param initial_value: Initial value before decay.
            :param min_value: Decay to this value by max_step.
            :param max_step: The final step count at which the value should equal min_value.
            """
            self.schedule = schedule
            self.initial_value = initial_value
            self.min_value = min_value
            self.max_step = max_step

        def get_value(self, global_step: int) -> float:
            """
            Get the value at a given global step.
            :param global_step: Step count.
            :returns: Decayed value at this global step.
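
            Example (illustrative): linear decay from 1e-3 to 1e-5 over 1000 steps.
            >>> sched = ModelUtils.DecayedValue(ScheduleType.LINEAR, 1e-3, 1e-5, 1000)
            >>> round(sched.get_value(500), 6)
            0.000505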
            """
            if self.schedule == ScheduleType.CONSTANT:
                return self.initial_value
            elif self.schedule == ScheduleType.LINEAR:
                return ModelUtils.polynomial_decay(
                    self.initial_value, self.min_value, self.max_step, global_step
                )
            else:
                raise UnityTrainerException(f"The schedule {self.schedule} is invalid.")

    @staticmethod
    def polynomial_decay(
        initial_value: float,
        min_value: float,
        max_step: int,
        global_step: int,
        power: float = 1.0,
    ) -> float:
        """
        Get a decayed value based on a polynomial schedule, with respect to the current global step.
        :param initial_value: Initial value before decay.
        :param min_value: Decay to this value by max_step.
        :param max_step: The final step count where the return value should equal min_value.
        :param global_step: The current step count.
        :param power: Power of polynomial decay. 1.0 (default) is a linear decay.
        :return: The current decayed value.
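
        Example (illustrative): halfway through a linear decay from 1.0 to 0.1.
        >>> round(ModelUtils.polynomial_decay(1.0, 0.1, 100, 50), 4)
        0.55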
        """
        global_step = min(global_step, max_step)
        decayed_value = (initial_value - min_value) * (
            1 - float(global_step) / max_step
        ) ** (power) + min_value
        return decayed_value

    @staticmethod
    def get_encoder_for_type(encoder_type: EncoderType) -> nn.Module:
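        """
        Returns the visual encoder class (not an instance) for the given encoder type.
        :param encoder_type: EncoderType to look up.
        :return: The matching encoder class, or None if the type is unknown.
        """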
        ENCODER_FUNCTION_BY_TYPE = {
            EncoderType.SIMPLE: SimpleVisualEncoder,
            EncoderType.NATURE_CNN: NatureVisualEncoder,
            EncoderType.RESNET: ResNetVisualEncoder,
            EncoderType.MATCH3: SmallVisualEncoder,
            EncoderType.FULLY_CONNECTED: FullyConnectedVisualEncoder,
        }
        return ENCODER_FUNCTION_BY_TYPE.get(encoder_type)

    @staticmethod
    def _check_resolution_for_encoder(
        height: int, width: int, vis_encoder_type: EncoderType
    ) -> None:
        min_res = ModelUtils.MIN_RESOLUTION_FOR_ENCODER[vis_encoder_type]
        if height < min_res or width < min_res:
            raise UnityTrainerException(
                f"Visual observation resolution ({width}x{height}) is too small for"
                f"the provided EncoderType ({vis_encoder_type.value}). The min dimension is {min_res}"
            )

    @staticmethod
    def get_encoder_for_obs(
        obs_spec: ObservationSpec,
        normalize: bool,
        h_size: int,
        attention_embedding_size: int,
        vis_encode_type: EncoderType,
    ) -> Tuple[nn.Module, int]:
        """
        Returns the encoder and the embedding size for a given observation spec.
        :param obs_spec: ObservationSpec describing the observation's shape and dimension properties.
        :param normalize: Normalize all vector inputs.
        :param h_size: Number of hidden units per layer excluding attention layers.
        :param attention_embedding_size: Number of hidden units per attention layer.
        :param vis_encode_type: Type of visual encoder to use.
        """
        shape = obs_spec.shape
        dim_prop = obs_spec.dimension_property

        # VISUAL
        if dim_prop in ModelUtils.VALID_VISUAL_PROP:
            visual_encoder_class = ModelUtils.get_encoder_for_type(vis_encode_type)
            ModelUtils._check_resolution_for_encoder(
                shape[0], shape[1], vis_encode_type
            )
            return (visual_encoder_class(shape[0], shape[1], shape[2], h_size), h_size)
        # VECTOR
        if dim_prop in ModelUtils.VALID_VECTOR_PROP:
            return (VectorInput(shape[0], normalize), shape[0])
        # VARIABLE LENGTH
        if dim_prop in ModelUtils.VALID_VAR_LEN_PROP:
            return (
                EntityEmbedding(
                    entity_size=shape[1],
                    entity_num_max_elements=shape[0],
                    embedding_size=attention_embedding_size,
                ),
                0,
            )
        # OTHER
        raise UnityTrainerException(f"Unsupported Sensor with specs {obs_spec}")

    @staticmethod
    def create_input_processors(
        observation_specs: List[ObservationSpec],
        h_size: int,
        vis_encode_type: EncoderType,
        attention_embedding_size: int,
        normalize: bool = False,
    ) -> Tuple[nn.ModuleList, List[int]]:
        """
        Creates visual and vector encoders, along with their normalizers.
        :param observation_specs: List of ObservationSpec that represent the observation dimensions.
        :param h_size: Number of hidden units per layer excluding attention layers.
        :param vis_encode_type: Type of visual encoder to use.
        :param attention_embedding_size: Number of hidden units per attention layer.
        :param normalize: Normalize all vector inputs.
        :return: Tuple of:
         - ModuleList of the encoders
         - A list of embedding sizes (0 if the input must be processed with a variable
           length observation encoder)
        """
        encoders: List[nn.Module] = []
        embedding_sizes: List[int] = []
        for obs_spec in observation_specs:
            encoder, embedding_size = ModelUtils.get_encoder_for_obs(
                obs_spec, normalize, h_size, attention_embedding_size, vis_encode_type
            )
            encoders.append(encoder)
            embedding_sizes.append(embedding_size)

        x_self_size = sum(embedding_sizes)  # The size of the "self" embedding
        if x_self_size > 0:
            for enc in encoders:
                if isinstance(enc, EntityEmbedding):
                    enc.add_self_embedding(attention_embedding_size)
        return (nn.ModuleList(encoders), embedding_sizes)

    @staticmethod
    def list_to_tensor(
        ndarray_list: List[np.ndarray], dtype: Optional[torch.dtype] = torch.float32
    ) -> torch.Tensor:
        """
        Converts a list of numpy arrays into a tensor. MUCH faster than
        calling as_tensor on the list directly.
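
        Example (illustrative):
        >>> ModelUtils.list_to_tensor([np.zeros(2), np.ones(2)]).shape
        torch.Size([2, 2])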
        """
        return torch.as_tensor(np.asanyarray(ndarray_list), dtype=dtype)

    @staticmethod
    def list_to_tensor_list(
        ndarray_list: List[np.ndarray], dtype: Optional[torch.dtype] = torch.float32
    ) -> List[torch.Tensor]:
        """
        Converts a list of numpy arrays into a list of tensors. MUCH faster than
        calling as_tensor on the list directly.
        """
        return [
            torch.as_tensor(np.asanyarray(_arr), dtype=dtype) for _arr in ndarray_list
        ]

    @staticmethod
    def to_numpy(tensor: torch.Tensor) -> np.ndarray:
        """
        Converts a Torch Tensor to a numpy array. If the Tensor is on the GPU, it will
        be brought to the CPU.
        """
        return tensor.detach().cpu().numpy()

    @staticmethod
    def break_into_branches(
        concatenated_logits: torch.Tensor, action_size: List[int]
    ) -> List[torch.Tensor]:
        """
        Takes a concatenated set of logits that represent multiple discrete action branches
        and breaks it up into one Tensor per branch.
        :param concatenated_logits: Tensor that represents the concatenated action branches
        :param action_size: List of ints containing the number of possible actions for each branch.
        :return: A List of Tensors containing one tensor per branch.
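
        Example (illustrative): splitting logits for two branches of sizes 2 and 3.
        >>> logits = torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.5]])
        >>> [b.shape for b in ModelUtils.break_into_branches(logits, [2, 3])]
        [torch.Size([1, 2]), torch.Size([1, 3])]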
        """
        action_idx = [0] + list(np.cumsum(action_size))
        branched_logits = [
            concatenated_logits[:, action_idx[i] : action_idx[i + 1]]
            for i in range(len(action_size))
        ]
        return branched_logits

    @staticmethod
    def actions_to_onehot(
        discrete_actions: torch.Tensor, action_size: List[int]
    ) -> List[torch.Tensor]:
        """
        Takes a tensor of discrete actions and turns it into a List of onehot encoding for each
        action.
        :param discrete_actions: Actions in integer form.
        :param action_size: List of branch sizes. Should be of same size as discrete_actions'
        last dimension.
        :return: List of one-hot tensors, one representing each branch.
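
        Example (illustrative): one agent acting on branches of sizes 2 and 3.
        >>> actions = torch.tensor([[1, 0]])
        >>> [t.shape for t in ModelUtils.actions_to_onehot(actions, [2, 3])]
        [torch.Size([1, 2]), torch.Size([1, 3])]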
        """
        onehot_branches = [
            torch.nn.functional.one_hot(_act.T, action_size[i]).float()
            for i, _act in enumerate(discrete_actions.long().T)
        ]
        return onehot_branches

    @staticmethod
    def dynamic_partition(
        data: torch.Tensor, partitions: torch.Tensor, num_partitions: int
    ) -> List[torch.Tensor]:
        """
        Torch implementation of dynamic_partition :
        https://www.tensorflow.org/api_docs/python/tf/dynamic_partition
        Splits the data Tensor input into num_partitions Tensors according to the indices in
        partitions.
        :param data: The Tensor data that will be split into partitions.
        :param partitions: An indices tensor that determines in which partition each element
        of data will be in.
        :param num_partitions: The number of partitions to output. Must be at least one
        greater than the maximum index appearing in partitions.
        :return: A list of Tensor partitions (Their indices correspond to their partition index).
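
        Example (illustrative):
        >>> data = torch.tensor([10.0, 20.0, 30.0, 40.0])
        >>> partitions = torch.tensor([0, 1, 0, 1])
        >>> ModelUtils.dynamic_partition(data, partitions, 2)
        [tensor([10., 30.]), tensor([20., 40.])]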
        """
        res: List[torch.Tensor] = []
        for i in range(num_partitions):
            res += [data[(partitions == i).nonzero().squeeze(1)]]
        return res

    @staticmethod
    def masked_mean(tensor: torch.Tensor, masks: torch.Tensor) -> torch.Tensor:
        """
        Returns the mean of the tensor, counting only the entries where the mask is True.
        Used for masking out portions of loss computations.
        :param tensor: Tensor which needs mean computation.
        :param masks: Boolean tensor of masks with same dimension as tensor.
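
        Example (illustrative): only entries with a True mask contribute.
        >>> t = torch.tensor([1.0, 2.0, 3.0, 4.0])
        >>> m = torch.tensor([True, True, False, False])
        >>> ModelUtils.masked_mean(t, m)
        tensor(1.5000)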
        """
        if tensor.ndim == 0:
            return (tensor * masks).sum() / torch.clamp(
                (torch.ones_like(tensor) * masks).float().sum(), min=1.0
            )
        else:
            return (
                tensor.permute(*torch.arange(tensor.ndim - 1, -1, -1)) * masks
            ).sum() / torch.clamp(
                (
                    torch.ones_like(
                        tensor.permute(*torch.arange(tensor.ndim - 1, -1, -1))
                    )
                    * masks
                )
                .float()
                .sum(),
                min=1.0,
            )

    @staticmethod
    def soft_update(source: nn.Module, target: nn.Module, tau: float) -> None:
        """
        Performs an in-place polyak update of the target module based on the source,
        by a ratio of tau. Note that source and target modules must have identically
        structured parameters, where:
            target = tau * source + (1-tau) * target
        :param source: Source module whose parameters will be used.
        :param target: Target module whose parameters will be updated.
        :param tau: Percentage of source parameters to use in average. Setting tau to
            1 will copy the source parameters to the target.
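
        Example (illustrative sketch): a small polyak step toward a copied network.
        >>> import copy
        >>> source_net = nn.Linear(2, 2)
        >>> target_net = copy.deepcopy(source_net)
        >>> ModelUtils.soft_update(source_net, target_net, tau=0.005)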
        """
        with torch.no_grad():
            for source_param, target_param in zip(
                source.parameters(), target.parameters()
            ):
                target_param.data.mul_(1.0 - tau)
                torch.add(
                    target_param.data,
                    source_param.data,
                    alpha=tau,
                    out=target_param.data,
                )

    @staticmethod
    def create_residual_self_attention(
        input_processors: nn.ModuleList, embedding_sizes: List[int], hidden_size: int
    ) -> Tuple[Optional[ResidualSelfAttention], Optional[LinearEncoder]]:
        """
        Creates an RSA if there are variable length observations found in the input processors.
        :param input_processors: A ModuleList of input processors as returned by the function
            create_input_processors().
        :param embedding_sizes: A List of embedding sizes as returned by create_input_processors().
        :param hidden_size: The hidden size to use for the RSA.
        :returns: A Tuple of the RSA and a self encoder. Both are None if no variable
            length inputs are detected.
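
        Example (illustrative sketch, reusing this module's EntityEmbedding):
        >>> ee = EntityEmbedding(entity_size=4, entity_num_max_elements=3, embedding_size=32)
        >>> rsa, enc = ModelUtils.create_residual_self_attention(nn.ModuleList([ee]), [0], 32)
        >>> rsa is not None, enc is None
        (True, True)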
        """
        rsa, x_self_encoder = None, None
        entity_num_max: int = 0
        var_processors = [p for p in input_processors if isinstance(p, EntityEmbedding)]
        for processor in var_processors:
            entity_max: int = processor.entity_num_max_elements
            # Only adds entity max if it was known at construction
            if entity_max > 0:
                entity_num_max += entity_max
        if len(var_processors) > 0:
            if sum(embedding_sizes):
                x_self_encoder = LinearEncoder(
                    sum(embedding_sizes),
                    1,
                    hidden_size,
                    kernel_init=Initialization.Normal,
                    kernel_gain=(0.125 / hidden_size) ** 0.5,
                )
            rsa = ResidualSelfAttention(hidden_size, entity_num_max)
        return rsa, x_self_encoder

    @staticmethod
    def trust_region_value_loss(
        values: Dict[str, torch.Tensor],
        old_values: Dict[str, torch.Tensor],
        returns: Dict[str, torch.Tensor],
        epsilon: float,
        loss_masks: torch.Tensor,
    ) -> torch.Tensor:
        """
        Evaluates value loss, clipping to stay within a trust region of old value estimates.
        Used for PPO and POCA.
        :param values: Value output of the current network.
        :param old_values: Value stored with experiences in buffer.
        :param returns: Computed returns.
        :param epsilon: Clipping value for value estimate.
        :param loss_masks: Mask for losses. Used with LSTM to ignore 0'ed out experiences.
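
        Example (illustrative): with unchanged value estimates, the loss is the mean
        squared error against the returns.
        >>> vals = {"extrinsic": torch.tensor([1.0, 1.0])}
        >>> rets = {"extrinsic": torch.tensor([2.0, 0.0])}
        >>> mask = torch.tensor([True, True])
        >>> ModelUtils.trust_region_value_loss(vals, vals, rets, 0.2, mask)
        tensor(1.)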
        """
        value_losses = []
        for name, head in values.items():
            old_val_tensor = old_values[name]
            returns_tensor = returns[name]
            clipped_value_estimate = old_val_tensor + torch.clamp(
                head - old_val_tensor, -1 * epsilon, epsilon
            )
            v_opt_a = (returns_tensor - head) ** 2
            v_opt_b = (returns_tensor - clipped_value_estimate) ** 2
            value_loss = ModelUtils.masked_mean(torch.max(v_opt_a, v_opt_b), loss_masks)
            value_losses.append(value_loss)
        value_loss = torch.mean(torch.stack(value_losses))
        return value_loss

    @staticmethod
    def trust_region_policy_loss(
        advantages: torch.Tensor,
        log_probs: torch.Tensor,
        old_log_probs: torch.Tensor,
        loss_masks: torch.Tensor,
        epsilon: float,
    ) -> torch.Tensor:
        """
        Evaluate policy loss clipped to stay within a trust region. Used for PPO and POCA.
        :param advantages: Computed advantages.
        :param log_probs: Current policy log probabilities.
        :param old_log_probs: Log probabilities of actions under the policy that collected the experiences.
        :param loss_masks: Mask for losses. Used with LSTM to ignore 0'ed out experiences.
        :param epsilon: Clipping range for the policy ratio.
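
        Example (illustrative): with an unchanged policy (ratio of 1), the loss is
        minus the mean advantage.
        >>> adv = torch.tensor([1.0, 1.0])
        >>> lp = torch.zeros((2, 1))
        >>> mask = torch.tensor([True, True])
        >>> ModelUtils.trust_region_policy_loss(adv, lp, lp, mask, 0.2)
        tensor(-1.)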
        """
        advantage = advantages.unsqueeze(-1)
        r_theta = torch.exp(log_probs - old_log_probs)
        p_opt_a = r_theta * advantage
        p_opt_b = torch.clamp(r_theta, 1.0 - epsilon, 1.0 + epsilon) * advantage
        policy_loss = -1 * ModelUtils.masked_mean(
            torch.min(p_opt_a, p_opt_b), loss_masks
        )
        return policy_loss