import torch
from torch import nn


def make_encoder(input_dim, enc_dec_dims):
    """Build a symmetric encoder/decoder pair of Linear + SELU blocks.

    The encoder compresses ``input_dim`` down through ``enc_dec_dims`` to the
    latent size ``enc_dec_dims[-1]``; the decoder mirrors the dimensions back
    up to ``input_dim``.

    Example: make_encoder(100, [64, 32, 16]) yields
        encoder: Linear(100, 64) -> SELU -> Linear(64, 32) -> SELU -> Linear(32, 16)
        decoder: Linear(16, 32) -> SELU -> Linear(32, 64) -> SELU -> Linear(64, 100)
    """
    encoder_layers = []
    decoder_layers = []
    output_dim = input_dim          # remember the original width for the reconstruction
    enc_shape = enc_dec_dims[-1]    # latent (bottleneck) dimension

    # Encoder: input_dim -> enc_dec_dims[0] -> ... -> enc_dec_dims[-2], SELU after each layer
    for enc_dim in enc_dec_dims[:-1]:
        encoder_layers.extend([nn.Linear(input_dim, enc_dim), nn.SELU()])
        input_dim = enc_dim

    # Final projection into the latent space (no activation)
    encoder_layers.append(nn.Linear(input_dim, enc_shape))

    # Decoder: mirror the dimensions back up to the original input width
    enc_dec_dims = list(reversed(enc_dec_dims))
    for dec_dim in enc_dec_dims[1:]:
        decoder_layers.extend([nn.Linear(enc_shape, dec_dim), nn.SELU()])
        enc_shape = dec_dim

    decoder_layers.append(nn.Linear(enc_shape, output_dim))

    return nn.Sequential(*encoder_layers), nn.Sequential(*decoder_layers)


class FsrFgModel(nn.Module):
    """Autoencoder-based property predictor over functional-group representations.

    ``method`` selects the input representation:
      'FG'       - encode the FG fingerprint only
      'MFG'      - encode the MFG fingerprint only
      'FGR'      - encode the concatenation of the FG and MFG fingerprints
      'FGR_desc' - as 'FGR', but the numeric descriptors are appended to the
                   latent code before prediction
    """

    def __init__(self, fg_input_dim, mfg_input_dim, num_input_dim, enc_dec_dims, output_dims,
                 num_tasks, dropout, method):
        super().__init__()

        self.method = method
        # Width of the autoencoder input depends on the chosen representation
        if self.method == 'FG':
            input_dim = fg_input_dim
        elif self.method == 'MFG':
            input_dim = mfg_input_dim
        else:  # 'FGR' and 'FGR_desc' both encode the concatenated fingerprints
            input_dim = fg_input_dim + mfg_input_dim

        # The prediction head sees the latent code, plus the numeric descriptors
        # when method == 'FGR_desc'
        if self.method != 'FGR_desc':
            fcn_input_dim = enc_dec_dims[-1]
        else:
            fcn_input_dim = num_input_dim + enc_dec_dims[-1]

        self.encoder, self.decoder = make_encoder(input_dim, enc_dec_dims)
        self.dropout = nn.Dropout(dropout)
        self.predict_out_dim = num_tasks
        self.batch_norm = nn.BatchNorm1d(fcn_input_dim)

        # Fully connected prediction head: Linear -> SELU -> BatchNorm blocks,
        # followed by dropout and a final linear layer with one output per task
        layers = []
        for output_dim in output_dims:
            layers.extend([nn.Linear(fcn_input_dim, output_dim), nn.SELU(), nn.BatchNorm1d(output_dim)])
            fcn_input_dim = output_dim

        layers.extend([self.dropout, nn.Linear(fcn_input_dim, num_tasks)])

        self.predictor = nn.Sequential(*layers)

    def forward(self, fg=None, mfg=None, num_features=None):
        # Encode the selected fingerprint(s) into the latent code z_d
        if self.method == 'FG':
            z_d = self.encoder(fg)
        elif self.method == 'MFG':
            z_d = self.encoder(mfg)
        else:  # 'FGR' and 'FGR_desc' encode the concatenated fingerprints
            z_d = self.encoder(torch.cat([fg, mfg], dim=1))

        # Reconstruct the input from the latent code (autoencoder branch)
        v_d_hat = self.decoder(z_d)

        # For 'FGR_desc', append the numeric descriptors to the latent code
        if self.method == 'FGR_desc':
            z_d = torch.cat([z_d, num_features], dim=1)

        # Both the task predictions and the reconstruction are returned, so a
        # prediction loss and a reconstruction loss can be combined during training
        output = self.predictor(z_d)
        return output, v_d_hat
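

# Minimal usage sketch (illustrative only): it shows how FsrFgModel is wired
# together end to end. All dimensions and tensors below are hypothetical
# placeholders chosen for the example, not values taken from this repository.
if __name__ == "__main__":
    model = FsrFgModel(
        fg_input_dim=2786,        # assumed FG fingerprint width
        mfg_input_dim=3000,       # assumed MFG fingerprint width
        num_input_dim=200,        # assumed descriptor width (only used by 'FGR_desc')
        enc_dec_dims=[1024, 512, 256],
        output_dims=[128, 64],
        num_tasks=1,
        dropout=0.1,
        method='FGR',
    )
    fg = torch.randn(8, 2786)     # random stand-ins for a batch of fingerprints
    mfg = torch.randn(8, 3000)
    preds, recon = model(fg=fg, mfg=mfg)
    print(preds.shape)            # torch.Size([8, 1])
    print(recon.shape)            # torch.Size([8, 5786]) -- reconstruction of [fg, mfg]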