import torch.nn as nn
from einops import rearrange

from src.models.vit.utils import init_weights

class DecoderLinear(nn.Module):
    def __init__(
        self,
        n_cls,
        d_encoder,
        scale_factor,
        dropout_rate=0.3,
    ):
        super().__init__()
        self.scale_factor = scale_factor
        # Per-token linear classifier from encoder features to class logits.
        self.head = nn.Linear(d_encoder, n_cls)
        # 1D linear upsampling along the flattened token sequence; the total
        # factor of scale_factor**2 grows each spatial side by scale_factor
        # once the sequence is reshaped back into a 2D map.
        self.upsampling = nn.Upsample(scale_factor=scale_factor**2, mode="linear")
        # The hard-coded 24 assumes a 384x384 input with patch size 16.
        self.norm = nn.LayerNorm((n_cls, 24 * scale_factor, 24 * scale_factor))
        self.dropout = nn.Dropout(dropout_rate)
        self.gelu = nn.GELU()
        self.apply(init_weights)

    def forward(self, x, img_size):
        H, _ = img_size  # assumes a square input image
        x = self.head(x)  # (B, N, n_cls), e.g. (2, 576, 64)
        x = x.transpose(2, 1)  # (B, n_cls, N), e.g. (2, 64, 576)
        x = self.upsampling(x)  # (B, n_cls, N * scale_factor**2)
        x = x.transpose(2, 1)  # (B, N * scale_factor**2, n_cls)
        x = rearrange(x, "b (h w) c -> b c h w", h=H // (16 // self.scale_factor))  # (B, n_cls, 24*scale_factor, 24*scale_factor)
        x = self.norm(x)
        x = self.dropout(x)
        x = self.gelu(x)
        return x  # (B, n_cls, 24*scale_factor, 24*scale_factor)
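

# Minimal usage sketch, not part of the original module: the concrete numbers
# (batch 2, d_encoder=192, n_cls=64, scale_factor=2, a 384x384 image with
# patch size 16, i.e. 24*24 = 576 patch tokens, CLS token already removed)
# are assumptions chosen to match the shape comments above. Running it still
# requires src.models.vit.utils to be importable for init_weights.
if __name__ == "__main__":
    import torch

    decoder = DecoderLinear(n_cls=64, d_encoder=192, scale_factor=2)
    tokens = torch.randn(2, 576, 192)  # (B, N, d_encoder)
    out = decoder(tokens, img_size=(384, 384))
    print(out.shape)  # torch.Size([2, 64, 48, 48])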