import math

import numpy as np
import torch
from torch import nn, Tensor
from torch.distributions import Categorical


class PositionalEncoding(nn.Module):
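    """Standard sinusoidal positional encoding; forward() looks up the encodings at the given integer positions."""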

    def __init__(self, d_model: int, max_len: int = 100):
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, positions: Tensor) -> Tensor:
        return self.pe[positions]


class Actor(nn.Module):
    """Policy network: two Transformer encoder stages produce one logit per job plus a no-op logit."""

    def __init__(self, pos_encoder):
        super().__init__()
        self.activation = nn.Tanh()
        # Project the 4 embedded input features to the model dimension (8).
        self.project = nn.Linear(4, 8)
        nn.init.xavier_uniform_(self.project.weight, gain=1.0)
        nn.init.constant_(self.project.bias, 0)
        self.pos_encoder = pos_encoder

        # 1-d embeddings for the two binary input features (columns 0 and 3 of the observation).
        self.embedding_fixed = nn.Embedding(2, 1)
        self.embedding_legal_op = nn.Embedding(2, 1)

        # Embeddings that replace the features of special start/end token positions.
        self.tokens_start_end = nn.Embedding(3, 4)

        # self.conv_transform = nn.Conv1d(5, 1, 1)
        # nn.init.kaiming_normal_(self.conv_transform.weight, mode="fan_out", nonlinearity="relu")
        # nn.init.constant_(self.conv_transform.bias, 0)

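        # enc1 attends over the intervals of one job; enc2 attends across jobs.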
        self.enc1 = nn.TransformerEncoderLayer(8, 1, dim_feedforward=8 * 4, dropout=0.0, batch_first=True,
                                               norm_first=True)
        self.enc2 = nn.TransformerEncoderLayer(8, 1, dim_feedforward=8 * 4, dropout=0.0, batch_first=True,
                                               norm_first=True)

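        # Per-job logit head and a pooled "no-op" logit head.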
        self.final_tmp = nn.Sequential(
            layer_init_tanh(nn.Linear(8, 32)),
            nn.Tanh(),
            layer_init_tanh(nn.Linear(32, 1), std=0.01)
        )
        self.no_op = nn.Sequential(
            layer_init_tanh(nn.Linear(8, 32)),
            nn.Tanh(),
            layer_init_tanh(nn.Linear(32, 1), std=0.01)
        )

    def forward(self, obs, attention_interval_mask, job_resource, mask, indexes_inter, tokens_start_end):
        # Embed the two binary features (columns 0 and 3) and keep the two continuous ones as-is.
        embedded_obs = torch.cat((self.embedding_fixed(obs[:, :, :, 0].long()), obs[:, :, :, 1:3],
                                  self.embedding_legal_op(obs[:, :, :, 3].long())), dim=3)
        # Replace start/end token positions with their learned embeddings and zero their positional encodings.
        non_zero_tokens = tokens_start_end != 0
        t = tokens_start_end[non_zero_tokens].long()
        embedded_obs[non_zero_tokens] = self.tokens_start_end(t)
        pos_encoder = self.pos_encoder(indexes_inter.long())
        pos_encoder[non_zero_tokens] = 0
        obs = self.project(embedded_obs) + pos_encoder

        # Stage 1: self-attention over the intervals of each job, then mean-pool to one vector per job.
        transformed_obs = obs.view(-1, obs.shape[2], obs.shape[3])
        attention_interval_mask = attention_interval_mask.view(-1, attention_interval_mask.shape[-1])
        transformed_obs = self.enc1(transformed_obs, src_key_padding_mask=attention_interval_mask == 1)
        transformed_obs = transformed_obs.view(obs.shape)
        obs = transformed_obs.mean(dim=2)

        # Stage 2: self-attention across jobs, masked by the job/resource relation, with a residual connection.
        job_resource = job_resource[:, :-1, :-1] == 0

        obs_action = self.enc2(obs, src_mask=job_resource) + obs

        # One logit per job plus a single pooled no-op logit; illegal actions are masked to the dtype minimum.
        logits = torch.cat((self.final_tmp(obs_action).squeeze(2), self.no_op(obs_action).mean(dim=1)), dim=1)
        return logits.masked_fill(mask == 0, torch.finfo(logits.dtype).min)


class Agent(nn.Module):
    """Wraps the Actor and exposes sampling and evaluation helpers over a categorical policy."""

    def __init__(self):
        super().__init__()
        self.pos_encoder = PositionalEncoding(8)
        self.actor = Actor(self.pos_encoder)

    def forward(self, data, attention_interval_mask, job_resource_masks, mask, indexes_inter, tokens_start_end,
                action=None):
        logits = self.actor(data, attention_interval_mask, job_resource_masks, mask, indexes_inter, tokens_start_end)
        probs = Categorical(logits=logits)
        if action is None:
            # Draw every action index without replacement (a full ranking over the action set)
            # and return it together with the full log-probability vector and the policy entropy.
            probabilities = probs.probs
            actions = torch.multinomial(probabilities, probabilities.shape[1])
            return actions, torch.log(probabilities), probs.entropy()
        else:
            # Evaluate a provided action: return the raw logits, its log-probability, and the entropy.
            return logits, probs.log_prob(action), probs.entropy()

    def get_action_only(self, data, attention_interval_mask, job_resource_masks, mask, indexes_inter, tokens_start_end):
        logits = self.actor(data, attention_interval_mask, job_resource_masks, mask, indexes_inter, tokens_start_end)
        probs = Categorical(logits=logits)
        return probs.sample()

    def get_logits_only(self, data, attention_interval_mask, job_resource_masks, mask, indexes_inter, tokens_start_end):
        logits = self.actor(data, attention_interval_mask, job_resource_masks, mask, indexes_inter, tokens_start_end)
        return logits


def layer_init_tanh(layer, std=np.sqrt(2), bias_const=0.0):
    # Orthogonal weight initialization with gain `std` and a constant bias.
    torch.nn.init.orthogonal_(layer.weight, std)
    if layer.bias is not None:
        torch.nn.init.constant_(layer.bias, bias_const)
    return layer
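

# A minimal smoke test sketching how the Agent might be called. The tensor shapes and
# value ranges below are assumptions chosen only to satisfy the network's indexing; the
# environment that actually produces these tensors defines the real layout.
if __name__ == "__main__":
    B, J, I = 2, 3, 5                                        # assumed batch, jobs, intervals per job
    obs = torch.zeros(B, J, I, 4)
    obs[..., 0] = torch.randint(0, 2, (B, J, I)).float()     # binary "fixed" flag
    obs[..., 1:3] = torch.rand(B, J, I, 2)                   # two continuous features
    obs[..., 3] = torch.randint(0, 2, (B, J, I)).float()     # binary "legal op" flag
    attention_interval_mask = torch.zeros(B, J, I)           # 1 would mark padded intervals
    job_resource = torch.ones(B, J + 1, J + 1)               # 0 would block attention between jobs
    mask = torch.ones(B, J + 1)                              # 0 would mark illegal actions
    indexes_inter = torch.arange(I).expand(B, J, I)          # interval positions for the positional encoding
    tokens_start_end = torch.zeros(B, J, I)                  # non-zero values (1 or 2) would mark start/end tokens

    agent = Agent()
    actions, log_probs, entropy = agent(obs, attention_interval_mask, job_resource,
                                        mask, indexes_inter, tokens_start_end)
    print(actions.shape, log_probs.shape, entropy.shape)     # expected: (2, 4), (2, 4), (2,)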