import torch.nn as nn
from einops import rearrange
from src.models.vit.utils import init_weights


class DecoderLinear(nn.Module):
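    """Linear decoder head for a ViT encoder.

    Projects patch tokens to per-class logits, upsamples the token sequence
    by scale_factor**2 with 1D linear interpolation, and reshapes it into a
    2D class map. The LayerNorm shape hard-codes a 24x24 patch grid, i.e. a
    384x384 input with patch size 16.
    """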
    def __init__(
        self,
        n_cls,
        d_encoder,
        scale_factor,
        dropout_rate=0.3,
    ):
        super().__init__()
        self.scale_factor = scale_factor
        self.head = nn.Linear(d_encoder, n_cls)
        # 1D linear interpolation along the token dimension; expects (B, C, L)
        self.upsampling = nn.Upsample(scale_factor=scale_factor**2, mode="linear")
        # normalizes over (n_cls, H', W'); the 24 assumes a 24x24 patch grid
        self.norm = nn.LayerNorm((n_cls, 24 * scale_factor, 24 * scale_factor))
        self.dropout = nn.Dropout(dropout_rate)
        self.gelu = nn.GELU()
        self.apply(init_weights)

    def forward(self, x, img_size):
        H, _ = img_size
        x = self.head(x)  # (B, num_patches, n_cls), e.g. (2, 576, 64)
        x = x.transpose(2, 1)  # (2, 64, 576)
        x = self.upsampling(x)  # (2, 64, 576 * scale_factor**2)
        x = x.transpose(2, 1)  # (2, 576 * scale_factor**2, 64)
        x = rearrange(x, "b (h w) c -> b c h w", h=H // (16 // self.scale_factor))  # (2, 64, 24 * scale_factor, 24 * scale_factor)
        x = self.norm(x)
        x = self.dropout(x)
        x = self.gelu(x)

        return x  # (B, n_cls, 24 * scale_factor, 24 * scale_factor)
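

if __name__ == "__main__":
    # Minimal smoke test -- a sketch under assumed values (d_encoder=768,
    # n_cls=64, scale_factor=2, 384x384 input -> 24 * 24 = 576 patch tokens);
    # none of these numbers are fixed by the module itself.
    import torch

    decoder = DecoderLinear(n_cls=64, d_encoder=768, scale_factor=2)
    tokens = torch.randn(2, 576, 768)  # (B, num_patches, d_encoder)
    out = decoder(tokens, img_size=(384, 384))
    print(out.shape)  # torch.Size([2, 64, 48, 48])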