import numpy
import torch
import torch.nn.functional as F
from torch_ac.intrinsic_reward_models import compute_forward_dynamics_loss, compute_inverse_dynamics_loss
from sklearn.metrics import f1_score

from torch_ac.algos.base import BaseAlgo

def compute_balance_mask(target, n_classes):
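    """Build a 0/1 mask over `target` that undersamples every present class down to
    the count of the least common class that appears, so the masked loss is
    class-balanced. If all targets belong to a single class, an all-zero mask is
    returned and nothing is trained on this batch.

    Illustrative example (values assumed for the sketch): for
    target = tensor([0, 0, 0, 1, 2, 2]) and n_classes=3, the rarest present class
    appears once, so one random index is kept per present class and the returned
    mask contains exactly three ones.
    """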
    if target.float().var() == 0:
        # all the same class, don't train at all
        return torch.zeros_like(target).detach()

    # compute the balance mask
    per_class_n = torch.bincount(target, minlength=n_classes)

    # number of times the least common class (that appeared) appeared
    n_for_each_class = per_class_n[torch.nonzero(per_class_n)].min()

    # undersample every present class down to that count

    balanced_indexes_ = []

    for c in range(n_classes):
        c_indexes = torch.where(target == c)[0]
        if len(c_indexes) == 0:
            continue

        c_sampled_indexes = c_indexes[torch.randperm(len(c_indexes))[:n_for_each_class]]
        balanced_indexes_.append(c_sampled_indexes)

    balanced_indexes = torch.concat(balanced_indexes_)
    balance_mask = torch.zeros_like(target)
    balance_mask[balanced_indexes] = 1.0

    return balance_mask.detach()


class PPOAlgo(BaseAlgo):
    """The Proximal Policy Optimization algorithm
    ([Schulman et al., 2017](https://arxiv.org/abs/1707.06347))."""
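
    # Illustrative usage (a sketch; `envs`, `acmodel` and `preprocess_obss` are assumed
    # to be set up elsewhere, and `collect_experiences()` is assumed to be provided by
    # BaseAlgo as in upstream torch_ac):
    #
    #     algo = PPOAlgo(envs, acmodel, device=device, preprocess_obss=preprocess_obss)
    #     exps, collect_logs = algo.collect_experiences()
    #     update_logs = algo.update_parameters(exps)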

    def __init__(self, envs, acmodel, device=None, num_frames_per_proc=None, discount=0.99, lr=0.001, gae_lambda=0.95,
                 entropy_coef=0.01, value_loss_coef=0.5, max_grad_norm=0.5, recurrence=4,
                 adam_eps=1e-5, clip_eps=0.2, epochs=4, batch_size=256, preprocess_obss=None,
                 reshape_reward=None, exploration_bonus=False, exploration_bonus_params=None,
                 expert_exploration_bonus=False, episodic_exploration_bonus=True, exploration_bonus_type="lang",
                 exploration_bonus_tanh=None, clipped_rewards=False, intrinsic_reward_epochs=0,
                 # default is set to fit RND
                 intrinsic_reward_coef=0.1,
                 intrinsic_reward_learning_rate=0.0001,
                 intrinsic_reward_momentum=0,
                 intrinsic_reward_epsilon=0.01,
                 intrinsic_reward_alpha=0.99,
                 intrinsic_reward_max_grad_norm=40,
                 intrinsic_reward_loss_coef=0.1,
                 intrinsic_reward_forward_loss_coef=10,
                 intrinsic_reward_inverse_loss_coef=0.1,
                 reset_rnd_ride_at_phase=False,
                 balance_moa_training=False,
                 moa_memory_dim=128,
                 schedule_lr=False,
                 lr_schedule_end_frames=0,
                 end_lr=0.0,
    ):
        num_frames_per_proc = num_frames_per_proc or 128

        # save config
        self.config = locals()

        super().__init__(
            envs=envs,
            acmodel=acmodel,
            device=device,
            num_frames_per_proc=num_frames_per_proc,
            discount=discount,
            lr=lr,
            gae_lambda=gae_lambda,
            entropy_coef=entropy_coef,
            value_loss_coef=value_loss_coef,
            max_grad_norm=max_grad_norm,
            recurrence=recurrence,
            preprocess_obss=preprocess_obss,
            reshape_reward=reshape_reward,
            exploration_bonus=exploration_bonus,
            expert_exploration_bonus=expert_exploration_bonus,
            episodic_exploration_bonus=episodic_exploration_bonus,
            exploration_bonus_params=exploration_bonus_params,
            exploration_bonus_tanh=exploration_bonus_tanh,
            exploration_bonus_type=exploration_bonus_type,
            clipped_rewards=clipped_rewards,
            intrinsic_reward_loss_coef=intrinsic_reward_loss_coef,
            intrinsic_reward_coef=intrinsic_reward_coef,
            intrinsic_reward_learning_rate=intrinsic_reward_learning_rate,
            intrinsic_reward_momentum=intrinsic_reward_momentum,
            intrinsic_reward_epsilon=intrinsic_reward_epsilon,
            intrinsic_reward_alpha=intrinsic_reward_alpha,
            intrinsic_reward_max_grad_norm=intrinsic_reward_max_grad_norm,
            intrinsic_reward_forward_loss_coef=intrinsic_reward_forward_loss_coef,
            intrinsic_reward_inverse_loss_coef=intrinsic_reward_inverse_loss_coef,
            balance_moa_training=balance_moa_training,
            moa_memory_dim=moa_memory_dim,
            reset_rnd_ride_at_phase=reset_rnd_ride_at_phase,
        )

        self.clip_eps = clip_eps
        self.epochs = epochs
        self.intrinsic_reward_epochs = intrinsic_reward_epochs
        self.batch_size = batch_size

        assert self.batch_size % self.recurrence == 0

        if self.exploration_bonus and "soc_inf" in self.exploration_bonus_type:
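            # MOA is trained jointly with the policy: gather both parameter sets and
            # deduplicate any shared parameters (dict.fromkeys preserves order)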
            adam_params = list(dict.fromkeys(list(self.acmodel.parameters()) + list(self.moa_net.parameters())))
            self.optimizer = torch.optim.Adam(adam_params, lr, eps=adam_eps)

        else:
            self.optimizer = torch.optim.Adam(self.acmodel.parameters(), lr, eps=adam_eps)

        self.schedule_lr = schedule_lr

        self.lr_schedule_end_frames = lr_schedule_end_frames

        assert end_lr <= lr
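
        # Linear learning-rate decay from `lr` to `end_lr` over `lr_schedule_end_frames`
        # frames. Illustrative numbers (not from the original config): with lr=1e-3,
        # end_lr=0.0 and lr_schedule_end_frames=1e6, the factor returned below is 0.5
        # after 5e5 frames and stays at 0.0 from 1e6 frames onward.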
        def lr_lambda(step):
            if self.lr_schedule_end_frames == 0:
                # no schedule
                return 1

            end_factor = end_lr/lr
            final_diminished_factor = 1-end_factor
            n_frames = self.step_to_n_frames(step)
            return 1 - (min(n_frames, self.lr_schedule_end_frames) / self.lr_schedule_end_frames) * final_diminished_factor

        self.lr_scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda)

        self.batch_num = 0

    def load_status_dict(self, status):
        super().load_status_dict(status)

        if "optimizer_state" in status:
            self.optimizer.load_state_dict(status["optimizer_state"])

        if "lr_scheduler_state" in status:
            self.lr_scheduler.load_state_dict(status["lr_scheduler_state"])

    def get_status_dict(self):

        status_dict = super().get_status_dict()

        status_dict["optimizer_state"] = self.optimizer.state_dict()

        status_dict["lr_scheduler_state"] = self.lr_scheduler.state_dict()

        return status_dict

    def update_parameters(self, exps):
        # Update the policy (and, where applicable, the intrinsic-reward models) on the collected experiences

        self.acmodel.train()

        self.update_epoch += 1

        intr_rew_perf = torch.tensor(0.0)
        intr_rew_perf_ = 0.0

        social_influence = False

        if self.exploration_bonus:
            if "rnd" in self.exploration_bonus_type:
                imgs = exps.obs.image.reshape(
                    self.num_procs, self.num_frames_per_proc, *exps.obs.image.shape[1:]
                ).transpose(0, 1)
                mask = exps.mask.reshape(
                    self.num_procs, self.num_frames_per_proc, 1,
                ).transpose(0, 1)

                self.random_target_network.train()
                self.predictor_network.train()

                random_embedding = self.random_target_network(imgs).reshape(self.num_frames_per_proc, self.num_procs, 128)
                predicted_embedding = self.predictor_network(imgs).reshape(self.num_frames_per_proc, self.num_procs, 128)
                intr_rew_loss = self.intrinsic_reward_loss_coef * compute_forward_dynamics_loss(mask*predicted_embedding, mask*random_embedding.detach())

                # update the intr rew models
                self.intrinsic_reward_optimizer.zero_grad()
                intr_rew_loss.backward()
                torch.nn.utils.clip_grad_norm_(self.predictor_network.parameters(), self.intrinsic_reward_max_grad_norm)
                self.intrinsic_reward_optimizer.step()

                intr_rew_perf = intr_rew_loss

            elif "ride" in self.exploration_bonus_type:
                imgs = exps.obs.image.reshape(
                    self.num_procs, self.num_frames_per_proc, *exps.obs.image.shape[1:]
                ).transpose(0, 1)

                mask = exps.mask.reshape(
                    self.num_procs, self.num_frames_per_proc
                ).transpose(0, 1).to(torch.int64)

                # we only take the first (primitive) action
                action = exps.action[:, 0].reshape(
                    self.num_procs, self.num_frames_per_proc
                ).transpose(0, 1).to(torch.int64)

                _mask = mask[:-1]
                _obs = imgs[:-1]
                _actions = action[:-1]
                _next_obs = imgs[1:]

                self.state_embedding_model.train()
                self.forward_dynamics_model.train()
                self.inverse_dynamics_model.train()

                state_emb = self.state_embedding_model(_obs.to(device=self.device))
                next_state_emb = self.state_embedding_model(_next_obs.to(device=self.device))

                pred_next_state_emb = self.forward_dynamics_model(state_emb, _actions.to(device=self.device))

                pred_actions = self.inverse_dynamics_model(state_emb, next_state_emb)

                forward_dynamics_loss = self.intrinsic_reward_forward_loss_coef * \
                                        compute_forward_dynamics_loss(_mask[:,:,None]*pred_next_state_emb, _mask[:,:,None]*next_state_emb)

                inverse_dynamics_loss = self.intrinsic_reward_inverse_loss_coef * \
                                        compute_inverse_dynamics_loss(_mask[:,:,None]*pred_actions, _mask*_actions)

                # update the intr rew models
                self.state_embedding_optimizer.zero_grad()
                self.forward_dynamics_optimizer.zero_grad()
                self.inverse_dynamics_optimizer.zero_grad()

                intr_rew_loss = forward_dynamics_loss + inverse_dynamics_loss
                intr_rew_loss.backward()

                torch.nn.utils.clip_grad_norm_(self.state_embedding_model.parameters(), self.intrinsic_reward_max_grad_norm)
                torch.nn.utils.clip_grad_norm_(self.forward_dynamics_model.parameters(), self.intrinsic_reward_max_grad_norm)
                torch.nn.utils.clip_grad_norm_(self.inverse_dynamics_model.parameters(), self.intrinsic_reward_max_grad_norm)

                self.state_embedding_optimizer.step()
                self.forward_dynamics_optimizer.step()
                self.inverse_dynamics_optimizer.step()

                intr_rew_perf = intr_rew_loss

            elif "soc_inf" in self.exploration_bonus_type:

                # trained together with the policy
                social_influence = True
                self.moa_net.train()
                if self.intrinsic_reward_epochs > 0:
                    raise ValueError(f"MOA must be trained with the agent: intrinsic_reward_epochs must be 0 but is {self.intrinsic_reward_epochs}")

        for _ in range(self.epochs):
            # Initialize log values

            log_entropies = []
            log_values = []
            log_policy_losses = []
            log_value_losses = []
            log_grad_norms = []
            log_lrs = []

            for inds in self._get_batches_starting_indexes():
                # Initialize batch values

                batch_entropy = 0
                batch_value = 0
                batch_policy_loss = 0
                batch_value_loss = 0
                batch_loss = 0

                # intr reward metrics
                batch_intr_rew_loss = 0
                batch_intr_rew_acc = 0
                batch_intr_rew_f1 = 0

                # Initialize memory

                if self.acmodel.recurrent:
                    memory = exps.memory[inds]

                if social_influence:
                    # Initialize moa memory
                    moa_memory = exps.moa_memory[inds]
                    prev_npc_prim_action = None
                    prev_npc_utt_action = None

                for i in range(self.recurrence):
                    # Create a sub-batch of experience
                    sb = exps[inds + i]

                    # Compute loss
                    if self.acmodel.recurrent:
                        dist, value, memory, policy_embeddings = self.acmodel(sb.obs, memory * sb.mask, return_embeddings=True)
                    else:
                        dist, value, policy_embeddings = self.acmodel(sb.obs, return_embeddings=True)

                    losses = []

                    # the gradient masks are the same for every head, so compute them once
                    action_masks = self.acmodel.calculate_action_gradient_masks(sb.action).type(sb.log_prob.type())

                    for head_i, d in enumerate(dist):

                        entropy = (d.entropy() * action_masks[:, head_i]).mean()
                        ratio = torch.exp(d.log_prob(sb.action[:, head_i]) - sb.log_prob[:, head_i])
                        surr1 = ratio * sb.advantage
                        surr2 = torch.clamp(ratio, 1.0 - self.clip_eps, 1.0 + self.clip_eps) * sb.advantage
                        policy_loss = (
                            -torch.min(surr1, surr2) * action_masks[:, head_i]
                        ).mean()

                        value_clipped = sb.value + torch.clamp(value - sb.value, -self.clip_eps, self.clip_eps)
                        surr1 = (value - sb.returnn).pow(2)
                        surr2 = (value_clipped - sb.returnn).pow(2)
                        value_loss = (
                            torch.max(surr1, surr2) * action_masks[:, head_i]
                        ).mean()

                        head_loss = policy_loss - self.entropy_coef * entropy + self.value_loss_coef * value_loss
                        losses.append(head_loss)

                    if social_influence:
                        # moa loss
                        imgs = sb.obs.image
                        mask = sb.mask.to(torch.int64)
                        # agent actions (all heads) as integer indices
                        agent_action = sb.action.to(torch.int64)
                        infos = numpy.array(sb.infos)
                        npc_prim_action = torch.tensor(
                            numpy.array([self.fn_name_to_npc_prim_act[info["NPC_prim_action"]] for info in infos]))
                        npc_utt_action = torch.tensor(
                            numpy.array([self.npc_utterance_to_id[info["NPC_utterance"]] for info in infos]))

                        assert infos.shape == imgs.shape[:1] == agent_action.shape[:1]  # [bs]

                        if i == 0:
                            prev_npc_prim_action = npc_prim_action
                            prev_npc_utt_action = npc_utt_action

                        else:
                            # compute loss and train moa net
                            if self.utterance_moa_net:
                                # transform to long logits
                                target = npc_prim_action.detach().to(self.device) * self.num_npc_utterance_actions + npc_utt_action.detach().to(self.device)
                            else:
                                target = npc_prim_action.detach().to(self.device)

                            if self.balance_moa_training:
                                balance_mask = compute_balance_mask(target, n_classes=self.num_npc_all_actions)
                            else:
                                balance_mask = torch.ones_like(target)

                            moa_predictions_logs, moa_memory = self.moa_net(
                                embeddings=policy_embeddings,
                                npc_previous_prim_actions=prev_npc_prim_action.detach().to(self.device),
                                npc_previous_utterance_actions=prev_npc_utt_action.detach().to(self.device) if self.utterance_moa_net else None,
                                agent_actions=agent_action.detach().to(self.device),
                                memory=moa_memory * mask,
                            )

                            # per-sample cross-entropy, masked by the episode mask and the
                            # (optional) class-balancing mask
                            per_sample_ce = F.cross_entropy(
                                input=moa_predictions_logs,
                                target=target,
                                reduction='none',
                            )
                            intr_rew_loss = (
                                balance_mask * mask.squeeze(-1) * per_sample_ce
                            ).mean() * self.intrinsic_reward_loss_coef

                            preds = moa_predictions_logs.detach().argmax(dim=-1)
                            intr_rew_f1 = f1_score(
                                y_pred=preds.detach().cpu().numpy(),
                                y_true=target.detach().cpu().numpy(),
                                average="macro"
                            )

                            intr_rew_acc = (
                                    torch.argmax(moa_predictions_logs.to(self.device), dim=-1) == target
                            ).to(float).mean()

                            batch_intr_rew_loss += intr_rew_loss
                            batch_intr_rew_acc += intr_rew_acc.detach().cpu().numpy().mean()
                            batch_intr_rew_f1 += intr_rew_f1

                            losses.append(intr_rew_loss)  # trained with the policy optimizer

                            # the current NPC actions become the "previous" ones for the next recurrence step
                            prev_npc_prim_action = npc_prim_action
                            prev_npc_utt_action = npc_utt_action

                    loss = torch.stack(losses).mean()

                    # Update batch values
                    batch_entropy += entropy.item()
                    batch_value += value.mean().item()
                    batch_policy_loss += policy_loss.item()
                    batch_value_loss += value_loss.item()
                    batch_loss += loss

                    # Update memories for next epoch
                    # assert self.acmodel.recurrent == (self.recurrence > 1)
                    if self.acmodel.recurrent and i < self.recurrence - 1:
                        exps.memory[inds + i + 1] = memory.detach()

                    if social_influence and i < self.recurrence - 1:
                        exps.moa_memory[inds + i + 1] = moa_memory.detach()


                # Update batch values
                batch_entropy /= self.recurrence
                batch_value /= self.recurrence
                batch_policy_loss /= self.recurrence
                batch_value_loss /= self.recurrence
                batch_loss /= self.recurrence

                # Update actor-critic
                self.optimizer.zero_grad()
                batch_loss.backward()
                grad_norm = sum(p.grad.data.norm(2).item() ** 2 for p in self.acmodel.parameters() if p.grad is not None) ** 0.5
                torch.nn.utils.clip_grad_norm_(self.acmodel.parameters(), self.max_grad_norm)
                self.optimizer.step()

                self.lr_scheduler.step()

                if social_influence:
                    # recurrence-1 because we skipped the first step
                    batch_intr_rew_loss /= self.recurrence - 1
                    batch_intr_rew_acc /= self.recurrence - 1
                    batch_intr_rew_f1 /= self.recurrence - 1

                    intr_rew_perf_ = batch_intr_rew_f1
                    intr_rew_perf = batch_intr_rew_acc

                # Update log values

                log_entropies.append(batch_entropy)
                log_values.append(batch_value)
                log_policy_losses.append(batch_policy_loss)
                log_value_losses.append(batch_value_loss)
                log_grad_norms.append(grad_norm)
                log_lrs.append(self.optimizer.param_groups[0]['lr'])

        # Log some values

        logs = {
            "entropy": numpy.mean(log_entropies),
            "value": numpy.mean(log_values),
            "policy_loss": numpy.mean(log_policy_losses),
            "value_loss": numpy.mean(log_value_losses),
            "grad_norm": numpy.mean(log_grad_norms),
            "intr_reward_perf": intr_rew_perf,
            "intr_reward_perf_": intr_rew_perf_,
            "lr": numpy.mean(log_lrs),
        }

        return logs

    def _get_batches_starting_indexes(self):
        """Gives, for each batch, the indexes of the observations given to
        the model and the experiences used to compute the loss at first.

        First, the indexes are the integers from 0 to `self.num_frames` with a step of
        `self.recurrence`, shifted by `self.recurrence//2` one time in two for having
        more diverse batches. Then, the indexes are splited into the different batches.
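
        For example (illustrative numbers): with two processes,
        `num_frames_per_proc = 8` and `recurrence = 4`, the unshifted starting
        indexes are {0, 4, 8, 12}; on every other call the last chunk of each
        process (indexes 4 and 12) is dropped and the rest are shifted to {2, 10}.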

        Returns
        -------
        batches_starting_indexes : list of list of int
            the starting indexes of the experience sub-sequences for each batch
        """

        indexes = numpy.arange(0, self.num_frames, self.recurrence)
        indexes = numpy.random.permutation(indexes)

        # Shift starting indexes by self.recurrence//2 half the time
        if self.batch_num % 2 == 1:
            indexes = indexes[(indexes + self.recurrence) % self.num_frames_per_proc != 0]
            indexes += self.recurrence // 2
        self.batch_num += 1

        num_indexes = self.batch_size // self.recurrence
        batches_starting_indexes = [indexes[i:i+num_indexes] for i in range(0, len(indexes), num_indexes)]

        return batches_starting_indexes

    def get_config_dict(self):
        # return the constructor arguments, excluding non-serializable objects,
        # without destructively mutating self.config
        excluded = {'envs', 'acmodel', '__class__', 'self', 'preprocess_obss', 'device'}
        return {k: v for k, v in self.config.items() if k not in excluded}