import torch
import torch.nn as nn
import torch.nn.functional as F

from dataclasses import dataclass

from transformers import PreTrainedModel
from transformers.utils import ModelOutput

from .configuration_mapper import MapperConfig

class FeedForward(nn.Module):
    """Gated feed-forward block (SwiGLU-style): project to 2*d_out, gate with SiLU, then mix."""

    def __init__(self, d_in, d_out):
        super().__init__()
        self.fc1 = nn.Linear(d_in, d_out * 2)
        self.fc2 = nn.Linear(d_out, d_out)

    def forward(self, x):
        x = self.fc1(x)
        # Split the doubled projection into a value half and a gate half.
        x1, x2 = x.chunk(2, dim=-1)
        x = self.fc2(F.silu(x1) * x2)
        return x

class FeedForwardLayer(nn.Module):
    """Feed-forward block with dropout, a residual connection (projected if widths differ) and optional LayerNorm."""

    def __init__(self, d_in, d_out, dropout=0.1, layer_norm_eps=None):
        super().__init__()
        self.ff = FeedForward(d_in, d_out)
        # Project the skip connection only when input and output widths differ.
        self.skip = nn.Linear(d_in, d_out) if d_in != d_out else nn.Identity()
        self.dropout = nn.Dropout(dropout)
        self.LayerNorm = nn.LayerNorm(d_out, eps=layer_norm_eps) if layer_norm_eps else None

    def forward(self, x):
        x = self.dropout(x)
        x = self.ff(x) + self.skip(x)
        if self.LayerNorm is not None:
            x = self.LayerNorm(x)
        return x
    
class Mapper(nn.Module):
    """Maps an input embedding of size d_in to n_out output vectors of size d_out each:
    (batch, d_in) -> (batch, n_out, d_out)."""

    def __init__(self, d_in, d_hidden, d_out, n_out, n_layers, dropout=0.1, layer_norm_eps=None):
        super().__init__()
        self.n_out = n_out
        # Input projection (no dropout) followed by n_layers hidden blocks.
        layers = [FeedForwardLayer(d_in, d_hidden, 0.0, layer_norm_eps)]
        layers += [FeedForwardLayer(d_hidden, d_hidden, dropout, layer_norm_eps)
                   for _ in range(n_layers)]
        self.layers = nn.Sequential(*layers)

        # Single output head producing all n_out vectors at once.
        self.output_layer = FeedForwardLayer(d_hidden, d_out * n_out, 0.0, None)

    def forward(self, x):
        x = self.layers(x)
        x = self.output_layer(x)
        # Split the flat (batch, d_out * n_out) output into n_out vectors: (batch, n_out, d_out).
        x = torch.stack(torch.chunk(x, self.n_out, -1), 1)
        return x

@dataclass
class MapperModelOutput(ModelOutput):
    mapper_out: torch.FloatTensor = None

class MapperModel(PreTrainedModel):
    """Hugging Face wrapper around Mapper, configured by MapperConfig."""

    config_class = MapperConfig

    def __init__(self, config):
        super().__init__(config)

        self.mapper = Mapper(config.d_in, config.d_hidden, config.d_out, config.n_out,
                             config.n_layers, config.dropout, config.layer_norm_eps)

    def forward(self, embedding, return_dict=True):
        # embedding: (batch, d_in) -> mapper_out: (batch, n_out, d_out)
        mapper_out = self.mapper(embedding)

        if not return_dict:
            return (mapper_out,)

        return MapperModelOutput(mapper_out=mapper_out)
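
# --- Minimal usage sketch (illustrative only, not part of the original model code) ---
# The dimensions below are arbitrary placeholders; the sketch exercises the bare Mapper
# module directly, so no MapperConfig instance is required.
if __name__ == "__main__":
    mapper = Mapper(d_in=768, d_hidden=1024, d_out=512, n_out=4, n_layers=2)
    dummy = torch.randn(8, 768)  # (batch, d_in)
    out = mapper(dummy)          # -> (batch, n_out, d_out) == (8, 4, 512)
    print(out.shape)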