import torch
import torch.nn as nn
from Modules.conformer import ConformerEncoder, ConformerDecoder
from Modules.mhsa_pro import RotaryEmbedding, ContinuousRotaryEmbedding

class ConvBlock(nn.Module):
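    """Conv1d -> BatchNorm1d -> activation block.

    Channel widths are read from args.encoder_dims; once num_layer runs past
    the end of that list, the width is clamped to the last entry.
    """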
    def __init__(self, args, num_layer) -> None:
        super().__init__()
        if args.activation == 'silu':
            self.activation = nn.SiLU()
        else:
            self.activation = nn.ReLU()
        in_channels = args.encoder_dims[num_layer - 1] if num_layer < len(args.encoder_dims) else args.encoder_dims[-1]
        out_channels = args.encoder_dims[num_layer] if num_layer < len(args.encoder_dims) else args.encoder_dims[-1]
        self.layers = nn.Sequential(
            nn.Conv1d(in_channels=in_channels,
                      out_channels=out_channels,
                      kernel_size=args.kernel_size,
                      stride=1, padding='same', bias=False),
            nn.BatchNorm1d(num_features=out_channels),
            self.activation,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layers(x)

class CNNEncoder(nn.Module):
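    """1D CNN encoder: a Conv1d embedding followed by a stack of ConvBlocks.

    The sequence length is halved by max pooling after each block while it
    stays above min_seq_len; when avg_output is set, the output is
    mean-pooled over the temporal dimension.
    """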
    def __init__(self, args) -> None:
        super().__init__()
        print("Using CNN encoder wit activation: ", args.activation, 'args avg_output: ', args.avg_output)
        if args.activation == 'silu':
            self.activation = nn.SiLU()
        else:
            self.activation = nn.ReLU()
        self.embedding = nn.Sequential(
            nn.Conv1d(in_channels=args.in_channels,
                      out_channels=args.encoder_dims[0],
                      kernel_size=3, stride=1, padding='same', bias=False),
            nn.BatchNorm1d(args.encoder_dims[0]),
            self.activation,
        )
        
        self.layers = nn.ModuleList([ConvBlock(args, i+1)
        for i in range(args.num_layers)])
        self.pool = nn.MaxPool1d(2)
        self.output_dim = args.encoder_dims[-1]
        self.min_seq_len = 2 
        self.avg_output = args.avg_output
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Accept (batch, seq) or (batch, seq, 1) inputs and bring them to the
        # (batch, channels, seq) layout expected by Conv1d.
        if len(x.shape) == 2:
            x = x.unsqueeze(1)
        if len(x.shape) == 3 and x.shape[-1] == 1:
            x = x.permute(0, 2, 1)
        x = self.embedding(x)
        for m in self.layers:
            x = m(x)
            # Halve the temporal dimension until it reaches min_seq_len.
            if x.shape[-1] > self.min_seq_len:
                x = self.pool(x)
        if self.avg_output:
            x = x.mean(dim=-1)
        return x


class MultiEncoder(nn.Module):
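    """CNN backbone followed by a Conformer encoder with rotary embeddings.

    The forward pass returns both the Conformer output and the raw backbone
    features, so downstream heads can use either representation.
    """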
    def __init__(self, args, conformer_args):
        super().__init__()
        self.backbone = CNNEncoder(args)
        self.backbone.avg_output = False  # keep the temporal dimension for the Conformer
        self.head_size = conformer_args.encoder_dim // conformer_args.num_heads
        self.rotary_ndims = int(self.head_size * 0.5)
        self.pe = RotaryEmbedding(self.rotary_ndims)
        self.encoder = ConformerEncoder(conformer_args)
        self.output_dim = conformer_args.encoder_dim
        self.avg_output = args.avg_output
        
    def forward(self, x):
        # Store backbone output in a separate tensor
        backbone_out = self.backbone(x)
        
        # Create x_enc from backbone_out
        if len(backbone_out.shape) == 2:
            x_enc = backbone_out.unsqueeze(1).clone()
        else:
            x_enc = backbone_out.permute(0,2,1).clone()
            
        RoPE = self.pe(x_enc, x_enc.shape[1])
        x_enc = self.encoder(x_enc, RoPE)
        
        if len(x_enc.shape) == 3:
            if self.avg_output:
                # Average over the sequence dimension (matches CNNEncoder's pooling).
                x_enc = x_enc.mean(dim=1)
            else:
                x_enc = x_enc.permute(0, 2, 1)
                
        # Return x_enc and the original backbone output
        return x_enc, backbone_out

class DualEncoder(nn.Module):
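    """Two parallel encoders whose pooled features feed a scalar regressor.

    encoder_x is a plain CNN encoder and encoder_f adds a Conformer on top of
    a CNN backbone; their outputs are concatenated and passed through a small
    MLP head.
    """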
    def __init__(self, args_x, args_f, conformer_args) -> None:
        super().__init__()
        self.encoder_x = CNNEncoder(args_x)
        self.encoder_f = MultiEncoder(args_f, conformer_args)
        # encoder_f's width is set by the Conformer (conformer_args.encoder_dim),
        # so take the output dims from the encoder modules themselves.
        total_output_dim = self.encoder_x.output_dim + self.encoder_f.output_dim
        self.regressor = nn.Sequential(
            nn.Linear(total_output_dim, total_output_dim//2),
            nn.BatchNorm1d(total_output_dim//2),
            nn.SiLU(),
            nn.Linear(total_output_dim//2, 1)
        )
            
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x1 = self.encoder_x(x)
        x2, _ = self.encoder_f(x)
        features = torch.cat([x1, x2], dim=-1)
        # Squeeze only the feature dimension so a batch of size 1 keeps its batch axis.
        return self.regressor(features).squeeze(-1)
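

# --- Minimal smoke test (illustrative sketch) --------------------------------
# The SimpleNamespace below is a hypothetical stand-in for the real config
# object: the field names mirror the attributes read by CNNEncoder above, but
# the values are placeholders, not the project's actual settings. MultiEncoder
# and DualEncoder are not exercised here because they also need a Conformer
# config. Running this requires the repo's Modules package on the path for the
# imports at the top of this file.
if __name__ == "__main__":
    from types import SimpleNamespace

    cnn_args = SimpleNamespace(
        activation='silu',
        avg_output=True,
        in_channels=1,
        encoder_dims=[32, 64],
        num_layers=2,
        kernel_size=3,
    )
    encoder = CNNEncoder(cnn_args)
    dummy = torch.randn(4, 1024)      # (batch, sequence)
    out = encoder(dummy)              # pooled to (batch, encoder_dims[-1])
    print(out.shape)                  # expected: torch.Size([4, 64])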